From 3e820cc68ea08127c821f806eb822c96f4cc21b1 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 14:04:59 +0100 Subject: [PATCH 1/6] fixed default factory --- lightrag/lightrag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5bb05764..990c1bcf 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -184,7 +184,7 @@ class LightRAG: """Maximum number of concurrent embedding function calls.""" embedding_cache_config: dict[str, Any] = field( - default={ + default_factory= lambda: { "enabled": False, "similarity_threshold": 0.95, "use_llm_check": False, From 214e3e8ad5c4d479d73afc3aee72ecdbfd3b0bf3 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 14:12:19 +0100 Subject: [PATCH 2/6] fixed last update --- examples/test_faiss.py | 2 +- lightrag/__init__.py | 2 +- lightrag/kg/networkx_impl.py | 4 ++-- lightrag/lightrag.py | 5 ++--- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/test_faiss.py b/examples/test_faiss.py index ab0ef9f7..c3ac6f47 100644 --- a/examples/test_faiss.py +++ b/examples/test_faiss.py @@ -70,7 +70,7 @@ def main(): ), vector_storage="FaissVectorDBStorage", vector_db_storage_cls_kwargs={ - "cosine_better_than_threshold": 0.3 # Your desired threshold + "cosine_better_than_threshold": 0.2 # Your desired threshold }, ) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 025fb73b..99f4052f 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.1.7" +__version__ = "1.1.10" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 614715c4..853bd369 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -16,12 +16,12 @@ import pipmaster as pm if not pm.is_installed("networkx"): pm.install("networkx") + if not pm.is_installed("graspologic"): pm.install("graspologic") -from graspologic import embed import networkx as nx - +from graspologic import embed @final @dataclass diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 990c1bcf..38a6e835 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -738,9 +738,8 @@ class LightRAG: if new_kg is None: logger.info("No new entities or relationships extracted.") else: - async with self._entity_lock: - logger.info("New entities or relationships extracted.") - self.chunk_entity_relation_graph = new_kg + logger.info("New entities or relationships extracted.") + self.chunk_entity_relation_graph = new_kg except Exception as e: logger.error("Failed to extract entities and relationships") From c4562f71b9dcac80fd95b5e5c32dae7d6fba3a67 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 14:17:26 +0100 Subject: [PATCH 3/6] cleanup extraction --- lightrag/kg/networkx_impl.py | 3 ++- lightrag/lightrag.py | 10 ++-------- lightrag/operate.py | 24 ++++++++++++------------ 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 853bd369..1874719f 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -16,13 +16,14 @@ import pipmaster as pm if not pm.is_installed("networkx"): pm.install("networkx") - + if not pm.is_installed("graspologic"): pm.install("graspologic") import networkx as nx from graspologic import embed + @final @dataclass class NetworkXStorage(BaseGraphStorage): diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 38a6e835..71784a8b 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -184,7 +184,7 @@ class LightRAG: """Maximum number of concurrent embedding function calls.""" embedding_cache_config: dict[str, Any] = field( - default_factory= lambda: { + default_factory=lambda: { "enabled": False, "similarity_threshold": 0.95, "use_llm_check": False, @@ -727,7 +727,7 @@ class LightRAG: async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None: try: - new_kg = await extract_entities( + await extract_entities( chunk, knowledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, @@ -735,12 +735,6 @@ class LightRAG: llm_response_cache=self.llm_response_cache, global_config=asdict(self), ) - if new_kg is None: - logger.info("No new entities or relationships extracted.") - else: - logger.info("New entities or relationships extracted.") - self.chunk_entity_relation_graph = new_kg - except Exception as e: logger.error("Failed to extract entities and relationships") raise e diff --git a/lightrag/operate.py b/lightrag/operate.py index 27950b7d..a79192ac 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -329,7 +329,7 @@ async def extract_entities( relationships_vdb: BaseVectorStorage, global_config: dict[str, str], llm_response_cache: BaseKVStorage | None = None, -) -> BaseGraphStorage | None: +) -> None: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] enable_llm_cache_for_entity_extract: bool = global_config[ @@ -522,16 +522,18 @@ async def extract_entities( ] ) - if not len(all_entities_data) and not len(all_relationships_data): - logger.warning( - "Didn't extract any entities and relationships, maybe your LLM is not working" - ) - return None + if not (all_entities_data or all_relationships_data): + logger.info("Didn't extract any entities and relationships.") + return - if not len(all_entities_data): - logger.warning("Didn't extract any entities") - if not len(all_relationships_data): - logger.warning("Didn't extract any relationships") + if not all_entities_data: + logger.info("Didn't extract any entities") + if not all_relationships_data: + logger.info("Didn't extract any relationships") + + logger.info( + f"New entities or relationships extracted, entities:{all_entities_data}, relationships:{all_relationships_data}" + ) if entity_vdb is not None: data_for_vdb = { @@ -560,8 +562,6 @@ async def extract_entities( } await relationships_vdb.upsert(data_for_vdb) - return knowledge_graph_inst - async def kg_query( query: str, From 439685e69c2a12931fd38ebfd31517ae0c0f5e13 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 14:29:36 +0100 Subject: [PATCH 4/6] Revert "removed get_knowledge_graph" --- lightrag/api/lightrag_server.py | 4 + lightrag/base.py | 7 ++ lightrag/kg/age_impl.py | 6 ++ lightrag/kg/gremlin_impl.py | 6 ++ lightrag/kg/mongo_impl.py | 174 ++++++++++++++++++++++++++++++++ lightrag/kg/neo4j_impl.py | 94 +++++++++++++++++ lightrag/kg/networkx_impl.py | 6 ++ lightrag/kg/oracle_impl.py | 6 ++ lightrag/kg/postgres_impl.py | 6 ++ lightrag/kg/tidb_impl.py | 7 ++ lightrag/lightrag.py | 8 ++ 11 files changed, 324 insertions(+) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 58931eec..96315b82 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1683,6 +1683,10 @@ def create_app(args): raise HTTPException(status_code=500, detail=str(e)) # query all graph + @app.get("/graphs") + async def get_knowledge_graph(label: str): + return await rag.get_knowledge_graph(nodel_label=label, max_depth=100) + # Add Ollama API routes ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") diff --git a/lightrag/base.py b/lightrag/base.py index 5f6f8850..af060435 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -13,6 +13,7 @@ from typing import ( ) import numpy as np from .utils import EmbeddingFunc +from .types import KnowledgeGraph load_dotenv() @@ -197,6 +198,12 @@ class BaseGraphStorage(StorageNameSpace, ABC): ) -> tuple[np.ndarray[Any, Any], list[str]]: """Get all labels in the graph.""" + @abstractmethod + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """Retrieve a subgraph of the knowledge graph starting from a given node.""" + class DocStatus(str, Enum): """Document processing status""" diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 583423bb..077c7321 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, NamedTuple, Optional, Union, final import numpy as np import pipmaster as pm +from lightrag.types import KnowledgeGraph from tenacity import ( retry, @@ -615,6 +616,11 @@ class AGEStorage(BaseGraphStorage): ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + async def index_done_callback(self) -> None: # AGES handles persistence automatically pass diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 45bc1fab..39077b5f 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -16,6 +16,7 @@ from tenacity import ( wait_exponential, ) +from lightrag.types import KnowledgeGraph from lightrag.utils import logger from ..base import BaseGraphStorage @@ -401,3 +402,8 @@ class GremlinStorage(BaseGraphStorage): self, algorithm: str ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index cfae4abd..07b48f8b 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -16,6 +16,7 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge import pipmaster as pm if not pm.is_installed("pymongo"): @@ -598,6 +599,179 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # QUERY # ------------------------------------------------------------------------- + # + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """ + Get complete connected subgraph for specified node (including the starting node itself) + + Args: + node_label: Label of the nodes to start from + max_depth: Maximum depth of traversal (default: 5) + + Returns: + KnowledgeGraph object containing nodes and edges of the subgraph + """ + label = node_label + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + try: + if label == "*": + # Get all nodes and edges + async for node_doc in self.collection.find({}): + node_id = str(node_doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[node_doc.get("_id")], + properties={ + k: v + for k, v in node_doc.items() + if k not in ["_id", "edges"] + }, + ) + ) + seen_nodes.add(node_id) + + # Process edges + for edge in node_doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + else: + # Verify if starting node exists + start_nodes = self.collection.find({"_id": label}) + start_nodes_exist = await start_nodes.to_list(length=1) + if not start_nodes_exist: + logger.warning(f"Starting node with label {label} does not exist!") + return result + + # Use $graphLookup for traversal + pipeline = [ + { + "$match": {"_id": label} + }, # Start with nodes having the specified label + { + "$graphLookup": { + "from": self._collection_name, + "startWith": "$edges.target", + "connectFromField": "edges.target", + "connectToField": "_id", + "maxDepth": max_depth, + "depthField": "depth", + "as": "connected_nodes", + } + }, + ] + + async for doc in self.collection.aggregate(pipeline): + # Add the start node + node_id = str(doc["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[ + doc.get( + "_id", + ) + ], + properties={ + k: v + for k, v in doc.items() + if k + not in [ + "_id", + "edges", + "connected_nodes", + "depth", + ] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from start node + for edge in doc.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + # Add connected nodes and their edges + for connected in doc.get("connected_nodes", []): + node_id = str(connected["_id"]) + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=node_id, + labels=[connected.get("_id")], + properties={ + k: v + for k, v in connected.items() + if k not in ["_id", "edges", "depth"] + }, + ) + ) + seen_nodes.add(node_id) + + # Add edges from connected nodes + for edge in connected.get("edges", []): + edge_id = f"{node_id}-{edge['target']}" + if edge_id not in seen_edges: + result.edges.append( + KnowledgeGraphEdge( + id=edge_id, + type=edge.get("relation", ""), + source=node_id, + target=edge["target"], + properties={ + k: v + for k, v in edge.items() + if k not in ["target", "relation"] + }, + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except PyMongoError as e: + logger.error(f"MongoDB query failed: {str(e)}") + + return result async def index_done_callback(self) -> None: # Mongo handles persistence automatically diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 9754ffc5..de0273ad 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -17,6 +17,7 @@ from tenacity import ( from ..utils import logger from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge import pipmaster as pm if not pm.is_installed("neo4j"): @@ -468,6 +469,99 @@ class Neo4JStorage(BaseGraphStorage): async def _node2vec_embed(self): print("Implemented but never called.") + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + """ + Get complete connected subgraph for specified node (including the starting node itself) + + Key fixes: + 1. Include the starting node itself + 2. Handle multi-label nodes + 3. Clarify relationship directions + 4. Add depth control + """ + label = node_label.strip('"') + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + + async with self._driver.session(database=self._DATABASE) as session: + try: + main_query = "" + if label == "*": + main_query = """ + MATCH (n) + WITH collect(DISTINCT n) AS nodes + MATCH ()-[r]-() + RETURN nodes, collect(DISTINCT r) AS relationships; + """ + else: + # Critical debug step: first verify if starting node exists + validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" + validate_result = await session.run(validate_query) + if not await validate_result.single(): + logger.warning(f"Starting node {label} does not exist!") + return result + + # Optimized query (including direction handling and self-loops) + main_query = f""" + MATCH (start:`{label}`) + WITH start + CALL apoc.path.subgraphAll(start, {{ + relationshipFilter: '>', + minLevel: 0, + maxLevel: {max_depth}, + bfs: true + }}) + YIELD nodes, relationships + RETURN nodes, relationships + """ + result_set = await session.run(main_query) + record = await result_set.single() + + if record: + # Handle nodes (compatible with multi-label cases) + for node in record["nodes"]: + # Use node ID + label combination as unique identifier + node_id = node.id + if node_id not in seen_nodes: + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=list(node.labels), + properties=dict(node), + ) + ) + seen_nodes.add(node_id) + + # Handle relationships (including direction information) + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + seen_edges.add(edge_id) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except neo4jExceptions.ClientError as e: + logger.error(f"APOC query failed: {str(e)}") + return await self._robust_fallback(label, max_depth) + + return result + async def _robust_fallback( self, label: str, max_depth: int ) -> Dict[str, List[Dict]]: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 1874719f..3e7a08fd 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -5,6 +5,7 @@ from typing import Any, final import numpy as np +from lightrag.types import KnowledgeGraph from lightrag.utils import ( logger, ) @@ -166,3 +167,8 @@ class NetworkXStorage(BaseGraphStorage): for source, target in edges: if self._graph.has_edge(source, target): self._graph.remove_edge(source, target) + + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 35983ad3..d65688da 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -8,6 +8,7 @@ from typing import Any, Union, final import numpy as np import configparser +from lightrag.types import KnowledgeGraph from ..base import ( BaseGraphStorage, @@ -669,6 +670,11 @@ class OracleGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index d7ace41a..a0e0f184 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -7,6 +7,7 @@ from typing import Any, Union, final import numpy as np import configparser +from lightrag.types import KnowledgeGraph import sys from tenacity import ( @@ -1084,6 +1085,11 @@ class PGGraphStorage(BaseGraphStorage): ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + async def drop(self) -> None: """Drop the storage""" drop_sql = SQL_TEMPLATES["drop_vdb_entity"] diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 2feb782a..7ba2cf66 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -5,6 +5,8 @@ from typing import Any, Union, final import numpy as np +from lightrag.types import KnowledgeGraph + from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage from ..namespace import NameSpace, is_namespace @@ -558,6 +560,11 @@ class TiDBGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 5 + ) -> KnowledgeGraph: + raise NotImplementedError + N_T = { NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 71784a8b..0ba34ef7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -47,6 +47,7 @@ from .utils import ( set_logger, encode_string_by_tiktoken, ) +from .types import KnowledgeGraph # TODO: TO REMOVE @Yannick config = configparser.ConfigParser() @@ -457,6 +458,13 @@ class LightRAG: self._storages_status = StoragesStatus.FINALIZED logger.debug("Finalized Storages") + async def get_knowledge_graph( + self, nodel_label: str, max_depth: int + ) -> KnowledgeGraph: + return await self.chunk_entity_relation_graph.get_knowledge_graph( + node_label=nodel_label, max_depth=max_depth + ) + def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: import_path = STORAGES[storage_name] storage_class = lazy_external_import(import_path, storage_name) From 3647bc9b11588b89647e0d32a12c500d45f16a23 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 14:32:24 +0100 Subject: [PATCH 5/6] updated version to 1.1.11 --- lightrag/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 99f4052f..2a78af9b 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.1.10" +__version__ = "1.1.11" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" From 678e0f9aea4dece46afab81ed4e8c2005d1de9f7 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 15:09:43 +0100 Subject: [PATCH 6/6] Revert "Cleanup of code" --- lightrag/api/lightrag_server.py | 5 ++++ lightrag/base.py | 4 +++ lightrag/kg/age_impl.py | 7 ++++++ lightrag/kg/gremlin_impl.py | 3 +++ lightrag/kg/mongo_impl.py | 18 +++++++++++++ lightrag/kg/neo4j_impl.py | 25 +++++++++++++++++++ lightrag/kg/networkx_impl.py | 3 +++ lightrag/kg/oracle_impl.py | 3 +++ lightrag/kg/postgres_impl.py | 11 +++++--- lightrag/kg/tidb_impl.py | 3 +++ lightrag/lightrag.py | 4 +++ .../lightrag_visualizer/graph_visualizer.py | 5 +--- 12 files changed, 84 insertions(+), 7 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 96315b82..0cf1d01e 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1682,6 +1682,11 @@ def create_app(args): trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) + # query all graph labels + @app.get("/graph/label/list") + async def get_graph_labels(): + return await rag.get_graph_labels() + # query all graph @app.get("/graphs") async def get_knowledge_graph(label: str): diff --git a/lightrag/base.py b/lightrag/base.py index af060435..5f6a1bf1 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -198,6 +198,10 @@ class BaseGraphStorage(StorageNameSpace, ABC): ) -> tuple[np.ndarray[Any, Any], list[str]]: """Get all labels in the graph.""" + @abstractmethod + async def get_all_labels(self) -> list[str]: + """Get a knowledge graph of a node.""" + @abstractmethod async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index 077c7321..97b3825d 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -60,6 +60,10 @@ class AGEQueryException(Exception): @final @dataclass class AGEStorage(BaseGraphStorage): + @staticmethod + def load_nx_graph(file_name): + print("no preloading of graph with AGE in production") + def __init__(self, namespace, global_config, embedding_func): super().__init__( namespace=namespace, @@ -616,6 +620,9 @@ class AGEStorage(BaseGraphStorage): ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index 39077b5f..3a26401d 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -403,6 +403,9 @@ class GremlinStorage(BaseGraphStorage): ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 07b48f8b..0048b384 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -601,6 +601,24 @@ class MongoGraphStorage(BaseGraphStorage): # ------------------------------------------------------------------------- # + async def get_all_labels(self) -> list[str]: + """ + Get all existing node _id in the database + Returns: + [id1, id2, ...] # Alphabetically sorted id list + """ + # Use MongoDB's distinct and aggregation to get all unique labels + pipeline = [ + {"$group": {"_id": "$_id"}}, # Group by _id + {"$sort": {"_id": 1}}, # Sort alphabetically + ] + + cursor = self.collection.aggregate(pipeline) + labels = [] + async for doc in cursor: + labels.append(doc["_id"]) + return labels + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index de0273ad..0ddc611d 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -628,6 +628,31 @@ class Neo4JStorage(BaseGraphStorage): await traverse(label, 0) return result + async def get_all_labels(self) -> list[str]: + """ + Get all existing node labels in the database + Returns: + ["Person", "Company", ...] # Alphabetically sorted label list + """ + async with self._driver.session(database=self._DATABASE) as session: + # Method 1: Direct metadata query (Available for Neo4j 4.3+) + # query = "CALL db.labels() YIELD label RETURN label" + + # Method 2: Query compatible with older versions + query = """ + MATCH (n) + WITH DISTINCT labels(n) AS node_labels + UNWIND node_labels AS label + RETURN DISTINCT label + ORDER BY label + """ + + result = await session.run(query) + labels = [] + async for record in result: + labels.append(record["label"]) + return labels + async def delete_node(self, node_id: str) -> None: raise NotImplementedError diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 3e7a08fd..9850b8c4 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -168,6 +168,9 @@ class NetworkXStorage(BaseGraphStorage): if self._graph.has_edge(source, target): self._graph.remove_edge(source, target) + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index d65688da..af2ededb 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -670,6 +670,9 @@ class OracleGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index a0e0f184..cbbd98c7 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -178,10 +178,12 @@ class PostgreSQLDB: asyncpg.exceptions.UniqueViolationError, asyncpg.exceptions.DuplicateTableError, ) as e: - if not upsert: - logger.error(f"PostgreSQL, upsert error: {e}") + if upsert: + print("Key value duplicate, but upsert succeeded.") + else: + logger.error(f"Upsert error: {e}") except Exception as e: - logger.error(f"PostgreSQL database, sql:{sql}, data:{data}, error:{e}") + logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}") raise @@ -1085,6 +1087,9 @@ class PGGraphStorage(BaseGraphStorage): ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 7ba2cf66..4adb0141 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -560,6 +560,9 @@ class TiDBGraphStorage(BaseGraphStorage): async def delete_node(self, node_id: str) -> None: raise NotImplementedError + async def get_all_labels(self) -> list[str]: + raise NotImplementedError + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 0ba34ef7..db61788a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -458,6 +458,10 @@ class LightRAG: self._storages_status = StoragesStatus.FINALIZED logger.debug("Finalized Storages") + async def get_graph_labels(self): + text = await self.chunk_entity_relation_graph.get_all_labels() + return text + async def get_knowledge_graph( self, nodel_label: str, max_depth: int ) -> KnowledgeGraph: diff --git a/lightrag/tools/lightrag_visualizer/graph_visualizer.py b/lightrag/tools/lightrag_visualizer/graph_visualizer.py index 9950041f..8a6f0976 100644 --- a/lightrag/tools/lightrag_visualizer/graph_visualizer.py +++ b/lightrag/tools/lightrag_visualizer/graph_visualizer.py @@ -1,6 +1,6 @@ from typing import Optional, Tuple, Dict, List import numpy as np - +import networkx as nx import pipmaster as pm # Added automatic libraries install using pipmaster @@ -12,10 +12,7 @@ if not pm.is_installed("pyglm"): pm.install("pyglm") if not pm.is_installed("python-louvain"): pm.install("python-louvain") -if not pm.is_installed("networkx"): - pm.install("networkx") -import networkx as nx import moderngl from imgui_bundle import imgui, immapp, hello_imgui import community