修改项目结构+7,4 一部分

This commit is contained in:
KMnO4-zx
2025-04-21 22:15:49 +08:00
parent ca959a0cb8
commit 81bc97f434
17 changed files with 46 additions and 7 deletions

View File

@@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : Embeddings.py
@Time : 2024/02/10 21:55:39
@Author : 不要葱姜蒜
@Version : 1.0
@Desc : None
'''
import os
from copy import copy
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
os.environ['CURL_CA_BUNDLE'] = ''
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
class BaseEmbeddings:
"""
Base class for embeddings
"""
def __init__(self, path: str, is_api: bool) -> None:
self.path = path
self.is_api = is_api
def get_embedding(self, text: str, model: str) -> List[float]:
raise NotImplementedError
@classmethod
def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
"""
calculate cosine similarity between two vectors
"""
dot_product = np.dot(vector1, vector2)
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
if not magnitude:
return 0
return dot_product / magnitude
class OpenAIEmbedding(BaseEmbeddings):
"""
class for OpenAI embeddings
"""
def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api)
if self.is_api:
from openai import OpenAI
self.client = OpenAI()
self.client.api_key = os.getenv("OPENAI_API_KEY")
self.client.base_url = os.getenv("OPENAI_BASE_URL")
def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
if self.is_api:
text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=model).data[0].embedding
else:
raise NotImplementedError
class JinaEmbedding(BaseEmbeddings):
"""
class for Jina embeddings
"""
def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
super().__init__(path, is_api)
self._model = self.load_model()
def get_embedding(self, text: str) -> List[float]:
return self._model.encode([text])[0].tolist()
def load_model(self):
import torch
from transformers import AutoModel
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
return model
class ZhipuEmbedding(BaseEmbeddings):
"""
class for Zhipu embeddings
"""
def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api)
if self.is_api:
from zhipuai import ZhipuAI
self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
def get_embedding(self, text: str) -> List[float]:
response = self.client.embeddings.create(
model="embedding-2",
input=text,
)
return response.data[0].embedding
class DashscopeEmbedding(BaseEmbeddings):
"""
class for Dashscope embeddings
"""
def __init__(self, path: str = '', is_api: bool = True) -> None:
super().__init__(path, is_api)
if self.is_api:
import dashscope
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
self.client = dashscope.TextEmbedding
def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]:
response = self.client.call(
model=model,
input=text
)
return response.output['embeddings'][0]['embedding']

113
docs/chapter7/RAG/LLM.py Normal file
View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : LLM.py
@Time : 2024/02/12 13:50:47
@Author : 不要葱姜蒜
@Version : 1.0
@Desc : None
'''
import os
from typing import Dict, List, Optional, Tuple, Union
PROMPT_TEMPLATE = dict(
RAG_PROMPT_TEMPALTE="""使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
问题: {question}
可参考的上下文:
···
{context}
···
如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
有用的回答:""",
InternLM_PROMPT_TEMPALTE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
问题: {question}
可参考的上下文:
···
{context}
···
如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
有用的回答:"""
)
class BaseModel:
def __init__(self, path: str = '') -> None:
self.path = path
def chat(self, prompt: str, history: List[dict], content: str) -> str:
pass
def load_model(self):
pass
class OpenAIChat(BaseModel):
def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None:
super().__init__(path)
self.model = model
def chat(self, prompt: str, history: List[dict], content: str) -> str:
from openai import OpenAI
client = OpenAI()
client.api_key = os.getenv("OPENAI_API_KEY")
client.base_url = os.getenv("OPENAI_BASE_URL")
history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
response = client.chat.completions.create(
model=self.model,
messages=history,
max_tokens=150,
temperature=0.1
)
return response.choices[0].message.content
class InternLMChat(BaseModel):
def __init__(self, path: str = '') -> None:
super().__init__(path)
self.load_model()
def chat(self, prompt: str, history: List = [], content: str='') -> str:
prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)
response, history = self.model.chat(self.tokenizer, prompt, history)
return response
def load_model(self):
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
class DashscopeChat(BaseModel):
def __init__(self, path: str = '', model: str = "qwen-turbo") -> None:
super().__init__(path)
self.model = model
def chat(self, prompt: str, history: List[Dict], content: str) -> str:
import dashscope
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
response = dashscope.Generation.call(
model=self.model,
messages=history,
result_format='message',
max_tokens=150,
temperature=0.1
)
return response.output.choices[0].message.content
class ZhipuChat(BaseModel):
def __init__(self, path: str = '', model: str = "glm-4") -> None:
super().__init__(path)
from zhipuai import ZhipuAI
self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
self.model = model
def chat(self, prompt: str, history: List[Dict], content: str) -> str:
history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
response = self.client.chat.completions.create(
model=self.model,
messages=history,
max_tokens=150,
temperature=0.1
)
return response.choices[0].message

View File

@@ -0,0 +1,52 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : VectorBase.py
@Time : 2024/02/12 10:11:13
@Author : 不要葱姜蒜
@Version : 1.0
@Desc : None
'''
import os
from typing import Dict, List, Optional, Tuple, Union
import json
from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding
import numpy as np
from tqdm import tqdm
class VectorStore:
def __init__(self, document: List[str] = ['']) -> None:
self.document = document
def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
self.vectors = []
for doc in tqdm(self.document, desc="Calculating embeddings"):
self.vectors.append(EmbeddingModel.get_embedding(doc))
return self.vectors
def persist(self, path: str = 'storage'):
if not os.path.exists(path):
os.makedirs(path)
with open(f"{path}/doecment.json", 'w', encoding='utf-8') as f:
json.dump(self.document, f, ensure_ascii=False)
if self.vectors:
with open(f"{path}/vectors.json", 'w', encoding='utf-8') as f:
json.dump(self.vectors, f)
def load_vector(self, path: str = 'storage'):
with open(f"{path}/vectors.json", 'r', encoding='utf-8') as f:
self.vectors = json.load(f)
with open(f"{path}/doecment.json", 'r', encoding='utf-8') as f:
self.document = json.load(f)
def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:
return BaseEmbeddings.cosine_similarity(vector1, vector2)
def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
query_vector = EmbeddingModel.get_embedding(query)
result = np.array([self.get_similarity(query_vector, vector)
for vector in self.vectors])
return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()

160
docs/chapter7/RAG/utils.py Normal file
View File

@@ -0,0 +1,160 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : utils.py
@Time : 2024/02/11 09:52:26
@Author : 不要葱姜蒜
@Version : 1.0
@Desc : None
'''
import os
from typing import Dict, List, Optional, Tuple, Union
import PyPDF2
import markdown
import html2text
import json
from tqdm import tqdm
import tiktoken
from bs4 import BeautifulSoup
import re
enc = tiktoken.get_encoding("cl100k_base")
class ReadFiles:
"""
class to read files
"""
def __init__(self, path: str) -> None:
self._path = path
self.file_list = self.get_files()
def get_files(self):
# argsdir_path目标文件夹路径
file_list = []
for filepath, dirnames, filenames in os.walk(self._path):
# os.walk 函数将递归遍历指定文件夹
for filename in filenames:
# 通过后缀名判断文件类型是否满足要求
if filename.endswith(".md"):
# 如果满足要求,将其绝对路径加入到结果列表
file_list.append(os.path.join(filepath, filename))
elif filename.endswith(".txt"):
file_list.append(os.path.join(filepath, filename))
elif filename.endswith(".pdf"):
file_list.append(os.path.join(filepath, filename))
return file_list
def get_content(self, max_token_len: int = 600, cover_content: int = 150):
docs = []
# 读取文件内容
for file in self.file_list:
content = self.read_file_content(file)
chunk_content = self.get_chunk(
content, max_token_len=max_token_len, cover_content=cover_content)
docs.extend(chunk_content)
return docs
@classmethod
def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150):
chunk_text = []
curr_len = 0
curr_chunk = ''
token_len = max_token_len - cover_content
lines = text.splitlines() # 假设以换行符分割文本为行
for line in lines:
line = line.replace(' ', '')
line_len = len(enc.encode(line))
if line_len > max_token_len:
# 如果单行长度就超过限制,则将其分割成多个块
num_chunks = (line_len + token_len - 1) // token_len
for i in range(num_chunks):
start = i * token_len
end = start + token_len
# 避免跨单词分割
while not line[start:end].rstrip().isspace():
start += 1
end += 1
if start >= line_len:
break
curr_chunk = curr_chunk[-cover_content:] + line[start:end]
chunk_text.append(curr_chunk)
# 处理最后一个块
start = (num_chunks - 1) * token_len
curr_chunk = curr_chunk[-cover_content:] + line[start:end]
chunk_text.append(curr_chunk)
if curr_len + line_len <= token_len:
curr_chunk += line
curr_chunk += '\n'
curr_len += line_len
curr_len += 1
else:
chunk_text.append(curr_chunk)
curr_chunk = curr_chunk[-cover_content:]+line
curr_len = line_len + cover_content
if curr_chunk:
chunk_text.append(curr_chunk)
return chunk_text
@classmethod
def read_file_content(cls, file_path: str):
# 根据文件扩展名选择读取方法
if file_path.endswith('.pdf'):
return cls.read_pdf(file_path)
elif file_path.endswith('.md'):
return cls.read_markdown(file_path)
elif file_path.endswith('.txt'):
return cls.read_text(file_path)
else:
raise ValueError("Unsupported file type")
@classmethod
def read_pdf(cls, file_path: str):
# 读取PDF文件
with open(file_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
text = ""
for page_num in range(len(reader.pages)):
text += reader.pages[page_num].extract_text()
return text
@classmethod
def read_markdown(cls, file_path: str):
# 读取Markdown文件
with open(file_path, 'r', encoding='utf-8') as file:
md_text = file.read()
html_text = markdown.markdown(md_text)
# 使用BeautifulSoup从HTML中提取纯文本
soup = BeautifulSoup(html_text, 'html.parser')
plain_text = soup.get_text()
# 使用正则表达式移除网址链接
text = re.sub(r'http\S+', '', plain_text)
return text
@classmethod
def read_text(cls, file_path: str):
# 读取文本文件
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()
class Documents:
"""
获取已分好类的json格式文档
"""
def __init__(self, path: str = '') -> None:
self.path = path
def get_content(self):
with open(self.path, mode='r', encoding='utf-8') as f:
content = json.load(f)
return content