- Add search_wikipedia and get_current_temperature tool functions
- Implement a Streamlit-based web interaction interface
- Update requirements.txt with the new dependencies
- Fix the misspelled PROMPT_TEMPLATE key name
- Remove tool functions that are no longer used
- Add a web interface screenshot to the documentation
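The tool functions and Streamlit interface named above live outside LLM.py (the file shown below). A minimal sketch of what search_wikipedia and get_current_temperature might look like, assuming the wikipedia package and the public Open-Meteo forecast API; the signatures, endpoint, and response shape are assumptions, not the repository's actual implementation:

import wikipedia
import requests

def search_wikipedia(query: str, sentences: int = 2) -> str:
    """Return a short Wikipedia summary for the query (assumed tool signature)."""
    try:
        return wikipedia.summary(query, sentences=sentences)
    except Exception as e:  # disambiguation / page errors are reported back to the model
        return f"Wikipedia lookup failed: {e}"

def get_current_temperature(latitude: float, longitude: float) -> float:
    """Fetch the current temperature in °C from the Open-Meteo API (assumed data source)."""
    resp = requests.get(
        "https://api.open-meteo.com/v1/forecast",
        params={"latitude": latitude, "longitude": longitude, "current_weather": True},
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()["current_weather"]["temperature"]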
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   LLM.py
@Time    :   2024/02/12 13:50:47
@Author  :   不要葱姜蒜
@Version :   1.0
@Desc    :   None
'''
import os
from typing import Dict, List, Optional, Tuple, Union

# Prompt templates (kept in Chinese, since the assistant is instructed to answer in Chinese):
# RAG_PROMPT_TEMPLATE answers directly from the retrieved context; InternLM_PROMPT_TEMPLATE
# first summarises the context, then answers. Both refuse when the context is insufficient.
PROMPT_TEMPLATE = dict(
    RAG_PROMPT_TEMPLATE="""使用以下上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
        问题: {question}
        可参考的上下文:
        ···
        {context}
        ···
        如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
        有用的回答:""",
    InternLM_PROMPT_TEMPLATE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
        问题: {question}
        可参考的上下文:
        ···
        {context}
        ···
        如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
        有用的回答:"""
)
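# Illustrative note (not part of the original file): the RAG pipeline is expected to fill a
# template with the user question and the retrieved context before calling the model, e.g.
#   PROMPT_TEMPLATE['RAG_PROMPT_TEMPLATE'].format(question=user_question, context=retrieved_text)
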
class BaseModel:
    def __init__(self, path: str = '') -> None:
        self.path = path

    def chat(self, prompt: str, history: List[dict], content: str) -> str:
        pass

    def load_model(self):
        pass


class OpenAIChat(BaseModel):
    def __init__(self, path: str = '', model: str = "gpt-3.5-turbo-1106") -> None:
        super().__init__(path)
        self.model = model

    def chat(self, prompt: str, history: List[dict], content: str) -> str:
        from openai import OpenAI
        # Credentials and (optional) custom endpoint come from the environment.
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL"))
        # Fill the RAG template with the user question and the retrieved context.
        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPLATE'].format(question=prompt, context=content)})
        response = client.chat.completions.create(
            model=self.model,
            messages=history,
            max_tokens=150,
            temperature=0.1
        )
        return response.choices[0].message.content


class InternLMChat(BaseModel):
    def __init__(self, path: str = '') -> None:
        super().__init__(path)
        self.load_model()

    def chat(self, prompt: str, history: Optional[List] = None, content: str = '') -> str:
        # Avoid a mutable default argument for the conversation history.
        history = history if history is not None else []
        prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPLATE'].format(question=prompt, context=content)
        response, history = self.model.chat(self.tokenizer, prompt, history)
        return response

    def load_model(self):
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()


class DashscopeChat(BaseModel):
    def __init__(self, path: str = '', model: str = "qwen-turbo") -> None:
        super().__init__(path)
        self.model = model

    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
        import dashscope
        dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPLATE'].format(question=prompt, context=content)})
        response = dashscope.Generation.call(
            model=self.model,
            messages=history,
            result_format='message',
            max_tokens=150,
            temperature=0.1
        )
        return response.output.choices[0].message.content


class ZhipuChat(BaseModel):
    def __init__(self, path: str = '', model: str = "glm-4") -> None:
        super().__init__(path)
        from zhipuai import ZhipuAI
        self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
        self.model = model

    def chat(self, prompt: str, history: List[Dict], content: str) -> str:
        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPLATE'].format(question=prompt, context=content)})
        response = self.client.chat.completions.create(
            model=self.model,
            messages=history,
            max_tokens=150,
            temperature=0.1
        )
        return response.choices[0].message.content
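

# Minimal usage sketch (not part of the original file). It assumes OPENAI_API_KEY
# (and optionally OPENAI_BASE_URL) are set in the environment and that the caller,
# e.g. the RAG pipeline, supplies the retrieved context string.
if __name__ == '__main__':
    retrieved_context = "RAG(检索增强生成)先检索相关文档,再基于这些文档生成回答。"  # placeholder context
    model = OpenAIChat()
    print(model.chat("什么是RAG?", history=[], content=retrieved_context))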