Add 3.3 RAG
This commit is contained in:
117
docs/chapter8/RAG/Embeddings.py
Normal file
117
docs/chapter8/RAG/Embeddings.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@File : Embeddings.py
|
||||
@Time : 2024/02/10 21:55:39
|
||||
@Author : 不要葱姜蒜
|
||||
@Version : 1.0
|
||||
@Desc : None
|
||||
'''
|
||||
|
||||
import os
|
||||
from copy import copy
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
os.environ['CURL_CA_BUNDLE'] = ''
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
_ = load_dotenv(find_dotenv())
|
||||
|
||||
|
||||
class BaseEmbeddings:
|
||||
"""
|
||||
Base class for embeddings
|
||||
"""
|
||||
def __init__(self, path: str, is_api: bool) -> None:
|
||||
self.path = path
|
||||
self.is_api = is_api
|
||||
|
||||
def get_embedding(self, text: str, model: str) -> List[float]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
|
||||
"""
|
||||
calculate cosine similarity between two vectors
|
||||
"""
|
||||
dot_product = np.dot(vector1, vector2)
|
||||
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
|
||||
if not magnitude:
|
||||
return 0
|
||||
return dot_product / magnitude
|
||||
|
||||
|
||||
class OpenAIEmbedding(BaseEmbeddings):
|
||||
"""
|
||||
class for OpenAI embeddings
|
||||
"""
|
||||
def __init__(self, path: str = '', is_api: bool = True) -> None:
|
||||
super().__init__(path, is_api)
|
||||
if self.is_api:
|
||||
from openai import OpenAI
|
||||
self.client = OpenAI()
|
||||
self.client.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.client.base_url = os.getenv("OPENAI_BASE_URL")
|
||||
|
||||
def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
|
||||
if self.is_api:
|
||||
text = text.replace("\n", " ")
|
||||
return self.client.embeddings.create(input=[text], model=model).data[0].embedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
class JinaEmbedding(BaseEmbeddings):
|
||||
"""
|
||||
class for Jina embeddings
|
||||
"""
|
||||
def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
|
||||
super().__init__(path, is_api)
|
||||
self._model = self.load_model()
|
||||
|
||||
def get_embedding(self, text: str) -> List[float]:
|
||||
return self._model.encode([text])[0].tolist()
|
||||
|
||||
def load_model(self):
|
||||
import torch
|
||||
from transformers import AutoModel
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
model = AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
|
||||
return model
|
||||
|
||||
class ZhipuEmbedding(BaseEmbeddings):
|
||||
"""
|
||||
class for Zhipu embeddings
|
||||
"""
|
||||
def __init__(self, path: str = '', is_api: bool = True) -> None:
|
||||
super().__init__(path, is_api)
|
||||
if self.is_api:
|
||||
from zhipuai import ZhipuAI
|
||||
self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
|
||||
|
||||
def get_embedding(self, text: str) -> List[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model="embedding-2",
|
||||
input=text,
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
class DashscopeEmbedding(BaseEmbeddings):
|
||||
"""
|
||||
class for Dashscope embeddings
|
||||
"""
|
||||
def __init__(self, path: str = '', is_api: bool = True) -> None:
|
||||
super().__init__(path, is_api)
|
||||
if self.is_api:
|
||||
import dashscope
|
||||
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||
self.client = dashscope.TextEmbedding
|
||||
|
||||
def get_embedding(self, text: str, model: str='text-embedding-v1') -> List[float]:
|
||||
response = self.client.call(
|
||||
model=model,
|
||||
input=text
|
||||
)
|
||||
return response.output['embeddings'][0]['embedding']
|
||||
Reference in New Issue
Block a user