updated clean of what implemented on BaseVectorStorage

This commit is contained in:
Yannick Stephan
2025-02-16 13:24:42 +01:00
parent ef0e81315f
commit 3eba41aab6
9 changed files with 94 additions and 30 deletions

View File

@@ -1,6 +1,6 @@
import asyncio
from dataclasses import dataclass
from typing import Union
from typing import Any
import numpy as np
from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings
@@ -102,7 +102,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"ChromaDB initialization failed: {str(e)}")
raise
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not data:
logger.warning("Empty data provided to vector DB")
return []
@@ -151,7 +151,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
try:
embedding = await self.embedding_func([query])
@@ -183,6 +183,15 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB query: {str(e)}")
raise
async def index_done_callback(self):
async def index_done_callback(self) -> None:
# ChromaDB handles persistence automatically
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError

View File

@@ -1,6 +1,7 @@
import os
import time
import asyncio
from typing import Any
import faiss
import json
import numpy as np
@@ -57,7 +58,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Attempt to load an existing index + metadata from disk
self._load_faiss_index()
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""
Insert or update vectors in the Faiss index.
@@ -147,7 +148,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
return [m["__id__"] for m in list_data]
async def query(self, query: str, top_k=5):
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""
Search by a textual query; returns top_k results with their metadata + similarity distance.
"""
@@ -210,7 +211,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
)
async def delete_entity(self, entity_name: str):
async def delete_entity(self, entity_name: str) -> None:
"""
Delete a single entity by computing its hashed ID
the same way your code does it with `compute_mdhash_id`.
@@ -234,7 +235,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
async def index_done_callback(self):
async def index_done_callback(self) -> None:
"""
Called after indexing is done (save Faiss index + metadata).
"""

View File

@@ -1,5 +1,6 @@
import asyncio
import os
from typing import Any
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
@@ -71,7 +72,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
dimension=self.embedding_func.embedding_dim,
)
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
@@ -106,7 +107,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
results = self._client.upsert(collection_name=self.namespace, data=list_data)
return results
async def query(self, query, top_k=5):
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query])
results = self._client.search(
collection_name=self.namespace,
@@ -123,3 +124,14 @@ class MilvusVectorDBStorage(BaseVectorStorage):
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
for dp in results[0]
]
async def index_done_callback(self) -> None:
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError

View File

@@ -844,7 +844,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
except PyMongoError as _:
logger.debug("vector index already exist")
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
if not data:
logger.warning("You are inserting an empty data set to vector DB")
@@ -887,7 +887,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data
async def query(self, query, top_k=5):
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search."""
# Generate the embedding
embedding = await self.embedding_func([query])
@@ -921,6 +921,16 @@ class MongoVectorDBStorage(BaseVectorStorage):
for doc in results
]
async def index_done_callback(self) -> None:
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
"""Check if the collection exists. if not, create it."""

View File

@@ -50,6 +50,7 @@ Usage:
import asyncio
import os
from typing import Any
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
@@ -95,7 +96,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
self.embedding_func.embedding_dim, storage_file=self._client_file_name
)
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
@@ -139,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
)
async def query(self, query: str, top_k=5):
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query])
embedding = embedding[0]
results = self._client.query(
@@ -176,7 +177,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str):
async def delete_entity(self, entity_name: str) -> None:
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
@@ -211,7 +212,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self):
async def index_done_callback(self) -> None:
# Protect file write operation
async with self._save_lock:
self._client.save()

View File

@@ -307,7 +307,7 @@ class OracleKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self):
async def index_done_callback(self) -> None:
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
@@ -330,16 +330,14 @@ class OracleVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def upsert(self, data: dict[str, dict]):
"""向向量数据库中插入数据"""
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pass
async def index_done_callback(self):
pass
#################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
"""从向量数据库中查询数据"""
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
# 转换精度
@@ -359,6 +357,13 @@ class OracleVectorDBStorage(BaseVectorStorage):
# print("vector search result:",results)
return results
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass
class OracleGraphStorage(BaseGraphStorage):

View File

@@ -287,7 +287,7 @@ class PGKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data)
async def index_done_callback(self):
async def index_done_callback(self) -> None:
if is_namespace(
self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
@@ -352,7 +352,7 @@ class PGVectorStorage(BaseVectorStorage):
}
return upsert_sql, data
async def upsert(self, data: Dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
@@ -398,12 +398,11 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data)
async def index_done_callback(self):
async def index_done_callback(self) -> None:
logger.info("vector data had been saved into postgresql db!")
#################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
"""从向量数据库中查询数据"""
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding))
@@ -417,6 +416,13 @@ class PGVectorStorage(BaseVectorStorage):
results = await self.db.query(sql, params=params, multirows=True)
return results
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass
class PGDocStatusStorage(DocStatusStorage):

View File

@@ -1,5 +1,6 @@
import asyncio
import os
from typing import Any
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import numpy as np
@@ -85,7 +86,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
),
)
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
@@ -130,7 +131,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
)
return results
async def query(self, query, top_k=5):
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query])
results = self._client.search(
collection_name=self.namespace,
@@ -143,3 +144,14 @@ class QdrantVectorDBStorage(BaseVectorStorage):
logger.debug(f"query result: {results}")
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
async def index_done_callback(self) -> None:
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError

View File

@@ -227,7 +227,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def query(self, query: str, top_k: int) -> list[dict]:
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Search from tidb vector"""
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
@@ -249,7 +249,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
return results
###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict]):
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
# ignore, upsert in TiDBKVStorage already
if not len(data):
logger.warning("You insert an empty data to vector DB")
@@ -333,6 +333,14 @@ class TiDBVectorDBStorage(BaseVectorStorage):
return await self.db.query(SQL, params, multirows=True)
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use