feat(RAG): update the RAG module code and docs

refactor: simplify the Embeddings and LLM class implementations and remove unnecessary dependencies
docs: update the documentation and add instructions for using the SiliconFlow API
chore: update dependency versions in requirements.txt
Author: KMnO4-zx
Date:   2025-06-20 22:53:23 +08:00
parent 0eea57b11f
commit fe07d0ede1
8 changed files with 233 additions and 218 deletions


@@ -1,10 +1,10 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 '''
-@File : Embeddings.py
-@Time : 2024/02/10 21:55:39
+@File : Embedding.py
+@Time : 2025/06/20 13:50:47
 @Author : 不要葱姜蒜
-@Version : 1.0
+@Version : 1.1
 @Desc : None
 '''
@@ -12,6 +12,7 @@ import os
 from copy import copy
 from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
+from openai import OpenAI
 from dotenv import load_dotenv, find_dotenv
 _ = load_dotenv(find_dotenv())
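
The credentials used below are read from a .env file at import time via load_dotenv(find_dotenv()). A minimal sketch of such a file, assuming the SiliconFlow OpenAI-compatible endpoint (the key is a placeholder; verify the exact base URL against the provider's documentation):

# .env  (loaded by load_dotenv(find_dotenv()); values below are placeholders)
OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
OPENAI_BASE_URL=https://api.siliconflow.cn/v1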
@@ -22,21 +23,59 @@ class BaseEmbeddings:
     Base class for embeddings
     """
     def __init__(self, path: str, is_api: bool) -> None:
+        """
+        Initialize the embeddings base class.
+        Args:
+            path (str): Path to the model or data.
+            is_api (bool): Whether to use an API. True means an online API service is used, False means a local model is used.
+        """
         self.path = path
         self.is_api = is_api
     def get_embedding(self, text: str, model: str) -> List[float]:
+        """
+        Get the embedding vector for a piece of text.
+        Args:
+            text (str): Input text.
+            model (str): Name of the model to use.
+        Returns:
+            List[float]: Embedding vector of the text.
+        Raises:
+            NotImplementedError: This method must be implemented by a subclass.
+        """
         raise NotImplementedError
     @classmethod
     def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
         """
-        calculate cosine similarity between two vectors
+        Calculate the cosine similarity between two vectors.
+        Args:
+            vector1 (List[float]): The first vector.
+            vector2 (List[float]): The second vector.
+        Returns:
+            float: Cosine similarity of the two vectors, in the range [-1, 1].
         """
-        dot_product = np.dot(vector1, vector2)
-        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
-        if not magnitude:
-            return 0
+        # Convert the input lists to numpy arrays with dtype float32
+        v1 = np.array(vector1, dtype=np.float32)
+        v2 = np.array(vector2, dtype=np.float32)
+        # Check whether the vectors contain infinities or NaN values
+        if not np.all(np.isfinite(v1)) or not np.all(np.isfinite(v2)):
+            return 0.0
+        # Compute the dot product of the two vectors
+        dot_product = np.dot(v1, v2)
+        # Compute the norms (lengths) of the vectors
+        norm_v1 = np.linalg.norm(v1)
+        norm_v2 = np.linalg.norm(v2)
+        # Denominator: product of the two norms
+        magnitude = norm_v1 * norm_v2
+        # Handle the special case where the denominator is 0
+        if magnitude == 0:
+            return 0.0
+        # Return the cosine similarity
         return dot_product / magnitude
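
For reference, a minimal sketch of exercising the hardened cosine_similarity above; the vectors are made up for illustration, and the import assumes the module is importable under the name in its @File header:

from Embedding import BaseEmbeddings  # module name assumed from the @File header; adjust to the real path

v_a = [1.0, 2.0, 3.0]
v_b = [2.0, 4.0, 6.0]        # parallel to v_a, so similarity should be close to 1.0
v_zero = [0.0, 0.0, 0.0]     # zero vector hits the magnitude == 0 guard

print(BaseEmbeddings.cosine_similarity(v_a, v_b))     # ~1.0
print(BaseEmbeddings.cosine_similarity(v_a, v_zero))  # 0.0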
@@ -47,70 +86,18 @@ class OpenAIEmbedding(BaseEmbeddings):
     def __init__(self, path: str = '', is_api: bool = True) -> None:
         super().__init__(path, is_api)
         if self.is_api:
-            from openai import OpenAI
             self.client = OpenAI()
+            # Read the SiliconFlow API key from the environment
             self.client.api_key = os.getenv("OPENAI_API_KEY")
+            # Read the SiliconFlow base URL from the environment
             self.client.base_url = os.getenv("OPENAI_BASE_URL")
-    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
+    def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> List[float]:
+        """
+        By default this uses SiliconFlow's free embedding model BAAI/bge-m3.
+        """
         if self.is_api:
             text = text.replace("\n", " ")
             return self.client.embeddings.create(input=[text], model=model).data[0].embedding
         else:
             raise NotImplementedError
-class JinaEmbedding(BaseEmbeddings):
-    """
-    class for Jina embeddings
-    """
-    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
-        super().__init__(path, is_api)
-        self._model = self.load_model()
-    def get_embedding(self, text: str) -> List[float]:
-        return self._model.encode([text])[0].tolist()
-    def load_model(self):
-        import torch
-        from transformers import AutoModel
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        else:
-            device = torch.device("cpu")
-        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
-        return model
-class ZhipuEmbedding(BaseEmbeddings):
-    """
-    class for Zhipu embeddings
-    """
-    def __init__(self, path: str = '', is_api: bool = True) -> None:
-        super().__init__(path, is_api)
-        if self.is_api:
-            from zhipuai import ZhipuAI
-            self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
-    def get_embedding(self, text: str) -> List[float]:
-        response = self.client.embeddings.create(
-            model="embedding-2",
-            input=text,
-        )
-        return response.data[0].embedding
-class DashscopeEmbedding(BaseEmbeddings):
-    """
-    class for Dashscope embeddings
-    """
-    def __init__(self, path: str = '', is_api: bool = True) -> None:
-        super().__init__(path, is_api)
-        if self.is_api:
-            import dashscope
-            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
-            self.client = dashscope.TextEmbedding
-    def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]:
-        response = self.client.call(
-            model=model,
-            input=text
-        )
-        return response.output['embeddings'][0]['embedding']
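
With the provider-specific classes removed, OpenAIEmbedding is the single entry point of the module. A minimal usage sketch, assuming OPENAI_API_KEY and OPENAI_BASE_URL are set in the environment (or .env file) as described above and that the module name matches its @File header:

from Embedding import OpenAIEmbedding, BaseEmbeddings  # module name assumed from the @File header

embedder = OpenAIEmbedding()  # is_api=True by default; the client reads the key and base URL from the environment

# Embed two texts with the default BAAI/bge-m3 model and compare them
vec_q = embedder.get_embedding("What is retrieval-augmented generation?")
vec_d = embedder.get_embedding("RAG augments an LLM prompt with retrieved documents.")
print(BaseEmbeddings.cosine_similarity(vec_q, vec_d))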