feat(RAG): 更新RAG模块代码和文档

refactor: 简化Embeddings和LLM类实现,移除不必要依赖
docs: 更新文档内容,添加硅基流动API使用说明
chore: 更新requirements.txt依赖版本
This commit is contained in:
KMnO4-zx
2025-06-20 22:53:23 +08:00
parent 0eea57b11f
commit fe07d0ede1
8 changed files with 233 additions and 218 deletions

View File

@@ -0,0 +1,4 @@
# 此处默认使用国内可访问的轨迹流动平台 https://cloud.siliconflow.cn/
OPENAI_API_KEY='your api key'
OPENAI_BASE_URL='https://api.siliconflow.cn/v1'

View File

@@ -1,10 +1,10 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
''' '''
@File : Embeddings.py @File : Embedding.py
@Time : 2024/02/10 21:55:39 @Time : 2025/06/20 13:50:47
@Author : 不要葱姜蒜 @Author : 不要葱姜蒜
@Version : 1.0 @Version : 1.1
@Desc : None @Desc : None
''' '''
@@ -12,6 +12,7 @@ import os
from copy import copy from copy import copy
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) _ = load_dotenv(find_dotenv())
@@ -22,21 +23,59 @@ class BaseEmbeddings:
Base class for embeddings Base class for embeddings
""" """
def __init__(self, path: str, is_api: bool) -> None: def __init__(self, path: str, is_api: bool) -> None:
"""
初始化嵌入基类
Args:
path (str): 模型或数据的路径
is_api (bool): 是否使用API方式。True表示使用在线API服务False表示使用本地模型
"""
self.path = path self.path = path
self.is_api = is_api self.is_api = is_api
def get_embedding(self, text: str, model: str) -> List[float]: def get_embedding(self, text: str, model: str) -> List[float]:
"""
获取文本的嵌入向量表示
Args:
text (str): 输入文本
model (str): 使用的模型名称
Returns:
List[float]: 文本的嵌入向量
Raises:
NotImplementedError: 该方法需要在子类中实现
"""
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float: def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
""" """
calculate cosine similarity between two vectors 计算两个向量之间的余弦相似度
Args:
vector1 (List[float]): 第一个向量
vector2 (List[float]): 第二个向量
Returns:
float: 两个向量的余弦相似度,范围在[-1,1]之间
""" """
dot_product = np.dot(vector1, vector2) # 将输入列表转换为numpy数组并指定数据类型为float32
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2) v1 = np.array(vector1, dtype=np.float32)
if not magnitude: v2 = np.array(vector2, dtype=np.float32)
return 0
# 检查向量中是否包含无穷大或NaN值
if not np.all(np.isfinite(v1)) or not np.all(np.isfinite(v2)):
return 0.0
# 计算向量的点积
dot_product = np.dot(v1, v2)
# 计算向量的范数(长度)
norm_v1 = np.linalg.norm(v1)
norm_v2 = np.linalg.norm(v2)
# 计算分母(两个向量范数的乘积)
magnitude = norm_v1 * norm_v2
# 处理分母为0的特殊情况
if magnitude == 0:
return 0.0
# 返回余弦相似度
return dot_product / magnitude return dot_product / magnitude
@@ -47,70 +86,18 @@ class OpenAIEmbedding(BaseEmbeddings):
def __init__(self, path: str = '', is_api: bool = True) -> None: def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api) super().__init__(path, is_api)
if self.is_api: if self.is_api:
from openai import OpenAI
self.client = OpenAI() self.client = OpenAI()
# 从环境变量中获取 硅基流动 密钥
self.client.api_key = os.getenv("OPENAI_API_KEY") self.client.api_key = os.getenv("OPENAI_API_KEY")
# 从环境变量中获取 硅基流动 的基础URL
self.client.base_url = os.getenv("OPENAI_BASE_URL") self.client.base_url = os.getenv("OPENAI_BASE_URL")
def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]: def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> List[float]:
"""
此处默认使用轨迹流动的免费嵌入模型 BAAI/bge-m3
"""
if self.is_api: if self.is_api:
text = text.replace("\n", " ") text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=model).data[0].embedding return self.client.embeddings.create(input=[text], model=model).data[0].embedding
else: else:
raise NotImplementedError raise NotImplementedError
class JinaEmbedding(BaseEmbeddings):
"""
class for Jina embeddings
"""
def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
super().__init__(path, is_api)
self._model = self.load_model()
def get_embedding(self, text: str) -> List[float]:
return self._model.encode([text])[0].tolist()
def load_model(self):
import torch
from transformers import AutoModel
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
return model
class ZhipuEmbedding(BaseEmbeddings):
"""
class for Zhipu embeddings
"""
def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api)
if self.is_api:
from zhipuai import ZhipuAI
self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
def get_embedding(self, text: str) -> List[float]:
response = self.client.embeddings.create(
model="embedding-2",
input=text,
)
return response.data[0].embedding
class DashscopeEmbedding(BaseEmbeddings):
"""
class for Dashscope embeddings
"""
def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api)
if self.is_api:
import dashscope
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
self.client = dashscope.TextEmbedding
def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]:
response = self.client.call(
model=model,
input=text
)
return response.output['embeddings'][0]['embedding']

View File

@@ -2,37 +2,33 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
''' '''
@File : LLM.py @File : LLM.py
@Time : 2024/02/12 13:50:47 @Time : 2025/06/20 13:50:47
@Author : 不要葱姜蒜 @Author : 不要葱姜蒜
@Version : 1.0 @Version : 1.1
@Desc : None @Desc : None
''' '''
import os import os
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from openai import OpenAI
PROMPT_TEMPLATE = dict( from dotenv import load_dotenv, find_dotenv
RAG_PROMPT_TEMPLATE="""使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。 _ = load_dotenv(find_dotenv())
问题: {question}
可参考的上下文: RAG_PROMPT_TEMPLATE="""
··· 使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
{context} 问题: {question}
··· 可参考的上下文:
如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。 ···
有用的回答:""", {context}
InternLM_PROMPT_TEMPLATE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。 ···
问题: {question} 如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
可参考的上下文: 有用的回答:
··· """
{context}
···
如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
有用的回答:"""
)
class BaseModel: class BaseModel:
def __init__(self, path: str = '') -> None: def __init__(self, model) -> None:
self.path = path self.model = model
def chat(self, prompt: str, history: List[dict], content: str) -> str: def chat(self, prompt: str, history: List[dict], content: str) -> str:
pass pass
@@ -41,73 +37,18 @@ class BaseModel:
pass pass
class OpenAIChat(BaseModel): class OpenAIChat(BaseModel):
def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None: def __init__(self, model: str = "Qwen/Qwen2.5-32B-Instruct") -> None:
super().__init__(path)
self.model = model self.model = model
def chat(self, prompt: str, history: List[dict], content: str) -> str: def chat(self, prompt: str, history: List[dict], content: str) -> str:
from openai import OpenAI
client = OpenAI() client = OpenAI()
client.api_key = os.getenv("OPENAI_API_KEY") client.api_key = os.getenv("OPENAI_API_KEY")
client.base_url = os.getenv("OPENAI_BASE_URL") client.base_url = os.getenv("OPENAI_BASE_URL")
history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)}) history.append({'role': 'user', 'content': RAG_PROMPT_TEMPLATE.format(question=prompt, context=content)})
response = client.chat.completions.create( response = client.chat.completions.create(
model=self.model, model=self.model,
messages=history, messages=history,
max_tokens=150, max_tokens=2048,
temperature=0.1 temperature=0.1
) )
return response.choices[0].message.content return response.choices[0].message.content
class InternLMChat(BaseModel):
def __init__(self, path: str = '') -> None:
super().__init__(path)
self.load_model()
def chat(self, prompt: str, history: List = [], content: str='') -> str:
prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPLATE'].format(question=prompt, context=content)
response, history = self.model.chat(self.tokenizer, prompt, history)
return response
def load_model(self):
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
class DashscopeChat(BaseModel):
def __init__(self, path: str = '', model: str = "qwen-turbo") -> None:
super().__init__(path)
self.model = model
def chat(self, prompt: str, history: List[Dict], content: str) -> str:
import dashscope
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
response = dashscope.Generation.call(
model=self.model,
messages=history,
result_format='message',
max_tokens=150,
temperature=0.1
)
return response.output.choices[0].message.content
class ZhipuChat(BaseModel):
def __init__(self, path: str = '', model: str = "glm-4") -> None:
super().__init__(path)
from zhipuai import ZhipuAI
self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
self.model = model
def chat(self, prompt: str, history: List[Dict], content: str) -> str:
history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
response = self.client.chat.completions.create(
model=self.model,
messages=history,
max_tokens=150,
temperature=0.1
)
return response.choices[0].message

View File

@@ -2,16 +2,16 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
''' '''
@File : VectorBase.py @File : VectorBase.py
@Time : 2024/02/12 10:11:13 @Time : 2025/06/20 10:11:13
@Author : 不要葱姜蒜 @Author : 不要葱姜蒜
@Version : 1.0 @Version : 1.1
@Desc : None @Desc : None
''' '''
import os import os
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import json import json
from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding from Embeddings import BaseEmbeddings, OpenAIEmbedding
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm

19
docs/chapter7/RAG/demo.py Normal file
View File

@@ -0,0 +1,19 @@
from VectorBase import VectorStore
from utils import ReadFiles
from LLM import OpenAIChat
from Embeddings import OpenAIEmbedding
# 没有保存数据库
docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150) # 获得data目录下的所有文件内容并分割
vector = VectorStore(docs)
embedding = OpenAIEmbedding() # 创建EmbeddingModel
vector.get_vector(EmbeddingModel=embedding)
vector.persist(path='storage') # 将向量和文档内容保存到storage目录下下次再用就可以直接加载本地的数据库
# vector.load_vector('./storage') # 加载本地的数据库
question = 'RAG的原理是什么'
content = vector.query(question, EmbeddingModel=embedding, k=1)[0]
chat = OpenAIChat(model='Qwen/Qwen2.5-32B-Instruct')
print(chat.chat(question, [], content))

View File

@@ -1,14 +1,28 @@
openai annotated-types==0.7.0
zhipuai anyio==4.9.0
numpy beautifulsoup4==4.13.4
python-dotenv bs4==0.0.2
torch certifi==2025.6.15
torchvision charset-normalizer==3.4.2
torchaudio distro==1.9.0
transformers h11==0.16.0
tqdm httpcore==1.0.9
PyPDF2 httpx==0.28.1
markdown idna==3.10
html2text jiter==0.10.0
tiktoken markdown==3.8.2
beautifulsoup4 numpy==2.3.0
openai==1.88.0
pydantic==2.11.7
pydantic-core==2.33.2
pypdf2==3.0.1
python-dotenv==1.1.0
regex==2024.11.6
requests==2.32.4
sniffio==1.3.1
soupsieve==2.7
tiktoken==0.9.0
tqdm==4.67.1
typing-extensions==4.14.0
typing-inspection==0.4.1
urllib3==2.5.0

View File

@@ -2,9 +2,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
''' '''
@File : utils.py @File : utils.py
@Time : 2024/02/11 09:52:26 @Time : 2025/06/20 13:50:47
@Author : 不要葱姜蒜 @Author : 不要葱姜蒜
@Version : 1.0 @Version : 1.1
@Desc : None @Desc : None
''' '''
@@ -13,7 +13,6 @@ from typing import Dict, List, Optional, Tuple, Union
import PyPDF2 import PyPDF2
import markdown import markdown
import html2text
import json import json
from tqdm import tqdm from tqdm import tqdm
import tiktoken import tiktoken

View File

@@ -146,21 +146,59 @@ class BaseEmbeddings:
Base class for embeddings Base class for embeddings
""" """
def __init__(self, path: str, is_api: bool) -> None: def __init__(self, path: str, is_api: bool) -> None:
"""
初始化嵌入基类
Args:
path (str): 模型或数据的路径
is_api (bool): 是否使用API方式。True表示使用在线API服务False表示使用本地模型
"""
self.path = path self.path = path
self.is_api = is_api self.is_api = is_api
def get_embedding(self, text: str, model: str) -> List[float]: def get_embedding(self, text: str, model: str) -> List[float]:
"""
获取文本的嵌入向量表示
Args:
text (str): 输入文本
model (str): 使用的模型名称
Returns:
List[float]: 文本的嵌入向量
Raises:
NotImplementedError: 该方法需要在子类中实现
"""
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float: def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
""" """
calculate cosine similarity between two vectors 计算两个向量之间的余弦相似度
Args:
vector1 (List[float]): 第一个向量
vector2 (List[float]): 第二个向量
Returns:
float: 两个向量的余弦相似度,范围在[-1,1]之间
""" """
dot_product = np.dot(vector1, vector2) # 将输入列表转换为numpy数组并指定数据类型为float32
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2) v1 = np.array(vector1, dtype=np.float32)
if not magnitude: v2 = np.array(vector2, dtype=np.float32)
return 0
# 检查向量中是否包含无穷大或NaN值
if not np.all(np.isfinite(v1)) or not np.all(np.isfinite(v2)):
return 0.0
# 计算向量的点积
dot_product = np.dot(v1, v2)
# 计算向量的范数(长度)
norm_v1 = np.linalg.norm(v1)
norm_v2 = np.linalg.norm(v2)
# 计算分母(两个向量范数的乘积)
magnitude = norm_v1 * norm_v2
# 处理分母为0的特殊情况
if magnitude == 0:
return 0.0
# 返回余弦相似度
return dot_product / magnitude return dot_product / magnitude
``` ```
@@ -176,12 +214,16 @@ class OpenAIEmbedding(BaseEmbeddings):
def __init__(self, path: str = '', is_api: bool = True) -> None: def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api) super().__init__(path, is_api)
if self.is_api: if self.is_api:
from openai import OpenAI
self.client = OpenAI() self.client = OpenAI()
# 从环境变量中获取 硅基流动 密钥
self.client.api_key = os.getenv("OPENAI_API_KEY") self.client.api_key = os.getenv("OPENAI_API_KEY")
# 从环境变量中获取 硅基流动 的基础URL
self.client.base_url = os.getenv("OPENAI_BASE_URL") self.client.base_url = os.getenv("OPENAI_BASE_URL")
def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]: def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> List[float]:
"""
此处默认使用轨迹流动的免费嵌入模型 BAAI/bge-m3
"""
if self.is_api: if self.is_api:
text = text.replace("\n", " ") text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=model).data[0].embedding return self.client.embeddings.create(input=[text], model=model).data[0].embedding
@@ -189,6 +231,9 @@ class OpenAIEmbedding(BaseEmbeddings):
raise NotImplementedError raise NotImplementedError
``` ```
> 此处我们默认使用国内用户可访问的硅基流动大模型API服务平台。
> 硅基流动https://cloud.siliconflow.cn/
#### Step 3: 文档加载和切分 #### Step 3: 文档加载和切分
接下来我们来实现一个文档加载和切分的类,这个类主要用于加载文档并将其切分成文档片段。 接下来我们来实现一个文档加载和切分的类,这个类主要用于加载文档并将其切分成文档片段。
@@ -251,7 +296,7 @@ def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150
- `get_vector`:获取文档的向量表示。 - `get_vector`:获取文档的向量表示。
- `query`:根据问题检索相关文档片段。 - `query`:根据问题检索相关文档片段。
完整代码可以在 ***[RAG/VectorBase.py](RAG/VectorBase.py)*** 文件中找到。 完整代码可以在 ***[/VectorBase.py](./RAG/VectorBase.py)*** 文件中找到。
```python ```python
class VectorStore: class VectorStore:
@@ -302,41 +347,43 @@ class BaseModel:
pass pass
``` ```
`BaseModel` 包含两个方法:`chat``load_model`。对于本地化运行的开源模型需要实现`load_model`而API模型则不需要。 `BaseModel` 包含两个方法:`chat``load_model`。对于本地化运行的开源模型需要实现`load_model`而API模型则不需要。在此处我们还是使用国内用户可访问的硅基流动大模型API服务平台使用API服务的好处就是用户不需要本地的计算资源可以大大降低学习者的学习门槛。
下面以 ***[InternLM2-chat-7B](https://huggingface.co/internlm/internlm2-chat-7b)*** 模型为例:
```python ```python
class InternLMChat(BaseModel): from openai import OpenAI
def __init__(self, path: str = '') -> None:
super().__init__(path)
self.load_model()
def chat(self, prompt: str, history: List = [], content: str='') -> str: class OpenAIChat(BaseModel):
prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPLATE'].format(question=prompt, context=content) def __init__(self, model: str = "Qwen/Qwen2.5-32B-Instruct") -> None:
response, history = self.model.chat(self.tokenizer, prompt, history) self.model = model
return response
def chat(self, prompt: str, history: List[dict], content: str) -> str:
client = OpenAI()
client.api_key = os.getenv("OPENAI_API_KEY")
client.base_url = os.getenv("OPENAI_BASE_URL")
history.append({'role': 'user', 'content': RAG_PROMPT_TEMPLATE.format(question=prompt, context=content)})
response = client.chat.completions.create(
model=self.model,
messages=history,
max_tokens=2048,
temperature=0.1
)
return response.choices[0].message.content
def load_model(self):
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
``` ```
可以用一个字典来保存所有的prompt方便维护 设计一个专用于RAG的大模型提示词如下
```python ```python
PROMPT_TEMPLATE = dict( RAG_PROMPT_TEMPLATE="""
InternLM_PROMPT_TEMPLATE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。 使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
问题: {question} 问题: {question}
可参考的上下文: 可参考的上下文:
··· ···
{context} {context}
··· ···
如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。 如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
有用的回答:""" 有用的回答:
) """
``` ```
这样我们就可以利用InternLM2模型来做RAG啦 这样我们就可以利用InternLM2模型来做RAG啦
@@ -346,47 +393,51 @@ PROMPT_TEMPLATE = dict(
接下来我们来看看Tiny-RAG的Demo吧 接下来我们来看看Tiny-RAG的Demo吧
```python ```python
from RAG.VectorBase import VectorStore from VectorBase import VectorStore
from RAG.utils import ReadFiles from utils import ReadFiles
from RAG.LLM import OpenAIChat, InternLMChat from LLM import OpenAIChat
from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding from Embeddings import OpenAIEmbedding
# 没有保存数据库 # 没有保存数据库
docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150) # 获data目录下的所有文件内容并分割 docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150) # 获data目录下的所有文件内容并分割
vector = VectorStore(docs) vector = VectorStore(docs)
embedding = ZhipuEmbedding() # 创建EmbeddingModel embedding = OpenAIEmbedding() # 创建EmbeddingModel
vector.get_vector(EmbeddingModel=embedding) vector.get_vector(EmbeddingModel=embedding)
vector.persist(path='storage') # 将向量和文档内容保存到storage目录下次再用可以直接加载本地数据库 vector.persist(path='storage') # 将向量和文档内容保存到storage目录,下次再用可以直接加载本地数据库
question = 'git的原理是什么' # vector.load_vector('./storage') # 加载本地的数据库
content = vector.query(question, model='zhipu', k=1)[0] question = 'RAG的原理是什么'
chat = InternLMChat(path='model_path')
content = vector.query(question, EmbeddingModel=embedding, k=1)[0]
chat = OpenAIChat(model='Qwen/Qwen2.5-32B-Instruct')
print(chat.chat(question, [], content)) print(chat.chat(question, [], content))
``` ```
也可以从本地加载已处理好的数据库: 也可以从本地加载已处理好的数据库:
```python ```python
from RAG.VectorBase import VectorStore from VectorBase import VectorStore
from RAG.utils import ReadFiles from utils import ReadFiles
from RAG.LLM import OpenAIChat, InternLMChat from LLM import OpenAIChat
from RAG.Embeddings import JinaEmbedding, ZhipuEmbedding from Embeddings import OpenAIEmbedding
# 保存数据库之后 # 保存数据库之后
vector = VectorStore() vector = VectorStore()
vector.load_vector('./storage') # 加载本地数据库 vector.load_vector('./storage') # 加载本地数据库
question = 'git的原理是什么?' question = 'RAG的原理是什么?'
embedding = ZhipuEmbedding() # 创建EmbeddingModel embedding = ZhipuEmbedding() # 创建EmbeddingModel
content = vector.query(question, EmbeddingModel=embedding, k=1)[0] content = vector.query(question, EmbeddingModel=embedding, k=1)[0]
chat = InternLMChat(path='model_path') chat = OpenAIChat(model='Qwen/Qwen2.5-32B-Instruct')
print(chat.chat(question, [], content)) print(chat.chat(question, [], content))
``` ```
> 7.2 章节的所有代码均可在 [Happy-LLM Chapter7 RAG](https://github.com/datawhalechina/happy-llm/tree/main/docs/chapter7/RAG) 中找到。
## 7.3 Agent ## 7.3 Agent
### 7.3.1 什么是 LLM Agent ### 7.3.1 什么是 LLM Agent