Inference running locally; use Neo4j next.
This commit is contained in:
@@ -6,8 +6,6 @@ from typing import Any, Union, cast
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from nano_vectordb import NanoVectorDB
|
||||
from kg.neo4j import GraphStorage
|
||||
|
||||
|
||||
from .utils import load_json, logger, write_json
|
||||
from .base import (
|
||||
@@ -99,66 +97,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
d["__vector__"] = embeddings[i]
|
||||
results = self._client.upsert(datas=list_data)
|
||||
return results
|
||||
|
||||
|
||||
@dataclass
class PineConeVectorDBStorage(BaseVectorStorage):
    """Vector storage backend targeting Pinecone (WIP — see __post_init__).

    NOTE(review): despite the name, upsert/query below still use the
    NanoVectorDB-style keyword API — confirm which client is intended.
    """

    # Similarity cutoff passed to the client's query as better_than_threshold;
    # results scoring below this are filtered out.
    cosine_better_than_threshold: float = 0.2
|
||||
|
||||
def __post_init__(self):
    """Initialize storage clients and read config overrides.

    NOTE(review): WIP code — ``self._client`` is assigned twice: the
    NanoVectorDB instance created first is immediately shadowed by the
    Pinecone index. Also ``self._client_pinecone_host`` is never set
    anywhere visible in this file, so the ``pc.Index(...)`` call will
    raise AttributeError at runtime — TODO confirm where the host comes
    from before relying on this class.
    """
    # Imports hoisted to the top of the method: the original imported os
    # *after* the os.path.join call below (harmless only if os is also
    # imported at module level).
    import os

    from pinecone import Pinecone

    self._client_file_name = os.path.join(
        self.global_config["working_dir"], f"vdb_{self.namespace}.json"
    )
    self._max_batch_size = self.global_config["embedding_batch_num"]
    # NOTE(review): dead assignment — overwritten by the Pinecone index below.
    self._client = NanoVectorDB(
        self.embedding_func.embedding_dim, storage_file=self._client_file_name
    )
    pc = Pinecone()  # api_key=os.environ.get('PINECONE_API_KEY'))
    # From here on, everything is identical to the REST-based SDK.
    self._client = pc.Index(host=self._client_pinecone_host)  # e.g. 'my-index-8833ca1.svc.us-east1-gcp.pinecone.io'
    # Allow the threshold default to be overridden from global_config.
    self.cosine_better_than_threshold = self.global_config.get(
        "cosine_better_than_threshold", self.cosine_better_than_threshold
    )
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
    """Embed and store a batch of records in the vector DB.

    ``data`` maps record id -> record dict; every record must carry a
    "content" field, and only fields listed in ``self.meta_fields`` are
    kept as metadata. Returns the underlying client's upsert result
    (an empty list when ``data`` is empty).
    """
    logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
    if not data:
        logger.warning("You insert an empty data to vector DB")
        return []

    # Build one storage entry per record: id plus whitelisted metadata.
    records = []
    for key, value in data.items():
        entry = {"__id__": key}
        for field, field_value in value.items():
            if field in self.meta_fields:
                entry[field] = field_value
        records.append(entry)

    # Embed contents in batches of at most _max_batch_size, concurrently.
    texts = [value["content"] for value in data.values()]
    step = self._max_batch_size
    chunks = [texts[start : start + step] for start in range(0, len(texts), step)]
    vectors = np.concatenate(
        await asyncio.gather(*(self.embedding_func(chunk) for chunk in chunks))
    )
    for position, entry in enumerate(records):
        entry["__vector__"] = vectors[position]
    # NanoVectorDB-style call; the Pinecone SDK would be upsert(vectors=...).
    return self._client.upsert(datas=records)
|
||||
|
||||
async def query(self, query: str, top_k=5):
    """Embed ``query`` and return the ``top_k`` nearest stored records.

    Each result dict is the stored entry augmented with "id" (its
    "__id__") and "distance" (its "__metrics__" similarity score).
    Results scoring below ``self.cosine_better_than_threshold`` are
    filtered out by the client.

    Fix: the original passed both ``vector=`` (Pinecone style) and
    ``query=`` (NanoVectorDB style), repeated ``better_than_threshold``
    twice, and contained a stray ``???`` token — a syntax error. This
    keeps the NanoVectorDB-style call, consistent with the
    ``upsert(datas=...)`` call used elsewhere in this class.
    """
    embedding = await self.embedding_func([query])
    embedding = embedding[0]
    # Pinecone equivalent would be: self._client.query(vector=[...], top_k=...)
    results = self._client.query(
        query=embedding,
        top_k=top_k,
        better_than_threshold=self.cosine_better_than_threshold,
    )
    results = [
        {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
    ]
    return results
|
||||
|
||||
async def index_done_callback(self):
    """Persist the vector index to disk after an indexing pass completes.

    Fix: removed a leftover debug ``print("self._client.save()")`` and a
    commented-out duplicate of the live call.
    """
    self._client.save()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -298,5 +243,3 @@ class NetworkXStorage(BaseGraphStorage):
|
||||
|
||||
nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
|
||||
return embeddings, nodes_ids
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user