修改项目结构+7,4 一部分
This commit is contained in:
52
docs/chapter7/RAG/VectorBase.py
Normal file
52
docs/chapter7/RAG/VectorBase.py
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
@File : VectorBase.py
|
||||
@Time : 2024/02/12 10:11:13
|
||||
@Author : 不要葱姜蒜
|
||||
@Version : 1.0
|
||||
@Desc : None
|
||||
'''
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import json
|
||||
from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class VectorStore:
|
||||
def __init__(self, document: List[str] = ['']) -> None:
|
||||
self.document = document
|
||||
|
||||
def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:
|
||||
|
||||
self.vectors = []
|
||||
for doc in tqdm(self.document, desc="Calculating embeddings"):
|
||||
self.vectors.append(EmbeddingModel.get_embedding(doc))
|
||||
return self.vectors
|
||||
|
||||
def persist(self, path: str = 'storage'):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
with open(f"{path}/doecment.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(self.document, f, ensure_ascii=False)
|
||||
if self.vectors:
|
||||
with open(f"{path}/vectors.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(self.vectors, f)
|
||||
|
||||
def load_vector(self, path: str = 'storage'):
|
||||
with open(f"{path}/vectors.json", 'r', encoding='utf-8') as f:
|
||||
self.vectors = json.load(f)
|
||||
with open(f"{path}/doecment.json", 'r', encoding='utf-8') as f:
|
||||
self.document = json.load(f)
|
||||
|
||||
def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:
|
||||
return BaseEmbeddings.cosine_similarity(vector1, vector2)
|
||||
|
||||
def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:
|
||||
query_vector = EmbeddingModel.get_embedding(query)
|
||||
result = np.array([self.get_similarity(query_vector, vector)
|
||||
for vector in self.vectors])
|
||||
return np.array(self.document)[result.argsort()[-k:][::-1]].tolist()
|
||||
Reference in New Issue
Block a user