Cleaned up what is implemented on BaseVectorStorage
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
from chromadb import HttpClient, PersistentClient
|
||||
from chromadb.config import Settings
|
||||
@@ -102,7 +102,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"ChromaDB initialization failed: {str(e)}")
|
||||
raise
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
if not data:
|
||||
logger.warning("Empty data provided to vector DB")
|
||||
return []
|
||||
@@ -151,7 +151,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error during ChromaDB upsert: {str(e)}")
|
||||
raise
|
||||
|
||||
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
try:
|
||||
embedding = await self.embedding_func([query])
|
||||
|
||||
@@ -183,6 +183,15 @@ class ChromaVectorDBStorage(BaseVectorStorage):
|
||||
logger.error(f"Error during ChromaDB query: {str(e)}")
|
||||
raise
|
||||
|
||||
async def index_done_callback(self):
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
# ChromaDB handles persistence automatically
|
||||
pass
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Any
|
||||
import faiss
|
||||
import json
|
||||
import numpy as np
|
||||
@@ -57,7 +58,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
# Attempt to load an existing index + metadata from disk
|
||||
self._load_faiss_index()
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Insert or update vectors in the Faiss index.
|
||||
|
||||
@@ -147,7 +148,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
||||
return [m["__id__"] for m in list_data]
|
||||
|
||||
async def query(self, query: str, top_k=5):
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Search by a textual query; returns top_k results with their metadata + similarity distance.
|
||||
"""
|
||||
@@ -210,7 +211,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
|
||||
)
|
||||
|
||||
async def delete_entity(self, entity_name: str):
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""
|
||||
Delete a single entity by computing its hashed ID
|
||||
the same way your code does it with `compute_mdhash_id`.
|
||||
@@ -234,7 +235,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
|
||||
self._remove_faiss_ids(relations)
|
||||
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
"""
|
||||
Called after indexing is done (save Faiss index + metadata).
|
||||
"""
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
@@ -71,7 +72,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
dimension=self.embedding_func.embedding_dim,
|
||||
)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
@@ -106,7 +107,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
results = self._client.upsert(collection_name=self.namespace, data=list_data)
|
||||
return results
|
||||
|
||||
async def query(self, query, top_k=5):
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
embedding = await self.embedding_func([query])
|
||||
results = self._client.search(
|
||||
collection_name=self.namespace,
|
||||
@@ -123,3 +124,14 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
|
||||
for dp in results[0]
|
||||
]
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
@@ -844,7 +844,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
except PyMongoError as _:
|
||||
logger.debug("vector index already exist")
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not data:
|
||||
logger.warning("You are inserting an empty data set to vector DB")
|
||||
@@ -887,7 +887,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
|
||||
return list_data
|
||||
|
||||
async def query(self, query, top_k=5):
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
"""Queries the vector database using Atlas Vector Search."""
|
||||
# Generate the embedding
|
||||
embedding = await self.embedding_func([query])
|
||||
@@ -921,6 +921,16 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
for doc in results
|
||||
]
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
|
||||
"""Check if the collection exists. if not, create it."""
|
||||
|
@@ -50,6 +50,7 @@ Usage:
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
@@ -95,7 +96,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
self.embedding_func.embedding_dim, storage_file=self._client_file_name
|
||||
)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
@@ -139,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_k=5):
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
embedding = await self.embedding_func([query])
|
||||
embedding = embedding[0]
|
||||
results = self._client.query(
|
||||
@@ -176,7 +177,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
|
||||
|
||||
async def delete_entity(self, entity_name: str):
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
try:
|
||||
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
||||
logger.debug(
|
||||
@@ -211,7 +212,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relations for {entity_name}: {e}")
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
# Protect file write operation
|
||||
async with self._save_lock:
|
||||
self._client.save()
|
||||
|
@@ -307,7 +307,7 @@ class OracleKVStorage(BaseKVStorage):
|
||||
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
if is_namespace(
|
||||
self.namespace,
|
||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
||||
@@ -330,16 +330,14 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
"""向向量数据库中插入数据"""
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
pass
|
||||
|
||||
async def index_done_callback(self):
|
||||
pass
|
||||
|
||||
#################### query method ###############
|
||||
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
||||
"""从向量数据库中查询数据"""
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
# 转换精度
|
||||
@@ -359,6 +357,13 @@ class OracleVectorDBStorage(BaseVectorStorage):
|
||||
# print("vector search result:",results)
|
||||
return results
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
||||
|
||||
@dataclass
|
||||
class OracleGraphStorage(BaseGraphStorage):
|
||||
|
@@ -287,7 +287,7 @@ class PGKVStorage(BaseKVStorage):
|
||||
|
||||
await self.db.execute(upsert_sql, _data)
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
if is_namespace(
|
||||
self.namespace,
|
||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
||||
@@ -352,7 +352,7 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
}
|
||||
return upsert_sql, data
|
||||
|
||||
async def upsert(self, data: Dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
@@ -398,12 +398,11 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
|
||||
await self.db.execute(upsert_sql, data)
|
||||
|
||||
async def index_done_callback(self):
|
||||
async def index_done_callback(self) -> None:
|
||||
logger.info("vector data had been saved into postgresql db!")
|
||||
|
||||
#################### query method ###############
|
||||
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
|
||||
"""从向量数据库中查询数据"""
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
embedding_string = ",".join(map(str, embedding))
|
||||
@@ -417,6 +416,13 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
results = await self.db.query(sql, params=params, multirows=True)
|
||||
return results
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
||||
|
||||
@dataclass
|
||||
class PGDocStatusStorage(DocStatusStorage):
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
@@ -85,7 +86,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
),
|
||||
)
|
||||
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
return []
|
||||
@@ -130,7 +131,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
return results
|
||||
|
||||
async def query(self, query, top_k=5):
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
embedding = await self.embedding_func([query])
|
||||
results = self._client.search(
|
||||
collection_name=self.namespace,
|
||||
@@ -143,3 +144,14 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||
logger.debug(f"query result: {results}")
|
||||
|
||||
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
|
||||
|
||||
async def index_done_callback(self) -> None:
|
||||
pass
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
@@ -227,7 +227,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
async def query(self, query: str, top_k: int) -> list[dict]:
|
||||
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
|
||||
"""Search from tidb vector"""
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
@@ -249,7 +249,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
return results
|
||||
|
||||
###### INSERT entities And relationships ######
|
||||
async def upsert(self, data: dict[str, dict]):
|
||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||
# ignore, upsert in TiDBKVStorage already
|
||||
if not len(data):
|
||||
logger.warning("You insert an empty data to vector DB")
|
||||
@@ -333,6 +333,14 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
|
||||
|
||||
async def delete_entity(self, entity_name: str) -> None:
|
||||
"""Delete a single entity by its name"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def delete_entity_relation(self, entity_name: str) -> None:
|
||||
"""Delete relations for a given entity by scanning metadata"""
|
||||
raise NotImplementedError
|
||||
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
# db instance must be injected before use
|
||||
|
Reference in New Issue
Block a user