cleaned code

This commit is contained in:
Yannick Stephan
2025-02-16 13:55:30 +01:00
parent 882190a515
commit 931c31fa8c
14 changed files with 93 additions and 59 deletions

View File

@@ -92,6 +92,7 @@ class StorageNameSpace:
class BaseVectorStorage(StorageNameSpace): class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc embedding_func: EmbeddingFunc
meta_fields: set[str] = field(default_factory=set) meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Query the vector storage and retrieve top_k results.""" """Query the vector storage and retrieve top_k results."""
raise NotImplementedError raise NotImplementedError
@@ -165,7 +166,6 @@ class BaseGraphStorage(StorageNameSpace):
"""Get all edges connected to a node.""" """Get all edges connected to a node."""
raise NotImplementedError raise NotImplementedError
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Upsert a node into the graph.""" """Upsert a node into the graph."""
raise NotImplementedError raise NotImplementedError
@@ -194,7 +194,9 @@ class BaseGraphStorage(StorageNameSpace):
"""Get a knowledge graph of a node.""" """Get a knowledge graph of a node."""
raise NotImplementedError raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""Retrieve a subgraph of the knowledge graph starting from a given node.""" """Retrieve a subgraph of the knowledge graph starting from a given node."""
raise NotImplementedError raise NotImplementedError

View File

@@ -614,12 +614,16 @@ class AGEStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -183,7 +183,6 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB query: {str(e)}") logger.error(f"Error during ChromaDB query: {str(e)}")
raise raise
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# ChromaDB handles persistence automatically # ChromaDB handles persistence automatically
pass pass
@@ -194,4 +193,4 @@ class ChromaVectorDBStorage(BaseVectorStorage):
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError

View File

@@ -389,14 +389,16 @@ class GremlinStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def embed_nodes( async def embed_nodes(
self, algorithm: str self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -39,7 +39,7 @@ class JsonKVStorage(BaseKVStorage):
] ]
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
return set(data) - set(self._data.keys()) return set(keys) - set(self._data.keys())
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
left_data = {k: v for k, v in data.items() if k not in self._data} left_data = {k: v for k, v in data.items() if k not in self._data}

View File

@@ -127,11 +127,11 @@ class MilvusVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
pass pass
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name""" """Delete a single entity by its name"""
raise NotImplementedError raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError

View File

@@ -68,9 +68,9 @@ class MongoKVStorage(BaseKVStorage):
return await cursor.to_list() return await cursor.to_list()
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1})
existing_ids = {str(x["_id"]) async for x in cursor} existing_ids = {str(x["_id"]) async for x in cursor}
return data - existing_ids return keys - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@@ -109,7 +109,7 @@ class MongoKVStorage(BaseKVStorage):
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
pass pass
async def drop(self) -> None: async def drop(self) -> None:
"""Drop the collection""" """Drop the collection"""
await self._data.drop() await self._data.drop()
@@ -570,7 +570,9 @@ class MongoGraphStorage(BaseGraphStorage):
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
# #
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
""" """
Placeholder for demonstration, raises NotImplementedError. Placeholder for demonstration, raises NotImplementedError.
""" """
@@ -600,7 +602,9 @@ class MongoGraphStorage(BaseGraphStorage):
labels.append(doc["_id"]) labels.append(doc["_id"])
return labels return labels
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
@@ -918,7 +922,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
pass pass
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name""" """Delete a single entity by its name"""
raise NotImplementedError raise NotImplementedError
@@ -927,6 +931,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
"""Check if the collection exists. if not, create it.""" """Check if the collection exists. if not, create it."""
client = MongoClient(uri) client = MongoClient(uri)

View File

@@ -254,7 +254,6 @@ class Neo4JStorage(BaseGraphStorage):
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
try: try:
entity_name_label_source = source_node_id.strip('"') entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"') entity_name_label_target = target_node_id.strip('"')
@@ -436,7 +435,9 @@ class Neo4JStorage(BaseGraphStorage):
async def _node2vec_embed(self): async def _node2vec_embed(self):
print("Implemented but never called.") print("Implemented but never called.")
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
""" """
Get complete connected subgraph for specified node (including the starting node itself) Get complete connected subgraph for specified node (including the starting node itself)
@@ -620,8 +621,8 @@ class Neo4JStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def embed_nodes( async def embed_nodes(
self, algorithm: str self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError raise NotImplementedError

View File

@@ -186,7 +186,9 @@ class NetworkXStorage(BaseGraphStorage):
else: else:
logger.warning(f"Node {node_id} not found in the graph for deletion.") logger.warning(f"Node {node_id} not found in the graph for deletion.")
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
@@ -225,6 +227,8 @@ class NetworkXStorage(BaseGraphStorage):
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
raise NotImplementedError raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: async def get_knowledge_graph(
raise NotImplementedError self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError

View File

@@ -318,6 +318,7 @@ class OracleKVStorage(BaseKVStorage):
async def drop(self) -> None: async def drop(self) -> None:
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -368,6 +369,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class OracleGraphStorage(BaseGraphStorage): class OracleGraphStorage(BaseGraphStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -452,7 +454,9 @@ class OracleGraphStorage(BaseGraphStorage):
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
# self._graph.add_edge(source_node_id, target_node_id, **edge_data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
@@ -593,13 +597,16 @@ class OracleGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
raise NotImplementedError raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError raise NotImplementedError
N_T = { N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",

View File

@@ -299,6 +299,7 @@ class PGKVStorage(BaseKVStorage):
async def drop(self) -> None: async def drop(self) -> None:
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class PGVectorStorage(BaseVectorStorage): class PGVectorStorage(BaseVectorStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -428,6 +429,7 @@ class PGVectorStorage(BaseVectorStorage):
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class PGDocStatusStorage(DocStatusStorage): class PGDocStatusStorage(DocStatusStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -1040,18 +1042,21 @@ class PGGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def embed_nodes( async def embed_nodes(
self, algorithm: str self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]: ) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def get_all_labels(self) -> list[str]:
raise NotImplementedError raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: async def get_all_labels(self) -> list[str]:
raise NotImplementedError raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
NAMESPACE_TABLE_MAP = { NAMESPACE_TABLE_MAP = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",

View File

@@ -147,11 +147,11 @@ class QdrantVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
pass pass
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name""" """Delete a single entity by its name"""
raise NotImplementedError raise NotImplementedError
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError

View File

@@ -41,12 +41,12 @@ class RedisKVStorage(BaseKVStorage):
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
pipe = self._redis.pipeline() pipe = self._redis.pipeline()
for key in data: for key in keys:
pipe.exists(f"{self.namespace}:{key}") pipe.exists(f"{self.namespace}:{key}")
results = await pipe.execute() results = await pipe.execute()
existing_ids = {data[i] for i, exists in enumerate(results) if exists} existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
return set(data) - existing_ids return set(keys) - existing_ids
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
pipe = self._redis.pipeline() pipe = self._redis.pipeline()
@@ -63,4 +63,4 @@ class RedisKVStorage(BaseKVStorage):
await self._redis.delete(*keys) await self._redis.delete(*keys)
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
pass pass

View File

@@ -127,7 +127,6 @@ class TiDBKVStorage(BaseKVStorage):
return await self.db.query(SQL, multirows=True) return await self.db.query(SQL, multirows=True)
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""过滤掉重复内容"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace), table_name=namespace_to_table_name(self.namespace),
id_field=namespace_to_id(self.namespace), id_field=namespace_to_id(self.namespace),
@@ -211,6 +210,7 @@ class TiDBKVStorage(BaseKVStorage):
async def drop(self) -> None: async def drop(self) -> None:
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class TiDBVectorDBStorage(BaseVectorStorage): class TiDBVectorDBStorage(BaseVectorStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -335,7 +335,6 @@ class TiDBVectorDBStorage(BaseVectorStorage):
params = {"workspace": self.db.workspace, "status": status} params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True) return await self.db.query(SQL, params, multirows=True)
async def delete_entity(self, entity_name: str) -> None: async def delete_entity(self, entity_name: str) -> None:
"""Delete a single entity by its name""" """Delete a single entity by its name"""
raise NotImplementedError raise NotImplementedError
@@ -343,7 +342,8 @@ class TiDBVectorDBStorage(BaseVectorStorage):
async def delete_entity_relation(self, entity_name: str) -> None: async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete relations for a given entity by scanning metadata""" """Delete relations for a given entity by scanning metadata"""
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class TiDBGraphStorage(BaseGraphStorage): class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use # db instance must be injected before use
@@ -420,7 +420,9 @@ class TiDBGraphStorage(BaseGraphStorage):
} }
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms: if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported") raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]() return await self._node_embed_algorithms[algorithm]()
@@ -481,13 +483,16 @@ class TiDBGraphStorage(BaseGraphStorage):
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph:
raise NotImplementedError raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
raise NotImplementedError
N_T = { N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",