Modify project structure + part of 7.4
docs/chapter7/RAG/Embeddings.py (new file, 117 lines)
@@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   Embeddings.py
@Time    :   2024/02/10 21:55:39
@Author  :   不要葱姜蒜
@Version :   1.0
@Desc    :   None
'''

import os
from copy import copy
from typing import Dict, List, Optional, Tuple, Union
import numpy as np

os.environ['CURL_CA_BUNDLE'] = ''
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())


class BaseEmbeddings:
    """
    Base class for embeddings
    """
    def __init__(self, path: str, is_api: bool) -> None:
        self.path = path
        self.is_api = is_api

    def get_embedding(self, text: str, model: str) -> List[float]:
        raise NotImplementedError

    @classmethod
    def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
        """
        calculate cosine similarity between two vectors
        """
        dot_product = np.dot(vector1, vector2)
        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
        if not magnitude:
            return 0
        return dot_product / magnitude


class OpenAIEmbedding(BaseEmbeddings):
    """
    class for OpenAI embeddings
    """
    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from openai import OpenAI
            self.client = OpenAI()
            self.client.api_key = os.getenv("OPENAI_API_KEY")
            self.client.base_url = os.getenv("OPENAI_BASE_URL")

    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
        if self.is_api:
            text = text.replace("\n", " ")
            return self.client.embeddings.create(input=[text], model=model).data[0].embedding
        else:
            raise NotImplementedError


class JinaEmbedding(BaseEmbeddings):
    """
    class for Jina embeddings
    """
    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
        super().__init__(path, is_api)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        return self._model.encode([text])[0].tolist()

    def load_model(self):
        import torch
        from transformers import AutoModel
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
        return model


class ZhipuEmbedding(BaseEmbeddings):
    """
    class for Zhipu embeddings
    """
    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from zhipuai import ZhipuAI
            self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))

    def get_embedding(self, text: str) -> List[float]:
        response = self.client.embeddings.create(
            model="embedding-2",
            input=text,
        )
        return response.data[0].embedding


class DashscopeEmbedding(BaseEmbeddings):
    """
    class for Dashscope embeddings
    """
    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            import dashscope
            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
            self.client = dashscope.TextEmbedding

    def get_embedding(self, text: str, model: str = 'text-embedding-v1') -> List[float]:
        response = self.client.call(
            model=model,
            input=text
        )
        return response.output['embeddings'][0]['embedding']
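The shared cosine_similarity helper can be sanity-checked without any API key. A minimal sketch, assuming it runs from docs/chapter7 so the RAG package is importable; the two vectors are made-up toy values:

from RAG.Embeddings import BaseEmbeddings

# Two toy 3-dimensional "embeddings" (illustrative values only).
v1 = [1.0, 0.0, 1.0]
v2 = [0.5, 0.5, 0.5]

# dot = 1.0, |v1| = sqrt(2) ~ 1.414, |v2| = sqrt(0.75) ~ 0.866
print(BaseEmbeddings.cosine_similarity(v1, v2))  # ~0.8165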
docs/chapter7/RAG/LLM.py (new file, 113 lines)
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   LLM.py
@Time    :   2024/02/12 13:50:47
@Author  :   不要葱姜蒜
@Version :   1.0
@Desc    :   None
'''
import os
from typing import Dict, List, Optional, Tuple, Union


PROMPT_TEMPLATE = dict(
    RAG_PROMPT_TEMPLATE="""使用以下上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
        问题: {question}
        可参考的上下文:
        ···
        {context}
        ···
        如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
        有用的回答:""",
    InternLM_PROMPT_TEMPLATE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
        问题: {question}
        可参考的上下文:
        ···
        {context}
        ···
        如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
        有用的回答:"""
)


class BaseModel:
    def __init__(self, path: str = '') -> None:
        self.path = path

    def chat(self, prompt: str, history: List[dict], content: str) -> str:
        pass

    def load_model(self):
        pass


class OpenAIChat(BaseModel):
    def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None:
        super().__init__(path)
        self.model = model

    def chat(self, prompt: str, history: List[dict], content: str) -> str:
        from openai import OpenAI
        client = OpenAI()
        client.api_key = os.getenv("OPENAI_API_KEY")
        client.base_url = os.getenv("OPENAI_BASE_URL")
        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPLATE'].format(question=prompt, context=content)})
        response = client.chat.completions.create(
            model=self.model,
            messages=history,
            max_tokens=150,
            temperature=0.1
        )
        return response.choices[0].message.content


class InternLMChat(BaseModel):
    def __init__(self, path: str = '') -> None:
        super().__init__(path)
        self.load_model()

    def chat(self, prompt: str, history: List = None, content: str = '') -> str:
        # Avoid a mutable default argument for history.
        if history is None:
            history = []
        prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPLATE'].format(question=prompt, context=content)
        response, history = self.model.chat(self.tokenizer, prompt, history)
        return response

    def load_model(self):
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()


class DashscopeChat(BaseModel):
    def __init__(self, path: str = '', model: str = "qwen-turbo") -> None:
        super().__init__(path)
        self.model = model

    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
        import dashscope
        dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPLATE'].format(question=prompt, context=content)})
        response = dashscope.Generation.call(
            model=self.model,
            messages=history,
            result_format='message',
            max_tokens=150,
            temperature=0.1
        )
        return response.output.choices[0].message.content


class ZhipuChat(BaseModel):
    def __init__(self, path: str = '', model: str = "glm-4") -> None:
        super().__init__(path)
        from zhipuai import ZhipuAI
        self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
        self.model = model

    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPLATE'].format(question=prompt, context=content)})
        response = self.client.chat.completions.create(
            model=self.model,
            messages=history,
            max_tokens=150,
            temperature=0.1
        )
        # Return the message text rather than the whole message object,
        # matching the declared -> str return type.
        return response.choices[0].message.content
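A minimal usage sketch for one of these wrappers, assuming OPENAI_API_KEY and OPENAI_BASE_URL are set in .env; the question and the retrieved context are made-up strings:

from RAG.LLM import OpenAIChat

model = OpenAIChat(model="gpt-3.5-turbo-1106")
history = []  # chat() appends the formatted RAG prompt to this list in place
answer = model.chat(
    prompt="什么是RAG?",
    history=history,
    content="RAG(检索增强生成)先检索相关文档,再将其作为上下文交给大模型生成回答。",
)
print(answer)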
docs/chapter7/RAG/VectorBase.py (new file, 52 lines)
@@ -0,0 +1,52 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   VectorBase.py
@Time    :   2024/02/12 10:11:13
@Author  :   不要葱姜蒜
@Version :   1.0
@Desc    :   None
'''

import os
from typing import Dict, List, Optional, Tuple, Union
import json
from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding
import numpy as np
from tqdm import tqdm


class VectorStore:
    def __init__(self, document: List[str] = ['']) -> None:
        self.document = document
        self.vectors = []  # filled by get_vector() or load_vector()

    def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
        self.vectors = []
        for doc in tqdm(self.document, desc="Calculating embeddings"):
            self.vectors.append(EmbeddingModel.get_embedding(doc))
        return self.vectors

    def persist(self, path: str = 'storage'):
        if not os.path.exists(path):
            os.makedirs(path)
        with open(f"{path}/document.json", 'w', encoding='utf-8') as f:
            json.dump(self.document, f, ensure_ascii=False)
        if self.vectors:
            with open(f"{path}/vectors.json", 'w', encoding='utf-8') as f:
                json.dump(self.vectors, f)

    def load_vector(self, path: str = 'storage'):
        with open(f"{path}/vectors.json", 'r', encoding='utf-8') as f:
            self.vectors = json.load(f)
        with open(f"{path}/document.json", 'r', encoding='utf-8') as f:
            self.document = json.load(f)

    def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:
        return BaseEmbeddings.cosine_similarity(vector1, vector2)

    def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
        query_vector = EmbeddingModel.get_embedding(query)
        result = np.array([self.get_similarity(query_vector, vector)
                           for vector in self.vectors])
        # Indices of the k highest similarities, best match first.
        return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()
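A persist/load round-trip sketch for the store, assuming ZHIPUAI_API_KEY is set (any BaseEmbeddings subclass would do); the documents are made-up:

from RAG.VectorBase import VectorStore
from RAG.Embeddings import ZhipuEmbedding

docs = ["RAG 结合检索与生成。", "向量数据库存储文本的嵌入表示。"]
embedding = ZhipuEmbedding()

store = VectorStore(docs)
store.get_vector(embedding)    # embed every document chunk
store.persist('storage')       # writes storage/document.json and storage/vectors.json

store2 = VectorStore()
store2.load_vector('storage')  # reload without re-embedding
print(store2.query("什么是RAG?", embedding, k=1))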
docs/chapter7/RAG/utils.py (new file, 160 lines)
@@ -0,0 +1,160 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   utils.py
@Time    :   2024/02/11 09:52:26
@Author  :   不要葱姜蒜
@Version :   1.0
@Desc    :   None
'''

import os
from typing import Dict, List, Optional, Tuple, Union

import PyPDF2
import markdown
import html2text
import json
from tqdm import tqdm
import tiktoken
from bs4 import BeautifulSoup
import re

enc = tiktoken.get_encoding("cl100k_base")


class ReadFiles:
    """
    class to read files
    """

    def __init__(self, path: str) -> None:
        self._path = path
        self.file_list = self.get_files()

    def get_files(self):
        # args: dir_path, the target directory
        file_list = []
        for filepath, dirnames, filenames in os.walk(self._path):
            # os.walk traverses the directory recursively
            for filename in filenames:
                # filter files by extension
                if filename.endswith(".md"):
                    # collect the full path of every matching file
                    file_list.append(os.path.join(filepath, filename))
                elif filename.endswith(".txt"):
                    file_list.append(os.path.join(filepath, filename))
                elif filename.endswith(".pdf"):
                    file_list.append(os.path.join(filepath, filename))
        return file_list

    def get_content(self, max_token_len: int = 600, cover_content: int = 150):
        docs = []
        # read and chunk every file
        for file in self.file_list:
            content = self.read_file_content(file)
            chunk_content = self.get_chunk(
                content, max_token_len=max_token_len, cover_content=cover_content)
            docs.extend(chunk_content)
        return docs

    @classmethod
    def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150):
        chunk_text = []

        curr_len = 0
        curr_chunk = ''

        # each chunk holds at most token_len new tokens plus a
        # cover_content-character overlap carried over from the previous chunk
        token_len = max_token_len - cover_content
        lines = text.splitlines()  # split the text into lines at newlines

        for line in lines:
            line = line.replace(' ', '')
            tokens = enc.encode(line)
            line_len = len(tokens)
            if line_len > max_token_len:
                # A single over-long line is split into fixed-size token
                # windows; each window is decoded back to text and prefixed
                # with the tail of the previous chunk for overlap.
                num_chunks = (line_len + token_len - 1) // token_len
                for i in range(num_chunks):
                    piece = enc.decode(tokens[i * token_len:(i + 1) * token_len])
                    curr_chunk = curr_chunk[-cover_content:] + piece
                    chunk_text.append(curr_chunk)
                curr_chunk = ''
                curr_len = 0
            elif curr_len + line_len <= token_len:
                curr_chunk += line
                curr_chunk += '\n'
                curr_len += line_len + 1  # +1 for the newline
            else:
                chunk_text.append(curr_chunk)
                curr_chunk = curr_chunk[-cover_content:] + line
                curr_len = line_len + cover_content

        if curr_chunk:
            chunk_text.append(curr_chunk)

        return chunk_text

    @classmethod
    def read_file_content(cls, file_path: str):
        # choose a reader based on the file extension
        if file_path.endswith('.pdf'):
            return cls.read_pdf(file_path)
        elif file_path.endswith('.md'):
            return cls.read_markdown(file_path)
        elif file_path.endswith('.txt'):
            return cls.read_text(file_path)
        else:
            raise ValueError("Unsupported file type")

    @classmethod
    def read_pdf(cls, file_path: str):
        # read a PDF file page by page
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            text = ""
            for page_num in range(len(reader.pages)):
                text += reader.pages[page_num].extract_text()
            return text

    @classmethod
    def read_markdown(cls, file_path: str):
        # read a Markdown file and reduce it to plain text
        with open(file_path, 'r', encoding='utf-8') as file:
            md_text = file.read()
            html_text = markdown.markdown(md_text)
            # extract plain text from the rendered HTML with BeautifulSoup
            soup = BeautifulSoup(html_text, 'html.parser')
            plain_text = soup.get_text()
            # strip URLs with a regular expression
            text = re.sub(r'http\S+', '', plain_text)
            return text

    @classmethod
    def read_text(cls, file_path: str):
        # read a plain-text file
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()


class Documents:
    """
    Load documents that have already been segmented into JSON format.
    """
    def __init__(self, path: str = '') -> None:
        self.path = path

    def get_content(self):
        with open(self.path, mode='r', encoding='utf-8') as f:
            content = json.load(f)
        return content
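Wiring the four modules together, a minimal end-to-end sketch of the pipeline; ./data is an illustrative path, and ZHIPUAI_API_KEY must be set for both the embedding and chat calls:

from RAG.utils import ReadFiles
from RAG.VectorBase import VectorStore
from RAG.Embeddings import ZhipuEmbedding
from RAG.LLM import ZhipuChat

# 1. Read every .md/.txt/.pdf under ./data and split into overlapping chunks.
docs = ReadFiles('./data').get_content(max_token_len=600, cover_content=150)

# 2. Embed the chunks and persist them to disk.
vector = VectorStore(docs)
embedding = ZhipuEmbedding()
vector.get_vector(EmbeddingModel=embedding)
vector.persist(path='storage')

# 3. Retrieve the best-matching chunk and answer the question with it.
question = '什么是RAG?'
content = vector.query(question, EmbeddingModel=embedding, k=1)[0]
print(ZhipuChat().chat(question, [], content))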