- 添加search_wikipedia和get_current_temperature工具函数 - 实现基于Streamlit的web交互界面 - 更新requirements.txt添加相关依赖 - 修复PROMPT_TEMPLATE变量名拼写错误 - 移除不再使用的工具函数 - 添加web界面截图到文档
116 lines
3.7 KiB
Python
116 lines
3.7 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
'''
|
|
@File : Embeddings.py
|
|
@Time : 2024/02/10 21:55:39
|
|
@Author : 不要葱姜蒜
|
|
@Version : 1.0
|
|
@Desc : None
|
|
'''
|
|
|
|
import os
|
|
from copy import copy
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
import numpy as np
|
|
|
|
from dotenv import load_dotenv, find_dotenv
|
|
_ = load_dotenv(find_dotenv())
|
|
|
|
|
|
class BaseEmbeddings:
    """
    Base class for embeddings.

    Keeps the model path and the API/local flag on the instance and offers a
    cosine-similarity helper; ``get_embedding`` must be provided by subclasses.
    """

    def __init__(self, path: str, is_api: bool) -> None:
        """
        Store the backend configuration.

        Args:
            path: Model path or identifier, stored as ``self.path``.
            is_api: Flag stored as ``self.is_api``; subclasses use it to
                decide whether to build an API client.
        """
        self.path = path
        self.is_api = is_api

    def get_embedding(self, text: str, model: str) -> List[float]:
        """
        Return the embedding vector for ``text``.

        Raises:
            NotImplementedError: Always; this base method is a placeholder.
        """
        raise NotImplementedError

    @classmethod
    def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
        """
        Calculate the cosine similarity between two vectors.

        Args:
            vector1: First vector.
            vector2: Second vector.

        Returns:
            The dot product divided by the product of the two norms, or
            ``0.0`` when that product is zero (at least one zero vector).
        """
        dot_product = np.dot(vector1, vector2)
        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
        if not magnitude:
            # Avoid a division by zero and honour the ``-> float`` annotation
            # (previously the int ``0`` was returned here).
            return 0.0
        return dot_product / magnitude
|
|
|
|
|
|
class OpenAIEmbedding(BaseEmbeddings):
    """
    Embeddings served by the OpenAI API.

    Only the API mode is implemented; with ``is_api=False`` no client is
    created and ``get_embedding`` raises ``NotImplementedError``.
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        """
        Build the OpenAI client when running in API mode.

        Args:
            path: Forwarded to ``BaseEmbeddings``; not otherwise used here.
            is_api: When true, an ``OpenAI`` client is created from the
                ``OPENAI_API_KEY`` and ``OPENAI_BASE_URL`` environment
                variables.
        """
        super().__init__(path, is_api)
        if self.is_api:
            # Imported lazily so the package is only required in API mode.
            from openai import OpenAI

            # Pass both settings to the constructor instead of assigning them
            # afterwards: with OPENAI_BASE_URL unset, os.getenv returns None,
            # and assigning None to ``client.base_url`` is rejected by the
            # client, whereas a None constructor argument lets the client
            # fall back to its own default (per the openai-python client).
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_BASE_URL"),
            )

    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
        """
        Request the embedding of ``text`` from the embeddings endpoint.

        Args:
            text: Input text; newlines are replaced by spaces before sending.
            model: Name of the embedding model to request.

        Returns:
            The ``embedding`` field of the first item in the response data.

        Raises:
            NotImplementedError: If the instance was created with
                ``is_api=False``.
        """
        if not self.is_api:
            raise NotImplementedError
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=model).data[0].embedding
|
|
|
|
class JinaEmbedding(BaseEmbeddings):
    """
    Embeddings computed by a Jina model loaded through ``transformers``.
    """

    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
        """
        Store the configuration and load the model immediately.

        Args:
            path: Model identifier or path handed to ``AutoModel.from_pretrained``.
            is_api: Forwarded to ``BaseEmbeddings``; the model is loaded
                regardless of its value.
        """
        super().__init__(path, is_api)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        """
        Encode ``text`` with the loaded model and return it as a plain list.

        NOTE(review): relies on the loaded model exposing ``encode`` and on its
        per-input result supporting ``tolist()`` - presumably supplied by the
        model's remote code; confirm.
        """
        encoded_batch = self._model.encode([text])
        return encoded_batch[0].tolist()

    def load_model(self):
        """
        Load the model from ``self.path`` and move it to CUDA if available,
        otherwise to the CPU.

        Returns:
            The loaded model, already placed on the selected device.
        """
        # Heavy dependencies are imported only when a model is actually loaded.
        import torch
        from transformers import AutoModel

        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loaded = AutoModel.from_pretrained(self.path, trust_remote_code=True)
        return loaded.to(target)
|
|
|
|
class ZhipuEmbedding(BaseEmbeddings):
    """
    Embeddings served by the ZhipuAI API.
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        """
        Build the ZhipuAI client when running in API mode.

        Args:
            path: Forwarded to ``BaseEmbeddings``; not otherwise used here.
            is_api: When true, a ``ZhipuAI`` client is created with the key
                read from the ``ZHIPUAI_API_KEY`` environment variable. When
                false, no ``client`` attribute is set.
        """
        super().__init__(path, is_api)
        if not self.is_api:
            return
        # Imported lazily so the package is only required in API mode.
        from zhipuai import ZhipuAI

        zhipu_key = os.getenv("ZHIPUAI_API_KEY")
        self.client = ZhipuAI(api_key=zhipu_key)

    def get_embedding(self, text: str) -> List[float]:
        """
        Request the embedding of ``text`` using the ``embedding-2`` model.

        Returns:
            The ``embedding`` field of the first item in the response data.
        """
        result = self.client.embeddings.create(model="embedding-2", input=text)
        first_item = result.data[0]
        return first_item.embedding
|
|
|
|
class DashscopeEmbedding(BaseEmbeddings):
    """
    Embeddings served by the Dashscope API.
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        """
        Configure Dashscope when running in API mode.

        Args:
            path: Forwarded to ``BaseEmbeddings``; not otherwise used here.
            is_api: When true, ``dashscope.api_key`` is set from the
                ``DASHSCOPE_API_KEY`` environment variable and
                ``dashscope.TextEmbedding`` is stored as ``self.client``.
                When false, no ``client`` attribute is set.
        """
        super().__init__(path, is_api)
        if not self.is_api:
            return
        # Imported lazily so the package is only required in API mode.
        import dashscope

        # The key is assigned on the dashscope module itself, not on this instance.
        dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
        self.client = dashscope.TextEmbedding

    def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]:
        """
        Request the embedding of ``text`` via ``self.client.call``.

        Args:
            text: Input text passed as ``input``.
            model: Name of the embedding model to request.

        Returns:
            ``output['embeddings'][0]['embedding']`` from the response.
        """
        reply = self.client.call(model=model, input=text)
        embedding_items = reply.output['embeddings']
        return embedding_items[0]['embedding']