feat(RAG): update the RAG module code and docs

refactor: simplify the Embeddings and LLM class implementations and remove unnecessary dependencies
docs: update the documentation and add instructions for using the SiliconFlow API
chore: bump dependency versions in requirements.txt
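As a companion to the diff below, here is a minimal setup sketch for the SiliconFlow-style usage this commit targets. It assumes a .env file in the project root and SiliconFlow's OpenAI-compatible endpoint; the key, the base URL value, and the test string are placeholders, not part of the commit.

# Minimal setup sketch (assumptions: .env in the project root, OpenAI-compatible endpoint)
#   OPENAI_API_KEY=sk-your-siliconflow-key        <- placeholder
#   OPENAI_BASE_URL=https://api.siliconflow.cn/v1 <- assumed SiliconFlow endpoint
import os
from dotenv import load_dotenv, find_dotenv
from openai import OpenAI

_ = load_dotenv(find_dotenv())
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)
# Request an embedding from the free BAAI/bge-m3 model
resp = client.embeddings.create(input=["hello rag"], model="BAAI/bge-m3")
print(len(resp.data[0].embedding))  # dimensionality of the returned vector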
@@ -1,10 +1,10 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 '''
-@File : Embeddings.py
-@Time : 2024/02/10 21:55:39
+@File : Embedding.py
+@Time : 2025/06/20 13:50:47
 @Author : 不要葱姜蒜
-@Version : 1.0
+@Version : 1.1
 @Desc : None
 '''
 
@@ -12,6 +12,7 @@ import os
 from copy import copy
 from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
+from openai import OpenAI
 
 from dotenv import load_dotenv, find_dotenv
 _ = load_dotenv(find_dotenv())
@@ -22,21 +23,59 @@ class BaseEmbeddings:
     Base class for embeddings
     """
     def __init__(self, path: str, is_api: bool) -> None:
+        """
+        Initialize the embeddings base class.
+        Args:
+            path (str): path to the model or the data
+            is_api (bool): whether to use an API. True calls an online API service, False uses a local model.
+        """
         self.path = path
         self.is_api = is_api
 
     def get_embedding(self, text: str, model: str) -> List[float]:
+        """
+        Get the embedding vector for a piece of text.
+        Args:
+            text (str): input text
+            model (str): name of the model to use
+        Returns:
+            List[float]: the embedding vector of the text
+        Raises:
+            NotImplementedError: this method must be implemented by a subclass
+        """
         raise NotImplementedError
 
     @classmethod
     def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
         """
         calculate cosine similarity between two vectors
+        Args:
+            vector1 (List[float]): the first vector
+            vector2 (List[float]): the second vector
+        Returns:
+            float: cosine similarity of the two vectors, in the range [-1, 1]
         """
-        dot_product = np.dot(vector1, vector2)
-        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
-        if not magnitude:
-            return 0
+        # Convert the input lists to numpy arrays with dtype float32
+        v1 = np.array(vector1, dtype=np.float32)
+        v2 = np.array(vector2, dtype=np.float32)
+
+        # Check the vectors for infinite or NaN values
+        if not np.all(np.isfinite(v1)) or not np.all(np.isfinite(v2)):
+            return 0.0
+
+        # Compute the dot product of the vectors
+        dot_product = np.dot(v1, v2)
+        # Compute the norms (lengths) of the vectors
+        norm_v1 = np.linalg.norm(v1)
+        norm_v2 = np.linalg.norm(v2)
+
+        # Compute the denominator (product of the two norms)
+        magnitude = norm_v1 * norm_v2
+        # Handle the special case of a zero denominator
+        if magnitude == 0:
+            return 0.0
+
+        # Return the cosine similarity
         return dot_product / magnitude
 
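A quick numeric check of the rewritten cosine_similarity above; the numbers are illustrative only and not part of the diff.

# For v1 = [1, 0] and v2 = [1, 1]:
#   dot = 1*1 + 0*1 = 1, |v1| = 1, |v2| = sqrt(2)
#   cosine similarity = 1 / sqrt(2) ≈ 0.7071
print(BaseEmbeddings.cosine_similarity([1.0, 0.0], [1.0, 1.0]))  # ≈ 0.7071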
@@ -47,70 +86,18 @@ class OpenAIEmbedding(BaseEmbeddings):
     def __init__(self, path: str = '', is_api: bool = True) -> None:
         super().__init__(path, is_api)
         if self.is_api:
-            from openai import OpenAI
             self.client = OpenAI()
+            # Get the SiliconFlow API key from the environment variables
             self.client.api_key = os.getenv("OPENAI_API_KEY")
+            # Get the SiliconFlow base URL from the environment variables
             self.client.base_url = os.getenv("OPENAI_BASE_URL")
 
-    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
+    def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> List[float]:
+        """
+        By default this uses SiliconFlow's free embedding model BAAI/bge-m3.
+        """
         if self.is_api:
             text = text.replace("\n", " ")
             return self.client.embeddings.create(input=[text], model=model).data[0].embedding
         else:
             raise NotImplementedError
-
-class JinaEmbedding(BaseEmbeddings):
-    """
-    class for Jina embeddings
-    """
-    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
-        super().__init__(path, is_api)
-        self._model = self.load_model()
-
-    def get_embedding(self, text: str) -> List[float]:
-        return self._model.encode([text])[0].tolist()
-
-    def load_model(self):
-        import torch
-        from transformers import AutoModel
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        else:
-            device = torch.device("cpu")
-        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
-        return model
-
-class ZhipuEmbedding(BaseEmbeddings):
-    """
-    class for Zhipu embeddings
-    """
-    def __init__(self, path: str = '', is_api: bool = True) -> None:
-        super().__init__(path, is_api)
-        if self.is_api:
-            from zhipuai import ZhipuAI
-            self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
-
-    def get_embedding(self, text: str) -> List[float]:
-        response = self.client.embeddings.create(
-            model="embedding-2",
-            input=text,
-        )
-        return response.data[0].embedding
-
-class DashscopeEmbedding(BaseEmbeddings):
-    """
-    class for Dashscope embeddings
-    """
-    def __init__(self, path: str = '', is_api: bool = True) -> None:
-        super().__init__(path, is_api)
-        if self.is_api:
-            import dashscope
-            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
-            self.client = dashscope.TextEmbedding
-
-    def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]:
-        response = self.client.call(
-            model=model,
-            input=text
-        )
-        return response.output['embeddings'][0]['embedding']
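A usage sketch for the refactored class; the module name and the example sentences are assumptions, everything else follows the code in the diff.

# Assumes the file is importable as Embeddings; match the actual file name in the repo
from Embeddings import BaseEmbeddings, OpenAIEmbedding

embedding = OpenAIEmbedding()  # picks up OPENAI_API_KEY / OPENAI_BASE_URL from the environment
v1 = embedding.get_embedding("What is retrieval-augmented generation?")  # defaults to BAAI/bge-m3
v2 = embedding.get_embedding("RAG combines retrieval with generation.")
print(BaseEmbeddings.cosine_similarity(v1, v2))  # similarity score in [-1, 1]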