From 1ca6837219ea38c512a3cf13504c930c3cddf162 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 2 Mar 2025 12:52:25 +0800
Subject: [PATCH 01/22] Add max node limit for graph retrieval of NetworkX
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Set MAX_GRAPH_NODES env var (default 1000)
• Change edge type to "RELATED"
---
 .env.example                         |  1 +
 lightrag/api/routers/graph_routes.py | 19 +++++++++++++++++--
 lightrag/kg/networkx_impl.py         | 14 ++++++++------
 3 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/.env.example b/.env.example
index de9b6452..70cb575c 100644
--- a/.env.example
+++ b/.env.example
@@ -3,6 +3,7 @@
 # PORT=9621
 # WORKERS=1
 # NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
+# MAX_GRAPH_NODES=1000 # Max nodes returned from graph retrieval
 # CORS_ORIGINS=http://localhost:3000,http://localhost:8080

 ### Optional SSL Configuration
diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py
index 95a72758..aa1803c2 100644
--- a/lightrag/api/routers/graph_routes.py
+++ b/lightrag/api/routers/graph_routes.py
@@ -16,12 +16,27 @@ def create_graph_routes(rag, api_key: Optional[str] = None):

     @router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
     async def get_graph_labels():
-        """Get all graph labels"""
+        """
+        Get all graph labels
+
+        Returns:
+            List[str]: List of graph labels
+        """
         return await rag.get_graph_labels()

     @router.get("/graphs", dependencies=[Depends(optional_api_key)])
     async def get_knowledge_graph(label: str, max_depth: int = 3):
-        """Get knowledge graph for a specific label"""
+        """
+        Get knowledge graph for a specific label.
+        Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
+
+        Args:
+            label (str): Label to get knowledge graph for
+            max_depth (int, optional): Maximum depth of graph. Defaults to 3.
+
+        Returns:
+            KnowledgeGraph: Knowledge graph for the specified label
+        """
         return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)

     return router
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index f11e9c0e..b1cc45fe 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -24,6 +24,8 @@ from .shared_storage import (
     is_multiprocess,
 )

+MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+

 @final
 @dataclass
@@ -234,6 +236,7 @@ class NetworkXStorage(BaseGraphStorage):
     ) -> KnowledgeGraph:
         """
         Get complete connected subgraph for specified node (including the starting node itself)
+        Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)

         Args:
             node_label: Label of the starting node
@@ -269,18 +272,17 @@ class NetworkXStorage(BaseGraphStorage):
             subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)

             # Check if number of nodes exceeds max_graph_nodes
-            max_graph_nodes = 500
-            if len(subgraph.nodes()) > max_graph_nodes:
+            if len(subgraph.nodes()) > MAX_GRAPH_NODES:
                 origin_nodes = len(subgraph.nodes())
                 node_degrees = dict(subgraph.degree())
                 top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
-                    :max_graph_nodes
+                    :MAX_GRAPH_NODES
                 ]
                 top_node_ids = [node[0] for node in top_nodes]
-                # Create new subgraph with only top nodes
+                # Create new subgraph keeping only the highest-degree nodes
                 subgraph = subgraph.subgraph(top_node_ids)
                 logger.info(
-                    f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
+                    f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
                 )

             # Add nodes to result
@@ -320,7 +322,7 @@ class NetworkXStorage(BaseGraphStorage):
                     result.edges.append(
                         KnowledgeGraphEdge(
                             id=edge_id,
-                            type="DIRECTED",
+                            type="RELATED",
                             source=str(source),
                             target=str(target),
                             properties=edge_data,

From 0f1eb42c8dd7e2440f6c4f1c18afbfc37ad2b9c0 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 2 Mar 2025 15:39:14 +0800
Subject: [PATCH 02/22] Add node limit and prioritization for knowledge graph
 retrieval
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Add MAX_GRAPH_NODES limit from env var
• Prioritize nodes by label match & connection
---
 lightrag/kg/neo4j_impl.py    | 56 ++++++++++++++++++++++++++++--------
 lightrag/kg/networkx_impl.py | 47 +++++++++++++++++++++++++++---
 2 files changed, 87 insertions(+), 16 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index f5c2237a..2fb2c494 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -23,7 +23,7 @@ import pipmaster as pm
 if not pm.is_installed("neo4j"):
     pm.install("neo4j")

-from neo4j import (
+from neo4j import (  # type: ignore
     AsyncGraphDatabase,
     exceptions as neo4jExceptions,
     AsyncDriver,
@@ -34,6 +34,9 @@ from neo4j import (
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")

+# 从环境变量获取最大图节点数,默认为1000
+MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
+

 @final
 @dataclass
@@ -471,12 +474,17 @@ class Neo4JStorage(BaseGraphStorage):
     ) -> KnowledgeGraph:
         """
         Get complete connected subgraph for specified node (including the starting node itself)
+        Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
+        When reducing the number of nodes, the prioritization criteria are as follows: 
+        1. Label matching nodes take precedence
+        2. Followed by nodes directly connected to the matching nodes
+        3. 
Finally, the degree of the nodes - Key fixes: - 1. Include the starting node itself - 2. Handle multi-label nodes - 3. Clarify relationship directions - 4. Add depth control + Args: + node_label (str): Label of the starting node + max_depth (int, optional): Maximum depth of the graph. Defaults to 5. + Returns: + KnowledgeGraph: Complete connected subgraph for specified node """ label = node_label.strip('"') result = KnowledgeGraph() @@ -485,14 +493,22 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session(database=self._DATABASE) as session: try: - main_query = "" if label == "*": main_query = """ MATCH (n) - WITH collect(DISTINCT n) AS nodes - MATCH ()-[r]-() - RETURN nodes, collect(DISTINCT r) AS relationships; + OPTIONAL MATCH (n)-[r]-() + WITH n, count(r) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect(n) AS nodes + MATCH (a)-[r]->(b) + WHERE a IN nodes AND b IN nodes + RETURN nodes, collect(DISTINCT r) AS relationships """ + result_set = await session.run( + main_query, {"max_nodes": MAX_GRAPH_NODES} + ) + else: # Critical debug step: first verify if starting node exists validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" @@ -512,9 +528,25 @@ class Neo4JStorage(BaseGraphStorage): bfs: true }}) YIELD nodes, relationships - RETURN nodes, relationships + WITH start, nodes, relationships + UNWIND nodes AS node + OPTIONAL MATCH (node)-[r]-() + WITH node, count(r) AS degree, start, nodes, relationships, + CASE + WHEN id(node) = id(start) THEN 2 + WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1 + ELSE 0 + END AS priority + ORDER BY priority DESC, degree DESC + LIMIT $max_nodes + WITH collect(node) AS filtered_nodes, nodes, relationships + RETURN filtered_nodes AS nodes, + [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships """ - result_set = await session.run(main_query) + result_set = await session.run( + main_query, {"max_nodes": MAX_GRAPH_NODES} + ) + record = await result_set.single() if record: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index b1cc45fe..462fb832 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -236,7 +236,11 @@ class NetworkXStorage(BaseGraphStorage): ) -> KnowledgeGraph: """ Get complete connected subgraph for specified node (including the starting node itself) - Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000) + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence + 2. Followed by nodes directly connected to the matching nodes + 3. 
Finally, the degree of the nodes Args: node_label: Label of the starting node @@ -268,14 +272,49 @@ class NetworkXStorage(BaseGraphStorage): logger.warning(f"No nodes found with label {node_label}") return result - # Get subgraph using ego_graph - subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth) + # Get subgraph using ego_graph from all matching nodes + combined_subgraph = nx.Graph() + for start_node in nodes_to_explore: + node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) + combined_subgraph = nx.compose(combined_subgraph, node_subgraph) + subgraph = combined_subgraph # Check if number of nodes exceeds max_graph_nodes if len(subgraph.nodes()) > MAX_GRAPH_NODES: origin_nodes = len(subgraph.nodes()) + + # 获取节点度数 node_degrees = dict(subgraph.degree()) - top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ + + # 标记起点节点和直接连接的节点 + start_nodes = set() + direct_connected_nodes = set() + + if node_label != "*" and nodes_to_explore: + # 所有在 nodes_to_explore 中的节点都是起点节点 + start_nodes = set(nodes_to_explore) + + # 获取与所有起点直接连接的节点 + for start_node in start_nodes: + direct_connected_nodes.update(subgraph.neighbors(start_node)) + + # 从直接连接节点中移除起点节点(避免重复) + direct_connected_nodes -= start_nodes + + # 按优先级和度数排序 + def priority_key(node_item): + node, degree = node_item + # 优先级排序:起点(2) > 直接连接(1) > 其他节点(0) + if node in start_nodes: + priority = 2 + elif node in direct_connected_nodes: + priority = 1 + else: + priority = 0 + return (priority, degree) # 先按优先级,再按度数 + + # 排序并选择前MAX_GRAPH_NODES个节点 + top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[ :MAX_GRAPH_NODES ] top_node_ids = [node[0] for node in top_nodes] From 68bf02abb6224ba1212af0ada9da54c23b0b3185 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 16:20:37 +0800 Subject: [PATCH 03/22] refactor: improve graph querying with label substring matching and security fixes --- lightrag/kg/neo4j_impl.py | 26 ++++++++++++++++---------- lightrag/kg/networkx_impl.py | 2 +- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 2fb2c494..8052b1f7 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -34,7 +34,7 @@ from neo4j import ( # type: ignore config = configparser.ConfigParser() config.read("config.ini", "utf-8") -# 从环境变量获取最大图节点数,默认为1000 +# Get maximum number of graph nodes from environment variable, default is 1000 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) @@ -473,20 +473,22 @@ class Neo4JStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ - Get complete connected subgraph for specified node (including the starting node itself) + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes, the prioritization criteria are as follows: - 1. Label matching nodes take precedence + 1. Label matching nodes take precedence (nodes containing the specified label string) 2. Followed by nodes directly connected to the matching nodes 3. Finally, the degree of the nodes Args: - node_label (str): Label of the starting node + node_label (str): String to match in node labels (will match any node containing this string in its label) max_depth (int, optional): Maximum depth of the graph. Defaults to 5. 
Returns: KnowledgeGraph: Complete connected subgraph for specified node """ label = node_label.strip('"') + # Escape single quotes to prevent injection attacks + escaped_label = label.replace("'", "\\'") result = KnowledgeGraph() seen_nodes = set() seen_edges = set() @@ -510,16 +512,20 @@ class Neo4JStorage(BaseGraphStorage): ) else: - # Critical debug step: first verify if starting node exists - validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1" + validate_query = f""" + MATCH (n) + WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}') + RETURN n LIMIT 1 + """ validate_result = await session.run(validate_query) if not await validate_result.single(): - logger.warning(f"Starting node {label} does not exist!") + logger.warning(f"No nodes containing '{label}' in their labels found!") return result - # Optimized query (including direction handling and self-loops) + # Main query uses partial matching main_query = f""" - MATCH (start:`{label}`) + MATCH (start) + WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}') WITH start CALL apoc.path.subgraphAll(start, {{ relationshipFilter: '>', @@ -598,7 +604,7 @@ class Neo4JStorage(BaseGraphStorage): result = {"nodes": [], "edges": []} visited_nodes = set() visited_edges = set() - + async def traverse(current_label: str, current_depth: int): if current_depth > max_depth: return diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 462fb832..92d36fa6 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -235,7 +235,7 @@ class NetworkXStorage(BaseGraphStorage): self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ - Get complete connected subgraph for specified node (including the starting node itself) + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). When reducing the number of nodes, the prioritization criteria are as follows: 1. Label matching nodes take precedence From 465737efed6e0d4b81854d87a142762a9d631b98 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 17:32:25 +0800 Subject: [PATCH 04/22] Fix linting --- lightrag/api/routers/graph_routes.py | 7 ++++++- lightrag/kg/neo4j_impl.py | 8 +++++--- lightrag/kg/networkx_impl.py | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index aa1803c2..e6f894a2 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -27,7 +27,12 @@ def create_graph_routes(rag, api_key: Optional[str] = None): @router.get("/graphs", dependencies=[Depends(optional_api_key)]) async def get_knowledge_graph(label: str, max_depth: int = 3): """ - Get knowledge graph for a specific label. + Retrieve a connected subgraph of nodes where the label includes the specified label. + Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Label matching nodes take precedence + 2. Followed by nodes directly connected to the matching nodes + 3. 
Finally, the degree of the nodes Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000) Args: diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 8052b1f7..dccee330 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -475,7 +475,7 @@ class Neo4JStorage(BaseGraphStorage): """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). - When reducing the number of nodes, the prioritization criteria are as follows: + When reducing the number of nodes, the prioritization criteria are as follows: 1. Label matching nodes take precedence (nodes containing the specified label string) 2. Followed by nodes directly connected to the matching nodes 3. Finally, the degree of the nodes @@ -519,7 +519,9 @@ class Neo4JStorage(BaseGraphStorage): """ validate_result = await session.run(validate_query) if not await validate_result.single(): - logger.warning(f"No nodes containing '{label}' in their labels found!") + logger.warning( + f"No nodes containing '{label}' in their labels found!" + ) return result # Main query uses partial matching @@ -604,7 +606,7 @@ class Neo4JStorage(BaseGraphStorage): result = {"nodes": [], "edges": []} visited_nodes = set() visited_edges = set() - + async def traverse(current_label: str, current_depth: int): if current_depth > max_depth: return diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 92d36fa6..9601a35e 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -237,7 +237,7 @@ class NetworkXStorage(BaseGraphStorage): """ Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000). - When reducing the number of nodes, the prioritization criteria are as follows: + When reducing the number of nodes, the prioritization criteria are as follows: 1. Label matching nodes take precedence 2. Followed by nodes directly connected to the matching nodes 3. 
Finally, the degree of the nodes From aa5888042e02625568785cafde8996f0b3d16831 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 2 Mar 2025 23:57:57 +0800 Subject: [PATCH 05/22] Improved file handling and validation for document processing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Enhanced UTF-8 validation for text files • Added content validation checks • Better handling of binary data • Added logging for ignored document IDs • Improved document ID filtering --- lightrag/api/routers/document_routes.py | 28 +++++++++++++++++++++---- lightrag/lightrag.py | 18 +++++++++++++++- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index ab5aff96..f7f87c2b 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -215,7 +215,27 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: | ".scss" | ".less" ): - content = file.decode("utf-8") + try: + # Try to decode as UTF-8 + content = file.decode("utf-8") + + # Validate content + if not content or len(content.strip()) == 0: + logger.error(f"Empty content in file: {file_path.name}") + return False + + # Check if content looks like binary data string representation + if content.startswith("b'") or content.startswith('b"'): + logger.error( + f"File {file_path.name} appears to contain binary data representation instead of text" + ) + return False + + except UnicodeDecodeError: + logger.error( + f"File {file_path.name} is not valid UTF-8 encoded text. Please convert it to UTF-8 before processing." + ) + return False case ".pdf": if not pm.is_installed("pypdf2"): pm.install("pypdf2") @@ -229,7 +249,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: case ".docx": if not pm.is_installed("docx"): pm.install("docx") - from docx import Document + from docx import Document # type: ignore from io import BytesIO docx_file = BytesIO(file) @@ -238,7 +258,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: case ".pptx": if not pm.is_installed("pptx"): pm.install("pptx") - from pptx import Presentation + from pptx import Presentation # type: ignore from io import BytesIO pptx_file = BytesIO(file) @@ -250,7 +270,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: case ".xlsx": if not pm.is_installed("openpyxl"): pm.install("openpyxl") - from openpyxl import load_workbook + from openpyxl import load_workbook # type: ignore from io import BytesIO xlsx_file = BytesIO(file) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8d9c1678..daf5c059 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -670,8 +670,24 @@ class LightRAG: all_new_doc_ids = set(new_docs.keys()) # Exclude IDs of documents that are already in progress unique_new_doc_ids = await self.doc_status.filter_keys(all_new_doc_ids) + + # Log ignored document IDs + ignored_ids = [ + doc_id for doc_id in unique_new_doc_ids if doc_id not in new_docs + ] + if ignored_ids: + logger.warning( + f"Ignoring {len(ignored_ids)} document IDs not found in new_docs" + ) + for doc_id in ignored_ids: + logger.warning(f"Ignored document ID: {doc_id}") + # Filter new_docs to only include documents with unique IDs - new_docs = {doc_id: new_docs[doc_id] for doc_id in unique_new_doc_ids} + new_docs = { + doc_id: new_docs[doc_id] + for doc_id in unique_new_doc_ids + if doc_id in new_docs + } if not new_docs: logger.info("No new 
unique documents were found.") From 11fdb60fe5ef30ec6cb447f6762f8cad1ff67b0b Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 3 Mar 2025 01:30:41 +0800 Subject: [PATCH 06/22] Remove Chinese comments --- lightrag/kg/networkx_impl.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 9601a35e..563fc554 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -283,37 +283,32 @@ class NetworkXStorage(BaseGraphStorage): if len(subgraph.nodes()) > MAX_GRAPH_NODES: origin_nodes = len(subgraph.nodes()) - # 获取节点度数 node_degrees = dict(subgraph.degree()) - # 标记起点节点和直接连接的节点 start_nodes = set() direct_connected_nodes = set() if node_label != "*" and nodes_to_explore: - # 所有在 nodes_to_explore 中的节点都是起点节点 start_nodes = set(nodes_to_explore) - - # 获取与所有起点直接连接的节点 + # Get nodes directly connected to all start nodes for start_node in start_nodes: direct_connected_nodes.update(subgraph.neighbors(start_node)) - # 从直接连接节点中移除起点节点(避免重复) + # Remove start nodes from directly connected nodes (avoid duplicates) direct_connected_nodes -= start_nodes - # 按优先级和度数排序 def priority_key(node_item): node, degree = node_item - # 优先级排序:起点(2) > 直接连接(1) > 其他节点(0) + # Priority order: start(2) > directly connected(1) > other nodes(0) if node in start_nodes: priority = 2 elif node in direct_connected_nodes: priority = 1 else: priority = 0 - return (priority, degree) # 先按优先级,再按度数 + return (priority, degree) - # 排序并选择前MAX_GRAPH_NODES个节点 + # Sort by priority and degree and select top MAX_GRAPH_NODES nodes top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[ :MAX_GRAPH_NODES ] From b07181ca39ca034d165f19ddfbae09a45ae738a3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 3 Mar 2025 01:59:20 +0800 Subject: [PATCH 07/22] Remove duplicated run_with_gunicorn.py from project root --- run_with_gunicorn.py | 203 ------------------------------------------- 1 file changed, 203 deletions(-) delete mode 100755 run_with_gunicorn.py diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py deleted file mode 100755 index 2e4e3cf7..00000000 --- a/run_with_gunicorn.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python -""" -Start LightRAG server with Gunicorn -""" - -import os -import sys -import signal -import pipmaster as pm -from lightrag.api.utils_api import parse_args, display_splash_screen -from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data - - -def check_and_install_dependencies(): - """Check and install required dependencies""" - required_packages = [ - "gunicorn", - "tiktoken", - "psutil", - # Add other required packages here - ] - - for package in required_packages: - if not pm.is_installed(package): - print(f"Installing {package}...") - pm.install(package) - print(f"{package} installed successfully") - - -# Signal handler for graceful shutdown -def signal_handler(sig, frame): - print("\n\n" + "=" * 80) - print("RECEIVED TERMINATION SIGNAL") - print(f"Process ID: {os.getpid()}") - print("=" * 80 + "\n") - - # Release shared resources - finalize_share_data() - - # Exit with success status - sys.exit(0) - - -def main(): - # Check and install dependencies - check_and_install_dependencies() - - # Register signal handlers for graceful shutdown - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # kill command - - # Parse all arguments using parse_args - args = parse_args(is_uvicorn_mode=False) - - # Display startup 
information - display_splash_screen(args) - - print("🚀 Starting LightRAG with Gunicorn") - print(f"🔄 Worker management: Gunicorn (workers={args.workers})") - print("🔍 Preloading app: Enabled") - print("📝 Note: Using Gunicorn's preload feature for shared data initialization") - print("\n\n" + "=" * 80) - print("MAIN PROCESS INITIALIZATION") - print(f"Process ID: {os.getpid()}") - print(f"Workers setting: {args.workers}") - print("=" * 80 + "\n") - - # Import Gunicorn's StandaloneApplication - from gunicorn.app.base import BaseApplication - - # Define a custom application class that loads our config - class GunicornApp(BaseApplication): - def __init__(self, app, options=None): - self.options = options or {} - self.application = app - super().__init__() - - def load_config(self): - # Define valid Gunicorn configuration options - valid_options = { - "bind", - "workers", - "worker_class", - "timeout", - "keepalive", - "preload_app", - "errorlog", - "accesslog", - "loglevel", - "certfile", - "keyfile", - "limit_request_line", - "limit_request_fields", - "limit_request_field_size", - "graceful_timeout", - "max_requests", - "max_requests_jitter", - } - - # Special hooks that need to be set separately - special_hooks = { - "on_starting", - "on_reload", - "on_exit", - "pre_fork", - "post_fork", - "pre_exec", - "pre_request", - "post_request", - "worker_init", - "worker_exit", - "nworkers_changed", - "child_exit", - } - - # Import and configure the gunicorn_config module - import gunicorn_config - - # Set configuration variables in gunicorn_config, prioritizing command line arguments - gunicorn_config.workers = ( - args.workers if args.workers else int(os.getenv("WORKERS", 1)) - ) - - # Bind configuration prioritizes command line arguments - host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") - port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) - gunicorn_config.bind = f"{host}:{port}" - - # Log level configuration prioritizes command line arguments - gunicorn_config.loglevel = ( - args.log_level.lower() - if args.log_level - else os.getenv("LOG_LEVEL", "info") - ) - - # Timeout configuration prioritizes command line arguments - gunicorn_config.timeout = ( - args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) - ) - - # Keepalive configuration - gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) - - # SSL configuration prioritizes command line arguments - if args.ssl or os.getenv("SSL", "").lower() in ( - "true", - "1", - "yes", - "t", - "on", - ): - gunicorn_config.certfile = ( - args.ssl_certfile - if args.ssl_certfile - else os.getenv("SSL_CERTFILE") - ) - gunicorn_config.keyfile = ( - args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") - ) - - # Set configuration options from the module - for key in dir(gunicorn_config): - if key in valid_options: - value = getattr(gunicorn_config, key) - # Skip functions like on_starting and None values - if not callable(value) and value is not None: - self.cfg.set(key, value) - # Set special hooks - elif key in special_hooks: - value = getattr(gunicorn_config, key) - if callable(value): - self.cfg.set(key, value) - - if hasattr(gunicorn_config, "logconfig_dict"): - self.cfg.set( - "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") - ) - - def load(self): - # Import the application - from lightrag.api.lightrag_server import get_application - - return get_application(args) - - # Create the application - app = GunicornApp("") - - # Force workers to be an integer and greater 
than 1 for multi-process mode - workers_count = int(args.workers) - if workers_count > 1: - # Set a flag to indicate we're in the main process - os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" - initialize_share_data(workers_count) - else: - initialize_share_data(1) - - # Run the application - print("\nStarting Gunicorn with direct Python API...") - app.run() - - -if __name__ == "__main__": - main() From c21d5744f9f7abb5b2058f8ff4007ef54d7c58e7 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 3 Mar 2025 02:05:54 +0800 Subject: [PATCH 08/22] Remove duplicated run_with_gunicorn.py --- run_with_gunicorn.py | 203 ------------------------------------------- 1 file changed, 203 deletions(-) delete mode 100755 run_with_gunicorn.py diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py deleted file mode 100755 index 2e4e3cf7..00000000 --- a/run_with_gunicorn.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python -""" -Start LightRAG server with Gunicorn -""" - -import os -import sys -import signal -import pipmaster as pm -from lightrag.api.utils_api import parse_args, display_splash_screen -from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data - - -def check_and_install_dependencies(): - """Check and install required dependencies""" - required_packages = [ - "gunicorn", - "tiktoken", - "psutil", - # Add other required packages here - ] - - for package in required_packages: - if not pm.is_installed(package): - print(f"Installing {package}...") - pm.install(package) - print(f"{package} installed successfully") - - -# Signal handler for graceful shutdown -def signal_handler(sig, frame): - print("\n\n" + "=" * 80) - print("RECEIVED TERMINATION SIGNAL") - print(f"Process ID: {os.getpid()}") - print("=" * 80 + "\n") - - # Release shared resources - finalize_share_data() - - # Exit with success status - sys.exit(0) - - -def main(): - # Check and install dependencies - check_and_install_dependencies() - - # Register signal handlers for graceful shutdown - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # kill command - - # Parse all arguments using parse_args - args = parse_args(is_uvicorn_mode=False) - - # Display startup information - display_splash_screen(args) - - print("🚀 Starting LightRAG with Gunicorn") - print(f"🔄 Worker management: Gunicorn (workers={args.workers})") - print("🔍 Preloading app: Enabled") - print("📝 Note: Using Gunicorn's preload feature for shared data initialization") - print("\n\n" + "=" * 80) - print("MAIN PROCESS INITIALIZATION") - print(f"Process ID: {os.getpid()}") - print(f"Workers setting: {args.workers}") - print("=" * 80 + "\n") - - # Import Gunicorn's StandaloneApplication - from gunicorn.app.base import BaseApplication - - # Define a custom application class that loads our config - class GunicornApp(BaseApplication): - def __init__(self, app, options=None): - self.options = options or {} - self.application = app - super().__init__() - - def load_config(self): - # Define valid Gunicorn configuration options - valid_options = { - "bind", - "workers", - "worker_class", - "timeout", - "keepalive", - "preload_app", - "errorlog", - "accesslog", - "loglevel", - "certfile", - "keyfile", - "limit_request_line", - "limit_request_fields", - "limit_request_field_size", - "graceful_timeout", - "max_requests", - "max_requests_jitter", - } - - # Special hooks that need to be set separately - special_hooks = { - "on_starting", - "on_reload", - "on_exit", - "pre_fork", - "post_fork", - "pre_exec", - 
"pre_request", - "post_request", - "worker_init", - "worker_exit", - "nworkers_changed", - "child_exit", - } - - # Import and configure the gunicorn_config module - import gunicorn_config - - # Set configuration variables in gunicorn_config, prioritizing command line arguments - gunicorn_config.workers = ( - args.workers if args.workers else int(os.getenv("WORKERS", 1)) - ) - - # Bind configuration prioritizes command line arguments - host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") - port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) - gunicorn_config.bind = f"{host}:{port}" - - # Log level configuration prioritizes command line arguments - gunicorn_config.loglevel = ( - args.log_level.lower() - if args.log_level - else os.getenv("LOG_LEVEL", "info") - ) - - # Timeout configuration prioritizes command line arguments - gunicorn_config.timeout = ( - args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) - ) - - # Keepalive configuration - gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) - - # SSL configuration prioritizes command line arguments - if args.ssl or os.getenv("SSL", "").lower() in ( - "true", - "1", - "yes", - "t", - "on", - ): - gunicorn_config.certfile = ( - args.ssl_certfile - if args.ssl_certfile - else os.getenv("SSL_CERTFILE") - ) - gunicorn_config.keyfile = ( - args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") - ) - - # Set configuration options from the module - for key in dir(gunicorn_config): - if key in valid_options: - value = getattr(gunicorn_config, key) - # Skip functions like on_starting and None values - if not callable(value) and value is not None: - self.cfg.set(key, value) - # Set special hooks - elif key in special_hooks: - value = getattr(gunicorn_config, key) - if callable(value): - self.cfg.set(key, value) - - if hasattr(gunicorn_config, "logconfig_dict"): - self.cfg.set( - "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") - ) - - def load(self): - # Import the application - from lightrag.api.lightrag_server import get_application - - return get_application(args) - - # Create the application - app = GunicornApp("") - - # Force workers to be an integer and greater than 1 for multi-process mode - workers_count = int(args.workers) - if workers_count > 1: - # Set a flag to indicate we're in the main process - os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" - initialize_share_data(workers_count) - else: - initialize_share_data(1) - - # Run the application - print("\nStarting Gunicorn with direct Python API...") - app.run() - - -if __name__ == "__main__": - main() From 0ea274a30dce6627986a4f40d51912f7c462b0bf Mon Sep 17 00:00:00 2001 From: MdNazishArmanShorthillsAI Date: Mon, 3 Mar 2025 13:53:45 +0530 Subject: [PATCH 09/22] Improved cashing check --- lightrag/lightrag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 4f1ad7dc..04f66adc 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1150,7 +1150,7 @@ class LightRAG: """ if param.mode in ["local", "global", "hybrid"]: response = await kg_query( - query, + query.strip(), self.chunk_entity_relation_graph, self.entities_vdb, self.relationships_vdb, @@ -1171,7 +1171,7 @@ class LightRAG: ) elif param.mode == "naive": response = await naive_query( - query, + query.strip(), self.chunks_vdb, self.text_chunks, param, @@ -1190,7 +1190,7 @@ class LightRAG: ) elif param.mode == "mix": response = await mix_kg_vector_query( - query, + 
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,
                 self.relationships_vdb,

From 462c27c1672c30f54313f832fac182d1587e8092 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 3 Mar 2025 23:18:41 +0800
Subject: [PATCH 10/22] Refactor logging setup and simplify Gunicorn
 configuration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Move logging setup code to utils.py
• Provide setup_logger for standalone LightRAG logger initialization
---
 lightrag/api/gunicorn_config.py |  53 ++------
 lightrag/api/lightrag_server.py |   3 +
 lightrag/lightrag.py            |   3 -
 lightrag/utils.py               |  96 +++++++++++++++
 run_with_gunicorn.py            | 203 --------------------------------
 5 files changed, 112 insertions(+), 246 deletions(-)
 delete mode 100755 run_with_gunicorn.py

diff --git a/lightrag/api/gunicorn_config.py b/lightrag/api/gunicorn_config.py
index 7f9b4d58..0594ceae 100644
--- a/lightrag/api/gunicorn_config.py
+++ b/lightrag/api/gunicorn_config.py
@@ -2,12 +2,15 @@ import os
 import logging

 from lightrag.kg.shared_storage import finalize_share_data
-from lightrag.api.lightrag_server import LightragPathFilter
+from lightrag.utils import setup_logger

 # Get log directory path from environment variable
 log_dir = os.getenv("LOG_DIR", os.getcwd())
 log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log"))

+# Ensure log directory exists
+os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
+
 # Get log file max size and backup count from environment variables
 log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
 log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups
@@ -108,6 +111,9 @@ def on_starting(server):
     except ImportError:
         print("psutil not installed, skipping memory usage reporting")

+    # Log the location of the LightRAG log file
+    print(f"LightRAG log file: {log_file_path}\n")
+
     print("Gunicorn initialization complete, forking workers...\n")

@@ -134,51 +140,18 @@ def post_fork(server, worker):
     """
     Executed after a worker has been forked.
    This is a good place to set up worker-specific configurations.
""" - # Configure formatters - detailed_formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - simple_formatter = logging.Formatter("%(levelname)s: %(message)s") - - def setup_logger(logger_name: str, level: str = "INFO", add_filter: bool = False): - """Set up a logger with console and file handlers""" - logger_instance = logging.getLogger(logger_name) - logger_instance.setLevel(level) - logger_instance.handlers = [] # Clear existing handlers - logger_instance.propagate = False - - # Add console handler - console_handler = logging.StreamHandler() - console_handler.setFormatter(simple_formatter) - console_handler.setLevel(level) - logger_instance.addHandler(console_handler) - - # Add file handler - file_handler = logging.handlers.RotatingFileHandler( - filename=log_file_path, - maxBytes=log_max_bytes, - backupCount=log_backup_count, - encoding="utf-8", - ) - file_handler.setFormatter(detailed_formatter) - file_handler.setLevel(level) - logger_instance.addHandler(file_handler) - - # Add path filter if requested - if add_filter: - path_filter = LightragPathFilter() - logger_instance.addFilter(path_filter) - # Set up main loggers log_level = loglevel.upper() if loglevel else "INFO" - setup_logger("uvicorn", log_level) - setup_logger("uvicorn.access", log_level, add_filter=True) - setup_logger("lightrag", log_level, add_filter=True) + setup_logger("uvicorn", log_level, add_filter=False, log_file_path=log_file_path) + setup_logger( + "uvicorn.access", log_level, add_filter=True, log_file_path=log_file_path + ) + setup_logger("lightrag", log_level, add_filter=True, log_file_path=log_file_path) # Set up lightrag submodule loggers for name in logging.root.manager.loggerDict: if name.startswith("lightrag."): - setup_logger(name, log_level, add_filter=True) + setup_logger(name, log_level, add_filter=True, log_file_path=log_file_path) # Disable uvicorn.error logger uvicorn_error_logger = logging.getLogger("uvicorn.error") diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 5f2c437f..693c6a9f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -437,6 +437,9 @@ def configure_logging(): log_dir = os.getenv("LOG_DIR", os.getcwd()) log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + print(f"\nLightRAG log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + # Get log file max size and backup count from environment variables log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 208bdf3e..adcb1029 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -266,9 +266,6 @@ class LightRAG: _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): - os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) - logger.info(f"Logger initialized for working directory: {self.working_dir}") - from lightrag.kg.shared_storage import ( initialize_share_data, ) diff --git a/lightrag/utils.py b/lightrag/utils.py index c86ad8c0..bb1d6fae 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -6,6 +6,7 @@ import io import csv import json import logging +import logging.handlers import os import re from dataclasses import dataclass @@ -68,6 +69,101 @@ logger.setLevel(logging.INFO) logging.getLogger("httpx").setLevel(logging.WARNING) +class 
LightragPathFilter(logging.Filter): + """Filter for lightrag logger to filter out frequent path access logs""" + + def __init__(self): + super().__init__() + # Define paths to be filtered + self.filtered_paths = ["/documents", "/health", "/webui/"] + + def filter(self, record): + try: + # Check if record has the required attributes for an access log + if not hasattr(record, "args") or not isinstance(record.args, tuple): + return True + if len(record.args) < 5: + return True + + # Extract method, path and status from the record args + method = record.args[1] + path = record.args[2] + status = record.args[4] + + # Filter out successful GET requests to filtered paths + if ( + method == "GET" + and (status == 200 or status == 304) + and path in self.filtered_paths + ): + return False + + return True + except Exception: + # In case of any error, let the message through + return True + + +def setup_logger( + logger_name: str, + level: str = "INFO", + add_filter: bool = False, + log_file_path: str = None, +): + """Set up a logger with console and file handlers + + Args: + logger_name: Name of the logger to set up + level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + add_filter: Whether to add LightragPathFilter to the logger + log_file_path: Path to the log file. If None, will use current directory/lightrag.log + """ + # Configure formatters + detailed_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + simple_formatter = logging.Formatter("%(levelname)s: %(message)s") + + # Get log file path + if log_file_path is None: + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag.log")) + + # Ensure log directory exists + os.makedirs(os.path.dirname(log_file_path), exist_ok=True) + + # Get log file max size and backup count from environment variables + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups + + logger_instance = logging.getLogger(logger_name) + logger_instance.setLevel(level) + logger_instance.handlers = [] # Clear existing handlers + logger_instance.propagate = False + + # Add console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(simple_formatter) + console_handler.setLevel(level) + logger_instance.addHandler(console_handler) + + # Add file handler + file_handler = logging.handlers.RotatingFileHandler( + filename=log_file_path, + maxBytes=log_max_bytes, + backupCount=log_backup_count, + encoding="utf-8", + ) + file_handler.setFormatter(detailed_formatter) + file_handler.setLevel(level) + logger_instance.addHandler(file_handler) + + # Add path filter if requested + if add_filter: + path_filter = LightragPathFilter() + logger_instance.addFilter(path_filter) + + class UnlimitedSemaphore: """A context manager that allows unlimited access.""" diff --git a/run_with_gunicorn.py b/run_with_gunicorn.py deleted file mode 100755 index 2e4e3cf7..00000000 --- a/run_with_gunicorn.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python -""" -Start LightRAG server with Gunicorn -""" - -import os -import sys -import signal -import pipmaster as pm -from lightrag.api.utils_api import parse_args, display_splash_screen -from lightrag.kg.shared_storage import initialize_share_data, finalize_share_data - - -def check_and_install_dependencies(): - """Check and install required dependencies""" - required_packages = [ - "gunicorn", - "tiktoken", - "psutil", - # Add other 
required packages here - ] - - for package in required_packages: - if not pm.is_installed(package): - print(f"Installing {package}...") - pm.install(package) - print(f"{package} installed successfully") - - -# Signal handler for graceful shutdown -def signal_handler(sig, frame): - print("\n\n" + "=" * 80) - print("RECEIVED TERMINATION SIGNAL") - print(f"Process ID: {os.getpid()}") - print("=" * 80 + "\n") - - # Release shared resources - finalize_share_data() - - # Exit with success status - sys.exit(0) - - -def main(): - # Check and install dependencies - check_and_install_dependencies() - - # Register signal handlers for graceful shutdown - signal.signal(signal.SIGINT, signal_handler) # Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # kill command - - # Parse all arguments using parse_args - args = parse_args(is_uvicorn_mode=False) - - # Display startup information - display_splash_screen(args) - - print("🚀 Starting LightRAG with Gunicorn") - print(f"🔄 Worker management: Gunicorn (workers={args.workers})") - print("🔍 Preloading app: Enabled") - print("📝 Note: Using Gunicorn's preload feature for shared data initialization") - print("\n\n" + "=" * 80) - print("MAIN PROCESS INITIALIZATION") - print(f"Process ID: {os.getpid()}") - print(f"Workers setting: {args.workers}") - print("=" * 80 + "\n") - - # Import Gunicorn's StandaloneApplication - from gunicorn.app.base import BaseApplication - - # Define a custom application class that loads our config - class GunicornApp(BaseApplication): - def __init__(self, app, options=None): - self.options = options or {} - self.application = app - super().__init__() - - def load_config(self): - # Define valid Gunicorn configuration options - valid_options = { - "bind", - "workers", - "worker_class", - "timeout", - "keepalive", - "preload_app", - "errorlog", - "accesslog", - "loglevel", - "certfile", - "keyfile", - "limit_request_line", - "limit_request_fields", - "limit_request_field_size", - "graceful_timeout", - "max_requests", - "max_requests_jitter", - } - - # Special hooks that need to be set separately - special_hooks = { - "on_starting", - "on_reload", - "on_exit", - "pre_fork", - "post_fork", - "pre_exec", - "pre_request", - "post_request", - "worker_init", - "worker_exit", - "nworkers_changed", - "child_exit", - } - - # Import and configure the gunicorn_config module - import gunicorn_config - - # Set configuration variables in gunicorn_config, prioritizing command line arguments - gunicorn_config.workers = ( - args.workers if args.workers else int(os.getenv("WORKERS", 1)) - ) - - # Bind configuration prioritizes command line arguments - host = args.host if args.host != "0.0.0.0" else os.getenv("HOST", "0.0.0.0") - port = args.port if args.port != 9621 else int(os.getenv("PORT", 9621)) - gunicorn_config.bind = f"{host}:{port}" - - # Log level configuration prioritizes command line arguments - gunicorn_config.loglevel = ( - args.log_level.lower() - if args.log_level - else os.getenv("LOG_LEVEL", "info") - ) - - # Timeout configuration prioritizes command line arguments - gunicorn_config.timeout = ( - args.timeout if args.timeout else int(os.getenv("TIMEOUT", 150)) - ) - - # Keepalive configuration - gunicorn_config.keepalive = int(os.getenv("KEEPALIVE", 5)) - - # SSL configuration prioritizes command line arguments - if args.ssl or os.getenv("SSL", "").lower() in ( - "true", - "1", - "yes", - "t", - "on", - ): - gunicorn_config.certfile = ( - args.ssl_certfile - if args.ssl_certfile - else os.getenv("SSL_CERTFILE") - ) - 
gunicorn_config.keyfile = ( - args.ssl_keyfile if args.ssl_keyfile else os.getenv("SSL_KEYFILE") - ) - - # Set configuration options from the module - for key in dir(gunicorn_config): - if key in valid_options: - value = getattr(gunicorn_config, key) - # Skip functions like on_starting and None values - if not callable(value) and value is not None: - self.cfg.set(key, value) - # Set special hooks - elif key in special_hooks: - value = getattr(gunicorn_config, key) - if callable(value): - self.cfg.set(key, value) - - if hasattr(gunicorn_config, "logconfig_dict"): - self.cfg.set( - "logconfig_dict", getattr(gunicorn_config, "logconfig_dict") - ) - - def load(self): - # Import the application - from lightrag.api.lightrag_server import get_application - - return get_application(args) - - # Create the application - app = GunicornApp("") - - # Force workers to be an integer and greater than 1 for multi-process mode - workers_count = int(args.workers) - if workers_count > 1: - # Set a flag to indicate we're in the main process - os.environ["LIGHTRAG_MAIN_PROCESS"] = "1" - initialize_share_data(workers_count) - else: - initialize_share_data(1) - - # Run the application - print("\nStarting Gunicorn with direct Python API...") - app.run() - - -if __name__ == "__main__": - main() From b26a574f40253ce6a32380b0f29b6d42d75ab0d6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 01:07:34 +0800 Subject: [PATCH 11/22] Deprecate log_level and log_file_path in LightRAG. - Remove log_level from API initialization - Add warnings for deprecated logging params --- README.md | 18 +++++++++++++++--- lightrag/api/lightrag_server.py | 2 -- lightrag/lightrag.py | 25 ++++++++++++++++++++----- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index abc2f8b3..5e8c5a94 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,9 @@ import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import setup_logger + +setup_logger("lightrag", level="INFO") async def initialize_rag(): rag = LightRAG( @@ -344,6 +347,10 @@ from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_i from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.llms.openai import OpenAI from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import setup_logger + +# Setup log handler for LightRAG +setup_logger("lightrag", level="INFO") async def initialize_rag(): rag = LightRAG( @@ -640,6 +647,9 @@ export NEO4J_URI="neo4j://localhost:7687" export NEO4J_USERNAME="neo4j" export NEO4J_PASSWORD="password" +# Setup logger for LightRAG +setup_logger("lightrag", level="INFO") + # When you launch the project be sure to override the default KG: NetworkX # by specifying kg="Neo4JStorage". @@ -649,8 +659,12 @@ rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model graph_storage="Neo4JStorage", #<-----------override KG default - log_level="DEBUG" #<-----------override log_level default ) + +# Initialize database connections +await rag.initialize_storages() +# Initialize pipeline status for document processing +await initialize_pipeline_status() ``` see test_neo4j.py for a working example. @@ -859,7 +873,6 @@ Valid modes are: | **kv\_storage** | `str` | Storage type for documents and text chunks. 
Supported types: `JsonKVStorage`, `OracleKVStorage` | `JsonKVStorage` | | **vector\_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`, `OracleVectorDBStorage` | `NanoVectorDBStorage` | | **graph\_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`, `Neo4JStorage`, `OracleGraphStorage` | `NetworkXStorage` | -| **log\_level** | | Log level for application runtime | `logging.DEBUG` | | **chunk\_token\_size** | `int` | Maximum token size per chunk when splitting documents | `1200` | | **chunk\_overlap\_token\_size** | `int` | Overlap token size between two chunks when splitting documents | `100` | | **tiktoken\_model\_name** | `str` | Model name for the Tiktoken encoder used to calculate token numbers | `gpt-4o-mini` | @@ -881,7 +894,6 @@ Valid modes are: | **addon\_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"], "insert_batch_size": 10}`: sets example limit, output language, and batch size for document processing | `example_number: all examples, language: English, insert_batch_size: 10` | | **convert\_response\_to\_json\_func** | `callable` | Not used | `convert_response_to_json` | | **embedding\_cache\_config** | `dict` | Configuration for question-answer caching. Contains three parameters:
- `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers.
- `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM.
- `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` | -|**log\_dir** | `str` | Directory to store logs. | `./` | diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 693c6a9f..c91f693f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -329,7 +329,6 @@ def create_app(args): "similarity_threshold": 0.95, "use_llm_check": False, }, - log_level=args.log_level, namespace_prefix=args.namespace_prefix, auto_manage_storages_states=False, ) @@ -359,7 +358,6 @@ def create_app(args): "similarity_threshold": 0.95, "use_llm_check": False, }, - log_level=args.log_level, namespace_prefix=args.namespace_prefix, auto_manage_storages_states=False, ) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 4dacac08..114b5735 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import configparser import os +import warnings from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial @@ -85,14 +86,10 @@ class LightRAG: doc_status_storage: str = field(default="JsonDocStatusStorage") """Storage type for tracking document processing statuses.""" - # Logging + # Logging (Deprecated, use setup_logger in utils.py instead) # --- - log_level: int = field(default=logger.level) - """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" - log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log")) - """Log file path.""" # Entity extraction # --- @@ -270,6 +267,24 @@ class LightRAG: initialize_share_data, ) + # Handle deprecated parameters + kwargs = self.__dict__ + if "log_level" in kwargs: + warnings.warn( + "WARNING: log_level parameter is deprecated, use setup_logger in utils.py instead", + UserWarning, + stacklevel=2, + ) + # Remove the attribute to prevent its use + delattr(self, "log_level") + if "log_file_path" in kwargs: + warnings.warn( + "WARNING: log_file_path parameter is deprecated, use setup_logger in utils.py instead", + UserWarning, + stacklevel=2, + ) + delattr(self, "log_file_path") + initialize_share_data() if not os.path.exists(self.working_dir): From 905699429281c576f9abbd29ae8c247b64bcda29 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 4 Mar 2025 01:28:08 +0800 Subject: [PATCH 12/22] Deprecate and remove logging parameters in LightRAG. 
- Set log_level and log_file_path to None by default
- Issue warnings if deprecated parameters are used
- Maintain backward compatibility with warnings
---
 lightrag/lightrag.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 114b5735..21688b7d 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -88,8 +88,8 @@ class LightRAG:

     # Logging (Deprecated, use setup_logger in utils.py instead)
     # ---
-    log_level: int = field(default=logger.level)
-    log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
+    log_level: int | None = field(default=None)
+    log_file_path: str | None = field(default=None)

     # Entity extraction
     # ---
@@ -268,21 +268,23 @@ class LightRAG:
         )

         # Handle deprecated parameters
-        kwargs = self.__dict__
-        if "log_level" in kwargs:
+        if self.log_level is not None:
             warnings.warn(
                 "WARNING: log_level parameter is deprecated, use setup_logger in utils.py instead",
                 UserWarning,
                 stacklevel=2,
             )
-            # Remove the attribute to prevent its use
-            delattr(self, "log_level")
-        if "log_file_path" in kwargs:
+        if self.log_file_path is not None:
             warnings.warn(
                 "WARNING: log_file_path parameter is deprecated, use setup_logger in utils.py instead",
                 UserWarning,
                 stacklevel=2,
             )
+
+        # Remove these attributes to prevent their use
+        if hasattr(self, "log_level"):
+            delattr(self, "log_level")
+        if hasattr(self, "log_file_path"):
             delattr(self, "log_file_path")

         initialize_share_data()

From 0af774a28f92b0fd6c2ba9ebcd8ce49f697a3eab Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 4 Mar 2025 01:28:39 +0800
Subject: [PATCH 13/22] Fix linting
---
 lightrag/lightrag.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 21688b7d..a2d806b6 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -280,7 +280,7 @@ class LightRAG:
                 UserWarning,
                 stacklevel=2,
             )
-
+
         # Remove these attributes to prevent their use
         if hasattr(self, "log_level"):
             delattr(self, "log_level")

From bc9905a06177961b6f0e78f1da967e1b45ecf8cf Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 4 Mar 2025 02:28:09 +0800
Subject: [PATCH 14/22] Fix gensim incompatibility with numpy and scipy

- Replace numpy with gensim in requirements.txt
- Let gensim choose the correct versions of numpy and scipy
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index a1a1157e..d9a5c68e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ configparser
 future

 # Basic modules
-numpy
+gensim
 pipmaster
 pydantic
 python-dotenv

From 61839f311a566531c038da57a0451272eff1d9c3 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 4 Mar 2025 10:00:07 +0800
Subject: [PATCH 15/22] Fix package name checks for docx and pptx modules. 

From bc9905a06177961b6f0e78f1da967e1b45ecf8cf Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 4 Mar 2025 02:28:09 +0800
Subject: [PATCH 14/22] Fix gensim not compatible with numpy and scipy problem

- Replace numpy with gensim in requirements.txt
- Let gensim choose a correct version of numpy and scipy
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index a1a1157e..d9a5c68e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ configparser
 future

 # Basic modules
-numpy
+gensim
 pipmaster
 pydantic
 python-dotenv

From 61839f311a566531c038da57a0451272eff1d9c3 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 4 Mar 2025 10:00:07 +0800
Subject: [PATCH 15/22] Fix package name checks for docx and pptx modules.

- Added type ignore for package checks
- Corrected docx/pptx package names for the new versions
---
 lightrag/api/routers/document_routes.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py
index ab5aff96..39314233 100644
--- a/lightrag/api/routers/document_routes.py
+++ b/lightrag/api/routers/document_routes.py
@@ -217,7 +217,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
             ):
                 content = file.decode("utf-8")
             case ".pdf":
-                if not pm.is_installed("pypdf2"):
+                if not pm.is_installed("pypdf2"):  # type: ignore
                     pm.install("pypdf2")
                 from PyPDF2 import PdfReader  # type: ignore
                 from io import BytesIO
@@ -227,7 +227,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                 for page in reader.pages:
                     content += page.extract_text() + "\n"
             case ".docx":
-                if not pm.is_installed("docx"):
+                if not pm.is_installed("python-docx"):  # type: ignore
                     pm.install("docx")
                 from docx import Document
                 from io import BytesIO
@@ -236,7 +236,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                 docx_file = BytesIO(file)
                 doc = Document(docx_file)
                 content = "\n".join([paragraph.text for paragraph in doc.paragraphs])
             case ".pptx":
-                if not pm.is_installed("pptx"):
+                if not pm.is_installed("python-pptx"):  # type: ignore
                     pm.install("pptx")
                 from pptx import Presentation
                 from io import BytesIO
@@ -248,7 +248,7 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool:
                     if hasattr(shape, "text"):
                         content += shape.text + "\n"
             case ".xlsx":
-                if not pm.is_installed("openpyxl"):
+                if not pm.is_installed("openpyxl"):  # type: ignore
                     pm.install("openpyxl")
                 from openpyxl import load_workbook
                 from io import BytesIO
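The root cause patched above is that pip distribution names and Python import names differ for these packages: `pip install python-docx` provides the `docx` module, and `python-pptx` provides `pptx`. The installation check must use the distribution name while the import uses the module name. A sketch using only the pipmaster calls shown in the diff:

```python
import pipmaster as pm

# Distribution name (what pip tracks) vs. import name (what Python loads):
#   python-docx  ->  import docx
#   python-pptx  ->  import pptx
if not pm.is_installed("python-docx"):  # check by distribution name
    pm.install("python-docx")  # assumption: installing by distribution name as well
from docx import Document  # import by module name
```

Note that the hunks above correct only the `is_installed` checks; the `pm.install("docx")` and `pm.install("pptx")` calls still pass the import name.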

From b12c05ec0a228ad6b2d99fb0c99d2c62131eb5d3 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 4 Mar 2025 12:09:00 +0800
Subject: [PATCH 16/22] fix: api server installation missing MANIFEST.in file

- Added MANIFEST.in to include webui files
- Removed /webui/ endpoint from lightrag_server.py
---
 MANIFEST.in                     | 1 +
 lightrag/api/lightrag_server.py | 4 ----
 2 files changed, 1 insertion(+), 4 deletions(-)
 create mode 100644 MANIFEST.in

diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..44c3aff1
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+recursive-include lightrag/api/webui *

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 5f2c437f..8695d6b6 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -410,10 +410,6 @@ def create_app(args):
         name="webui",
     )

-    @app.get("/webui/")
-    async def webui_root():
-        return FileResponse(static_dir / "index.html")
-
     return app

From d7f7c07251edf21d8460b9a91ba31c06aad9314e Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 4 Mar 2025 12:19:40 +0800
Subject: [PATCH 17/22] Fix linting

---
 lightrag/api/lightrag_server.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 8695d6b6..631fa238 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -6,7 +6,6 @@ from fastapi import (
     FastAPI,
     Depends,
 )
-from fastapi.responses import FileResponse
 import asyncio
 import os
 import logging
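One packaging subtlety behind the MANIFEST.in fix above: `recursive-include` controls what goes into the source distribution, while getting those files into wheels and installed packages additionally requires setuptools' `include_package_data`. A sketch, assuming a setuptools-based build and a hypothetical distribution name:

```python
# setup.py (sketch)
from setuptools import find_packages, setup

setup(
    name="lightrag-hku",  # hypothetical distribution name
    packages=find_packages(),
    include_package_data=True,  # carry MANIFEST.in-matched files into wheels/installs
)
```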

From 6c8fa9521477b3a9440f640337b060fa46f2f5c8 Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 4 Mar 2025 12:25:07 +0800
Subject: [PATCH 18/22] fix demo

---
 README.md                                     | 21 ++++---
 examples/lightrag_azure_openai_demo.py        | 58 +++++++++++--------
 examples/lightrag_bedrock_demo.py             |  4 ++
 examples/lightrag_nvidia_demo.py              |  2 +-
 examples/lightrag_openai_compatible_demo.py   |  2 +-
 ..._openai_compatible_demo_embedding_cache.py |  2 +-
 examples/lightrag_oracle_demo.py              |  2 +-
 examples/lightrag_tidb_demo.py                |  2 +-
 examples/lightrag_zhipu_postgres_demo.py      |  2 +-
 examples/query_keyword_separation_example.py  |  2 +-
 10 files changed, 58 insertions(+), 39 deletions(-)

diff --git a/README.md b/README.md
index 5e8c5a94..f863d9ed 100644
--- a/README.md
+++ b/README.md
@@ -655,16 +655,19 @@ setup_logger("lightrag", level="INFO")

 # Note: Default settings use NetworkX
 # Initialize LightRAG with Neo4J implementation.
-rag = LightRAG(
-    working_dir=WORKING_DIR,
-    llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
-    graph_storage="Neo4JStorage", #<-----------override KG default
-)
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
+        graph_storage="Neo4JStorage", #<-----------override KG default
+    )

-# Initialize database connections
-await rag.initialize_storages()
-# Initialize pipeline status for document processing
-await initialize_pipeline_status()
+    # Initialize database connections
+    await rag.initialize_storages()
+    # Initialize pipeline status for document processing
+    await initialize_pipeline_status()
+
+    return rag
 ```

 see test_neo4j.py for a working example.

diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py
index e0840366..c101383d 100644
--- a/examples/lightrag_azure_openai_demo.py
+++ b/examples/lightrag_azure_openai_demo.py
@@ -81,34 +81,46 @@ asyncio.run(test_funcs())

 embedding_dimension = 3072

-rag = LightRAG(
-    working_dir=WORKING_DIR,
-    llm_model_func=llm_model_func,
-    embedding_func=EmbeddingFunc(
-        embedding_dim=embedding_dimension,
-        max_token_size=8192,
-        func=embedding_func,
-    ),
-)

-rag.initialize_storages()
-initialize_pipeline_status()
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=embedding_dimension,
+            max_token_size=8192,
+            func=embedding_func,
+        ),
+    )

-book1 = open("./book_1.txt", encoding="utf-8")
-book2 = open("./book_2.txt", encoding="utf-8")
+    await rag.initialize_storages()
+    await initialize_pipeline_status()

-rag.insert([book1.read(), book2.read()])
+    return rag

-query_text = "What are the main themes?"

-print("Result (Naive):")
-print(rag.query(query_text, param=QueryParam(mode="naive")))
+def main():
+    rag = asyncio.run(initialize_rag())

-print("\nResult (Local):")
-print(rag.query(query_text, param=QueryParam(mode="local")))
+    book1 = open("./book_1.txt", encoding="utf-8")
+    book2 = open("./book_2.txt", encoding="utf-8")

-print("\nResult (Global):")
-print(rag.query(query_text, param=QueryParam(mode="global")))
+    rag.insert([book1.read(), book2.read()])

-print("\nResult (Hybrid):")
-print(rag.query(query_text, param=QueryParam(mode="hybrid")))
+    query_text = "What are the main themes?"
+
+    print("Result (Naive):")
+    print(rag.query(query_text, param=QueryParam(mode="naive")))
+
+    print("\nResult (Local):")
+    print(rag.query(query_text, param=QueryParam(mode="local")))
+
+    print("\nResult (Global):")
+    print(rag.query(query_text, param=QueryParam(mode="global")))
+
+    print("\nResult (Hybrid):")
+    print(rag.query(query_text, param=QueryParam(mode="hybrid")))
+
+
+if __name__ == "__main__":
+    main()

diff --git a/examples/lightrag_bedrock_demo.py b/examples/lightrag_bedrock_demo.py
index 68e9f962..c7f41677 100644
--- a/examples/lightrag_bedrock_demo.py
+++ b/examples/lightrag_bedrock_demo.py
@@ -53,3 +53,7 @@ def main():
             "What are the top themes in this story?", param=QueryParam(mode=mode)
         )
     )
+
+
+if __name__ == "__main__":
+    main()

diff --git a/examples/lightrag_nvidia_demo.py b/examples/lightrag_nvidia_demo.py
index 6de0814c..0e9259bc 100644
--- a/examples/lightrag_nvidia_demo.py
+++ b/examples/lightrag_nvidia_demo.py
@@ -125,7 +125,7 @@ async def initialize_rag():
 async def main():
     try:
         # Initialize RAG instance
-        rag = asyncio.run(initialize_rag())
+        rag = await initialize_rag()

         # reading file
         with open("./book.txt", "r", encoding="utf-8") as f:

diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
index 1c4a7a92..d26a8de3 100644
--- a/examples/lightrag_openai_compatible_demo.py
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -77,7 +77,7 @@ async def initialize_rag():
 async def main():
     try:
         # Initialize RAG instance
-        rag = asyncio.run(initialize_rag())
+        rag = await initialize_rag()

         with open("./book.txt", "r", encoding="utf-8") as f:
             await rag.ainsert(f.read())

diff --git a/examples/lightrag_openai_compatible_demo_embedding_cache.py b/examples/lightrag_openai_compatible_demo_embedding_cache.py
index 85408f3b..4638219f 100644
--- a/examples/lightrag_openai_compatible_demo_embedding_cache.py
+++ b/examples/lightrag_openai_compatible_demo_embedding_cache.py
@@ -81,7 +81,7 @@ async def initialize_rag():
 async def main():
     try:
         # Initialize RAG instance
-        rag = asyncio.run(initialize_rag())
+        rag = await initialize_rag()

         with open("./book.txt", "r", encoding="utf-8") as f:
             await rag.ainsert(f.read())

diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index 420f1af0..6663f6a1 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -107,7 +107,7 @@ async def initialize_rag():
 async def main():
     try:
         # Initialize RAG instance
-        rag = asyncio.run(initialize_rag())
+        rag = await initialize_rag()

         # Extract and Insert into LightRAG storage
         with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:

diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py
index f167e9cc..52695560 100644
--- a/examples/lightrag_tidb_demo.py
+++ b/examples/lightrag_tidb_demo.py
@@ -87,7 +87,7 @@ async def initialize_rag():
 async def main():
     try:
         # Initialize RAG instance
-        rag = asyncio.run(initialize_rag())
+        rag = await initialize_rag()

         with open("./book.txt", "r", encoding="utf-8") as f:
             rag.insert(f.read())

diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py
index 304c5f2c..e4a20f26 100644
--- a/examples/lightrag_zhipu_postgres_demo.py
+++ b/examples/lightrag_zhipu_postgres_demo.py
@@ -59,7 +59,7 @@ async def initialize_rag():
 async def main():
     # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
+    rag = await initialize_rag()

     # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
     rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func

diff --git a/examples/query_keyword_separation_example.py b/examples/query_keyword_separation_example.py
index cbfdd930..092330f4 100644
--- a/examples/query_keyword_separation_example.py
+++ b/examples/query_keyword_separation_example.py
@@ -102,7 +102,7 @@ async def initialize_rag():
 # Example function demonstrating the new query_with_separate_keyword_extraction usage
 async def run_example():
     # Initialize RAG instance
-    rag = asyncio.run(initialize_rag())
+    rag = await initialize_rag()

     book1 = open("./book_1.txt", encoding="utf-8")
     book2 = open("./book_2.txt", encoding="utf-8")
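The recurring one-line change in the demos above (`rag = asyncio.run(initialize_rag())` becoming `rag = await initialize_rag()`) fixes a real failure mode: `asyncio.run()` creates a new event loop and raises `RuntimeError` when invoked from a coroutine that is already running on one. A minimal sketch of the corrected driver pattern (function names follow the demos):

```python
import asyncio


async def initialize_rag():
    ...  # build and initialize the LightRAG instance, as in the demos


async def main():
    # Already inside the event loop here, so await the coroutine directly;
    # asyncio.run(initialize_rag()) at this point would raise
    # "RuntimeError: asyncio.run() cannot be called from a running event loop".
    rag = await initialize_rag()  # noqa: F841


if __name__ == "__main__":
    asyncio.run(main())  # the only place a fresh event loop is started
```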

From 23106b81fbaeb9f4ddff9c18881874c458d8ab26 Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 4 Mar 2025 12:29:17 +0800
Subject: [PATCH 19/22] fix custom kg demo

---
 README.md | 70 +++++++++++++++++++++++++++++++++----------------------
 1 file changed, 42 insertions(+), 28 deletions(-)

diff --git a/README.md b/README.md
index f863d9ed..ed257049 100644
--- a/README.md
+++ b/README.md
@@ -505,44 +505,58 @@ rag.query_with_separate_keyword_extraction(

 ```python
 custom_kg = {
+    "chunks": [
+        {
+            "content": "Alice and Bob are collaborating on quantum computing research.",
+            "source_id": "doc-1"
+        }
+    ],
     "entities": [
         {
-            "entity_name": "CompanyA",
-            "entity_type": "Organization",
-            "description": "A major technology company",
-            "source_id": "Source1"
+            "entity_name": "Alice",
+            "entity_type": "person",
+            "description": "Alice is a researcher specializing in quantum physics.",
+            "source_id": "doc-1"
         },
         {
-            "entity_name": "ProductX",
-            "entity_type": "Product",
-            "description": "A popular product developed by CompanyA",
-            "source_id": "Source1"
+            "entity_name": "Bob",
+            "entity_type": "person",
+            "description": "Bob is a mathematician.",
+            "source_id": "doc-1"
+        },
+        {
+            "entity_name": "Quantum Computing",
+            "entity_type": "technology",
+            "description": "Quantum computing utilizes quantum mechanical phenomena for computation.",
+            "source_id": "doc-1"
         }
     ],
     "relationships": [
         {
-            "src_id": "CompanyA",
-            "tgt_id": "ProductX",
-            "description": "CompanyA develops ProductX",
-            "keywords": "develop, produce",
+            "src_id": "Alice",
+            "tgt_id": "Bob",
+            "description": "Alice and Bob are research partners.",
+            "keywords": "collaboration research",
             "weight": 1.0,
-            "source_id": "Source1"
+            "source_id": "doc-1"
+        },
+        {
+            "src_id": "Alice",
+            "tgt_id": "Quantum Computing",
+            "description": "Alice conducts research on quantum computing.",
+            "keywords": "research expertise",
+            "weight": 1.0,
+            "source_id": "doc-1"
+        },
+        {
+            "src_id": "Bob",
+            "tgt_id": "Quantum Computing",
+            "description": "Bob researches quantum computing.",
+            "keywords": "research application",
+            "weight": 1.0,
+            "source_id": "doc-1"
         }
-    ],
-    "chunks": [
-        {
-            "content": "ProductX, developed by CompanyA, has revolutionized the market with its cutting-edge features.",
-            "source_id": "Source1",
-        },
-        {
-            "content": "PersonA is a prominent researcher at UniversityB, focusing on artificial intelligence and machine learning.",
-            "source_id": "Source2",
-        },
-        {
-            "content": "None",
-            "source_id": "UNKNOWN",
-        },
-    ],
+    ]
 }

 rag.insert_custom_kg(custom_kg)
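Once the hand-built graph above is inserted, it participates in retrieval like extracted knowledge. A brief usage sketch, assuming `QueryParam` is imported from `lightrag` as in the demos earlier in this series:

```python
from lightrag import QueryParam

rag.insert_custom_kg(custom_kg)

# The custom entities and relationships are now visible to graph-aware modes:
print(rag.query(
    "Who is collaborating on quantum computing research?",
    param=QueryParam(mode="local"),
))
```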

From 0f430ca1a7f6058897c8a1cff098b6630801011c Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 4 Mar 2025 12:42:40 +0800
Subject: [PATCH 20/22] update README.md

---
 README.md | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index ed257049..57563a1f 100644
--- a/README.md
+++ b/README.md
@@ -785,7 +785,8 @@ rag.delete_by_doc_id("doc_id")

 LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.

-### Create Entities and Relations
+<details>
+  <summary> <b>Create Entities and Relations</b> </summary>

 ```python
 # Create new entity
 entity = rag.create_entity("Google", {
     "description": "Google is a multinational technology company",
     "entity_type": "company"
 })

 # Create another entity
 entity2 = rag.create_entity("Gmail", {
     "description": "Gmail is an email service provided by Google",
     "entity_type": "service"
 })

 # Create relation between entities
 relation = rag.create_relation("Google", "Gmail", {
     "description": "Google provides Gmail email service",
     "keywords": "email service",
     "weight": 2.0
 })
 ```
+</details>

-### Edit Entities and Relations
+<details>
+  <summary> <b>Edit Entities and Relations</b> </summary>

 ```python
 # Edit an existing entity
 updated_entity = rag.edit_entity("Google", {
     "description": "Google is a subsidiary of Alphabet Inc., founded in 1998.",
     "entity_type": "tech_company"
 })

 # Rename an entity (with all its relationships properly migrated)
 renamed_entity = rag.edit_entity("Gmail", {
     "entity_name": "Google Mail",
     "description": "Google Mail (formerly Gmail) is an email service."
 })

 # Edit a relation between entities
 updated_relation = rag.edit_relation("Google", "Google Mail", {
     "description": "Google provides and operates Google Mail (Gmail).",
     "keywords": "email service provider",
     "weight": 3.0
 })
 ```
+</details>

 All operations are available in both synchronous and asynchronous versions. The asynchronous versions have the prefix "a" (e.g., `acreate_entity`, `aedit_relation`).
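As the README text above notes, each of these operations has an async twin prefixed with "a". A hedged sketch of the async form, assuming the "a"-prefixed methods mirror the synchronous signatures shown above:

```python
import asyncio


async def update_kg(rag):
    # Async variants: same arguments as create_entity/edit_relation above.
    await rag.acreate_entity("Google", {
        "description": "Google is a multinational technology company",
        "entity_type": "company",
    })
    await rag.aedit_relation("Google", "Google Mail", {"weight": 3.0})


# asyncio.run(update_kg(rag))
```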

From fd9f71e0eee26189f19448d04678ff5dc0254524 Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 4 Mar 2025 13:22:33 +0800
Subject: [PATCH 21/22] fix delete_by_doc_id

---
 lightrag/kg/json_kv_impl.py |  9 +++++++++
 lightrag/kg/tidb_impl.py    |  8 ++++++++
 lightrag/lightrag.py        | 33 +++++++++++++++++++++++++--------
 3 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 8d707899..c0b61a63 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -44,6 +44,15 @@ class JsonKVStorage(BaseKVStorage):
             )
             write_json(data_dict, self._file_name)

+    async def get_all(self) -> dict[str, Any]:
+        """Get all data from storage
+
+        Returns:
+            Dictionary containing all stored data
+        """
+        async with self._storage_lock:
+            return dict(self._data)
+
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
         async with self._storage_lock:
             return self._data.get(id)

diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
index 4adb0141..51d1c365 100644
--- a/lightrag/kg/tidb_impl.py
+++ b/lightrag/kg/tidb_impl.py
@@ -174,6 +174,14 @@ class TiDBKVStorage(BaseKVStorage):
         self.db = None

     ################ QUERY METHODS ################
+    async def get_all(self) -> dict[str, Any]:
+        """Get all data from storage
+
+        Returns:
+            Dictionary containing all stored data
+        """
+        async with self._storage_lock:
+            return dict(self._data)

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
         """Fetch doc_full data by id."""

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index a5d3c94b..b2e9845e 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1431,14 +1431,22 @@ class LightRAG:

         logger.debug(f"Starting deletion for document {doc_id}")

-        doc_to_chunk_id = doc_id.replace("doc", "chunk")
+        # 2. Get all chunks related to this document
+        # Find all chunks where full_doc_id equals the current doc_id
+        all_chunks = await self.text_chunks.get_all()
+        related_chunks = {
+            chunk_id: chunk_data
+            for chunk_id, chunk_data in all_chunks.items()
+            if isinstance(chunk_data, dict)
+            and chunk_data.get("full_doc_id") == doc_id
+        }

-        # 2. Get all related chunks
-        chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
-        if not chunks:
+        if not related_chunks:
+            logger.warning(f"No chunks found for document {doc_id}")
             return

-        chunk_ids = {chunks["full_doc_id"].replace("doc", "chunk")}
+        # Get all related chunk IDs
+        chunk_ids = set(related_chunks.keys())
         logger.debug(f"Found {len(chunk_ids)} chunks to delete")

         # 3. Before deleting, check the related entities and relationships for these chunks
@@ -1626,9 +1634,18 @@ class LightRAG:
                 logger.warning(f"Document {doc_id} still exists in full_docs")

             # Verify if chunks have been deleted
-            remaining_chunks = await self.text_chunks.get_by_id(doc_to_chunk_id)
-            if remaining_chunks:
-                logger.warning(f"Found {len(remaining_chunks)} remaining chunks")
+            all_remaining_chunks = await self.text_chunks.get_all()
+            remaining_related_chunks = {
+                chunk_id: chunk_data
+                for chunk_id, chunk_data in all_remaining_chunks.items()
+                if isinstance(chunk_data, dict)
+                and chunk_data.get("full_doc_id") == doc_id
+            }
+
+            if remaining_related_chunks:
+                logger.warning(
+                    f"Found {len(remaining_related_chunks)} remaining chunks"
+                )

             # Verify entities and relationships
             for chunk_id in chunk_ids:
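The essence of the delete_by_doc_id fix above: the old code derived a chunk ID from the document ID by string substitution, which matches at most one chunk and misses documents split into many; the new code fetches everything via `get_all()` and filters on each chunk's `full_doc_id`. The filter, extracted as a standalone sketch:

```python
from typing import Any


def chunks_for_doc(all_chunks: dict[str, Any], doc_id: str) -> dict[str, Any]:
    """Keep only chunks whose full_doc_id points back at doc_id (as in the patch)."""
    return {
        chunk_id: chunk
        for chunk_id, chunk in all_chunks.items()
        if isinstance(chunk, dict) and chunk.get("full_doc_id") == doc_id
    }


# usage sketch: related = chunks_for_doc(await rag.text_chunks.get_all(), "doc-123")
```

One observation on the diff itself: the TiDB `get_all` shown here returns the in-memory `self._data` under `self._storage_lock`, mirroring the JSON implementation rather than issuing a SQL query.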

From 0679ca4055d36dfd53afcb9ab87ea5d4c056cd31 Mon Sep 17 00:00:00 2001
From: zrguo
Date: Tue, 4 Mar 2025 14:20:55 +0800
Subject: [PATCH 22/22] Update neo4j_impl.py

---
 lightrag/kg/neo4j_impl.py | 92 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 91 insertions(+), 1 deletion(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index dccee330..fec39138 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -690,8 +690,98 @@ class Neo4JStorage(BaseGraphStorage):
                 labels.append(record["label"])
             return labels

+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
+            )
+        ),
+    )
     async def delete_node(self, node_id: str) -> None:
-        raise NotImplementedError
+        """Delete a node with the specified label
+
+        Args:
+            node_id: The label of the node to delete
+        """
+        label = await self._ensure_label(node_id)
+
+        async def _do_delete(tx: AsyncManagedTransaction):
+            query = f"""
+            MATCH (n:`{label}`)
+            DETACH DELETE n
+            """
+            await tx.run(query)
+            logger.debug(f"Deleted node with label '{label}'")
+
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
+                await session.execute_write(_do_delete)
+        except Exception as e:
+            logger.error(f"Error during node deletion: {str(e)}")
+            raise
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
+            )
+        ),
+    )
+    async def remove_nodes(self, nodes: list[str]):
+        """Delete multiple nodes
+
+        Args:
+            nodes: List of node labels to be deleted
+        """
+        for node in nodes:
+            await self.delete_node(node)
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+                neo4jExceptions.ClientError,
+            )
+        ),
+    )
+    async def remove_edges(self, edges: list[tuple[str, str]]):
+        """Delete multiple edges
+
+        Args:
+            edges: List of edges to be deleted, each edge is a (source, target) tuple
+        """
+        for source, target in edges:
+            source_label = await self._ensure_label(source)
+            target_label = await self._ensure_label(target)
+
+            async def _do_delete_edge(tx: AsyncManagedTransaction):
+                query = f"""
+                MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
+                DELETE r
+                """
+                await tx.run(query)
logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'") + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise async def embed_nodes( self, algorithm: str