feat(RAG): update the RAG module code and docs

refactor: simplify the Embeddings and LLM class implementations and remove unnecessary dependencies
docs: update the documentation and add instructions for the SiliconFlow API
chore: pin dependency versions in requirements.txt
docs/chapter7/RAG/.env_example (new file, +4 lines)
@@ -0,0 +1,4 @@
+# Defaults to SiliconFlow (https://cloud.siliconflow.cn/), an API platform accessible from mainland China
+
+OPENAI_API_KEY='your api key'
+OPENAI_BASE_URL='https://api.siliconflow.cn/v1'
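As a sketch of how these two variables are consumed by the code in the hunks below (this block is illustrative and not part of the commit; it assumes a valid SiliconFlow key has been placed in `.env`):

```python
import os

from dotenv import find_dotenv, load_dotenv
from openai import OpenAI

# Load OPENAI_API_KEY and OPENAI_BASE_URL from the nearest .env file
load_dotenv(find_dotenv())

# The openai SDK accepts any OpenAI-compatible endpoint, so pointing
# base_url at SiliconFlow reroutes all requests to that platform.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)
```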
docs/chapter7/RAG/Embeddings.py
@@ -1,10 +1,10 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 '''
-@File : Embeddings.py
-@Time : 2024/02/10 21:55:39
+@File : Embedding.py
+@Time : 2025/06/20 13:50:47
 @Author : 不要葱姜蒜
-@Version : 1.0
+@Version : 1.1
 @Desc : None
 '''
 
@@ -12,6 +12,7 @@ import os
 from copy import copy
 from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
+from openai import OpenAI
 
 from dotenv import load_dotenv, find_dotenv
 _ = load_dotenv(find_dotenv())
@@ -22,21 +23,59 @@ class BaseEmbeddings:
     Base class for embeddings
     """
     def __init__(self, path: str, is_api: bool) -> None:
+        """
+        Initialize the embeddings base class.
+        Args:
+            path (str): path to the model or data
+            is_api (bool): whether to use an API. True calls an online API service; False runs a local model
+        """
         self.path = path
         self.is_api = is_api
 
     def get_embedding(self, text: str, model: str) -> List[float]:
+        """
+        Get the embedding vector for a piece of text.
+        Args:
+            text (str): the input text
+            model (str): the name of the model to use
+        Returns:
+            List[float]: the embedding vector of the text
+        Raises:
+            NotImplementedError: this method must be implemented by a subclass
+        """
         raise NotImplementedError
 
     @classmethod
     def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
         """
-        calculate cosine similarity between two vectors
+        Compute the cosine similarity between two vectors.
+        Args:
+            vector1 (List[float]): the first vector
+            vector2 (List[float]): the second vector
+        Returns:
+            float: the cosine similarity of the two vectors, in the range [-1, 1]
         """
-        dot_product = np.dot(vector1, vector2)
-        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
-        if not magnitude:
-            return 0
+        # Convert the input lists to numpy arrays with dtype float32
+        v1 = np.array(vector1, dtype=np.float32)
+        v2 = np.array(vector2, dtype=np.float32)
+
+        # Return 0 if either vector contains infinities or NaNs
+        if not np.all(np.isfinite(v1)) or not np.all(np.isfinite(v2)):
+            return 0.0
+
+        # Dot product of the two vectors
+        dot_product = np.dot(v1, v2)
+        # Norms (lengths) of the vectors
+        norm_v1 = np.linalg.norm(v1)
+        norm_v2 = np.linalg.norm(v2)
+
+        # Denominator: the product of the two norms
+        magnitude = norm_v1 * norm_v2
+        # Guard against division by zero
+        if magnitude == 0:
+            return 0.0
+
+        # Cosine similarity
         return dot_product / magnitude
 
 
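A quick sanity check of the new guards (hypothetical vectors, not from the repo; assumes `docs/chapter7/RAG` is the working directory):

```python
from Embeddings import BaseEmbeddings

print(BaseEmbeddings.cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0 (parallel)
print(BaseEmbeddings.cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal)
print(BaseEmbeddings.cosine_similarity([1.0, 0.0], [0.0, 0.0]))  # 0.0 (zero vector guarded)
print(BaseEmbeddings.cosine_similarity([1.0, float("nan")], [1.0, 0.0]))  # 0.0 (NaN guarded)
```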
@@ -47,70 +86,18 @@ class OpenAIEmbedding(BaseEmbeddings):
     def __init__(self, path: str = '', is_api: bool = True) -> None:
         super().__init__(path, is_api)
         if self.is_api:
-            from openai import OpenAI
             self.client = OpenAI()
+            # Read the SiliconFlow API key from the environment
             self.client.api_key = os.getenv("OPENAI_API_KEY")
+            # Read the SiliconFlow base URL from the environment
             self.client.base_url = os.getenv("OPENAI_BASE_URL")
 
-    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
+    def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> List[float]:
+        """
+        Defaults to SiliconFlow's free embedding model BAAI/bge-m3.
+        """
         if self.is_api:
             text = text.replace("\n", " ")
             return self.client.embeddings.create(input=[text], model=model).data[0].embedding
         else:
             raise NotImplementedError
 
-class JinaEmbedding(BaseEmbeddings):
-    """
-    class for Jina embeddings
-    """
-    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
-        super().__init__(path, is_api)
-        self._model = self.load_model()
-
-    def get_embedding(self, text: str) -> List[float]:
-        return self._model.encode([text])[0].tolist()
-
-    def load_model(self):
-        import torch
-        from transformers import AutoModel
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        else:
-            device = torch.device("cpu")
-        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
-        return model
-
-class ZhipuEmbedding(BaseEmbeddings):
-    """
-    class for Zhipu embeddings
-    """
-    def __init__(self, path: str = '', is_api: bool = True) -> None:
-        super().__init__(path, is_api)
-        if self.is_api:
-            from zhipuai import ZhipuAI
-            self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
-
-    def get_embedding(self, text: str) -> List[float]:
-        response = self.client.embeddings.create(
-            model="embedding-2",
-            input=text,
-        )
-        return response.data[0].embedding
-
-class DashscopeEmbedding(BaseEmbeddings):
-    """
-    class for Dashscope embeddings
-    """
-    def __init__(self, path: str = '', is_api: bool = True) -> None:
-        super().__init__(path, is_api)
-        if self.is_api:
-            import dashscope
-            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
-            self.client = dashscope.TextEmbedding
-
-    def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]:
-        response = self.client.call(
-            model=model,
-            input=text
-        )
-        return response.output['embeddings'][0]['embedding']
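With the provider-specific classes removed, the remaining API-backed class covers the whole embedding workflow. A minimal sketch of using it (illustrative, not part of the commit; requires a valid key in `.env` and `docs/chapter7/RAG` as the working directory):

```python
from Embeddings import BaseEmbeddings, OpenAIEmbedding

embedding = OpenAIEmbedding()
v1 = embedding.get_embedding("What is RAG?")  # defaults to BAAI/bge-m3
v2 = embedding.get_embedding("How does retrieval-augmented generation work?")

# Semantically close questions should score high
print(BaseEmbeddings.cosine_similarity(v1, v2))
```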
docs/chapter7/RAG/LLM.py
@@ -2,37 +2,33 @@
 # -*- coding: utf-8 -*-
 '''
 @File : LLM.py
-@Time : 2024/02/12 13:50:47
+@Time : 2025/06/20 13:50:47
 @Author : 不要葱姜蒜
-@Version : 1.0
+@Version : 1.1
 @Desc : None
 '''
 import os
 from typing import Dict, List, Optional, Tuple, Union
+from openai import OpenAI
 
-PROMPT_TEMPLATE = dict(
-    RAG_PROMPT_TEMPLATE="""使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
-问题: {question}
-可参考的上下文:
-···
-{context}
-···
-如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
-有用的回答:""",
-    InternLM_PROMPT_TEMPLATE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
-问题: {question}
-可参考的上下文:
-···
-{context}
-···
-如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
-有用的回答:"""
-)
+from dotenv import load_dotenv, find_dotenv
+_ = load_dotenv(find_dotenv())
+
+RAG_PROMPT_TEMPLATE="""
+使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
+问题: {question}
+可参考的上下文:
+···
+{context}
+···
+如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
+有用的回答:
+"""
 
 
 class BaseModel:
-    def __init__(self, path: str = '') -> None:
-        self.path = path
+    def __init__(self, model) -> None:
+        self.model = model
 
     def chat(self, prompt: str, history: List[dict], content: str) -> str:
         pass
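Since the template is now a plain module-level string rather than a dict entry, building a grounded prompt is a single `format` call. A sketch with placeholder values (not from the commit):

```python
from LLM import RAG_PROMPT_TEMPLATE  # assumes docs/chapter7/RAG is the working directory

# Fill the template with a user question and a retrieved chunk (placeholder values)
prompt = RAG_PROMPT_TEMPLATE.format(
    question="What is RAG?",
    context="RAG combines retrieval with generation ...",
)
print(prompt)
```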
@@ -41,73 +37,18 @@ class BaseModel:
         pass
 
 class OpenAIChat(BaseModel):
-    def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None:
-        super().__init__(path)
+    def __init__(self, model: str = "Qwen/Qwen2.5-32B-Instruct") -> None:
         self.model = model
 
     def chat(self, prompt: str, history: List[dict], content: str) -> str:
-        from openai import OpenAI
         client = OpenAI()
         client.api_key = os.getenv("OPENAI_API_KEY")
         client.base_url = os.getenv("OPENAI_BASE_URL")
-        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
+        history.append({'role': 'user', 'content': RAG_PROMPT_TEMPLATE.format(question=prompt, context=content)})
         response = client.chat.completions.create(
             model=self.model,
             messages=history,
-            max_tokens=150,
+            max_tokens=2048,
             temperature=0.1
         )
         return response.choices[0].message.content
 
-class InternLMChat(BaseModel):
-    def __init__(self, path: str = '') -> None:
-        super().__init__(path)
-        self.load_model()
-
-    def chat(self, prompt: str, history: List = [], content: str='') -> str:
-        prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPLATE'].format(question=prompt, context=content)
-        response, history = self.model.chat(self.tokenizer, prompt, history)
-        return response
-
-    def load_model(self):
-        import torch
-        from transformers import AutoTokenizer, AutoModelForCausalLM
-        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
-        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
-
-class DashscopeChat(BaseModel):
-    def __init__(self, path: str = '', model: str = "qwen-turbo") -> None:
-        super().__init__(path)
-        self.model = model
-
-    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
-        import dashscope
-        dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
-        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
-        response = dashscope.Generation.call(
-            model=self.model,
-            messages=history,
-            result_format='message',
-            max_tokens=150,
-            temperature=0.1
-        )
-        return response.output.choices[0].message.content
-
-
-class ZhipuChat(BaseModel):
-    def __init__(self, path: str = '', model: str = "glm-4") -> None:
-        super().__init__(path)
-        from zhipuai import ZhipuAI
-        self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
-        self.model = model
-
-    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
-        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
-        response = self.client.chat.completions.create(
-            model=self.model,
-            messages=history,
-            max_tokens=150,
-            temperature=0.1
-        )
-        return response.choices[0].message
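Because `chat` appends the templated user turn to whatever `history` it is given, multi-turn use is just a matter of carrying the list forward. A sketch (not from the commit; needs a valid key and model access):

```python
from LLM import OpenAIChat  # assumes docs/chapter7/RAG is the working directory

chat = OpenAIChat()  # defaults to Qwen/Qwen2.5-32B-Instruct
history = []

# chat() appends the templated user turn to history and returns the reply text;
# storing the assistant turn ourselves keeps the conversation going.
answer = chat.chat("What is RAG?", history, "...retrieved context...")
history.append({"role": "assistant", "content": answer})
```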
docs/chapter7/RAG/VectorBase.py
@@ -2,16 +2,16 @@
 # -*- coding: utf-8 -*-
 '''
 @File : VectorBase.py
-@Time : 2024/02/12 10:11:13
+@Time : 2025/06/20 10:11:13
 @Author : 不要葱姜蒜
-@Version : 1.0
+@Version : 1.1
 @Desc : None
 '''
 
 import os
 from typing import Dict, List, Optional, Tuple, Union
 import json
-from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding
+from Embeddings import BaseEmbeddings, OpenAIEmbedding
 import numpy as np
 from tqdm import tqdm
 
docs/chapter7/RAG/demo.py (new file, +19 lines)
@@ -0,0 +1,19 @@
+from VectorBase import VectorStore
+from utils import ReadFiles
+from LLM import OpenAIChat
+from Embeddings import OpenAIEmbedding
+
+# No database has been persisted yet
+docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150) # read and chunk every file under the data directory
+vector = VectorStore(docs)
+embedding = OpenAIEmbedding() # create the embedding model
+vector.get_vector(EmbeddingModel=embedding)
+vector.persist(path='storage') # save vectors and chunks under the storage directory so the local database can be loaded directly next time
+
+# vector.load_vector('./storage') # load the local database
+
+question = 'RAG的原理是什么?'
+
+content = vector.query(question, EmbeddingModel=embedding, k=1)[0]
+chat = OpenAIChat(model='Qwen/Qwen2.5-32B-Instruct')
+print(chat.chat(question, [], content))
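With the pinned dependencies below installed and a populated `.env`, the demo should presumably be runnable from `docs/chapter7/RAG` as `python demo.py`, provided a `./data` directory of source documents exists (an assumption based on the `ReadFiles('./data')` call above).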
@@ -1,14 +1,28 @@
|
|||||||
openai
|
annotated-types==0.7.0
|
||||||
zhipuai
|
anyio==4.9.0
|
||||||
numpy
|
beautifulsoup4==4.13.4
|
||||||
python-dotenv
|
bs4==0.0.2
|
||||||
torch
|
certifi==2025.6.15
|
||||||
torchvision
|
charset-normalizer==3.4.2
|
||||||
torchaudio
|
distro==1.9.0
|
||||||
transformers
|
h11==0.16.0
|
||||||
tqdm
|
httpcore==1.0.9
|
||||||
PyPDF2
|
httpx==0.28.1
|
||||||
markdown
|
idna==3.10
|
||||||
html2text
|
jiter==0.10.0
|
||||||
tiktoken
|
markdown==3.8.2
|
||||||
beautifulsoup4
|
numpy==2.3.0
|
||||||
|
openai==1.88.0
|
||||||
|
pydantic==2.11.7
|
||||||
|
pydantic-core==2.33.2
|
||||||
|
pypdf2==3.0.1
|
||||||
|
python-dotenv==1.1.0
|
||||||
|
regex==2024.11.6
|
||||||
|
requests==2.32.4
|
||||||
|
sniffio==1.3.1
|
||||||
|
soupsieve==2.7
|
||||||
|
tiktoken==0.9.0
|
||||||
|
tqdm==4.67.1
|
||||||
|
typing-extensions==4.14.0
|
||||||
|
typing-inspection==0.4.1
|
||||||
|
urllib3==2.5.0
|
||||||
|
|||||||
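The pinned set can be installed with `pip install -r requirements.txt`. Note that the heavyweight local-inference dependencies (`torch`, `torchvision`, `transformers`) are gone: every model now sits behind the OpenAI-compatible API.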
docs/chapter7/RAG/utils.py
@@ -2,9 +2,9 @@
 # -*- coding: utf-8 -*-
 '''
 @File : utils.py
-@Time : 2024/02/11 09:52:26
+@Time : 2025/06/20 13:50:47
 @Author : 不要葱姜蒜
-@Version : 1.0
+@Version : 1.1
 @Desc : None
 '''
 
@@ -13,7 +13,6 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import PyPDF2
 import markdown
-import html2text
 import json
 from tqdm import tqdm
 import tiktoken
@@ -146,21 +146,59 @@ class BaseEmbeddings:
     Base class for embeddings
     """
     def __init__(self, path: str, is_api: bool) -> None:
+        """
+        Initialize the embeddings base class.
+        Args:
+            path (str): path to the model or data
+            is_api (bool): whether to use an API. True calls an online API service; False runs a local model
+        """
         self.path = path
         self.is_api = is_api
 
     def get_embedding(self, text: str, model: str) -> List[float]:
+        """
+        Get the embedding vector for a piece of text.
+        Args:
+            text (str): the input text
+            model (str): the name of the model to use
+        Returns:
+            List[float]: the embedding vector of the text
+        Raises:
+            NotImplementedError: this method must be implemented by a subclass
+        """
         raise NotImplementedError
 
     @classmethod
     def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
         """
-        calculate cosine similarity between two vectors
+        Compute the cosine similarity between two vectors.
+        Args:
+            vector1 (List[float]): the first vector
+            vector2 (List[float]): the second vector
+        Returns:
+            float: the cosine similarity of the two vectors, in the range [-1, 1]
         """
-        dot_product = np.dot(vector1, vector2)
-        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
-        if not magnitude:
-            return 0
+        # Convert the input lists to numpy arrays with dtype float32
+        v1 = np.array(vector1, dtype=np.float32)
+        v2 = np.array(vector2, dtype=np.float32)
+
+        # Return 0 if either vector contains infinities or NaNs
+        if not np.all(np.isfinite(v1)) or not np.all(np.isfinite(v2)):
+            return 0.0
+
+        # Dot product of the two vectors
+        dot_product = np.dot(v1, v2)
+        # Norms (lengths) of the vectors
+        norm_v1 = np.linalg.norm(v1)
+        norm_v2 = np.linalg.norm(v2)
+
+        # Denominator: the product of the two norms
+        magnitude = norm_v1 * norm_v2
+        # Guard against division by zero
+        if magnitude == 0:
+            return 0.0
+
+        # Cosine similarity
         return dot_product / magnitude
 ```
 
@@ -176,12 +214,16 @@ class OpenAIEmbedding(BaseEmbeddings):
     def __init__(self, path: str = '', is_api: bool = True) -> None:
         super().__init__(path, is_api)
         if self.is_api:
-            from openai import OpenAI
             self.client = OpenAI()
+            # Read the SiliconFlow API key from the environment
             self.client.api_key = os.getenv("OPENAI_API_KEY")
+            # Read the SiliconFlow base URL from the environment
             self.client.base_url = os.getenv("OPENAI_BASE_URL")
 
-    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
+    def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> List[float]:
+        """
+        Defaults to SiliconFlow's free embedding model BAAI/bge-m3.
+        """
         if self.is_api:
             text = text.replace("\n", " ")
             return self.client.embeddings.create(input=[text], model=model).data[0].embedding
@@ -189,6 +231,9 @@ class OpenAIEmbedding(BaseEmbeddings):
             raise NotImplementedError
 ```
 
+> Note: here we default to SiliconFlow, an LLM API service platform accessible to users in mainland China.
+> SiliconFlow: https://cloud.siliconflow.cn/
 
 #### Step 3: Document loading and chunking
 
 Next we implement a document loading and chunking class, which loads documents and splits them into chunks.
@@ -251,7 +296,7 @@ def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150
 - `get_vector`: compute the vector representations of the documents.
 - `query`: retrieve the document chunks relevant to a question.
 
-The complete code can be found in ***[RAG/VectorBase.py](RAG/VectorBase.py)***.
+The complete code can be found in ***[RAG/VectorBase.py](./RAG/VectorBase.py)***.
 
 ```python
 class VectorStore:
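The chapter's Markdown documentation does not reproduce the full class here. As a rough sketch (my own, not the repository's exact code), `query` can be built from the `cosine_similarity` helper shown earlier:

```python
from typing import List

import numpy as np

from Embeddings import BaseEmbeddings  # assumes docs/chapter7/RAG is the working directory

def query(documents: List[str], vectors: List[List[float]],
          question: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
    # Embed the question, score it against every stored vector,
    # and return the text of the k best-scoring chunks.
    q = EmbeddingModel.get_embedding(question)
    scores = np.array([BaseEmbeddings.cosine_similarity(q, v) for v in vectors])
    top_k = scores.argsort()[-k:][::-1]
    return [documents[i] for i in top_k]
```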
@@ -302,41 +347,43 @@ class BaseModel:
         pass
 ```
 
-`BaseModel` has two methods: `chat` and `load_model`. Open-source models run locally need to implement `load_model`, while API-backed models do not.
+`BaseModel` has two methods: `chat` and `load_model`. Open-source models run locally need to implement `load_model`, while API-backed models do not. Here we again use SiliconFlow, an LLM API service platform accessible to users in mainland China; the benefit of an API service is that users need no local compute, which greatly lowers the barrier to entry for learners.
 
-Take the ***[InternLM2-chat-7B](https://huggingface.co/internlm/internlm2-chat-7b)*** model as an example:
 
 ```python
-class InternLMChat(BaseModel):
-    def __init__(self, path: str = '') -> None:
-        super().__init__(path)
-        self.load_model()
-
-    def chat(self, prompt: str, history: List = [], content: str='') -> str:
-        prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPLATE'].format(question=prompt, context=content)
-        response, history = self.model.chat(self.tokenizer, prompt, history)
-        return response
-
-    def load_model(self):
-        import torch
-        from transformers import AutoTokenizer, AutoModelForCausalLM
-        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
-        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
+from openai import OpenAI
+
+class OpenAIChat(BaseModel):
+    def __init__(self, model: str = "Qwen/Qwen2.5-32B-Instruct") -> None:
+        self.model = model
+
+    def chat(self, prompt: str, history: List[dict], content: str) -> str:
+        client = OpenAI()
+        client.api_key = os.getenv("OPENAI_API_KEY")
+        client.base_url = os.getenv("OPENAI_BASE_URL")
+        history.append({'role': 'user', 'content': RAG_PROMPT_TEMPLATE.format(question=prompt, context=content)})
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=history,
+            max_tokens=2048,
+            temperature=0.1
+        )
+        return response.choices[0].message.content
 ```
 
-We can keep all the prompts in one dict for easier maintenance:
+Design a prompt dedicated to RAG, as follows:
 
 ```python
-PROMPT_TEMPLATE = dict(
-    InternLM_PROMPT_TEMPLATE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
+RAG_PROMPT_TEMPLATE="""
+使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
 问题: {question}
 可参考的上下文:
 ···
 {context}
 ···
 如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
-有用的回答:"""
-)
+有用的回答:
+"""
 ```
 
 With this, we can use the InternLM2 model to do RAG!
@@ -346,47 +393,51 @@ PROMPT_TEMPLATE = dict(
 Next, let's look at the Tiny-RAG demo!
 
 ```python
-from RAG.VectorBase import VectorStore
-from RAG.utils import ReadFiles
-from RAG.LLM import OpenAIChat, InternLMChat
-from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding
+from VectorBase import VectorStore
+from utils import ReadFiles
+from LLM import OpenAIChat
+from Embeddings import OpenAIEmbedding
 
 # No database has been persisted yet
 docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150) # read and chunk every file under the data directory
 vector = VectorStore(docs)
-embedding = ZhipuEmbedding() # create the embedding model
+embedding = OpenAIEmbedding() # create the embedding model
 vector.get_vector(EmbeddingModel=embedding)
 vector.persist(path='storage') # save vectors and chunks under the storage directory so the local database can be loaded directly next time
 
-question = 'git的原理是什么?'
+# vector.load_vector('./storage') # load the local database
 
-content = vector.query(question, model='zhipu', k=1)[0]
-chat = InternLMChat(path='model_path')
+question = 'RAG的原理是什么?'
+
+content = vector.query(question, EmbeddingModel=embedding, k=1)[0]
+chat = OpenAIChat(model='Qwen/Qwen2.5-32B-Instruct')
 print(chat.chat(question, [], content))
 ```
 
 You can also load an already-built database from disk:
 
 ```python
-from RAG.VectorBase import VectorStore
-from RAG.utils import ReadFiles
-from RAG.LLM import OpenAIChat, InternLMChat
-from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding
+from VectorBase import VectorStore
+from utils import ReadFiles
+from LLM import OpenAIChat
+from Embeddings import OpenAIEmbedding
 
 # After the database has been saved
 vector = VectorStore()
 
 vector.load_vector('./storage') # load the local database
 
-question = 'git的原理是什么?'
+question = 'RAG的原理是什么?'
 
 embedding = ZhipuEmbedding() # create the embedding model
 
 content = vector.query(question, EmbeddingModel=embedding, k=1)[0]
-chat = InternLMChat(path='model_path')
+chat = OpenAIChat(model='Qwen/Qwen2.5-32B-Instruct')
 print(chat.chat(question, [], content))
 ```
 
+> Note: all code for section 7.2 can be found in [Happy-LLM Chapter7 RAG](https://github.com/datawhalechina/happy-llm/tree/main/docs/chapter7/RAG).
 
 ## 7.3 Agent
 
 ### 7.3.1 What is an LLM Agent?