Updated and cleaned up what is implemented on BaseVectorStorage

This commit is contained in:
Yannick Stephan
2025-02-16 13:24:42 +01:00
parent ef0e81315f
commit 3eba41aab6
9 changed files with 94 additions and 30 deletions

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union from typing import Any
import numpy as np import numpy as np
from chromadb import HttpClient, PersistentClient from chromadb import HttpClient, PersistentClient
from chromadb.config import Settings from chromadb.config import Settings
@@ -102,7 +102,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"ChromaDB initialization failed: {str(e)}") logger.error(f"ChromaDB initialization failed: {str(e)}")
raise raise
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not data: if not data:
logger.warning("Empty data provided to vector DB") logger.warning("Empty data provided to vector DB")
return [] return []
@@ -151,7 +151,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB upsert: {str(e)}") logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise raise
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
try: try:
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
@@ -183,6 +183,15 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB query: {str(e)}") logger.error(f"Error during ChromaDB query: {str(e)}")
raise raise
async def index_done_callback(self):
async def index_done_callback(self) -> None:
# ChromaDB handles persistence automatically # ChromaDB handles persistence automatically
pass pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError

View File

@@ -1,6 +1,7 @@
import os import os
import time import time
import asyncio import asyncio
from typing import Any
import faiss import faiss
import json import json
import numpy as np import numpy as np
@@ -57,7 +58,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Attempt to load an existing index + metadata from disk # Attempt to load an existing index + metadata from disk
self._load_faiss_index() self._load_faiss_index()
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
""" """
Insert or update vectors in the Faiss index. Insert or update vectors in the Faiss index.
@@ -147,7 +148,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.") logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
return [m["__id__"] for m in list_data] return [m["__id__"] for m in list_data]
async def query(self, query: str, top_k=5): async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
""" """
Search by a textual query; returns top_k results with their metadata + similarity distance. Search by a textual query; returns top_k results with their metadata + similarity distance.
""" """
@@ -210,7 +211,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}" f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
) )
async def delete_entity(self, entity_name: str): async def delete_entity(self, entity_name: str) -> None:
""" """
Delete a single entity by computing its hashed ID Delete a single entity by computing its hashed ID
the same way your code does it with `compute_mdhash_id`. the same way your code does it with `compute_mdhash_id`.
@@ -234,7 +235,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._remove_faiss_ids(relations) self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}") logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
async def index_done_callback(self): async def index_done_callback(self) -> None:
""" """
Called after indexing is done (save Faiss index + metadata). Called after indexing is done (save Faiss index + metadata).
""" """

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
from typing import Any
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@@ -71,7 +72,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
dimension=self.embedding_func.embedding_dim, dimension=self.embedding_func.embedding_dim,
) )
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
@@ -106,7 +107,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
results = self._client.upsert(collection_name=self.namespace, data=list_data) results = self._client.upsert(collection_name=self.namespace, data=list_data)
return results return results
async def query(self, query, top_k=5): async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
results = self._client.search( results = self._client.search(
collection_name=self.namespace, collection_name=self.namespace,
@@ -123,3 +124,14 @@ class MilvusVectorDBStorage(BaseVectorStorage):
{**dp["entity"], "id": dp["id"], "distance": dp["distance"]} {**dp["entity"], "id": dp["id"], "distance": dp["distance"]}
for dp in results[0] for dp in results[0]
] ]
async def index_done_callback(self) -> None:
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError

View File

@@ -844,7 +844,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
except PyMongoError as _: except PyMongoError as _:
logger.debug("vector index already exist") logger.debug("vector index already exist")
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"Inserting {len(data)} vectors to {self.namespace}") logger.debug(f"Inserting {len(data)} vectors to {self.namespace}")
if not data: if not data:
logger.warning("You are inserting an empty data set to vector DB") logger.warning("You are inserting an empty data set to vector DB")
@@ -887,7 +887,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
return list_data return list_data
async def query(self, query, top_k=5): async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Queries the vector database using Atlas Vector Search.""" """Queries the vector database using Atlas Vector Search."""
# Generate the embedding # Generate the embedding
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
@@ -921,6 +921,16 @@ class MongoVectorDBStorage(BaseVectorStorage):
for doc in results for doc in results
] ]
async def index_done_callback(self) -> None:
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
"""Check if the collection exists. if not, create it.""" """Check if the collection exists. if not, create it."""

View File

@@ -50,6 +50,7 @@ Usage:
import asyncio import asyncio
import os import os
from typing import Any
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@@ -95,7 +96,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
self.embedding_func.embedding_dim, storage_file=self._client_file_name self.embedding_func.embedding_dim, storage_file=self._client_file_name
) )
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
@@ -139,7 +140,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}" f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
) )
async def query(self, query: str, top_k=5): async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
embedding = embedding[0] embedding = embedding[0]
results = self._client.query( results = self._client.query(
@@ -176,7 +177,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e: except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}") logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str): async def delete_entity(self, entity_name: str) -> None:
try: try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-") entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug( logger.debug(
@@ -211,7 +212,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
except Exception as e: except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}") logger.error(f"Error deleting relations for {entity_name}: {e}")
async def index_done_callback(self): async def index_done_callback(self) -> None:
# Protect file write operation # Protect file write operation
async with self._save_lock: async with self._save_lock:
self._client.save() self._client.save()

View File

@@ -307,7 +307,7 @@ class OracleKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
async def index_done_callback(self): async def index_done_callback(self) -> None:
if is_namespace( if is_namespace(
self.namespace, self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
@@ -330,16 +330,14 @@ class OracleVectorDBStorage(BaseVectorStorage):
) )
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
"""向向量数据库中插入数据"""
pass pass
async def index_done_callback(self): async def index_done_callback(self):
pass pass
#################### query method ############### #################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""从向量数据库中查询数据"""
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]
# 转换精度 # 转换精度
@@ -359,6 +357,13 @@ class OracleVectorDBStorage(BaseVectorStorage):
# print("vector search result:",results) # print("vector search result:",results)
return results return results
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):

View File

@@ -287,7 +287,7 @@ class PGKVStorage(BaseKVStorage):
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
async def index_done_callback(self): async def index_done_callback(self) -> None:
if is_namespace( if is_namespace(
self.namespace, self.namespace,
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
@@ -352,7 +352,7 @@ class PGVectorStorage(BaseVectorStorage):
} }
return upsert_sql, data return upsert_sql, data
async def upsert(self, data: Dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} vectors to {self.namespace}") logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
@@ -398,12 +398,11 @@ class PGVectorStorage(BaseVectorStorage):
await self.db.execute(upsert_sql, data) await self.db.execute(upsert_sql, data)
async def index_done_callback(self): async def index_done_callback(self) -> None:
logger.info("vector data had been saved into postgresql db!") logger.info("vector data had been saved into postgresql db!")
#################### query method ############### #################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""从向量数据库中查询数据"""
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]
embedding_string = ",".join(map(str, embedding)) embedding_string = ",".join(map(str, embedding))
@@ -417,6 +416,13 @@ class PGVectorStorage(BaseVectorStorage):
results = await self.db.query(sql, params=params, multirows=True) results = await self.db.query(sql, params=params, multirows=True)
return results return results
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass @dataclass
class PGDocStatusStorage(DocStatusStorage): class PGDocStatusStorage(DocStatusStorage):

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
from typing import Any
from tqdm.asyncio import tqdm as tqdm_async from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@@ -85,7 +86,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
), ),
) )
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
return [] return []
@@ -130,7 +131,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
) )
return results return results
async def query(self, query, top_k=5): async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embedding = await self.embedding_func([query]) embedding = await self.embedding_func([query])
results = self._client.search( results = self._client.search(
collection_name=self.namespace, collection_name=self.namespace,
@@ -143,3 +144,14 @@ class QdrantVectorDBStorage(BaseVectorStorage):
logger.debug(f"query result: {results}") logger.debug(f"query result: {results}")
return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results]
async def index_done_callback(self) -> None:
pass
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError

View File

@@ -227,7 +227,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
) )
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def query(self, query: str, top_k: int) -> list[dict]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Search from tidb vector""" """Search from tidb vector"""
embeddings = await self.embedding_func([query]) embeddings = await self.embedding_func([query])
embedding = embeddings[0] embedding = embeddings[0]
@@ -249,7 +249,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
return results return results
###### INSERT entities And relationships ###### ###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
# ignore, upsert in TiDBKVStorage already # ignore, upsert in TiDBKVStorage already
if not len(data): if not len(data):
logger.warning("You insert an empty data to vector DB") logger.warning("You insert an empty data to vector DB")
@@ -333,6 +333,14 @@ class TiDBVectorDBStorage(BaseVectorStorage):
return await self.db.query(SQL, params, multirows=True) return await self.db.query(SQL, params, multirows=True)
async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name"""
raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata"""
raise NotImplementedError
@dataclass @dataclass
class TiDBGraphStorage(BaseGraphStorage): class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use # db instance must be injected before use