Files
happy-llm/docs/chapter8/RAG/Embeddings.py
2024-10-24 15:05:13 +08:00

117 lines
3.7 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : Embeddings.py
@Time : 2024/02/10 21:55:39
@Author : 不要葱姜蒜
@Version : 1.0
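@Desc : Embedding backends (OpenAI, Jina, Zhipu, Dashscope) sharing a common base class with a cosine-similarity helper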
'''
import os
from copy import copy
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
# Blank out CURL_CA_BUNDLE before the HTTP-based SDKs are imported further down.
# NOTE(review): an empty CURL_CA_BUNDLE is commonly used to bypass TLS
# certificate verification in requests-based clients - confirm this is
# intentional, since it weakens transport security for every API call.
os.environ['CURL_CA_BUNDLE'] = ''
from dotenv import load_dotenv, find_dotenv
# Load variables from the .env file located by find_dotenv() so that the
# os.getenv(...) lookups for the API keys in the classes below can see them.
# The return value is deliberately discarded.
_ = load_dotenv(find_dotenv())
class BaseEmbeddings:
    """
    Base class for embeddings.

    Stores the two configuration values shared by every backend and provides
    a cosine-similarity helper. Subclasses override ``get_embedding``.
    """

    def __init__(self, path: str, is_api: bool) -> None:
        """
        Args:
            path: stored as ``self.path``; how it is used is up to the subclass.
            is_api: stored as ``self.is_api``; how it is used is up to the subclass.
        """
        self.path = path
        self.is_api = is_api

    def get_embedding(self, text: str, model: str) -> List[float]:
        """
        Return the embedding vector of ``text``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    @classmethod
    def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
        """
        Calculate the cosine similarity between two vectors.

        Args:
            vector1: first vector.
            vector2: second vector, same length as ``vector1``.

        Returns:
            The dot product divided by the product of the two Euclidean norms,
            as a plain Python ``float``. ``0.0`` is returned when either
            vector has zero magnitude, so no division by zero takes place.
        """
        dot_product = np.dot(vector1, vector2)
        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
        if not magnitude:
            # Similarity is undefined for a zero vector; report "no similarity"
            # using the same type as the regular return path.
            return 0.0
        # Cast so callers always get a built-in float rather than a numpy scalar.
        return float(dot_product / magnitude)
class OpenAIEmbedding(BaseEmbeddings):
    """
    class for OpenAI embeddings
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        """
        Args:
            path: forwarded to the base class; not used by this class itself.
            is_api: when True, an OpenAI client is created and stored as
                ``self.client``; when False no client is created.
        """
        super().__init__(path, is_api)
        if self.is_api:
            from openai import OpenAI
            # Pass the settings to the constructor instead of assigning them
            # afterwards: a missing OPENAI_BASE_URL then yields None, which
            # the SDK treats as "use the default endpoint", rather than
            # overwriting the client's base_url with None.
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_BASE_URL"),
            )

    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
        """
        Return the embedding of ``text`` produced by the OpenAI embeddings API.

        Args:
            text: input text; newlines are replaced by spaces before sending.
            model: name of the embedding model to request.

        Raises:
            NotImplementedError: if the instance was created with ``is_api=False``.
        """
        if not self.is_api:
            raise NotImplementedError
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=model).data[0].embedding
class JinaEmbedding(BaseEmbeddings):
    """
    Embedding backend that runs a Jina model loaded through ``transformers``.
    """

    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
        """
        Store the configuration on the base class and load the model right away.

        Args:
            path: model identifier or path handed to ``AutoModel.from_pretrained``.
            is_api: forwarded to the base class; the model is loaded regardless
                of its value.
        """
        super().__init__(path, is_api)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        """
        Encode ``text`` as a one-element batch and return its vector as a list.
        """
        batch = self._model.encode([text])
        return batch[0].tolist()

    def load_model(self):
        """
        Load the model at ``self.path`` and move it to CUDA when available,
        otherwise to the CPU.

        NOTE(review): ``trust_remote_code=True`` allows Python code shipped with
        the checkpoint to run at load time - confirm the model source is trusted.
        """
        # Heavy dependencies are imported lazily, only when a model is loaded.
        import torch
        from transformers import AutoModel

        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        target = torch.device(device_name)
        loaded = AutoModel.from_pretrained(self.path, trust_remote_code=True)
        return loaded.to(target)
class ZhipuEmbedding(BaseEmbeddings):
    """
    class for Zhipu embeddings
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        """
        Args:
            path: forwarded to the base class; not used by this class itself.
            is_api: when True, a ZhipuAI client is created from the
                ``ZHIPUAI_API_KEY`` environment variable and stored as
                ``self.client``; when False no client is created.
        """
        super().__init__(path, is_api)
        if self.is_api:
            from zhipuai import ZhipuAI
            self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))

    def get_embedding(self, text: str, model: str = "embedding-2") -> List[float]:
        """
        Return the embedding of ``text`` produced by the Zhipu embeddings API.

        Args:
            text: input text to embed.
            model: name of the embedding model to request. Defaults to
                ``"embedding-2"``, the value that used to be hard-coded.

        Raises:
            NotImplementedError: if the instance was created with
                ``is_api=False`` (no client exists in that case), matching
                the behaviour of ``OpenAIEmbedding``.
        """
        if not self.is_api:
            raise NotImplementedError
        response = self.client.embeddings.create(
            model=model,
            input=text,
        )
        return response.data[0].embedding
class DashscopeEmbedding(BaseEmbeddings):
    """
    class for Dashscope embeddings
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        """
        Args:
            path: forwarded to the base class; not used by this class itself.
            is_api: when True, the module-level ``dashscope.api_key`` is set
                from the ``DASHSCOPE_API_KEY`` environment variable and
                ``dashscope.TextEmbedding`` is stored as ``self.client``;
                when False nothing is configured.
        """
        super().__init__(path, is_api)
        if self.is_api:
            import dashscope
            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
            self.client = dashscope.TextEmbedding

    def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]:
        """
        Return the embedding of ``text`` produced by the Dashscope API.

        Args:
            text: input text to embed.
            model: name of the embedding model to request.

        Raises:
            NotImplementedError: if the instance was created with
                ``is_api=False`` (no client exists in that case), matching
                the behaviour of ``OpenAIEmbedding``.
            RuntimeError: if the API call does not succeed or carries no
                output, instead of failing later with an opaque ``TypeError``
                while indexing the missing output.
        """
        from http import HTTPStatus

        if not self.is_api:
            raise NotImplementedError
        response = self.client.call(
            model=model,
            input=text
        )
        # A failed call reports the reason in code/message and has no output.
        if response.status_code != HTTPStatus.OK or not response.output:
            raise RuntimeError(
                f"Dashscope embedding request failed "
                f"(status={response.status_code}, code={response.code}, "
                f"message={response.message})"
            )
        return response.output['embeddings'][0]['embedding']