Add 3.3 RAG

This commit is contained in:
KMnO4-zx
2024-10-24 15:05:13 +08:00
parent d8cf0c0031
commit 90122a2ba9
7 changed files with 715 additions and 306 deletions

View File

@@ -0,0 +1,52 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File : VectorBase.py
@Time : 2024/02/12 10:11:13
@Author : 不要葱姜蒜
@Version : 1.0
@Desc : None
'''
import os
from typing import Dict, List, Optional, Tuple, Union
import json
from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding
import numpy as np
from tqdm import tqdm
class VectorStore:
def __init__(self, document: List[str] = ['']) -> None:
self.document = document
def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
self.vectors = []
for doc in tqdm(self.document, desc="Calculating embeddings"):
self.vectors.append(EmbeddingModel.get_embedding(doc))
return self.vectors
def persist(self, path: str = 'storage'):
if not os.path.exists(path):
os.makedirs(path)
with open(f"{path}/doecment.json", 'w', encoding='utf-8') as f:
json.dump(self.document, f, ensure_ascii=False)
if self.vectors:
with open(f"{path}/vectors.json", 'w', encoding='utf-8') as f:
json.dump(self.vectors, f)
def load_vector(self, path: str = 'storage'):
with open(f"{path}/vectors.json", 'r', encoding='utf-8') as f:
self.vectors = json.load(f)
with open(f"{path}/doecment.json", 'r', encoding='utf-8') as f:
self.document = json.load(f)
def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:
return BaseEmbeddings.cosine_similarity(vector1, vector2)
def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
query_vector = EmbeddingModel.get_embedding(query)
result = np.array([self.get_similarity(query_vector, vector)
for vector in self.vectors])
return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()