Files
happy-llm/docs/chapter7/RAG/Embeddings.py
KMnO4-zx fe07d0ede1 feat(RAG): 更新RAG模块代码和文档
refactor: 简化Embeddings和LLM类实现,移除不必要依赖
docs: 更新文档内容,添加硅基流动API使用说明
chore: 更新requirements.txt依赖版本
2025-06-20 22:53:23 +08:00

104 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : Embedding.py
@Time : 2025/06/20 13:50:47
@Author : 不要葱姜蒜
@Version : 1.1
@Desc : None
'''
import os
from copy import copy
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
class BaseEmbeddings:
"""
Base class for embeddings
"""
def __init__(self, path: str, is_api: bool) -> None:
"""
初始化嵌入基类
Args:
path (str): 模型或数据的路径
is_api (bool): 是否使用API方式。True表示使用在线API服务False表示使用本地模型
"""
self.path = path
self.is_api = is_api
def get_embedding(self, text: str, model: str) -> List[float]:
"""
获取文本的嵌入向量表示
Args:
text (str): 输入文本
model (str): 使用的模型名称
Returns:
List[float]: 文本的嵌入向量
Raises:
NotImplementedError: 该方法需要在子类中实现
"""
raise NotImplementedError
@classmethod
def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
"""
计算两个向量之间的余弦相似度
Args:
vector1 (List[float]): 第一个向量
vector2 (List[float]): 第二个向量
Returns:
float: 两个向量的余弦相似度,范围在[-1,1]之间
"""
# 将输入列表转换为numpy数组并指定数据类型为float32
v1 = np.array(vector1, dtype=np.float32)
v2 = np.array(vector2, dtype=np.float32)
# 检查向量中是否包含无穷大或NaN值
if not np.all(np.isfinite(v1)) or not np.all(np.isfinite(v2)):
return 0.0
# 计算向量的点积
dot_product = np.dot(v1, v2)
# 计算向量的范数(长度)
norm_v1 = np.linalg.norm(v1)
norm_v2 = np.linalg.norm(v2)
# 计算分母(两个向量范数的乘积)
magnitude = norm_v1 * norm_v2
# 处理分母为0的特殊情况
if magnitude == 0:
return 0.0
# 返回余弦相似度
return dot_product / magnitude
class OpenAIEmbedding(BaseEmbeddings):
"""
class for OpenAI embeddings
"""
def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api)
if self.is_api:
self.client = OpenAI()
# 从环境变量中获取 硅基流动 密钥
self.client.api_key = os.getenv("OPENAI_API_KEY")
# 从环境变量中获取 硅基流动 的基础URL
self.client.base_url = os.getenv("OPENAI_BASE_URL")
def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> List[float]:
"""
此处默认使用轨迹流动的免费嵌入模型 BAAI/bge-m3
"""
if self.is_api:
text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=model).data[0].embedding
else:
raise NotImplementedError