#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    :   VectorBase.py
@Time    :   2024/02/12 10:11:13
@Author  :   不要葱姜蒜
@Version :   1.0
@Desc    :   Minimal in-memory vector store: embeds documents, persists them
             to JSON files, and answers top-k cosine-similarity queries.
'''

import os
from typing import Dict, List, Optional, Tuple, Union
import json

from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding
import numpy as np
from tqdm import tqdm


class VectorStore:
    """In-memory store of documents and their embedding vectors.

    ``self.document`` and ``self.vectors`` are parallel lists: once
    :meth:`get_vector` or :meth:`load_vector` has run, ``self.vectors[i]``
    is the embedding of ``self.document[i]``.
    """

    def __init__(self, document: Optional[List[str]] = None) -> None:
        """Create a store over *document*.

        Args:
            document: Texts to index. When omitted, defaults to ``['']``
                (a single empty string), as before. A fresh list is built per
                instance so separate stores never share state.
        """
        # A list literal as the default value would be one object shared by
        # every instance constructed without arguments.
        self.document = [''] if document is None else document
        # Initialised here so persist()/query() do not raise AttributeError
        # when get_vector() has not been called yet.
        self.vectors: List[List[float]] = []

    def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
        """Embed every document with *EmbeddingModel* and store the result.

        Replaces ``self.vectors`` with one embedding per document, in document
        order, and returns that same list.
        """
        self.vectors = [
            EmbeddingModel.get_embedding(doc)
            for doc in tqdm(self.document, desc="Calculating embeddings")
        ]
        return self.vectors

    def persist(self, path: str = 'storage') -> None:
        """Write the documents (and the vectors, if any) as JSON under *path*.

        The directory is created when missing. ``vectors.json`` is only written
        when ``self.vectors`` is non-empty.

        NOTE: the documents file is named ``doecment.json`` (sic). The spelling
        is kept so stores written by earlier versions stay loadable.
        """
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "doecment.json"), 'w', encoding='utf-8') as f:
            # ensure_ascii=False keeps non-ASCII text human-readable on disk.
            json.dump(self.document, f, ensure_ascii=False)
        if self.vectors:
            with open(os.path.join(path, "vectors.json"), 'w', encoding='utf-8') as f:
                json.dump(self.vectors, f)

    def load_vector(self, path: str = 'storage') -> None:
        """Load vectors and documents previously written by :meth:`persist`.

        Raises:
            FileNotFoundError: if ``vectors.json`` or ``doecment.json`` is
                missing under *path*.
            ValueError: if the two files hold different numbers of entries
                (e.g. a stale ``vectors.json`` left behind by an earlier
                persist). The store is left unchanged in that case.
        """
        with open(os.path.join(path, "vectors.json"), 'r', encoding='utf-8') as f:
            vectors = json.load(f)
        with open(os.path.join(path, "doecment.json"), 'r', encoding='utf-8') as f:
            document = json.load(f)
        # query() maps vector indices back onto documents, so the lists must
        # line up one-to-one.
        if len(vectors) != len(document):
            raise ValueError(
                f"{path}: {len(vectors)} vectors but {len(document)} documents"
            )
        self.vectors = vectors
        self.document = document

    def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:
        """Return the cosine similarity of two vectors (via BaseEmbeddings)."""
        return BaseEmbeddings.cosine_similarity(vector1, vector2)

    def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
        """Return the *k* documents most similar to *query*, best match first.

        Args:
            query: Text to search for.
            EmbeddingModel: Model used to embed *query*. Presumably this must be
                the model that produced ``self.vectors``; confirm with callers.
            k: Number of results. ``k <= 0`` yields ``[]``. A ``k`` larger than
                the store returns every document.

        Returns:
            Up to *k* documents ordered by descending cosine similarity. If no
            vectors have been computed or loaded, the result is empty.
        """
        if k <= 0:
            # argsort()[-0:] would select *every* index, so guard explicitly.
            return []
        query_vector = EmbeddingModel.get_embedding(query)
        scores = np.array(
            [self.get_similarity(query_vector, vector) for vector in self.vectors]
        )
        # argsort is ascending: take the last k indices, then reverse them.
        top = scores.argsort()[-k:][::-1]
        return [self.document[i] for i in top]