From 931c31fa8c2d893572e3c787c1f3fbc9f683eef2 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Sun, 16 Feb 2025 13:55:30 +0100 Subject: [PATCH] cleaned code --- lightrag/base.py | 6 ++++-- lightrag/kg/age_impl.py | 18 +++++++++++------- lightrag/kg/chroma_impl.py | 3 +-- lightrag/kg/gremlin_impl.py | 16 +++++++++------- lightrag/kg/json_kv_impl.py | 2 +- lightrag/kg/milvus_impl.py | 4 ++-- lightrag/kg/mongo_impl.py | 17 +++++++++++------ lightrag/kg/neo4j_impl.py | 9 +++++---- lightrag/kg/networkx_impl.py | 12 ++++++++---- lightrag/kg/oracle_impl.py | 15 +++++++++++---- lightrag/kg/postgres_impl.py | 17 +++++++++++------ lightrag/kg/qdrant_impl.py | 4 ++-- lightrag/kg/redis_impl.py | 8 ++++---- lightrag/kg/tidb_impl.py | 21 +++++++++++++-------- 14 files changed, 93 insertions(+), 59 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index 3cc7646d..8e3a7ecf 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -92,6 +92,7 @@ class StorageNameSpace: class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: """Query the vector storage and retrieve top_k results.""" raise NotImplementedError @@ -165,7 +166,6 @@ class BaseGraphStorage(StorageNameSpace): """Get all edges connected to a node.""" raise NotImplementedError - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """Upsert a node into the graph.""" raise NotImplementedError @@ -194,7 +194,9 @@ class BaseGraphStorage(StorageNameSpace): """Get a knowledge graph of a node.""" raise NotImplementedError - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: """Retrieve a subgraph of the knowledge graph starting from a given node.""" raise NotImplementedError diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index a64e4785..37ab57d7 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -614,12 +614,16 @@ class AGEStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError - - async def get_all_labels(self) -> list[str]: + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: - raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index e32346f9..7e325abd 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -183,7 +183,6 @@ class ChromaVectorDBStorage(BaseVectorStorage): logger.error(f"Error during ChromaDB query: {str(e)}") raise - async def index_done_callback(self) -> None: # ChromaDB handles persistence automatically pass @@ -194,4 +193,4 @@ class ChromaVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity by scanning metadata""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 77c627b6..48bf77c8 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -389,14 +389,16 @@ class GremlinStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def embed_nodes( self, algorithm: str - ) -> tuple[np.ndarray[Any, Any], list[str]]: + ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError - - async def get_all_labels(self) -> list[str]: + + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: - raise NotImplementedError \ No newline at end of file diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 5683801f..7d51ae93 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -39,7 +39,7 @@ class JsonKVStorage(BaseKVStorage): ] async def filter_keys(self, keys: set[str]) -> set[str]: - return set(data) - set(self._data.keys()) + return set(keys) - set(self._data.keys()) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index d67f03b1..703229c8 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -127,11 +127,11 @@ class MilvusVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: pass - + async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity by scanning metadata""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index ce15fe29..463e24d2 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -68,9 +68,9 @@ class MongoKVStorage(BaseKVStorage): return await cursor.to_list() async def filter_keys(self, keys: set[str]) -> set[str]: - cursor = self._data.find({"_id": {"$in": list(data)}}, {"_id": 1}) + cursor = self._data.find({"_id": {"$in": list(keys)}}, {"_id": 1}) existing_ids = {str(x["_id"]) async for x in cursor} - return data - existing_ids + return keys - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -109,7 +109,7 @@ class MongoKVStorage(BaseKVStorage): async def index_done_callback(self) -> None: pass - + async def drop(self) -> None: """Drop the collection""" await self._data.drop() @@ -570,7 +570,9 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: """ Placeholder for demonstration, raises NotImplementedError. """ @@ -600,7 +602,9 @@ class MongoGraphStorage(BaseGraphStorage): labels.append(doc["_id"]) return labels - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -918,7 +922,7 @@ class MongoVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: pass - + async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name""" raise NotImplementedError @@ -927,6 +931,7 @@ class MongoVectorDBStorage(BaseVectorStorage): """Delete relations for a given entity by scanning metadata""" raise NotImplementedError + def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): """Check if the collection exists. if not, create it.""" client = MongoClient(uri) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index f27a9645..d8e8faa8 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -254,7 +254,6 @@ class Neo4JStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - try: entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') @@ -436,7 +435,9 @@ class Neo4JStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) @@ -620,8 +621,8 @@ class Neo4JStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 254bb0ed..109c5827 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -186,7 +186,9 @@ class NetworkXStorage(BaseGraphStorage): else: logger.warning(f"Node {node_id} not found in the graph for deletion.") - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -225,6 +227,8 @@ class NetworkXStorage(BaseGraphStorage): async def get_all_labels(self) -> list[str]: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: - raise NotImplementedError \ No newline at end of file + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 360a4847..74268a67 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -318,6 +318,7 @@ class OracleKVStorage(BaseKVStorage): async def drop(self) -> None: raise NotImplementedError + @dataclass class OracleVectorDBStorage(BaseVectorStorage): # db instance must be injected before use @@ -368,6 +369,7 @@ class OracleVectorDBStorage(BaseVectorStorage): """Delete relations for a given entity by scanning metadata""" raise NotImplementedError + @dataclass class OracleGraphStorage(BaseGraphStorage): # db instance must be injected before use @@ -452,7 +454,9 @@ class OracleGraphStorage(BaseGraphStorage): await self.db.execute(merge_sql, data) # self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -593,13 +597,16 @@ class OracleGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def get_all_labels(self) -> list[str]: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: raise NotImplementedError + N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 47336190..77a42ad1 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -299,6 +299,7 @@ class PGKVStorage(BaseKVStorage): async def drop(self) -> None: raise NotImplementedError + @dataclass class PGVectorStorage(BaseVectorStorage): # db instance must be injected before use @@ -428,6 +429,7 @@ class PGVectorStorage(BaseVectorStorage): """Delete relations for a given entity by scanning metadata""" raise NotImplementedError + @dataclass class PGDocStatusStorage(DocStatusStorage): # db instance must be injected before use @@ -1040,18 +1042,21 @@ class PGGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def embed_nodes( self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: - raise NotImplementedError - - async def get_all_labels(self) -> list[str]: raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: + + async def get_all_labels(self) -> list[str]: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + + NAMESPACE_TABLE_MAP = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 18a50082..eb9582e6 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -147,11 +147,11 @@ class QdrantVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: pass - + async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name""" raise NotImplementedError async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity by scanning metadata""" - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index f735c72a..71e39c5c 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -41,12 +41,12 @@ class RedisKVStorage(BaseKVStorage): async def filter_keys(self, keys: set[str]) -> set[str]: pipe = self._redis.pipeline() - for key in data: + for key in keys: pipe.exists(f"{self.namespace}:{key}") results = await pipe.execute() - existing_ids = {data[i] for i, exists in enumerate(results) if exists} - return set(data) - existing_ids + existing_ids = {keys[i] for i, exists in enumerate(results) if exists} + return set(keys) - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: pipe = self._redis.pipeline() @@ -63,4 +63,4 @@ class RedisKVStorage(BaseKVStorage): await self._redis.delete(*keys) async def index_done_callback(self) -> None: - pass \ No newline at end of file + pass diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 44c0d9e7..27850d81 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -127,7 +127,6 @@ class TiDBKVStorage(BaseKVStorage): return await self.db.query(SQL, multirows=True) async def filter_keys(self, keys: set[str]) -> set[str]: - """过滤掉重复内容""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=namespace_to_table_name(self.namespace), id_field=namespace_to_id(self.namespace), @@ -211,6 +210,7 @@ class TiDBKVStorage(BaseKVStorage): async def drop(self) -> None: raise NotImplementedError + @dataclass class TiDBVectorDBStorage(BaseVectorStorage): # db instance must be injected before use @@ -335,7 +335,6 @@ class TiDBVectorDBStorage(BaseVectorStorage): params = {"workspace": self.db.workspace, "status": status} return await self.db.query(SQL, params, multirows=True) - async def delete_entity(self, entity_name: str) -> None: """Delete a single entity by its name""" raise NotImplementedError @@ -343,7 +342,8 @@ class TiDBVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: """Delete relations for a given entity by scanning metadata""" raise NotImplementedError - + + @dataclass class TiDBGraphStorage(BaseGraphStorage): # db instance must be injected before use @@ -420,7 +420,9 @@ class TiDBGraphStorage(BaseGraphStorage): } await self.db.execute(merge_sql, data) - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: if algorithm not in self._node_embed_algorithms: raise ValueError(f"Node embedding algorithm {algorithm} not supported") return await self._node_embed_algorithms[algorithm]() @@ -481,13 +483,16 @@ class TiDBGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError - + async def get_all_labels(self) -> list[str]: - raise NotImplementedError - - async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + + N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",