feat(RAG): update the RAG module code and documentation

refactor: simplify the Embeddings and LLM class implementations and remove unnecessary dependencies
docs: update the documentation and add usage notes for the SiliconFlow (硅基流动) API
chore: update dependency versions in requirements.txt
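For context, the simplified OpenAIChat below reads OPENAI_API_KEY and OPENAI_BASE_URL from the environment, so it can be pointed at any OpenAI-compatible endpoint such as SiliconFlow. A minimal usage sketch follows; the import path and the .env values (including the SiliconFlow base URL) are assumptions, while the environment variable names, the default model and the chat() signature come from the diff itself.

# Minimal usage sketch (not part of this commit). The import path and the
# .env contents are assumptions; the env var names, the default model and the
# chat(prompt, history, content) signature are taken from the diff below.
# LLM.py calls load_dotenv() on import, so a .env file defining OPENAI_API_KEY
# and OPENAI_BASE_URL (e.g. https://api.siliconflow.cn/v1 for SiliconFlow,
# assumed value) is enough.
from LLM import OpenAIChat  # assumed import path; the file changed here is LLM.py

llm = OpenAIChat(model="Qwen/Qwen2.5-32B-Instruct")
answer = llm.chat(
    prompt="什么是RAG?",           # user question
    history=[],                     # prior conversation turns, if any
    content="检索到的上下文片段",     # retrieved context, filled into RAG_PROMPT_TEMPLATE
)
print(answer)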
@@ -2,37 +2,33 @@
 # -*- coding: utf-8 -*-
 '''
 @File : LLM.py
-@Time : 2024/02/12 13:50:47
+@Time : 2025/06/20 13:50:47
 @Author : 不要葱姜蒜
-@Version : 1.0
+@Version : 1.1
 @Desc : None
 '''
 import os
 from typing import Dict, List, Optional, Tuple, Union
+from openai import OpenAI

-PROMPT_TEMPLATE = dict(
-    RAG_PROMPT_TEMPLATE="""使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
-问题: {question}
-可参考的上下文:
-···
-{context}
-···
-如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
-有用的回答:""",
-    InternLM_PROMPT_TEMPLATE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
-问题: {question}
-可参考的上下文:
-···
-{context}
-···
-如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
-有用的回答:"""
-)
+from dotenv import load_dotenv, find_dotenv
+_ = load_dotenv(find_dotenv())
+
+RAG_PROMPT_TEMPLATE="""
+使用以上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
+问题: {question}
+可参考的上下文:
+···
+{context}
+···
+如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
+有用的回答:
+"""


 class BaseModel:
-    def __init__(self, path: str = '') -> None:
-        self.path = path
+    def __init__(self, model) -> None:
+        self.model = model

     def chat(self, prompt: str, history: List[dict], content: str) -> str:
         pass
@@ -41,73 +37,18 @@ class BaseModel:
         pass

 class OpenAIChat(BaseModel):
-    def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None:
-        super().__init__(path)
+    def __init__(self, model: str = "Qwen/Qwen2.5-32B-Instruct") -> None:
         self.model = model

     def chat(self, prompt: str, history: List[dict], content: str) -> str:
-        from openai import OpenAI
         client = OpenAI()
         client.api_key = os.getenv("OPENAI_API_KEY")
         client.base_url = os.getenv("OPENAI_BASE_URL")
-        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
+        history.append({'role': 'user', 'content': RAG_PROMPT_TEMPLATE.format(question=prompt, context=content)})
         response = client.chat.completions.create(
             model=self.model,
             messages=history,
-            max_tokens=150,
+            max_tokens=2048,
             temperature=0.1
         )
         return response.choices[0].message.content
-
-class InternLMChat(BaseModel):
-    def __init__(self, path: str = '') -> None:
-        super().__init__(path)
-        self.load_model()
-
-    def chat(self, prompt: str, history: List = [], content: str='') -> str:
-        prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPLATE'].format(question=prompt, context=content)
-        response, history = self.model.chat(self.tokenizer, prompt, history)
-        return response
-
-
-    def load_model(self):
-        import torch
-        from transformers import AutoTokenizer, AutoModelForCausalLM
-        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
-        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
-
-class DashscopeChat(BaseModel):
-    def __init__(self, path: str = '', model: str = "qwen-turbo") -> None:
-        super().__init__(path)
-        self.model = model
-
-    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
-        import dashscope
-        dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
-        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
-        response = dashscope.Generation.call(
-            model=self.model,
-            messages=history,
-            result_format='message',
-            max_tokens=150,
-            temperature=0.1
-        )
-        return response.output.choices[0].message.content
-
-
-class ZhipuChat(BaseModel):
-    def __init__(self, path: str = '', model: str = "glm-4") -> None:
-        super().__init__(path)
-        from zhipuai import ZhipuAI
-        self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
-        self.model = model
-
-    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
-        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})
-        response = self.client.chat.completions.create(
-            model=self.model,
-            messages=history,
-            max_tokens=150,
-            temperature=0.1
-        )
-        return response.choices[0].message
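This commit drops the InternLM, Dashscope and Zhipu wrappers, but the simplified BaseModel keeps the same chat(prompt, history, content) contract, so a backend can still be re-added by subclassing. Below is a hedged sketch, not part of the repo, that mirrors the removed ZhipuChat; note that the removed version returned response.choices[0].message, whereas the reply text itself lives in .message.content.

# Illustrative only, assuming LLM.py exposes BaseModel and RAG_PROMPT_TEMPLATE
# as in the diff above; the module path and class placement are assumptions.
import os
from typing import Dict, List

from LLM import BaseModel, RAG_PROMPT_TEMPLATE  # assumed import path

class ZhipuChat(BaseModel):
    def __init__(self, model: str = "glm-4") -> None:
        from zhipuai import ZhipuAI  # expects ZHIPUAI_API_KEY in the environment
        self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
        self.model = model

    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
        # Fill the shared RAG prompt with the user question and retrieved context.
        history.append({'role': 'user', 'content': RAG_PROMPT_TEMPLATE.format(question=prompt, context=content)})
        response = self.client.chat.completions.create(
            model=self.model,
            messages=history,
            max_tokens=2048,
            temperature=0.1
        )
        return response.choices[0].message.content  # return the text, not the message object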