修改项目结构+7,4 一部分

This commit is contained in:
KMnO4-zx
2025-04-21 22:15:49 +08:00
parent ca959a0cb8
commit 81bc97f434
17 changed files with 46 additions and 7 deletions

View File

@@ -0,0 +1,52 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : VectorBase.py
@Time : 2024/02/12 10:11:13
@Author : 不要葱姜蒜
@Version : 1.0
@Desc : None
'''
import os
from typing import Dict, List, Optional, Tuple, Union
import json
from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding
import numpy as np
from tqdm import tqdm
class VectorStore:
def __init__(self, document: List[str] = ['']) -> None:
self.document = document
def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
self.vectors = []
for doc in tqdm(self.document, desc="Calculating embeddings"):
self.vectors.append(EmbeddingModel.get_embedding(doc))
return self.vectors
def persist(self, path: str = 'storage'):
if not os.path.exists(path):
os.makedirs(path)
with open(f"{path}/doecment.json", 'w', encoding='utf-8') as f:
json.dump(self.document, f, ensure_ascii=False)
if self.vectors:
with open(f"{path}/vectors.json", 'w', encoding='utf-8') as f:
json.dump(self.vectors, f)
def load_vector(self, path: str = 'storage'):
with open(f"{path}/vectors.json", 'r', encoding='utf-8') as f:
self.vectors = json.load(f)
with open(f"{path}/doecment.json", 'r', encoding='utf-8') as f:
self.document = json.load(f)
def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:
return BaseEmbeddings.cosine_similarity(vector1, vector2)
def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
query_vector = EmbeddingModel.get_embedding(query)
result = np.array([self.get_similarity(query_vector, vector)
for vector in self.vectors])
return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()