From 6e3b23069c0a76a5bfa7e27189faac57ff7d0691 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 7 Mar 2025 16:43:18 +0800
Subject: [PATCH 01/33] - Remove useless `_label_exists` method

---
 lightrag/kg/neo4j_impl.py | 22 ++++++----------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index fec39138..2498341d 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -164,23 +164,13 @@ class Neo4JStorage(BaseGraphStorage):
         # Neo4j handles persistence automatically
         pass

-    async def _label_exists(self, label: str) -> bool:
-        """Check if a label exists in the Neo4j database."""
-        query = "CALL db.labels() YIELD label RETURN label"
-        try:
-            async with self._driver.session(database=self._DATABASE) as session:
-                result = await session.run(query)
-                labels = [record["label"] for record in await result.data()]
-                return label in labels
-        except Exception as e:
-            logger.error(f"Error checking label existence: {e}")
-            return False
-
     async def _ensure_label(self, label: str) -> str:
-        """Ensure a label exists by validating it."""
+        """Ensure a label is valid
+
+        Args:
+            label: The label to validate
+        """
         clean_label = label.strip('"')
-        if not await self._label_exists(clean_label):
-            logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
         return clean_label

     async def has_node(self, node_id: str) -> bool:
@@ -290,7 +280,7 @@ class Neo4JStorage(BaseGraphStorage):
             if record:
                 try:
                     result = dict(record["edge_properties"])
-                    logger.info(f"Result: {result}")
+                    logger.debug(f"Result: {result}")
                     # Ensure required keys exist with defaults
                     required_keys = {
                         "weight": 0.0,

From 0ee2e7fd4800050ef2d1819c157a196ed66cf4fa Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 7 Mar 2025 16:56:48 +0800
Subject: [PATCH 02/33] Suppress Neo4j warning logs by setting logger level.

---
 lightrag/kg/neo4j_impl.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 2498341d..265c0347 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -15,6 +15,7 @@ from tenacity import (
     retry_if_exception_type,
 )

+import logging
 from ..utils import logger
 from ..base import BaseGraphStorage
 from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@@ -37,6 +38,8 @@ config.read("config.ini", "utf-8")
 # Get maximum number of graph nodes from environment variable, default is 1000
 MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

+# Set neo4j logger level to ERROR to suppress warning logs
+logging.getLogger("neo4j").setLevel(logging.ERROR)
 @final
 @dataclass
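A note on the patch above: the official Neo4j Python driver routes its warnings through the standard `logging` module under the "neo4j" logger name, so its verbosity can be capped independently of the application's own logging. A minimal standalone sketch of the same pattern (the INFO-level application logger here is illustrative, not part of the patch):

    import logging

    # The application keeps whatever verbosity it wants...
    logging.basicConfig(level=logging.INFO)

    # ...while the driver's logger is capped at ERROR, silencing the
    # notification warnings Neo4j emits for things like unknown labels.
    logging.getLogger("neo4j").setLevel(logging.ERROR)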
From af803f4e7ad3267fcd184fd6c3914b4c6b2c6bef Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 01:20:36 +0800
Subject: [PATCH 03/33] Refactor Neo4J graph query with min_degree and inclusive match support

---
 lightrag/kg/neo4j_impl.py | 434 ++++++++++++++++++++++++--------------
 1 file changed, 275 insertions(+), 159 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 265c0347..f6567249 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -41,6 +41,7 @@ MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
 # Set neo4j logger level to ERROR to suppress warning logs
 logging.getLogger("neo4j").setLevel(logging.ERROR)

+
 @final
 @dataclass
 class Neo4JStorage(BaseGraphStorage):
@@ -63,19 +64,25 @@ class Neo4JStorage(BaseGraphStorage):
         MAX_CONNECTION_POOL_SIZE = int(
             os.environ.get(
                 "NEO4J_MAX_CONNECTION_POOL_SIZE",
-                config.get("neo4j", "connection_pool_size", fallback=800),
+                config.get("neo4j", "connection_pool_size", fallback=50),  # Reduced from 800
             )
         )
         CONNECTION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_TIMEOUT",
-                config.get("neo4j", "connection_timeout", fallback=60.0),
+                config.get("neo4j", "connection_timeout", fallback=30.0),  # Reduced from 60.0
             ),
         )
         CONNECTION_ACQUISITION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
-                config.get("neo4j", "connection_acquisition_timeout", fallback=60.0),
+                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),  # Reduced from 60.0
+            ),
+        )
+        MAX_TRANSACTION_RETRY_TIME = float(
+            os.environ.get(
+                "NEO4J_MAX_TRANSACTION_RETRY_TIME",
+                config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
             ),
         )
         DATABASE = os.environ.get(
@@ -88,6 +95,7 @@ class Neo4JStorage(BaseGraphStorage):
             max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
             connection_timeout=CONNECTION_TIMEOUT,
             connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
+            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
         )

         # Try to connect to the database
@@ -169,21 +177,24 @@ class Neo4JStorage(BaseGraphStorage):

     async def _ensure_label(self, label: str) -> str:
         """Ensure a label is valid
-
+
         Args:
             label: The label to validate
         """
         clean_label = label.strip('"')
+        if not clean_label:
+            raise ValueError("Neo4j: Label cannot be empty")
         return clean_label

     async def has_node(self, node_id: str) -> bool:
         entity_name_label = await self._ensure_label(node_id)
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = (
                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
             )
             result = await session.run(query)
             single_result = await result.single()
+            await result.consume()  # Ensure result is fully consumed
             logger.debug(
                 f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
             )
@@ -193,13 +204,14 @@ class Neo4JStorage(BaseGraphStorage):
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')

-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = (
                 f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
                 "RETURN COUNT(r) > 0 AS edgeExists"
             )
             result = await session.run(query)
             single_result = await result.single()
+            await result.consume()  # Ensure result is fully consumed
             logger.debug(
                 f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
             )
@@ -215,13 +227,16 @@ class Neo4JStorage(BaseGraphStorage):
             dict: Node properties if found
             None: If node not found
         """
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             entity_name_label = await self._ensure_label(node_id)
             query = f"MATCH (n:`{entity_name_label}`) RETURN n"
             result = await session.run(query)
-            record = await result.single()
-            if record:
-                node = record["n"]
+            records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+            await result.consume()  # Ensure result is fully consumed
+            if len(records) > 1:
+                logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.")
+            if records:
+                node = records[0]["n"]
                 node_dict = dict(node)
                 logger.debug(
                     f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
                 )
@@ -230,23 +245,40 @@ class Neo4JStorage(BaseGraphStorage):
             return None

     async def node_degree(self, node_id: str) -> int:
+        """Get the degree (number of relationships) of a node with the given label.
+        If multiple nodes have the same label, returns the degree of the first node.
+        If no node is found, returns 0.
+
+        Args:
+            node_id: The label of the node
+
+        Returns:
+            int: The number of relationships the node has, or 0 if no node found
+        """
         entity_name_label = node_id.strip('"')

-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             query = f"""
                 MATCH (n:`{entity_name_label}`)
-                RETURN COUNT{{ (n)--() }} AS totalEdgeCount
+                OPTIONAL MATCH (n)-[r]-()
+                RETURN n, COUNT(r) AS degree
             """
             result = await session.run(query)
-            record = await result.single()
-            if record:
-                edge_count = record["totalEdgeCount"]
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
-                )
-                return edge_count
-            else:
-                return None
+            records = await result.fetch(100)
+            await result.consume()  # Ensure result is fully consumed
+
+            if not records:
+                logger.warning(f"No node found with label '{entity_name_label}'")
+                return 0
+
+            if len(records) > 1:
+                logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree")
+
+            degree = records[0]["degree"]
+            logger.debug(
+                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
+            )
+            return degree

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         entity_name_label_source = src_id.strip('"')
         entity_name_label_target = tgt_id.strip('"')
@@ -264,6 +296,31 @@ class Neo4JStorage(BaseGraphStorage):
         )
         return degrees

+    async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
+        """Find all labels that have multiple nodes
+
+        Returns:
+            list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
+        """
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+            query = """
+                MATCH (n)
+                WITH labels(n) as nodeLabels
+                UNWIND nodeLabels as label
+                WITH label, count(*) as node_count
+                WHERE node_count > 1
+                RETURN label, node_count
+                ORDER BY node_count DESC
+            """
+            result = await session.run(query)
+            duplicates = []
+            async for record in result:
+                label = record["label"]
+                count = record["node_count"]
+                logger.info(f"Found {count} nodes with label: {label}")
+                duplicates.append((label, count))
+            return duplicates
+
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
@@ -271,18 +328,21 @@ class Neo4JStorage(BaseGraphStorage):
             entity_name_label_source = source_node_id.strip('"')
             entity_name_label_target = target_node_id.strip('"')

-            async with self._driver.session(database=self._DATABASE) as session:
+            async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
                 query = f"""
-                MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
+                MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
                 RETURN properties(r) as edge_properties
-                LIMIT 1
                 """

                 result = await session.run(query)
-                record = await result.single()
-                if record:
+                records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+                if len(records) > 1:
+                    logger.warning(
+                        f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
+                    )
+                if records:
                     try:
-                        result = dict(record["edge_properties"])
+                        result = dict(records[0]["edge_properties"])
                         logger.debug(f"Result: {result}")
                         # Ensure required keys exist with defaults
                         required_keys = {
@@ -349,24 +409,27 @@ class Neo4JStorage(BaseGraphStorage):
         query = f"""MATCH (n:`{node_label}`)
                 OPTIONAL MATCH (n)-[r]-(connected)
                 RETURN n, r, connected"""
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             results = await session.run(query)
             edges = []
-            async for record in results:
-                source_node = record["n"]
-                connected_node = record["connected"]
+            try:
+                async for record in results:
+                    source_node = record["n"]
+                    connected_node = record["connected"]

-                source_label = (
-                    list(source_node.labels)[0] if source_node.labels else None
-                )
-                target_label = (
-                    list(connected_node.labels)[0]
-                    if connected_node and connected_node.labels
-                    else None
-                )
+                    source_label = (
+                        list(source_node.labels)[0] if source_node.labels else None
+                    )
+                    target_label = (
+                        list(connected_node.labels)[0]
+                        if connected_node and connected_node.labels
+                        else None
+                    )

-                if source_label and target_label:
-                    edges.append((source_label, target_label))
+                    if source_label and target_label:
+                        edges.append((source_label, target_label))
+            finally:
+                await results.consume()  # Ensure results are consumed even if processing fails

         return edges
@@ -427,30 +490,46 @@ class Neo4JStorage(BaseGraphStorage):
     ) -> None:
         """
         Upsert an edge and its properties between two nodes identified by their labels.
+        Checks if both source and target nodes exist before creating the edge.

         Args:
             source_node_id (str): Label of the source node (used as identifier)
             target_node_id (str): Label of the target node (used as identifier)
             edge_data (dict): Dictionary of properties to set on the edge
+
+        Raises:
+            ValueError: If either source or target node does not exist
         """
         source_label = await self._ensure_label(source_node_id)
         target_label = await self._ensure_label(target_node_id)
         edge_properties = edge_data

+        # Check if both nodes exist
+        source_exists = await self.has_node(source_label)
+        target_exists = await self.has_node(target_label)
+
+        if not source_exists:
+            raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist")
+        if not target_exists:
+            raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist")
+
         async def _do_upsert_edge(tx: AsyncManagedTransaction):
             query = f"""
             MATCH (source:`{source_label}`)
             WITH source
             MATCH (target:`{target_label}`)
-            MERGE (source)-[r:DIRECTED]->(target)
+            MERGE (source)-[r:DIRECTED]-(target)
             SET r += $properties
             RETURN r
             """
             result = await tx.run(query, properties=edge_properties)
-            record = await result.single()
-            logger.debug(
-                f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
-            )
+            try:
+                record = await result.single()
+                logger.debug(
+                    f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
+                )
+            finally:
+                await result.consume()  # Ensure result is consumed

         try:
             async with self._driver.session(database=self._DATABASE) as session:
@@ -463,145 +542,179 @@ class Neo4JStorage(BaseGraphStorage):
         print("Implemented but never called.")

     async def get_knowledge_graph(
-        self, node_label: str, max_depth: int = 5
+        self,
+        node_label: str,
+        max_depth: int = 3,
+        min_degree: int = 0,
+        inclusive: bool = False,
     ) -> KnowledgeGraph:
         """
         Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
         Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
         When reducing the number of nodes, the prioritization criteria are as follows:
-            1. Label matching nodes take precedence (nodes containing the specified label string)
-            2. Followed by nodes directly connected to the matching nodes
-            3. Finally, the degree of the nodes
+            1. min_degree does not affect nodes directly connected to the matching nodes
+            2. Label matching nodes take precedence
+            3. Followed by nodes directly connected to the matching nodes
+            4. Finally, the degree of the nodes

         Args:
-            node_label (str): String to match in node labels (will match any node containing this string in its label)
-            max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
+            node_label: Label of the starting node
+            max_depth: Maximum depth of the subgraph
+            min_degree: Minimum degree of nodes to include. Defaults to 0
+            inclusive: Do an inclusive search if true

         Returns:
             KnowledgeGraph: Complete connected subgraph for specified node
         """
         label = node_label.strip('"')
-        # Escape single quotes to prevent injection attacks
-        escaped_label = label.replace("'", "\\'")
         result = KnowledgeGraph()
         seen_nodes = set()
         seen_edges = set()

-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             try:
                 if label == "*":
                     main_query = """
                     MATCH (n)
                     OPTIONAL MATCH (n)-[r]-()
                     WITH n, count(r) AS degree
+                    WHERE degree >= $min_degree
                     ORDER BY degree DESC
                     LIMIT $max_nodes
-                    WITH collect(n) AS nodes
-                    MATCH (a)-[r]->(b)
-                    WHERE a IN nodes AND b IN nodes
-                    RETURN nodes, collect(DISTINCT r) AS relationships
+                    WITH collect({node: n}) AS filtered_nodes
+                    UNWIND filtered_nodes AS node_info
+                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                    MATCH (a)-[r]-(b)
+                    WHERE a IN kept_nodes AND b IN kept_nodes
+                    RETURN filtered_nodes AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = await session.run(
-                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                        main_query,
+                        {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
                     )
                 else:
-                    validate_query = f"""
-                    MATCH (n)
-                    WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
-                    RETURN n LIMIT 1
-                    """
-                    validate_result = await session.run(validate_query)
-                    if not await validate_result.single():
-                        logger.warning(
-                            f"No nodes containing '{label}' in their labels found!"
-                        )
-                        return result
-
-                    # Main query uses partial matching
-                    main_query = f"""
+                    main_query = """
                     MATCH (start)
-                    WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
+                    WHERE any(label IN labels(start) WHERE
+                        CASE
+                            WHEN $inclusive THEN label CONTAINS $label
+                            ELSE label = $label
+                        END
+                    )
                     WITH start
-                    CALL apoc.path.subgraphAll(start, {{
-                        relationshipFilter: '>',
+                    CALL apoc.path.subgraphAll(start, {
+                        relationshipFilter: '',
                         minLevel: 0,
-                        maxLevel: {max_depth},
+                        maxLevel: $max_depth,
                         bfs: true
-                    }})
+                    })
                     YIELD nodes, relationships
                     WITH start, nodes, relationships
                     UNWIND nodes AS node
                     OPTIONAL MATCH (node)-[r]-()
-                    WITH node, count(r) AS degree, start, nodes, relationships,
-                        CASE
-                            WHEN id(node) = id(start) THEN 2
-                            WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
-                            ELSE 0
-                        END AS priority
-                    ORDER BY priority DESC, degree DESC
+                    WITH node, count(r) AS degree, start, nodes, relationships
+                    WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
+                    ORDER BY
+                        CASE
+                            WHEN node = start THEN 3
+                            WHEN EXISTS((start)--(node)) THEN 2
+                            ELSE 1
+                        END DESC,
+                        degree DESC
                     LIMIT $max_nodes
-                    WITH collect(node) AS filtered_nodes, nodes, relationships
-                    RETURN filtered_nodes AS nodes,
-                        [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
+                    WITH collect({node: node}) AS filtered_nodes
+                    UNWIND filtered_nodes AS node_info
+                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
+                    MATCH (a)-[r]-(b)
+                    WHERE a IN kept_nodes AND b IN kept_nodes
+                    RETURN filtered_nodes AS node_info,
+                           collect(DISTINCT r) AS relationships
                     """
                     result_set = await session.run(
-                        main_query, {"max_nodes": MAX_GRAPH_NODES}
+                        main_query,
+                        {
+                            "max_nodes": MAX_GRAPH_NODES,
+                            "label": label,
+                            "inclusive": inclusive,
+                            "max_depth": max_depth,
+                            "min_degree": min_degree,
+                        },
                     )

-                record = await result_set.single()
+                try:
+                    record = await result_set.single()

-                if record:
-                    # Handle nodes (compatible with multi-label cases)
-                    for node in record["nodes"]:
-                        # Use node ID + label combination as unique identifier
-                        node_id = node.id
-                        if node_id not in seen_nodes:
-                            result.nodes.append(
-                                KnowledgeGraphNode(
-                                    id=f"{node_id}",
-                                    labels=list(node.labels),
-                                    properties=dict(node),
+                    if record:
+                        # Handle nodes (compatible with multi-label cases)
+                        for node_info in record["node_info"]:
+                            node = node_info["node"]
+                            node_id = node.id
+                            if node_id not in seen_nodes:
+                                result.nodes.append(
+                                    KnowledgeGraphNode(
+                                        id=f"{node_id}",
+                                        labels=list(node.labels),
+                                        properties=dict(node),
+                                    )
                                 )
-                            )
-                            seen_nodes.add(node_id)
+                                seen_nodes.add(node_id)

-                    # Handle relationships (including direction information)
-                    for rel in record["relationships"]:
-                        edge_id = rel.id
-                        if edge_id not in seen_edges:
-                            start = rel.start_node
-                            end = rel.end_node
-                            result.edges.append(
-                                KnowledgeGraphEdge(
-                                    id=f"{edge_id}",
-                                    type=rel.type,
-                                    source=f"{start.id}",
-                                    target=f"{end.id}",
-                                    properties=dict(rel),
+                        # Handle relationships (including direction information)
+                        for rel in record["relationships"]:
+                            edge_id = rel.id
+                            if edge_id not in seen_edges:
+                                start = rel.start_node
+                                end = rel.end_node
+                                result.edges.append(
+                                    KnowledgeGraphEdge(
+                                        id=f"{edge_id}",
+                                        type=rel.type,
+                                        source=f"{start.id}",
+                                        target=f"{end.id}",
+                                        properties=dict(rel),
+                                    )
                                 )
-                            )
-                            seen_edges.add(edge_id)
+                                seen_edges.add(edge_id)

-                logger.info(
-                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
-                )
+                    logger.info(
+                        f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+                    )
+                finally:
+                    await result_set.consume()  # Ensure result set is consumed

             except neo4jExceptions.ClientError as e:
-                logger.error(f"APOC query failed: {str(e)}")
-                return await self._robust_fallback(label, max_depth)
+                logger.warning(
+                    f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
+                )
+                if inclusive:
+                    logger.warning(
+                        "Inclusive search mode is not supported in recursive query, using exact matching"
+                    )
+                return await self._robust_fallback(label, max_depth, min_degree)

         return result

     async def _robust_fallback(
-        self, label: str, max_depth: int
+        self, label: str, max_depth: int, min_degree: int = 0
     ) -> Dict[str, List[Dict]]:
-        """Enhanced fallback query solution"""
+        """
+        Fallback implementation when APOC plugin is not available or incompatible.
+        This method implements the same functionality as get_knowledge_graph but uses
+        only basic Cypher queries and recursive traversal instead of APOC procedures.
+        """
         result = {"nodes": [], "edges": []}
         visited_nodes = set()
         visited_edges = set()

         async def traverse(current_label: str, current_depth: int):
+            # Check traversal limits
             if current_depth > max_depth:
+                logger.debug(f"Reached max depth: {max_depth}")
+                return
+            if len(visited_nodes) >= MAX_GRAPH_NODES:
+                logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
                 return

             # Get current node details
@@ -614,46 +727,46 @@ class Neo4JStorage(BaseGraphStorage):
                 return
             visited_nodes.add(node_id)

-            # Add node data (with complete labels)
-            node_data = {k: v for k, v in node.items()}
-            node_data["labels"] = [
-                current_label
-            ]  # Assume get_node method returns label information
-            result["nodes"].append(node_data)
+            # Add node data with label as ID
+            result["nodes"].append({
+                "id": current_label,
+                "labels": current_label,
+                "properties": node
+            })

-            # Get all outgoing and incoming edges
+            # Get connected nodes that meet the degree requirement
+            # Note: We don't need to check a's degree since it's the current node
+            # and was already validated in the previous iteration
             query = f"""
-            MATCH (a)-[r]-(b)
-            WHERE a:`{current_label}` OR b:`{current_label}`
-            RETURN a, r, b,
-                CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction
+            MATCH (a:`{current_label}`)-[r]-(b)
+            WITH r, b,
+                 COUNT((b)--()) AS b_degree
+            WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
+            RETURN r, b
             """
-            async with self._driver.session(database=self._DATABASE) as session:
-                results = await session.run(query)
+            async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+                results = await session.run(query, {"min_degree": min_degree})
                 async for record in results:
                     # Handle edges
                     rel = record["r"]
                     edge_id = f"{rel.id}_{rel.type}"
                     if edge_id not in visited_edges:
-                        edge_data = dict(rel)
-                        edge_data.update(
-                            {
-                                "source": list(record["a"].labels)[0],
-                                "target": list(record["b"].labels)[0],
+                        b_node = record["b"]
+                        if b_node.labels:  # Only process if target node has labels
+                            target_label = list(b_node.labels)[0]
+                            result["edges"].append({
+                                "id": f"{current_label}_{target_label}",
                                 "type": rel.type,
-                                "direction": record["direction"],
-                            }
-                        )
-                        result["edges"].append(edge_data)
-                        visited_edges.add(edge_id)
+                                "source": current_label,
+                                "target": target_label,
+                                "properties": dict(rel)
+                            })
+                            visited_edges.add(edge_id)

-                    # Recursively traverse adjacent nodes
-                    next_label = (
-                        list(record["b"].labels)[0]
-                        if record["direction"] == "OUTGOING"
-                        else list(record["a"].labels)[0]
-                    )
-                    await traverse(next_label, current_depth + 1)
+                            # Continue traversal
+                            await traverse(target_label, current_depth + 1)
+                        else:
+                            logger.warning(f"Skipping edge {edge_id} due to missing labels on target node")

         await traverse(label, 0)
         return result
@@ -664,7 +777,7 @@ class Neo4JStorage(BaseGraphStorage):
         Returns:
             ["Person", "Company", ...]  # Alphabetically sorted label list
         """
-        async with self._driver.session(database=self._DATABASE) as session:
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
             # Method 1: Direct metadata query (Available for Neo4j 4.3+)
             # query = "CALL db.labels() YIELD label RETURN label"
@@ -679,8 +792,11 @@ class Neo4JStorage(BaseGraphStorage):

             result = await session.run(query)
             labels = []
-            async for record in result:
-                labels.append(record["label"])
+            try:
+                async for record in result:
+                    labels.append(record["label"])
+            finally:
+                await result.consume()  # Ensure results are consumed even if processing fails
             return labels

     @retry(
@@ -763,7 +879,7 @@ class Neo4JStorage(BaseGraphStorage):

         async def _do_delete_edge(tx: AsyncManagedTransaction):
             query = f"""
-            MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
+            MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
             DELETE r
             """
             await tx.run(query)
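Patch 03 changes the caller-facing contract of `get_knowledge_graph`: labels now match exactly unless `inclusive=True`, `min_degree` prunes weakly connected nodes (the start node and its direct neighbours are always kept), and the default depth drops from 5 to 3. A call-site sketch of the new parameters (the label values are hypothetical, and an already-initialized `storage` instance is assumed, since construction is outside these diffs):

    # Exact label match, depth 3, drop nodes with fewer than 2 relationships
    kg = await storage.get_knowledge_graph("Apple Inc", max_depth=3, min_degree=2)

    # Substring match: also picks up labels such as "Apple Inc Subsidiary"
    kg = await storage.get_knowledge_graph("Apple", inclusive=True)

    print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges")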
From c07b592e1bfe73cde40c46f46e06f1dc9c3ae292 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 02:39:51 +0800
Subject: [PATCH 04/33] Add missing await consume

---
 lightrag/kg/neo4j_impl.py | 250 ++++++++++++++++++++------------------
 1 file changed, 130 insertions(+), 120 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index f6567249..ea316d0f 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -64,19 +64,19 @@ class Neo4JStorage(BaseGraphStorage):
         MAX_CONNECTION_POOL_SIZE = int(
             os.environ.get(
                 "NEO4J_MAX_CONNECTION_POOL_SIZE",
-                config.get("neo4j", "connection_pool_size", fallback=50),  # Reduced from 800
+                config.get("neo4j", "connection_pool_size", fallback=50),
             )
         )
         CONNECTION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_TIMEOUT",
-                config.get("neo4j", "connection_timeout", fallback=30.0),  # Reduced from 60.0
+                config.get("neo4j", "connection_timeout", fallback=30.0),
             ),
         )
         CONNECTION_ACQUISITION_TIMEOUT = float(
             os.environ.get(
                 "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
-                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),  # Reduced from 60.0
+                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
             ),
         )
         MAX_TRANSACTION_RETRY_TIME = float(
@@ -188,23 +188,24 @@ class Neo4JStorage(BaseGraphStorage):

     async def has_node(self, node_id: str) -> bool:
         entity_name_label = await self._ensure_label(node_id)
-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             query = (
                 f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
             )
             result = await session.run(query)
             single_result = await result.single()
             await result.consume()  # Ensure result is fully consumed
-            logger.debug(
-                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
-            )
             return single_result["node_exists"]

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         entity_name_label_source = source_node_id.strip('"')
         entity_name_label_target = target_node_id.strip('"')

-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             query = (
                 f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
                 "RETURN COUNT(r) > 0 AS edgeExists"
@@ -212,9 +213,6 @@ class Neo4JStorage(BaseGraphStorage):
             result = await session.run(query)
             single_result = await result.single()
             await result.consume()  # Ensure result is fully consumed
-            logger.debug(
-                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
-            )
             return single_result["edgeExists"]

     async def get_node(self, node_id: str) -> dict[str, str] | None:
@@ -227,14 +225,20 @@ class Neo4JStorage(BaseGraphStorage):
             dict: Node properties if found
             None: If node not found
         """
-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             entity_name_label = await self._ensure_label(node_id)
             query = f"MATCH (n:`{entity_name_label}`) RETURN n"
             result = await session.run(query)
-            records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+            records = await result.fetch(
+                2
+            )  # Get up to 2 records to check for duplicates
             await result.consume()  # Ensure result is fully consumed
             if len(records) > 1:
-                logger.warning(f"Multiple nodes found with label '{entity_name_label}'. Using first node.")
+                logger.warning(
+                    f"Multiple nodes found with label '{entity_name_label}'. Using first node."
+                )
             if records:
                 node = records[0]["n"]
                 node_dict = dict(node)
@@ -248,16 +252,18 @@ class Neo4JStorage(BaseGraphStorage):
         """Get the degree (number of relationships) of a node with the given label.
         If multiple nodes have the same label, returns the degree of the first node.
         If no node is found, returns 0.
-
+
         Args:
             node_id: The label of the node
-
+
         Returns:
             int: The number of relationships the node has, or 0 if no node found
         """
         entity_name_label = node_id.strip('"')

-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             query = f"""
                 MATCH (n:`{entity_name_label}`)
                 OPTIONAL MATCH (n)-[r]-()
@@ -266,14 +272,16 @@ class Neo4JStorage(BaseGraphStorage):
             result = await session.run(query)
             records = await result.fetch(100)
             await result.consume()  # Ensure result is fully consumed
-
+
             if not records:
                 logger.warning(f"No node found with label '{entity_name_label}'")
                 return 0
-
+
             if len(records) > 1:
-                logger.warning(f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree")
-
+                logger.warning(
+                    f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
+                )
+
             degree = records[0]["degree"]
             logger.debug(
                 f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
@@ -296,30 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
         )
         return degrees

-    async def check_duplicate_nodes(self) -> list[tuple[str, int]]:
-        """Find all labels that have multiple nodes
-
-        Returns:
-            list[tuple[str, int]]: List of tuples containing (label, node_count) for labels with multiple nodes
-        """
-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
-            query = """
-                MATCH (n)
-                WITH labels(n) as nodeLabels
-                UNWIND nodeLabels as label
-                WITH label, count(*) as node_count
-                WHERE node_count > 1
-                RETURN label, node_count
-                ORDER BY node_count DESC
-            """
-            result = await session.run(query)
-            duplicates = []
-            async for record in result:
-                label = record["label"]
-                count = record["node_count"]
-                logger.info(f"Found {count} nodes with label: {label}")
-                duplicates.append((label, count))
-            return duplicates
-
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
@@ -328,64 +312,69 @@ class Neo4JStorage(BaseGraphStorage):
             entity_name_label_source = source_node_id.strip('"')
             entity_name_label_target = target_node_id.strip('"')

-            async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
                 query = f"""
                 MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
                 RETURN properties(r) as edge_properties
                 """

                 result = await session.run(query)
-                records = await result.fetch(2)  # Get up to 2 records to check for duplicates
-                if len(records) > 1:
-                    logger.warning(
-                        f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
-                    )
-                if records:
-                    try:
-                        result = dict(records[0]["edge_properties"])
-                        logger.debug(f"Result: {result}")
-                        # Ensure required keys exist with defaults
-                        required_keys = {
-                            "weight": 0.0,
-                            "source_id": None,
-                            "description": None,
-                            "keywords": None,
-                        }
-                        for key, default_value in required_keys.items():
-                            if key not in result:
-                                result[key] = default_value
-                                logger.warning(
-                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
-                                    f"missing {key}, using default: {default_value}"
-                                )
-
-                        logger.debug(
-                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
-                        )
-                        return result
-                    except (KeyError, TypeError, ValueError) as e:
-                        logger.error(
-                            f"Error processing edge properties between {entity_name_label_source} "
-                            f"and {entity_name_label_target}: {str(e)}"
-                        )
-                        # Return default edge properties on error
-                        return {
-                            "weight": 0.0,
-                            "description": None,
-                            "keywords": None,
-                            "source_id": None,
-                        }
-
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
-                )
-                # Return default edge properties when no edge found
-                return {
-                    "weight": 0.0,
-                    "description": None,
-                    "keywords": None,
-                    "source_id": None,
-                }
+                try:
+                    records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+                    if len(records) > 1:
+                        logger.warning(
+                            f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
+                        )
+                    if records:
+                        try:
+                            result = dict(records[0]["edge_properties"])
+                            logger.debug(f"Result: {result}")
+                            # Ensure required keys exist with defaults
+                            required_keys = {
+                                "weight": 0.0,
+                                "source_id": None,
+                                "description": None,
+                                "keywords": None,
+                            }
+                            for key, default_value in required_keys.items():
+                                if key not in result:
+                                    result[key] = default_value
+                                    logger.warning(
+                                        f"Edge between {entity_name_label_source} and {entity_name_label_target} "
+                                        f"missing {key}, using default: {default_value}"
+                                    )
+
+                            logger.debug(
+                                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
+                            )
+                            return result
+                        except (KeyError, TypeError, ValueError) as e:
+                            logger.error(
+                                f"Error processing edge properties between {entity_name_label_source} "
+                                f"and {entity_name_label_target}: {str(e)}"
+                            )
+                            # Return default edge properties on error
+                            return {
+                                "weight": 0.0,
+                                "description": None,
+                                "keywords": None,
+                                "source_id": None,
+                            }
+
+                    logger.debug(
+                        f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
+                    )
+                    # Return default edge properties when no edge found
+                    return {
+                        "weight": 0.0,
+                        "description": None,
+                        "keywords": None,
+                        "source_id": None,
+                    }
+                finally:
+                    await result.consume()  # Ensure result is fully consumed

         except Exception as e:
             logger.error(
@@ -409,7 +398,9 @@ class Neo4JStorage(BaseGraphStorage):
         query = f"""MATCH (n:`{node_label}`)
                 OPTIONAL MATCH (n)-[r]-(connected)
                 RETURN n, r, connected"""
-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             results = await session.run(query)
             edges = []
             try:
@@ -429,7 +420,9 @@ class Neo4JStorage(BaseGraphStorage):
                     if source_label and target_label:
                         edges.append((source_label, target_label))
             finally:
-                await results.consume()  # Ensure results are consumed even if processing fails
+                await (
+                    results.consume()
+                )  # Ensure results are consumed even if processing fails

         return edges
@@ -461,10 +454,11 @@ class Neo4JStorage(BaseGraphStorage):
             MERGE (n:`{label}`)
             SET n += $properties
             """
-            await tx.run(query, properties=properties)
+            result = await tx.run(query, properties=properties)
             logger.debug(
                 f"Upserted node with label '{label}' and properties: {properties}"
             )
+            await result.consume()  # Ensure result is fully consumed

         try:
             async with self._driver.session(database=self._DATABASE) as session:
@@ -509,9 +503,13 @@ class Neo4JStorage(BaseGraphStorage):
         target_exists = await self.has_node(target_label)

         if not source_exists:
-            raise ValueError(f"Neo4j: source node with label '{source_label}' does not exist")
+            raise ValueError(
+                f"Neo4j: source node with label '{source_label}' does not exist"
+            )
         if not target_exists:
-            raise ValueError(f"Neo4j: target node with label '{target_label}' does not exist")
+            raise ValueError(
+                f"Neo4j: target node with label '{target_label}' does not exist"
+            )

         async def _do_upsert_edge(tx: AsyncManagedTransaction):
             query = f"""
@@ -570,7 +568,9 @@ class Neo4JStorage(BaseGraphStorage):
         seen_nodes = set()
         seen_edges = set()

-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             try:
                 if label == "*":
                     main_query = """
@@ -728,11 +728,9 @@ class Neo4JStorage(BaseGraphStorage):
             visited_nodes.add(node_id)

             # Add node data with label as ID
-            result["nodes"].append({
-                "id": current_label,
-                "labels": current_label,
-                "properties": node
-            })
+            result["nodes"].append(
+                {"id": current_label, "labels": current_label, "properties": node}
+            )

             # Get connected nodes that meet the degree requirement
             # Note: We don't need to check a's degree since it's the current node
@@ -744,7 +742,9 @@ class Neo4JStorage(BaseGraphStorage):
             WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
             RETURN r, b
             """
-            async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
                 results = await session.run(query, {"min_degree": min_degree})
                 async for record in results:
                     # Handle edges
@@ -754,19 +754,23 @@ class Neo4JStorage(BaseGraphStorage):
                     b_node = record["b"]
                     if b_node.labels:  # Only process if target node has labels
                         target_label = list(b_node.labels)[0]
-                        result["edges"].append({
-                            "id": f"{current_label}_{target_label}",
-                            "type": rel.type,
-                            "source": current_label,
-                            "target": target_label,
-                            "properties": dict(rel)
-                        })
+                        result["edges"].append(
+                            {
+                                "id": f"{current_label}_{target_label}",
+                                "type": rel.type,
+                                "source": current_label,
+                                "target": target_label,
+                                "properties": dict(rel),
+                            }
+                        )
                         visited_edges.add(edge_id)

                         # Continue traversal
                         await traverse(target_label, current_depth + 1)
                     else:
-                        logger.warning(f"Skipping edge {edge_id} due to missing labels on target node")
+                        logger.warning(
+                            f"Skipping edge {edge_id} due to missing labels on target node"
+                        )

         await traverse(label, 0)
         return result
@@ -777,7 +781,9 @@ class Neo4JStorage(BaseGraphStorage):
         Returns:
             ["Person", "Company", ...]  # Alphabetically sorted label list
         """
-        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
             # Method 1: Direct metadata query (Available for Neo4j 4.3+)
             # query = "CALL db.labels() YIELD label RETURN label"
@@ -796,7 +802,9 @@ class Neo4JStorage(BaseGraphStorage):
                 async for record in result:
                     labels.append(record["label"])
             finally:
-                await result.consume()  # Ensure results are consumed even if processing fails
+                await (
+                    result.consume()
+                )  # Ensure results are consumed even if processing fails
             return labels
@@ -824,8 +832,9 @@ class Neo4JStorage(BaseGraphStorage):
             MATCH (n:`{label}`)
             DETACH DELETE n
             """
-            await tx.run(query)
+            result = await tx.run(query)
             logger.debug(f"Deleted node with label '{label}'")
+            await result.consume()  # Ensure result is fully consumed

         try:
             async with self._driver.session(database=self._DATABASE) as session:
@@ -882,8 +891,9 @@ class Neo4JStorage(BaseGraphStorage):
             MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
             DELETE r
             """
-            await tx.run(query)
+            result = await tx.run(query)
             logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
+            await result.consume()  # Ensure result is fully consumed

         try:
             async with self._driver.session(database=self._DATABASE) as session:
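The `await result.consume()` calls added throughout patches 03 and 04 follow the async driver's resource rule: a result cursor pins its connection until its records are exhausted or the result is explicitly consumed, so every exit path — early returns included — should consume the cursor before the session goes back to the pool. The shape the patches converge on, as a generic sketch (not a specific method from this file):

    async def run_single(session, query: str):
        result = await session.run(query)
        try:
            return await result.single()  # read only what is needed
        finally:
            await result.consume()  # discard the rest and release the connection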
From fcb04e47e5f1beda21c9304ba3c07d90e2e07fc1 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 04:28:54 +0800
Subject: [PATCH 05/33] Refactor Neo4J APOC fallback retrieval implementation

---
 lightrag/kg/neo4j_impl.py | 255 ++++++++++++++++++++++----------------
 1 file changed, 149 insertions(+), 106 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index ea316d0f..60e8982e 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -3,7 +3,7 @@ import inspect
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, List, Dict, final
+from typing import Any, final, Optional
 import numpy as np
 import configparser
@@ -304,7 +304,6 @@ class Neo4JStorage(BaseGraphStorage):
         )
         return degrees

-
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
@@ -321,69 +320,58 @@ class Neo4JStorage(BaseGraphStorage):
                 """

                 result = await session.run(query)
-                try:
-                    records = await result.fetch(2)  # Get up to 2 records to check for duplicates
-                    if len(records) > 1:
-                        logger.warning(
-                            f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
-                        )
-                    if records:
-                        try:
-                            result = dict(records[0]["edge_properties"])
-                            logger.debug(f"Result: {result}")
-                            # Ensure required keys exist with defaults
-                            required_keys = {
-                                "weight": 0.0,
-                                "source_id": None,
-                                "description": None,
-                                "keywords": None,
-                            }
-                            for key, default_value in required_keys.items():
-                                if key not in result:
-                                    result[key] = default_value
-                                    logger.warning(
-                                        f"Edge between {entity_name_label_source} and {entity_name_label_target} "
-                                        f"missing {key}, using default: {default_value}"
-                                    )
-
-                            logger.debug(
-                                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
-                            )
-                            return result
-                        except (KeyError, TypeError, ValueError) as e:
-                            logger.error(
-                                f"Error processing edge properties between {entity_name_label_source} "
-                                f"and {entity_name_label_target}: {str(e)}"
-                            )
-                            # Return default edge properties on error
-                            return {
-                                "weight": 0.0,
-                                "description": None,
-                                "keywords": None,
-                                "source_id": None,
-                            }
-
-                    logger.debug(
-                        f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
-                    )
-                    # Return default edge properties when no edge found
-                    return {
-                        "weight": 0.0,
-                        "description": None,
-                        "keywords": None,
-                        "source_id": None,
-                    }
-                finally:
-                    await result.consume()  # Ensure result is fully consumed
+                records = await result.fetch(2)  # Get up to 2 records to check for duplicates
+                await result.consume()  # Ensure result is fully consumed before processing records
+
+                if len(records) > 1:
+                    logger.warning(
+                        f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
+                    )
+                if records:
+                    try:
+                        edge_result = dict(records[0]["edge_properties"])
+                        logger.debug(f"Result: {edge_result}")
+                        # Ensure required keys exist with defaults
+                        required_keys = {
+                            "weight": 0.0,
+                            "source_id": None,
+                            "description": None,
+                            "keywords": None,
+                        }
+                        for key, default_value in required_keys.items():
+                            if key not in edge_result:
+                                edge_result[key] = default_value
+                                logger.warning(
+                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
+                                    f"missing {key}, using default: {default_value}"
+                                )
+
+                        logger.debug(
+                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
+                        )
+                        return edge_result
+                    except (KeyError, TypeError, ValueError) as e:
+                        logger.error(
+                            f"Error processing edge properties between {entity_name_label_source} "
+                            f"and {entity_name_label_target}: {str(e)}"
+                        )
+                        # Return default edge properties on error
+                        return {
+                            "weight": 0.0,
+                            "description": None,
+                            "keywords": None,
+                            "source_id": None,
+                        }
+
+                logger.debug(
+                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
+                )
+                # Return default edge properties when no edge found
+                return {
+                    "weight": 0.0,
+                    "description": None,
+                    "keywords": None,
+                    "source_id": None,
+                }

         except Exception as e:
             logger.error(
@@ -685,19 +673,25 @@ class Neo4JStorage(BaseGraphStorage):
             except neo4jExceptions.ClientError as e:
-                logger.warning(
-                    f"APOC plugin error: {str(e)}, falling back to basic Cypher implementation"
-                )
-                if inclusive:
-                    logger.warning(
-                        "Inclusive search mode is not supported in recursive query, using exact matching"
-                    )
+                logger.warning(f"APOC plugin error: {str(e)}")
+                if label != "*":
+                    logger.warning(
+                        "Neo4j: falling back to basic Cypher recursive search..."
+                    )
+                    if inclusive:
+                        logger.warning(
+                            "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
+                        )
                 return await self._robust_fallback(label, max_depth, min_degree)

         return result

     async def _robust_fallback(
         self, label: str, max_depth: int, min_degree: int = 0
-    ) -> Dict[str, List[Dict]]:
+    ) -> KnowledgeGraph:
         """
         Fallback implementation when APOC plugin is not available or incompatible.
         This method implements the same functionality as get_knowledge_graph but uses
         only basic Cypher queries and recursive traversal instead of APOC procedures.
         """
-        result = {"nodes": [], "edges": []}
+        result = KnowledgeGraph()
         visited_nodes = set()
         visited_edges = set()

@@ -717,90 +711,139 @@ class Neo4JStorage(BaseGraphStorage):
-        async def traverse(current_label: str, current_depth: int):
+        async def traverse(
+            node: KnowledgeGraphNode,
+            edge: Optional[KnowledgeGraphEdge],
+            current_depth: int,
+        ):
             # Check traversal limits
             if current_depth > max_depth:
                 logger.debug(f"Reached max depth: {max_depth}")
                 return
             if len(visited_nodes) >= MAX_GRAPH_NODES:
                 logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
                 return

-            # Get current node details
-            node = await self.get_node(current_label)
-            if not node:
-                return
-
-            node_id = f"{current_label}"
-            if node_id in visited_nodes:
+            # Check if node already visited
+            if node.id in visited_nodes:
                 return
-            visited_nodes.add(node_id)
-
-            # Add node data with label as ID
-            result["nodes"].append(
-                {"id": current_label, "labels": current_label, "properties": node}
-            )
-
-            # Get connected nodes that meet the degree requirement
-            # Note: We don't need to check a's degree since it's the current node
-            # and was already validated in the previous iteration
-            query = f"""
-            MATCH (a:`{current_label}`)-[r]-(b)
-            WITH r, b,
-                 COUNT((b)--()) AS b_degree
-            WHERE b_degree >= $min_degree OR EXISTS((a)--(b))
-            RETURN r, b
-            """
+
+            # Get all edges and target nodes
             async with self._driver.session(
                 database=self._DATABASE, default_access_mode="READ"
             ) as session:
-                results = await session.run(query, {"min_degree": min_degree})
-                async for record in results:
-                    # Handle edges
-                    rel = record["r"]
-                    edge_id = f"{rel.id}_{rel.type}"
-                    if edge_id not in visited_edges:
-                        b_node = record["b"]
-                        if b_node.labels:  # Only process if target node has labels
-                            target_label = list(b_node.labels)[0]
-                            result["edges"].append(
-                                {
-                                    "id": f"{current_label}_{target_label}",
-                                    "type": rel.type,
-                                    "source": current_label,
-                                    "target": target_label,
-                                    "properties": dict(rel),
-                                }
-                            )
-                            visited_edges.add(edge_id)
-
-                            # Continue traversal
-                            await traverse(target_label, current_depth + 1)
-                        else:
-                            logger.warning(
-                                f"Skipping edge {edge_id} due to missing labels on target node"
-                            )
-
-        await traverse(label, 0)
+                query = """
+                MATCH (a)-[r]-(b)
+                WHERE id(a) = toInteger($node_id)
+                WITH r, b, id(r) as edge_id, id(b) as target_id
+                RETURN r, b, edge_id, target_id
+                """
+                results = await session.run(query, {"node_id": node.id})
+
+                # Get all records and release database connection
+                records = await results.fetch(MAX_GRAPH_NODES)
+                await results.consume()  # Ensure results are consumed
+
+                # Nodes not connected to start node need to check degree
+                if current_depth > 1 and len(records) < min_degree:
+                    return
+
+                # Add current node to result
+                result.nodes.append(node)
+                visited_nodes.add(node.id)
+
+                # Add edge to result if it exists and not already added
+                if edge and edge.id not in visited_edges:
+                    result.edges.append(edge)
+                    visited_edges.add(edge.id)
+
+                # Prepare nodes and edges for recursive processing
+                nodes_to_process = []
+                for record in records:
+                    rel = record["r"]
+                    edge_id = str(record["edge_id"])
+                    if edge_id not in visited_edges:
+                        b_node = record["b"]
+                        target_id = str(record["target_id"])
+
+                        if b_node.labels:  # Only process if target node has labels
+                            # Create KnowledgeGraphNode for target
+                            target_node = KnowledgeGraphNode(
+                                id=target_id,
+                                labels=list(b_node.labels),
+                                properties=dict(b_node),
+                            )
+
+                            # Create KnowledgeGraphEdge
+                            target_edge = KnowledgeGraphEdge(
+                                id=edge_id,
+                                type=rel.type,
+                                source=node.id,
+                                target=target_id,
+                                properties=dict(rel),
+                            )
+
+                            nodes_to_process.append((target_node, target_edge))
+                        else:
+                            logger.warning(
+                                f"Skipping edge {edge_id} due to missing labels on target node"
+                            )
+
+                # Process nodes after releasing database connection
+                for target_node, target_edge in nodes_to_process:
+                    await traverse(target_node, target_edge, current_depth + 1)
+
+        # Get the starting node's data
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = f"""
+            MATCH (n:`{label}`)
+            RETURN id(n) as node_id, n
+            """
+            node_result = await session.run(query)
+            try:
+                node_record = await node_result.single()
+                if not node_record:
+                    return result
+
+                # Create initial KnowledgeGraphNode
+                start_node = KnowledgeGraphNode(
+                    id=str(node_record["node_id"]),
+                    labels=list(node_record["n"].labels),
+                    properties=dict(node_record["n"]),
+                )
+            finally:
+                await node_result.consume()  # Ensure results are consumed
+
+        # Start traversal with the initial node
+        await traverse(start_node, None, 0)
+
         return result

     async def get_all_labels(self) -> list[str]:
From 84222b8b76bb077b144463af8acfde8df188d505 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 10:19:20 +0800
Subject: [PATCH 06/33] Refactor Neo4JStorage methods for robustness and clarity.

- Add error handling and resource cleanup
- Improve method documentation
- Optimize result consumption
---
 lightrag/kg/neo4j_impl.py | 412 +++++++++++++++++++++++---------------
 1 file changed, 255 insertions(+), 157 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 60e8982e..082b4bf2 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -163,13 +163,14 @@ class Neo4JStorage(BaseGraphStorage):
         }

     async def close(self):
+        """Close the Neo4j driver and release all resources"""
         if self._driver:
             await self._driver.close()
             self._driver = None

     async def __aexit__(self, exc_type, exc, tb):
-        if self._driver:
-            await self._driver.close()
+        """Ensure driver is closed when context manager exits"""
+        await self.close()

     async def index_done_callback(self) -> None:
         # Neo4j handles persistence automatically
         pass
@@ -187,222 +188,320 @@ class Neo4JStorage(BaseGraphStorage):
         return clean_label

     async def has_node(self, node_id: str) -> bool:
+        """
+        Check if a node with the given label exists in the database
+
+        Args:
+            node_id: Label of the node to check
+
+        Returns:
+            bool: True if node exists, False otherwise
+
+        Raises:
+            ValueError: If node_id is invalid
+            Exception: If there is an error executing the query
+        """
         entity_name_label = await self._ensure_label(node_id)
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = (
-                f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
-            )
-            result = await session.run(query)
-            single_result = await result.single()
-            await result.consume()  # Ensure result is fully consumed
-            return single_result["node_exists"]
+            try:
+                query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
+                result = await session.run(query)
+                single_result = await result.single()
+                await result.consume()  # Ensure result is fully consumed
+                return single_result["node_exists"]
+            except Exception as e:
+                logger.error(
+                    f"Error checking node existence for {entity_name_label}: {str(e)}"
+                )
+                await result.consume()  # Ensure results are consumed even on error
+                raise

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        entity_name_label_source = source_node_id.strip('"')
-        entity_name_label_target = target_node_id.strip('"')
+        """
+        Check if an edge exists between two nodes
+
+        Args:
+            source_node_id: Label of the source node
+            target_node_id: Label of the target node
+
+        Returns:
+            bool: True if edge exists, False otherwise
+
+        Raises:
+            ValueError: If either node_id is invalid
+            Exception: If there is an error executing the query
+        """
+        entity_name_label_source = await self._ensure_label(source_node_id)
+        entity_name_label_target = await self._ensure_label(target_node_id)

         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = (
-                f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
-                "RETURN COUNT(r) > 0 AS edgeExists"
-            )
-            result = await session.run(query)
-            single_result = await result.single()
-            await result.consume()  # Ensure result is fully consumed
-            return single_result["edgeExists"]
+            try:
+                query = (
+                    f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
+                    "RETURN COUNT(r) > 0 AS edgeExists"
+                )
+                result = await session.run(query)
+                single_result = await result.single()
+                await result.consume()  # Ensure result is fully consumed
+                return single_result["edgeExists"]
+            except Exception as e:
+                logger.error(
+                    f"Error checking edge existence between {entity_name_label_source} and {entity_name_label_target}: {str(e)}"
+                )
+                await result.consume()  # Ensure results are consumed even on error
+                raise

     async def get_node(self, node_id: str) -> dict[str, str] | None:
         """Get node by its label identifier.

         Args:
             node_id: The node label to look up

         Returns:
             dict: Node properties if found
             None: If node not found
+
+        Raises:
+            ValueError: If node_id is invalid
+            Exception: If there is an error executing the query
         """
+        entity_name_label = await self._ensure_label(node_id)
         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            entity_name_label = await self._ensure_label(node_id)
-            query = f"MATCH (n:`{entity_name_label}`) RETURN n"
-            result = await session.run(query)
-            records = await result.fetch(
-                2
-            )  # Get up to 2 records to check for duplicates
-            await result.consume()  # Ensure result is fully consumed
-            if len(records) > 1:
-                logger.warning(
-                    f"Multiple nodes found with label '{entity_name_label}'. Using first node."
-                )
-            if records:
-                node = records[0]["n"]
-                node_dict = dict(node)
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
-                )
-                return node_dict
-            return None
+            try:
+                query = f"MATCH (n:`{entity_name_label}`) RETURN n"
+                result = await session.run(query)
+                try:
+                    records = await result.fetch(
+                        2
+                    )  # Get up to 2 records to check for duplicates
+
+                    if len(records) > 1:
+                        logger.warning(
+                            f"Multiple nodes found with label '{entity_name_label}'. Using first node."
+                        )
+                    if records:
+                        node = records[0]["n"]
+                        node_dict = dict(node)
+                        logger.debug(
+                            f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
+                        )
+                        return node_dict
+                    return None
+                finally:
+                    await result.consume()  # Ensure result is fully consumed
+            except Exception as e:
+                logger.error(f"Error getting node for {entity_name_label}: {str(e)}")
+                raise

     async def node_degree(self, node_id: str) -> int:
         """Get the degree (number of relationships) of a node with the given label.
         If multiple nodes have the same label, returns the degree of the first node.
         If no node is found, returns 0.

         Args:
             node_id: The label of the node

         Returns:
             int: The number of relationships the node has, or 0 if no node found
+
+        Raises:
+            ValueError: If node_id is invalid
+            Exception: If there is an error executing the query
         """
-        entity_name_label = node_id.strip('"')
+        entity_name_label = await self._ensure_label(node_id)

         async with self._driver.session(
             database=self._DATABASE, default_access_mode="READ"
         ) as session:
-            query = f"""
-                MATCH (n:`{entity_name_label}`)
-                OPTIONAL MATCH (n)-[r]-()
-                RETURN n, COUNT(r) AS degree
-            """
-            result = await session.run(query)
-            records = await result.fetch(100)
-            await result.consume()  # Ensure result is fully consumed
-
-            if not records:
-                logger.warning(f"No node found with label '{entity_name_label}'")
-                return 0
-
-            if len(records) > 1:
-                logger.warning(
-                    f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
-                )
-
-            degree = records[0]["degree"]
-            logger.debug(
-                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
-            )
-            return degree
+            try:
+                query = f"""
+                    MATCH (n:`{entity_name_label}`)
+                    OPTIONAL MATCH (n)-[r]-()
+                    RETURN n, COUNT(r) AS degree
+                """
+                result = await session.run(query)
+                try:
+                    records = await result.fetch(100)
+
+                    if not records:
+                        logger.warning(
+                            f"No node found with label '{entity_name_label}'"
+                        )
+                        return 0
+
+                    if len(records) > 1:
+                        logger.warning(
+                            f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
+                        )
+
+                    degree = records[0]["degree"]
+                    logger.debug(
+                        f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
+                    )
+                    return degree
+                finally:
+                    await result.consume()  # Ensure result is fully consumed
+            except Exception as e:
+                logger.error(
+                    f"Error getting node degree for {entity_name_label}: {str(e)}"
+                )
+                raise

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        entity_name_label_source = src_id.strip('"')
-        entity_name_label_target = tgt_id.strip('"')
+        """Get the total degree (sum of relationships) of two nodes.
+
+        Args:
+            src_id: Label of the source node
+            tgt_id: Label of the target node
+
+        Returns:
+            int: Sum of the degrees of both nodes
+        """
+        entity_name_label_source = await self._ensure_label(src_id)
+        entity_name_label_target = await self._ensure_label(tgt_id)
+
         src_degree = await self.node_degree(entity_name_label_source)
         trg_degree = await self.node_degree(entity_name_label_target)

         # Convert None to 0 for addition
         src_degree = 0 if src_degree is None else src_degree
         trg_degree = 0 if trg_degree is None else trg_degree

         degrees = int(src_degree) + int(trg_degree)
-        logger.debug(
-            f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
-        )
         return degrees

     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
+        """Get edge properties between two nodes.
+
+        Args:
+            source_node_id: Label of the source node
+            target_node_id: Label of the target node
+
+        Returns:
+            dict: Edge properties if found, default properties if not found or on error
+
+        Raises:
+            ValueError: If either node_id is invalid
+            Exception: If there is an error executing the query
+        """
         try:
-            entity_name_label_source = source_node_id.strip('"')
-            entity_name_label_target = target_node_id.strip('"')
+            entity_name_label_source = await self._ensure_label(source_node_id)
+            entity_name_label_target = await self._ensure_label(target_node_id)

             async with self._driver.session(
                 database=self._DATABASE, default_access_mode="READ"
             ) as session:
                 query = f"""
                 MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
                 RETURN properties(r) as edge_properties
                 """

                 result = await session.run(query)
-                records = await result.fetch(2)  # Get up to 2 records to check for duplicates
-                await result.consume()  # Ensure result is fully consumed before processing records
-
-                if len(records) > 1:
-                    logger.warning(
-                        f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
-                    )
-                if records:
-                    try:
-                        edge_result = dict(records[0]["edge_properties"])
-                        logger.debug(f"Result: {edge_result}")
-                        # Ensure required keys exist with defaults
-                        required_keys = {
-                            "weight": 0.0,
-                            "source_id": None,
-                            "description": None,
-                            "keywords": None,
-                        }
-                        for key, default_value in required_keys.items():
-                            if key not in edge_result:
-                                edge_result[key] = default_value
-                                logger.warning(
-                                    f"Edge between {entity_name_label_source} and {entity_name_label_target} "
-                                    f"missing {key}, using default: {default_value}"
-                                )
-
-                        logger.debug(
-                            f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
-                        )
-                        return edge_result
-                    except (KeyError, TypeError, ValueError) as e:
-                        logger.error(
-                            f"Error processing edge properties between {entity_name_label_source} "
-                            f"and {entity_name_label_target}: {str(e)}"
-                        )
-                        # Return default edge properties on error
-                        return {
-                            "weight": 0.0,
-                            "description": None,
-                            "keywords": None,
-                            "source_id": None,
-                        }
-
-                logger.debug(
-                    f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
-                )
-                # Return default edge properties when no edge found
-                return {
-                    "weight": 0.0,
-                    "description": None,
-                    "keywords": None,
-                    "source_id": None,
-                }
+                try:
+                    records = await result.fetch(
+                        2
+                    )  # Get up to 2 records to check for duplicates
+
+                    if len(records) > 1:
+                        logger.warning(
+                            f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
+                        )
+                    if records:
+                        try:
+                            edge_result = dict(records[0]["edge_properties"])
+                            logger.debug(f"Result: {edge_result}")
+                            # Ensure required keys exist with defaults
+                            required_keys = {
+                                "weight": 0.0,
+                                "source_id": None,
+                                "description": None,
+                                "keywords": None,
+                            }
+                            for key, default_value in required_keys.items():
+                                if key not in edge_result:
+                                    edge_result[key] = default_value
+                                    logger.warning(
+                                        f"Edge between {entity_name_label_source} and {entity_name_label_target} "
+                                        f"missing {key}, using default: {default_value}"
+                                    )
+
+                            logger.debug(
+                                f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
+                            )
+                            return edge_result
+                        except (KeyError, TypeError, ValueError) as e:
+                            logger.error(
+                                f"Error processing edge properties between {entity_name_label_source} "
+                                f"and {entity_name_label_target}: {str(e)}"
+                            )
+                            # Return default edge properties on error
+                            return {
+                                "weight": 0.0,
+                                "source_id": None,
+                                "description": None,
+                                "keywords": None,
+                            }
+
+                    logger.debug(
+                        f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
+                    )
+                    # Return default edge properties when no edge found
+                    return {
+                        "weight": 0.0,
+                        "source_id": None,
+                        "description": None,
+                        "keywords": None,
+                    }
+                finally:
+                    await result.consume()  # Ensure result is fully consumed

         except Exception as e:
             logger.error(
                 f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
             )
-            # Return default edge properties on error
-            return {
-                "weight": 0.0,
-                "description": None,
-                "keywords": None,
-                "source_id": None,
-            }
+            raise

     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        node_label = source_node_id.strip('"')
+        """Retrieves all edges (relationships) for a particular node identified by its label.

-        """
-        Retrieves all edges (relationships) for a particular node identified by its label.
-        :return: List of dictionaries containing edge information
-        """
-        query = f"""MATCH (n:`{node_label}`)
-                OPTIONAL MATCH (n)-[r]-(connected)
-                RETURN n, r, connected"""
-        async with self._driver.session(
-            database=self._DATABASE, default_access_mode="READ"
-        ) as session:
-            results = await session.run(query)
-            edges = []
-            try:
-                async for record in results:
-                    source_node = record["n"]
-                    connected_node = record["connected"]
-
-                    source_label = (
-                        list(source_node.labels)[0] if source_node.labels else None
-                    )
-                    target_label = (
-                        list(connected_node.labels)[0]
-                        if connected_node and connected_node.labels
-                        else None
-                    )
-
-                    if source_label and target_label:
-                        edges.append((source_label, target_label))
-            finally:
-                await (
-                    results.consume()
-                )  # Ensure results are consumed even if processing fails
-
-        return edges
+        Args:
+            source_node_id: Label of the node to get edges for
+
+        Returns:
+            list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
+            None: If no edges found
+
+        Raises:
+            ValueError: If source_node_id is invalid
+            Exception: If there is an error executing the query
+        """
+        try:
+            node_label = await self._ensure_label(source_node_id)
+
+            query = f"""MATCH (n:`{node_label}`)
+                    OPTIONAL MATCH (n)-[r]-(connected)
+                    RETURN n, r, connected"""
+
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
+                try:
+                    results = await session.run(query)
+                    edges = []
+
+                    async for record in results:
+                        source_node = record["n"]
+                        connected_node = record["connected"]
+
+                        source_label = (
+                            list(source_node.labels)[0] if source_node.labels else None
+                        )
+                        target_label = (
+                            list(connected_node.labels)[0]
+                            if connected_node and connected_node.labels
+                            else None
+                        )
+
+                        if source_label and target_label:
+                            edges.append((source_label, target_label))
+
+                    await results.consume()  # Ensure results are consumed
+                    return edges if edges else None
+                except Exception as e:
+                    logger.error(f"Error getting edges for node {node_label}: {str(e)}")
+                    await results.consume()  # Ensure results are consumed even on error
+                    raise
+        except Exception as e:
+            logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
+            raise

     @retry(
         stop=stop_after_attempt(3),
@@ -937,7 +1036,6 @@ class Neo4JStorage(BaseGraphStorage):
                 RETURN DISTINCT label
                 ORDER BY label
             """
-
             result = await session.run(query)
             labels = []
             try:

From 78f8d7a1ce1186ce3398afb946f3da79bad50df7 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 10:20:10 +0800
Subject: [PATCH 07/33] Convert node and edge IDs to f-strings for
consistency.
- Use f-strings for node IDs
- Use f-strings for edge IDs
- Ensure consistent ID formatting
---
 lightrag/kg/neo4j_impl.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 082b4bf2..05deb0a9 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -865,17 +865,17 @@ class Neo4JStorage(BaseGraphStorage):
 if b_node.labels: # Only process if target node has labels
 # Create KnowledgeGraphNode for target
 target_node = KnowledgeGraphNode(
- id=target_id,
+ id=f"{target_id}",
 labels=list(b_node.labels),
 properties=dict(b_node),
 )

 # Create KnowledgeGraphEdge
 target_edge = KnowledgeGraphEdge(
- id=edge_id,
+ id=f"{edge_id}",
 type=rel.type,
- source=node.id,
- target=target_id,
+ source=f"{node.id}",
+ target=f"{target_id}",
 properties=dict(rel),
 )
@@ -905,7 +905,7 @@ class Neo4JStorage(BaseGraphStorage):
 # Create initial KnowledgeGraphNode
 start_node = KnowledgeGraphNode(
- id=str(node_record["node_id"]),
+ id=f"{node_record['node_id']}",
 labels=list(node_record["n"].labels),
 properties=dict(node_record["n"]),
 )

From af26d656985e0d9dd722c1cc8ea0d65f6348dc79 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 10:23:27 +0800
Subject: [PATCH 08/33] Convert _ensure_label method from async to sync
---
 lightrag/kg/neo4j_impl.py | 40 ++++++++++++++++++++++-----------------
 1 file changed, 23 insertions(+), 17 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 05deb0a9..cf3c024f 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -176,11 +176,17 @@ class Neo4JStorage(BaseGraphStorage):
 # Neo4J handles persistence automatically
 pass

- async def _ensure_label(self, label: str) -> str:
+ def _ensure_label(self, label: str) -> str:
 """Ensure a label is valid

 Args:
 label: The label to validate
+
+ Returns:
+ str: The cleaned label
+
+ Raises:
+ ValueError: If label is empty after cleaning
 """
 clean_label = label.strip('"')
 if not clean_label:
@@ -201,7 +207,7 @@ class Neo4JStorage(BaseGraphStorage):
 ValueError: If node_id is invalid
 Exception: If there is an error executing the query
 """
- entity_name_label = await self._ensure_label(node_id)
+ entity_name_label = self._ensure_label(node_id)
 async with self._driver.session(
 database=self._DATABASE, default_access_mode="READ"
 ) as session:
@@ -233,8 +239,8 @@ class Neo4JStorage(BaseGraphStorage):
 ValueError: If either node_id is invalid
 Exception: If there is an error executing the query
 """
- entity_name_label_source = await self._ensure_label(source_node_id)
- entity_name_label_target = await self._ensure_label(target_node_id)
+ entity_name_label_source = self._ensure_label(source_node_id)
+ entity_name_label_target = self._ensure_label(target_node_id)

 async with self._driver.session(
 database=self._DATABASE, default_access_mode="READ"
@@ -269,7 +275,7 @@ class Neo4JStorage(BaseGraphStorage):
 ValueError: If node_id is invalid
 Exception: If there is an error executing the query
 """
- entity_name_label = await self._ensure_label(node_id)
+ entity_name_label = self._ensure_label(node_id)
 async with self._driver.session(
 database=self._DATABASE, default_access_mode="READ"
 ) as session:
@@ -314,7 +320,7 @@ class Neo4JStorage(BaseGraphStorage):
 ValueError: If node_id is invalid
 Exception: If there is an error executing the query
 """
- entity_name_label = await self._ensure_label(node_id)
+ entity_name_label = self._ensure_label(node_id)
 async with self._driver.session(
 database=self._DATABASE, default_access_mode="READ"
@@ -363,8 +369,8 @@ class Neo4JStorage(BaseGraphStorage):
 Returns:
 int: Sum of the degrees of both nodes
 """
- entity_name_label_source = await self._ensure_label(src_id)
- entity_name_label_target = await self._ensure_label(tgt_id)
+ entity_name_label_source = self._ensure_label(src_id)
+ entity_name_label_target = self._ensure_label(tgt_id)

 src_degree = await self.node_degree(entity_name_label_source)
 trg_degree = await self.node_degree(entity_name_label_target)
@@ -393,8 +399,8 @@ class Neo4JStorage(BaseGraphStorage):
 Exception: If there is an error executing the query
 """
 try:
- entity_name_label_source = await self._ensure_label(source_node_id)
- entity_name_label_target = await self._ensure_label(target_node_id)
+ entity_name_label_source = self._ensure_label(source_node_id)
+ entity_name_label_target = self._ensure_label(target_node_id)

 async with self._driver.session(
 database=self._DATABASE, default_access_mode="READ"
@@ -484,7 +490,7 @@ class Neo4JStorage(BaseGraphStorage):
 Exception: If there is an error executing the query
 """
 try:
- node_label = await self._ensure_label(source_node_id)
+ node_label = self._ensure_label(source_node_id)

 query = f"""MATCH (n:`{node_label}`)
 OPTIONAL MATCH (n)-[r]-(connected)
 RETURN n, r, connected"""
@@ -543,7 +549,7 @@ class Neo4JStorage(BaseGraphStorage):
 node_id: The unique identifier for the node (used as label)
 node_data: Dictionary of node properties
 """
- label = await self._ensure_label(node_id)
+ label = self._ensure_label(node_id)
 properties = node_data

 async def _do_upsert(tx: AsyncManagedTransaction):
@@ -591,8 +597,8 @@ class Neo4JStorage(BaseGraphStorage):
 Raises:
 ValueError: If either source or target node does not exist
 """
- source_label = await self._ensure_label(source_node_id)
- target_label = await self._ensure_label(target_node_id)
+ source_label = self._ensure_label(source_node_id)
+ target_label = self._ensure_label(target_node_id)
 edge_properties = edge_data

 # Check if both nodes exist
@@ -966,7 +972,7 @@ class Neo4JStorage(BaseGraphStorage):
 Args:
 node_id: The label of the node to delete
 """
- label = await self._ensure_label(node_id)
+ label = self._ensure_label(node_id)

 async def _do_delete(tx: AsyncManagedTransaction):
 query = f"""
@@ -1024,8 +1030,8 @@ class Neo4JStorage(BaseGraphStorage):
 edges: List of edges to be deleted, each edge is a (source, target) tuple
 """
 for source, target in edges:
- source_label = await self._ensure_label(source)
- target_label = await self._ensure_label(target)
+ source_label = self._ensure_label(source)
+ target_label = self._ensure_label(target)

 async def _do_delete_edge(tx: AsyncManagedTransaction):
 query = f"""

From 887f6ed81a2cb6036163105433b160e1343daf98 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 11:20:22 +0800
Subject: [PATCH 09/33] Fix: return empty list when no edges are found
---
 lightrag/kg/neo4j_impl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index cf3c024f..34226df7 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -520,7 +520,7 @@ class Neo4JStorage(BaseGraphStorage):
 edges.append((source_label, target_label))

 await results.consume() # Ensure results are consumed
- return edges if edges else None
+ return edges
 except Exception as e:
 logger.error(f"Error getting edges for node {node_label}: {str(e)}")
 await results.consume() # Ensure results are consumed even on error
 raise

From 22a93fb717b7a66dda345fbacdb2e6d5df874707 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 11:29:08 +0800
Subject: [PATCH 10/33] Limit neighbor nodes fetch to 1000 in Neo4JStorage.
---
 lightrag/kg/neo4j_impl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 34226df7..7e1007b9 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -843,7 +843,7 @@ class Neo4JStorage(BaseGraphStorage):
 results = await session.run(query, {"node_id": node.id})

 # Get all records and release database connection
- records = await results.fetch()
+ records = await results.fetch(1000) # Max neighbour nodes we can handle
 await results.consume() # Ensure results are consumed

 # Nodes not connected to start node need to check degree

From fb4a4c736edca76f8ab5968c0b4d8869bec94bf2 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 8 Mar 2025 11:36:24 +0800
Subject: [PATCH 11/33] Add duplicate edge upsert checking and logging
---
 lightrag/kg/neo4j_impl.py | 78 ++++++++++++++++++++++-----------------
 1 file changed, 44 insertions(+), 34 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 7e1007b9..1e46798a 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -412,9 +412,7 @@ class Neo4JStorage(BaseGraphStorage):
 result = await session.run(query)
 try:
- records = await result.fetch(
- 2
- ) # Get up to 2 records to check for duplicates
+ records = await result.fetch(2)

 if len(records) > 1:
 logger.warning(
@@ -552,20 +550,20 @@ class Neo4JStorage(BaseGraphStorage):
 label = self._ensure_label(node_id)
 properties = node_data

- async def _do_upsert(tx: AsyncManagedTransaction):
- query = f"""
- MERGE (n:`{label}`)
- SET n += $properties
- """
- result = await tx.run(query, properties=properties)
- logger.debug(
- f"Upserted node with label '{label}' and properties: {properties}"
- )
- await result.consume() # Ensure result is fully consumed
-
 try:
 async with self._driver.session(database=self._DATABASE) as session:
- await session.execute_write(_do_upsert)
+ async def execute_upsert(tx: AsyncManagedTransaction):
+ query = f"""
+ MERGE (n:`{label}`)
+ SET n += $properties
+ """
+ result = await tx.run(query, properties=properties)
+ logger.debug(
+ f"Upserted node with label '{label}' and properties: {properties}"
+ )
+ await result.consume() # Ensure result is fully consumed
+
+ await session.execute_write(execute_upsert)
 except Exception as e:
 logger.error(f"Error during upsert: {str(e)}")
 raise
@@ -614,27 +612,39 @@ class Neo4JStorage(BaseGraphStorage):
 f"Neo4j: target node with label '{target_label}' does not exist"
 )

- async def _do_upsert_edge(tx: AsyncManagedTransaction):
- query = f"""
- MATCH (source:`{source_label}`)
- WITH source
- MATCH (target:`{target_label}`)
- MERGE (source)-[r:DIRECTED]-(target)
- SET r += $properties
- RETURN r
- """
- result = await tx.run(query, properties=edge_properties)
- try:
- record = await result.single()
- logger.debug(
- f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
- )
- finally:
- await result.consume() # Ensure result is consumed
-
 try:
 async with self._driver.session(database=self._DATABASE) as session:
- await session.execute_write(_do_upsert_edge)
+ async def execute_upsert(tx: AsyncManagedTransaction):
+ query = f"""
+ MATCH (source:`{source_label}`)
+ WITH source
+ MATCH (target:`{target_label}`)
+ MERGE (source)-[r:DIRECTED]-(target)
+ SET r += $properties
+ RETURN r, source,
target + """ + result = await tx.run(query, properties=edge_properties) + try: + records = await result.fetch(100) + if len(records) > 1: + source_nodes = [dict(r['source']) for r in records] + target_nodes = [dict(r['target']) for r in records] + logger.warning( + f"Multiple edges created: found {len(records)} results for edge between " + f"source label '{source_label}' and target label '{target_label}'. " + f"Source nodes: {source_nodes}, " + f"Target nodes: {target_nodes}. " + "Using first edge only." + ) + if records: + logger.debug( + f"Upserted edge from '{source_label}' to '{target_label}' " + f"with properties: {edge_properties}" + ) + finally: + await result.consume() # Ensure result is consumed + + await session.execute_write(execute_upsert) except Exception as e: logger.error(f"Error during edge upsert: {str(e)}") raise From 95c06f1bde92bb5ced2c2a2536b2c304a414db0a Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 22:36:41 +0800 Subject: [PATCH 12/33] Add graph DB lock to shared storage system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Introduced new graph_db_lock • Added detailed lock debugging output --- lightrag/kg/shared_storage.py | 94 +++++++++++++++++++++++++---------- 1 file changed, 69 insertions(+), 25 deletions(-) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index c8c154aa..67206971 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -7,12 +7,18 @@ from typing import Any, Dict, Optional, Union, TypeVar, Generic # Define a direct print function for critical logs that must be visible in all processes -def direct_log(message, level="INFO"): +def direct_log(message, level="INFO", enable_output: bool = True): """ Log a message directly to stderr to ensure visibility in all processes, including the Gunicorn master process. 
+ + Args: + message: The message to log + level: Log level (default: "INFO") + enable_output: Whether to actually output the log (default: True) """ - print(f"{level}: {message}", file=sys.stderr, flush=True) + if enable_output: + print(f"{level}: {message}", file=sys.stderr, flush=True) T = TypeVar("T") @@ -32,55 +38,88 @@ _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated _storage_lock: Optional[LockType] = None _internal_lock: Optional[LockType] = None _pipeline_status_lock: Optional[LockType] = None +_graph_db_lock: Optional[LockType] = None class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" - def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool): + def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool, name: str = "unnamed", enable_logging: bool = True): self._lock = lock self._is_async = is_async + self._pid = os.getpid() # for debug only + self._name = name # for debug only + self._enable_logging = enable_logging # for debug only async def __aenter__(self) -> "UnifiedLock[T]": - if self._is_async: - await self._lock.acquire() - else: - self._lock.acquire() - return self + try: + direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging) + if self._is_async: + await self._lock.acquire() + else: + self._lock.acquire() + direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", enable_output=self._enable_logging) + return self + except Exception as e: + direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging) + raise async def __aexit__(self, exc_type, exc_val, exc_tb): - if self._is_async: - self._lock.release() - else: - self._lock.release() + try: + direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging) + if self._is_async: + self._lock.release() + else: + self._lock.release() + direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", enable_output=self._enable_logging) + except Exception as e: + direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging) + raise def __enter__(self) -> "UnifiedLock[T]": """For backward compatibility""" - if self._is_async: - raise RuntimeError("Use 'async with' for shared_storage lock") - self._lock.acquire() - return self + try: + if self._is_async: + raise RuntimeError("Use 'async with' for shared_storage lock") + direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", enable_output=self._enable_logging) + self._lock.acquire() + direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", enable_output=self._enable_logging) + return self + except Exception as e: + direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging) + raise def __exit__(self, exc_type, exc_val, exc_tb): """For backward compatibility""" - if self._is_async: - raise RuntimeError("Use 'async with' for shared_storage lock") - self._lock.release() + try: + if self._is_async: + raise RuntimeError("Use 'async with' for shared_storage lock") + direct_log(f"== Lock == Process {self._pid}: Releasing lock 
'{self._name}' (sync)", enable_output=self._enable_logging) + self._lock.release() + direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", enable_output=self._enable_logging) + except Exception as e: + direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging) + raise -def get_internal_lock() -> UnifiedLock: +def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess) + return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess, name="internal_lock", enable_logging=enable_logging) -def get_storage_lock() -> UnifiedLock: +def get_storage_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess) + return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess, name="storage_lock", enable_logging=enable_logging) -def get_pipeline_status_lock() -> UnifiedLock: +def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess) + return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess, name="pipeline_status_lock", enable_logging=enable_logging) + + +def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: + """return unified graph database lock for ensuring atomic operations""" + return UnifiedLock(lock=_graph_db_lock, is_async=not is_multiprocess, name="graph_db_lock", enable_logging=enable_logging) def initialize_share_data(workers: int = 1): @@ -108,6 +147,7 @@ def initialize_share_data(workers: int = 1): _storage_lock, \ _internal_lock, \ _pipeline_status_lock, \ + _graph_db_lock, \ _shared_dicts, \ _init_flags, \ _initialized, \ @@ -128,6 +168,7 @@ def initialize_share_data(workers: int = 1): _internal_lock = _manager.Lock() _storage_lock = _manager.Lock() _pipeline_status_lock = _manager.Lock() + _graph_db_lock = _manager.Lock() _shared_dicts = _manager.dict() _init_flags = _manager.dict() _update_flags = _manager.dict() @@ -139,6 +180,7 @@ def initialize_share_data(workers: int = 1): _internal_lock = asyncio.Lock() _storage_lock = asyncio.Lock() _pipeline_status_lock = asyncio.Lock() + _graph_db_lock = asyncio.Lock() _shared_dicts = {} _init_flags = {} _update_flags = {} @@ -304,6 +346,7 @@ def finalize_share_data(): _storage_lock, \ _internal_lock, \ _pipeline_status_lock, \ + _graph_db_lock, \ _shared_dicts, \ _init_flags, \ _initialized, \ @@ -369,6 +412,7 @@ def finalize_share_data(): _storage_lock = None _internal_lock = None _pipeline_status_lock = None + _graph_db_lock = None _update_flags = None direct_log(f"Process {os.getpid()} storage data finalization complete") From 73452e63fa76f6b710de42ca34e4f5823c27e01a Mon Sep 17 00:00:00 2001 From: yangdx Date: Sat, 8 Mar 2025 22:48:12 +0800 Subject: [PATCH 13/33] Add async lock for atomic graph database operations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Introduced graph_db_lock mechanism • Ensured atomic node/edge merge and insert operation --- lightrag/operate.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index 30983145..f89a551d 100644 
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -519,19 +519,24 @@ async def extract_entities(
 for k, v in m_edges.items():
 maybe_edges[tuple(sorted(k))].extend(v)

- all_entities_data = await asyncio.gather(
- *[
- _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
- for k, v in maybe_nodes.items()
- ]
- )
+ from .kg.shared_storage import get_graph_db_lock
+ graph_db_lock = get_graph_db_lock(enable_logging = True)
+
+ # Ensure that nodes and edges are merged and upserted atomically
+ async with graph_db_lock:
+ all_entities_data = await asyncio.gather(
+ *[
+ _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
+ for k, v in maybe_nodes.items()
+ ]
+ )

- all_relationships_data = await asyncio.gather(
- *[
- _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
- for k, v in maybe_edges.items()
- ]
- )
+ all_relationships_data = await asyncio.gather(
+ *[
+ _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
+ for k, v in maybe_edges.items()
+ ]
+ )

 if not (all_entities_data or all_relationships_data):
 log_message = "Didn't extract any entities and relationships."

From 18c077040939c7e5a90a90e06af1b0da3c6911f6 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 9 Mar 2025 00:24:55 +0800
Subject: [PATCH 14/33] fix: duplicate nodes for same entity(label) problem in Neo4j

- Add entity_id field as key in Neo4j nodes
- Use entity_id for node retrieval and upsert
---
 lightrag/kg/neo4j_impl.py | 106 ++++++++++++++++++++++++++------------
 lightrag/operate.py | 2 +
 2 files changed, 74 insertions(+), 34 deletions(-)

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 1e46798a..0b660d68 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -280,12 +280,10 @@ class Neo4JStorage(BaseGraphStorage):
 database=self._DATABASE, default_access_mode="READ"
 ) as session:
 try:
- query = f"MATCH (n:`{entity_name_label}`) RETURN n"
- result = await session.run(query)
+ query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
+ result = await session.run(query, entity_id=entity_name_label)
 try:
- records = await result.fetch(
- 2
- ) # Get up to 2 records to check for duplicates
+ records = await result.fetch(2) # Get 2 records for duplication check

 if len(records) > 1:
 logger.warning(
@@ -549,12 +547,14 @@ class Neo4JStorage(BaseGraphStorage):
 """
 label = self._ensure_label(node_id)
 properties = node_data
+ if "entity_id" not in properties:
+ raise ValueError("Neo4j: node properties must contain an 'entity_id' field")

 try:
 async with self._driver.session(database=self._DATABASE) as session:
 async def execute_upsert(tx: AsyncManagedTransaction):
 query = f"""
- MERGE (n:`{label}`)
+ MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
 SET n += $properties
 """
 result = await tx.run(query, properties=properties)
@@ -568,6 +568,56 @@ class Neo4JStorage(BaseGraphStorage):
 logger.error(f"Error during upsert: {str(e)}")
 raise

+ @retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type(
+ (
+ neo4jExceptions.ServiceUnavailable,
+ neo4jExceptions.TransientError,
+ neo4jExceptions.WriteServiceUnavailable,
+ neo4jExceptions.ClientError,
+ )
+ ),
+ )
+ async def _get_unique_node_entity_id(self, node_label: str) -> str:
+ """
+ Get the entity_id of a node with the given label, ensuring the node is unique.
+ + Args: + node_label (str): Label of the node to check + + Returns: + str: The entity_id of the unique node + + Raises: + ValueError: If no node with the given label exists or if multiple nodes have the same label + """ + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + MATCH (n:`{node_label}`) + RETURN n, count(n) as node_count + """ + result = await session.run(query) + try: + records = await result.fetch(2) # We only need to know if there are 0, 1, or >1 nodes + + if not records or records[0]["node_count"] == 0: + raise ValueError(f"Neo4j: node with label '{node_label}' does not exist") + + if records[0]["node_count"] > 1: + raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node") + + node = records[0]["n"] + if "entity_id" not in node: + raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property") + + return node["entity_id"] + finally: + await result.consume() # Ensure result is fully consumed + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -585,7 +635,8 @@ class Neo4JStorage(BaseGraphStorage): ) -> None: """ Upsert an edge and its properties between two nodes identified by their labels. - Checks if both source and target nodes exist before creating the edge. + Ensures both source and target nodes exist and are unique before creating the edge. + Uses entity_id property to uniquely identify nodes. Args: source_node_id (str): Label of the source node (used as identifier) @@ -593,52 +644,39 @@ class Neo4JStorage(BaseGraphStorage): edge_data (dict): Dictionary of properties to set on the edge Raises: - ValueError: If either source or target node does not exist + ValueError: If either source or target node does not exist or is not unique """ source_label = self._ensure_label(source_node_id) target_label = self._ensure_label(target_node_id) edge_properties = edge_data - # Check if both nodes exist - source_exists = await self.has_node(source_label) - target_exists = await self.has_node(target_label) - - if not source_exists: - raise ValueError( - f"Neo4j: source node with label '{source_label}' does not exist" - ) - if not target_exists: - raise ValueError( - f"Neo4j: target node with label '{target_label}' does not exist" - ) + # Get entity_ids for source and target nodes, ensuring they are unique + source_entity_id = await self._get_unique_node_entity_id(source_label) + target_entity_id = await self._get_unique_node_entity_id(target_label) try: async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): query = f""" - MATCH (source:`{source_label}`) + MATCH (source:`{source_label}` {{entity_id: $source_entity_id}}) WITH source - MATCH (target:`{target_label}`) + MATCH (target:`{target_label}` {{entity_id: $target_entity_id}}) MERGE (source)-[r:DIRECTED]-(target) SET r += $properties RETURN r, source, target """ - result = await tx.run(query, properties=edge_properties) + result = await tx.run( + query, + source_entity_id=source_entity_id, + target_entity_id=target_entity_id, + properties=edge_properties + ) try: records = await result.fetch(100) - if len(records) > 1: - source_nodes = [dict(r['source']) for r in records] - target_nodes = [dict(r['target']) for r in records] - logger.warning( - f"Multiple edges created: found {len(records)} results for edge between " - f"source label '{source_label}' and target label 
'{target_label}'. " - f"Source nodes: {source_nodes}, " - f"Target nodes: {target_nodes}. " - "Using first edge only." - ) if records: logger.debug( - f"Upserted edge from '{source_label}' to '{target_label}' " + f"Upserted edge from '{source_label}' (entity_id: {source_entity_id}) " + f"to '{target_label}' (entity_id: {target_entity_id}) " f"with properties: {edge_properties}" ) finally: diff --git a/lightrag/operate.py b/lightrag/operate.py index f89a551d..fb7b27a0 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -220,6 +220,7 @@ async def _merge_nodes_then_upsert( entity_name, description, global_config ) node_data = dict( + entity_id=entity_name, entity_type=entity_type, description=description, source_id=source_id, @@ -301,6 +302,7 @@ async def _merge_edges_then_upsert( await knowledge_graph_inst.upsert_node( need_insert_id, node_data={ + "entity_id": need_insert_id, "source_id": source_id, "description": description, "entity_type": "UNKNOWN", From 3cf4268e7abdd238036e596f06c6036d636a6c74 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 9 Mar 2025 00:59:40 +0800 Subject: [PATCH 15/33] Change logging level from INFO to DEBUG for cache hit/miss messages --- lightrag/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index bb1d6fae..1b65097e 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -667,11 +667,11 @@ async def handle_cache( cache_type=cache_type, ) if best_cached_response is not None: - logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})") + logger.debug(f"Embedding cached hit(mode:{mode} type:{cache_type})") return best_cached_response, None, None, None else: # if caching keyword embedding is enabled, return the quantized embedding for saving it latter - logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})") + logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})") return None, quantized, min_val, max_val # For default mode or is_embedding_cache_enabled is False, use regular cache @@ -681,10 +681,10 @@ async def handle_cache( else: mode_cache = await hashing_kv.get_by_id(mode) or {} if args_hash in mode_cache: - logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})") + logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})") return mode_cache[args_hash]["return"], None, None, None - logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})") + logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})") return None, None, None, None From c5d0962872bf525945931cac19245f5553db3e5d Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 9 Mar 2025 01:00:42 +0800 Subject: [PATCH 16/33] Fix linting --- lightrag/kg/neo4j_impl.py | 46 +++++++++------ lightrag/kg/shared_storage.py | 102 ++++++++++++++++++++++++++++------ lightrag/operate.py | 9 ++- 3 files changed, 120 insertions(+), 37 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 0b660d68..d0841eec 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -181,10 +181,10 @@ class Neo4JStorage(BaseGraphStorage): Args: label: The label to validate - + Returns: str: The cleaned label - + Raises: ValueError: If label is empty after cleaning """ @@ -283,7 +283,9 @@ class Neo4JStorage(BaseGraphStorage): query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n" result = await session.run(query, entity_id=entity_name_label) try: - records = await result.fetch(2) # Get 2 records 
for duplication check + records = await result.fetch( + 2 + ) # Get 2 records for duplication check if len(records) > 1: logger.warning( @@ -552,6 +554,7 @@ class Neo4JStorage(BaseGraphStorage): try: async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): query = f""" MERGE (n:`{label}` {{entity_id: $properties.entity_id}}) @@ -562,7 +565,7 @@ class Neo4JStorage(BaseGraphStorage): f"Upserted node with label '{label}' and properties: {properties}" ) await result.consume() # Ensure result is fully consumed - + await session.execute_write(execute_upsert) except Exception as e: logger.error(f"Error during upsert: {str(e)}") @@ -602,18 +605,26 @@ class Neo4JStorage(BaseGraphStorage): """ result = await session.run(query) try: - records = await result.fetch(2) # We only need to know if there are 0, 1, or >1 nodes - + records = await result.fetch( + 2 + ) # We only need to know if there are 0, 1, or >1 nodes + if not records or records[0]["node_count"] == 0: - raise ValueError(f"Neo4j: node with label '{node_label}' does not exist") - + raise ValueError( + f"Neo4j: node with label '{node_label}' does not exist" + ) + if records[0]["node_count"] > 1: - raise ValueError(f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node") - + raise ValueError( + f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node" + ) + node = records[0]["n"] if "entity_id" not in node: - raise ValueError(f"Neo4j: node with label '{node_label}' does not have an entity_id property") - + raise ValueError( + f"Neo4j: node with label '{node_label}' does not have an entity_id property" + ) + return node["entity_id"] finally: await result.consume() # Ensure result is fully consumed @@ -656,6 +667,7 @@ class Neo4JStorage(BaseGraphStorage): try: async with self._driver.session(database=self._DATABASE) as session: + async def execute_upsert(tx: AsyncManagedTransaction): query = f""" MATCH (source:`{source_label}` {{entity_id: $source_entity_id}}) @@ -666,10 +678,10 @@ class Neo4JStorage(BaseGraphStorage): RETURN r, source, target """ result = await tx.run( - query, + query, source_entity_id=source_entity_id, target_entity_id=target_entity_id, - properties=edge_properties + properties=edge_properties, ) try: records = await result.fetch(100) @@ -681,7 +693,7 @@ class Neo4JStorage(BaseGraphStorage): ) finally: await result.consume() # Ensure result is consumed - + await session.execute_write(execute_upsert) except Exception as e: logger.error(f"Error during edge upsert: {str(e)}") @@ -891,7 +903,9 @@ class Neo4JStorage(BaseGraphStorage): results = await session.run(query, {"node_id": node.id}) # Get all records and release database connection - records = await results.fetch(1000) # Max neighbour nodes we can handled + records = await results.fetch( + 1000 + ) # Max neighbour nodes we can handled await results.consume() # Ensure results are consumed # Nodes not connected to start node need to check degree diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 67206971..9ccb2a99 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -11,7 +11,7 @@ def direct_log(message, level="INFO", enable_output: bool = True): """ Log a message directly to stderr to ensure visibility in all processes, including the Gunicorn master process. 
- + Args: message: The message to log level: Log level (default: "INFO") @@ -44,7 +44,13 @@ _graph_db_lock: Optional[LockType] = None class UnifiedLock(Generic[T]): """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" - def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool, name: str = "unnamed", enable_logging: bool = True): + def __init__( + self, + lock: Union[ProcessLock, asyncio.Lock], + is_async: bool, + name: str = "unnamed", + enable_logging: bool = True, + ): self._lock = lock self._is_async = is_async self._pid = os.getpid() # for debug only @@ -53,27 +59,47 @@ class UnifiedLock(Generic[T]): async def __aenter__(self) -> "UnifiedLock[T]": try: - direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})", + enable_output=self._enable_logging, + ) if self._is_async: await self._lock.acquire() else: self._lock.acquire() - direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})", + enable_output=self._enable_logging, + ) return self except Exception as e: - direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", + level="ERROR", + enable_output=self._enable_logging, + ) raise async def __aexit__(self, exc_type, exc_val, exc_tb): try: - direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})", + enable_output=self._enable_logging, + ) if self._is_async: self._lock.release() else: self._lock.release() - direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})", + enable_output=self._enable_logging, + ) except Exception as e: - direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", level="ERROR", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", + level="ERROR", + enable_output=self._enable_logging, + ) raise def __enter__(self) -> "UnifiedLock[T]": @@ -81,12 +107,22 @@ class UnifiedLock(Generic[T]): try: if self._is_async: raise RuntimeError("Use 'async with' for shared_storage lock") - direct_log(f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)", + enable_output=self._enable_logging, + ) self._lock.acquire() - direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)", + enable_output=self._enable_logging, + ) return self except Exception as e: - direct_log(f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' 
(sync): {e}", level="ERROR", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", + level="ERROR", + enable_output=self._enable_logging, + ) raise def __exit__(self, exc_type, exc_val, exc_tb): @@ -94,32 +130,62 @@ class UnifiedLock(Generic[T]): try: if self._is_async: raise RuntimeError("Use 'async with' for shared_storage lock") - direct_log(f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", + enable_output=self._enable_logging, + ) self._lock.release() - direct_log(f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)", + enable_output=self._enable_logging, + ) except Exception as e: - direct_log(f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", level="ERROR", enable_output=self._enable_logging) + direct_log( + f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", + level="ERROR", + enable_output=self._enable_logging, + ) raise def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess, name="internal_lock", enable_logging=enable_logging) + return UnifiedLock( + lock=_internal_lock, + is_async=not is_multiprocess, + name="internal_lock", + enable_logging=enable_logging, + ) def get_storage_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess, name="storage_lock", enable_logging=enable_logging) + return UnifiedLock( + lock=_storage_lock, + is_async=not is_multiprocess, + name="storage_lock", + enable_logging=enable_logging, + ) def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock: """return unified storage lock for data consistency""" - return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess, name="pipeline_status_lock", enable_logging=enable_logging) + return UnifiedLock( + lock=_pipeline_status_lock, + is_async=not is_multiprocess, + name="pipeline_status_lock", + enable_logging=enable_logging, + ) def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: """return unified graph database lock for ensuring atomic operations""" - return UnifiedLock(lock=_graph_db_lock, is_async=not is_multiprocess, name="graph_db_lock", enable_logging=enable_logging) + return UnifiedLock( + lock=_graph_db_lock, + is_async=not is_multiprocess, + name="graph_db_lock", + enable_logging=enable_logging, + ) def initialize_share_data(workers: int = 1): diff --git a/lightrag/operate.py b/lightrag/operate.py index fb7b27a0..6c1bfd05 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -522,8 +522,9 @@ async def extract_entities( maybe_edges[tuple(sorted(k))].extend(v) from .kg.shared_storage import get_graph_db_lock - graph_db_lock = get_graph_db_lock(enable_logging = True) - + + graph_db_lock = get_graph_db_lock(enable_logging=True) + # Ensure that nodes and edges are merged and upserted atomically async with graph_db_lock: all_entities_data = await asyncio.gather( @@ -535,7 +536,9 @@ async def extract_entities( all_relationships_data = await asyncio.gather( *[ - 
_merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
+ _merge_edges_then_upsert(
+ k[0], k[1], v, knowledge_graph_inst, global_config
+ )
 for k, v in maybe_edges.items()
 ]
 )

From 6a969e8de442fd2e9bb54eb3d4b8be9805c887cc Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 9 Mar 2025 01:14:24 +0800
Subject: [PATCH 17/33] Disable logging for graph database lock acquisition and release
---
 lightrag/operate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/operate.py b/lightrag/operate.py
index 6c1bfd05..ce686feb 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -523,7 +523,7 @@ async def extract_entities(
 from .kg.shared_storage import get_graph_db_lock

- graph_db_lock = get_graph_db_lock(enable_logging=True)
+ graph_db_lock = get_graph_db_lock(enable_logging=False)

From 90527875fd74c0eb4ace001ef96d9de71ffa3146 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 9 Mar 2025 15:22:06 +0800
Subject: [PATCH 18/33] Fix async issues in namespace init
---
 lightrag/kg/json_doc_status_impl.py | 2 +-
 lightrag/kg/json_kv_impl.py | 2 +-
 lightrag/kg/shared_storage.py | 18 ++++++++++--------
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 01c657fa..824bd052 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -33,7 +33,7 @@ class JsonDocStatusStorage(DocStatusStorage):
 async def initialize(self):
 """Initialize storage data"""
 # check need_init must before get_namespace_data
- need_init = try_initialize_namespace(self.namespace)
+ need_init = await try_initialize_namespace(self.namespace)
 self._data = await get_namespace_data(self.namespace)
 if need_init:
 loaded_data = load_json(self._file_name) or {}

diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index c0b61a63..96217d4b 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -29,7 +29,7 @@ class JsonKVStorage(BaseKVStorage):
 async def initialize(self):
 """Initialize storage data"""
 # check need_init must before get_namespace_data
- need_init = try_initialize_namespace(self.namespace)
+ need_init = await try_initialize_namespace(self.namespace)
 self._data = await get_namespace_data(self.namespace)
 if need_init:
 loaded_data = load_json(self._file_name) or {}

diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 9ccb2a99..68747ff8 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -355,7 +355,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
 return result

-def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(namespace: str) -> bool:
 """
 Returns True if the current worker (process) gets initialization permission for loading data later.
 A worker that does not get the permission is prohibited from loading data from files.
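The next hunk moves the check-and-set on _init_flags under the internal lock, closing the window in which two async tasks could both win initialization permission. A minimal sketch of the resulting first-writer-wins idiom, assuming a single-process run where a module-level dict and a plain asyncio.Lock stand in for the manager-backed shared structures:

import asyncio

_init_flags: dict[str, bool] = {}
_internal_lock = asyncio.Lock()  # stand-in for get_internal_lock()

async def try_initialize_namespace(namespace: str) -> bool:
    # Exactly one caller per namespace observes True (and may load data
    # from files); every later caller gets False and reuses shared data.
    async with _internal_lock:
        if namespace not in _init_flags:
            _init_flags[namespace] = True
            return True
    return False

async def main() -> None:
    results = await asyncio.gather(
        *[try_initialize_namespace("doc_status") for _ in range(4)]
    )
    print(results)  # [True, False, False, False]

asyncio.run(main())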
@@ -365,15 +365,17 @@ def try_initialize_namespace(namespace: str) -> bool:
 if _init_flags is None:
 raise ValueError("Try to create namespace before Shared-Data is initialized")

- if namespace not in _init_flags:
- _init_flags[namespace] = True
+ async with get_internal_lock():
+ if namespace not in _init_flags:
+ _init_flags[namespace] = True
+ direct_log(
+ f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+ )
+ return True
 direct_log(
- f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+ f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
 )
- return True
- direct_log(
- f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
- )
+ return False

From c854aabde09b569e15721212a31a285206a1e07f Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 9 Mar 2025 15:25:10 +0800
Subject: [PATCH 19/33] Add process ID to log messages for better multi-process debugging clarity

- Add PID to KV and Neo4j storage logs
- Add PID to query context logs
- Improve KV data count logging for llm cache
---
 lightrag/kg/json_doc_status_impl.py | 3 ++-
 lightrag/kg/json_kv_impl.py | 28 +++++++++++++++++++++++++---
 lightrag/kg/neo4j_impl.py | 2 +-
 lightrag/operate.py | 2 ++
 4 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 824bd052..e05c04f6 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -40,7 +40,7 @@ class JsonDocStatusStorage(DocStatusStorage):
 async with self._storage_lock:
 self._data.update(loaded_data)
 logger.info(
- f"Loaded document status storage with {len(loaded_data)} records"
+ f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
 )

 async def filter_keys(self, keys: set[str]) -> set[str]:
@@ -90,6 +90,7 @@ class JsonDocStatusStorage(DocStatusStorage):
 data_dict = (
 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
 )
+ logger.info(f"Process {os.getpid()} doc status writing {len(data_dict)} records to {self.namespace}")
 write_json(data_dict, self._file_name)

 async def upsert(self, data: dict[str, dict[str, Any]]) -> None:

diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 96217d4b..c0aa81b2 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -35,13 +35,34 @@ class JsonKVStorage(BaseKVStorage):
 loaded_data = load_json(self._file_name) or {}
 async with self._storage_lock:
 self._data.update(loaded_data)
- logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
+
+ # Calculate data count based on namespace
+ if self.namespace.endswith("cache"):
+ # For cache namespaces, sum the cache entries across all cache types
+ data_count = sum(len(first_level_dict) for first_level_dict in loaded_data.values()
+ if isinstance(first_level_dict, dict))
+ else:
+ # For non-cache namespaces, use the original count method
+ data_count = len(loaded_data)
+
+ logger.info(f"Process {os.getpid()} KV load {self.namespace} with {data_count} records")

 async def index_done_callback(self) -> None:
 async with self._storage_lock:
 data_dict = (
 dict(self._data) if hasattr(self._data, "_getvalue") else self._data
 )
+
+ # Calculate data count based on namespace
+ if self.namespace.endswith("cache"):
+ # For cache namespaces, sum the cache entries across all cache types
+ data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values()
+ if isinstance(first_level_dict, dict))
+ else:
+ # For non-cache namespaces, use the original count method
+ data_count = len(data_dict)
+
+ logger.info(f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}")
 write_json(data_dict, self._file_name)

 async def get_all(self) -> dict[str, Any]:
@@ -73,12 +94,13 @@ class JsonKVStorage(BaseKVStorage):
 return set(keys) - set(self._data.keys())

 async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
- logger.info(f"Inserting {len(data)} to {self.namespace}")
 if not data:
 return
 async with self._storage_lock:
 left_data = {k: v for k, v in data.items() if k not in self._data}
- self._data.update(left_data)
+ if left_data:
+ logger.info(f"Process {os.getpid()} KV inserting {len(left_data)} to {self.namespace}")
+ self._data.update(left_data)

 async def delete(self, ids: list[str]) -> None:
 async with self._storage_lock:

diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index d0841eec..8d5a1a55 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -842,7 +842,7 @@ class Neo4JStorage(BaseGraphStorage):
 seen_edges.add(edge_id)

 logger.info(
- f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+ f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
 )
 finally:
 await result_set.consume() # Ensure result set is consumed

diff --git a/lightrag/operate.py b/lightrag/operate.py
index ce686feb..d16e170c 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import asyncio
 import json
 import re
+import os
 from typing import Any, AsyncIterator
 from collections import Counter, defaultdict
@@ -1027,6 +1028,7 @@ async def _build_query_context(
 text_chunks_db: BaseKVStorage,
 query_param: QueryParam,
 ):
+ logger.info(f"Process {os.getpid()} building query context...")
 if query_param.mode == "local":
 entities_context, relations_context, text_units_context = await _get_node_data(
 ll_keywords,

From 020a6b5ae0605effd2a0a0c78903c31d146443ae Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 9 Mar 2025 16:45:57 +0800
Subject: [PATCH 20/33] Refactor LLM cache config to use argparse and add status display
---
 lightrag/api/lightrag_server.py | 8 +++-----
 lightrag/api/utils_api.py | 11 ++++++++++-
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 5df4f765..c42a816a 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -50,9 +50,6 @@ from .auth import auth_handler
 # This update allows the user to put a different .env file for each lightrag folder
 load_dotenv(".env", override=True)

-# Read entity extraction cache config
-enable_llm_cache = os.getenv("ENABLE_LLM_CACHE_FOR_EXTRACT", "false").lower() == "true"
-
 # Initialize config parser
 config = configparser.ConfigParser()
 config.read("config.ini")
@@ -326,7 +323,7 @@ def create_app(args):
 vector_db_storage_cls_kwargs={
 "cosine_better_than_threshold": args.cosine_threshold
 },
- enable_llm_cache_for_entity_extract=enable_llm_cache, # Read from environment variable
+
enable_llm_cache_for_entity_extract=args.enable_llm_cache, # Read from args embedding_cache_config={ "enabled": True, "similarity_threshold": 0.95, @@ -419,6 +416,7 @@ def create_app(args): "doc_status_storage": args.doc_status_storage, "graph_storage": args.graph_storage, "vector_storage": args.vector_storage, + "enable_llm_cache": args.enable_llm_cache, }, "update_status": update_status, } diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index dc467449..da443558 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -359,6 +359,13 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: # Inject chunk configuration args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int) args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) + + # Inject LLM cache configuration + args.enable_llm_cache = get_env_value( + "ENABLE_LLM_CACHE_FOR_EXTRACT", + False, + bool + ) ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name @@ -451,8 +458,10 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.history_turns}") ASCIIColors.white(" ├─ Cosine Threshold: ", end="") ASCIIColors.yellow(f"{args.cosine_threshold}") - ASCIIColors.white(" └─ Top-K: ", end="") + ASCIIColors.white(" ├─ Top-K: ", end="") ASCIIColors.yellow(f"{args.top_k}") + ASCIIColors.white(" └─ LLM Cache Enabled: ", end="") + ASCIIColors.yellow(f"{args.enable_llm_cache}") # System Configuration ASCIIColors.magenta("\n💾 Storage Configuration:") From e47883d8728ceb91f571632e87512d7398fe07e4 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 9 Mar 2025 17:33:15 +0800 Subject: [PATCH 21/33] Add atomic data initialization lock to prevent race conditions --- lightrag/kg/json_doc_status_impl.py | 22 +++++++++-------- lightrag/kg/json_kv_impl.py | 38 +++++++++++++++-------------- lightrag/kg/shared_storage.py | 18 +++++++++++++- 3 files changed, 49 insertions(+), 29 deletions(-) diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index e05c04f6..67a4705a 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -15,6 +15,7 @@ from lightrag.utils import ( from .shared_storage import ( get_namespace_data, get_storage_lock, + get_data_init_lock, try_initialize_namespace, ) @@ -27,21 +28,22 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._storage_lock = get_storage_lock() self._data = None async def initialize(self): """Initialize storage data""" - # check need_init must before get_namespace_data - need_init = await try_initialize_namespace(self.namespace) + self._storage_lock = get_storage_lock() self._data = await get_namespace_data(self.namespace) - if need_init: - loaded_data = load_json(self._file_name) or {} - async with self._storage_lock: - self._data.update(loaded_data) - logger.info( - f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records" - ) + async with get_data_init_lock(): + # check need_init must before get_namespace_data + need_init = await try_initialize_namespace(self.namespace) + if need_init: + loaded_data = load_json(self._file_name) or {} + async with self._storage_lock: + self._data.update(loaded_data) + logger.info( + f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records" + ) async def filter_keys(self, keys: set[str]) -> 
set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index c0aa81b2..5070c0b4 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -13,6 +13,7 @@ from lightrag.utils import ( from .shared_storage import ( get_namespace_data, get_storage_lock, + get_data_init_lock, try_initialize_namespace, ) @@ -23,29 +24,30 @@ class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._storage_lock = get_storage_lock() self._data = None async def initialize(self): """Initialize storage data""" - # check need_init must before get_namespace_data - need_init = await try_initialize_namespace(self.namespace) + self._storage_lock = get_storage_lock() self._data = await get_namespace_data(self.namespace) - if need_init: - loaded_data = load_json(self._file_name) or {} - async with self._storage_lock: - self._data.update(loaded_data) - - # Calculate data count based on namespace - if self.namespace.endswith("cache"): - # For cache namespaces, sum the cache entries across all cache types - data_count = sum(len(first_level_dict) for first_level_dict in loaded_data.values() - if isinstance(first_level_dict, dict)) - else: - # For non-cache namespaces, use the original count method - data_count = len(loaded_data) - - logger.info(f"Process {os.getpid()} KV load {self.namespace} with {data_count} records") + async with get_data_init_lock(): + # check need_init must before get_namespace_data + need_init = await try_initialize_namespace(self.namespace) + if need_init: + loaded_data = load_json(self._file_name) or {} + async with self._storage_lock: + self._data.update(loaded_data) + + # Calculate data count based on namespace + if self.namespace.endswith("cache"): + # For cache namespaces, sum the cache entries across all cache types + data_count = sum(len(first_level_dict) for first_level_dict in loaded_data.values() + if isinstance(first_level_dict, dict)) + else: + # For non-cache namespaces, use the original count method + data_count = len(loaded_data) + + logger.info(f"Process {os.getpid()} KV load {self.namespace} with {data_count} records") async def index_done_callback(self) -> None: async with self._storage_lock: diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 68747ff8..e3c25d34 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -39,6 +39,7 @@ _storage_lock: Optional[LockType] = None _internal_lock: Optional[LockType] = None _pipeline_status_lock: Optional[LockType] = None _graph_db_lock: Optional[LockType] = None +_data_init_lock: Optional[LockType] = None class UnifiedLock(Generic[T]): @@ -188,6 +189,16 @@ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: ) +def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock: + """return unified data initialization lock for ensuring atomic data initialization""" + return UnifiedLock( + lock=_data_init_lock, + is_async=not is_multiprocess, + name="data_init_lock", + enable_logging=enable_logging, + ) + + def initialize_share_data(workers: int = 1): """ Initialize shared storage data for single or multi-process mode. 
@@ -214,6 +225,7 @@ def initialize_share_data(workers: int = 1): _internal_lock, \ _pipeline_status_lock, \ _graph_db_lock, \ + _data_init_lock, \ _shared_dicts, \ _init_flags, \ _initialized, \ @@ -226,15 +238,16 @@ def initialize_share_data(workers: int = 1): ) return - _manager = Manager() _workers = workers if workers > 1: is_multiprocess = True + _manager = Manager() _internal_lock = _manager.Lock() _storage_lock = _manager.Lock() _pipeline_status_lock = _manager.Lock() _graph_db_lock = _manager.Lock() + _data_init_lock = _manager.Lock() _shared_dicts = _manager.dict() _init_flags = _manager.dict() _update_flags = _manager.dict() @@ -247,6 +260,7 @@ def initialize_share_data(workers: int = 1): _storage_lock = asyncio.Lock() _pipeline_status_lock = asyncio.Lock() _graph_db_lock = asyncio.Lock() + _data_init_lock = asyncio.Lock() _shared_dicts = {} _init_flags = {} _update_flags = {} @@ -415,6 +429,7 @@ def finalize_share_data(): _internal_lock, \ _pipeline_status_lock, \ _graph_db_lock, \ + _data_init_lock, \ _shared_dicts, \ _init_flags, \ _initialized, \ @@ -481,6 +496,7 @@ def finalize_share_data(): _internal_lock = None _pipeline_status_lock = None _graph_db_lock = None + _data_init_lock = None _update_flags = None direct_log(f"Process {os.getpid()} storage data finalization complete") From bc42afe7b65f92a5d73eb01f5410bdde9385ddd0 Mon Sep 17 00:00:00 2001 From: yangdx Date: Sun, 9 Mar 2025 22:15:26 +0800 Subject: [PATCH 22/33] Unify llm_response_cache and hashing_kv, prevent creating an independent hashing_kv. --- lightrag/api/lightrag_server.py | 6 +-- lightrag/api/utils_api.py | 6 +-- lightrag/lightrag.py | 90 ++++----------------------------- lightrag/operate.py | 2 +- lightrag/utils.py | 22 ++++---- 5 files changed, 30 insertions(+), 96 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index c42a816a..8871650a 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -323,7 +323,7 @@ def create_app(args): vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, - enable_llm_cache_for_entity_extract=args.enable_llm_cache, # Read from args + enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, embedding_cache_config={ "enabled": True, "similarity_threshold": 0.95, @@ -352,7 +352,7 @@ def create_app(args): vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, - enable_llm_cache_for_entity_extract=args.enable_llm_cache, # Read from args + enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, embedding_cache_config={ "enabled": True, "similarity_threshold": 0.95, @@ -416,7 +416,7 @@ def create_app(args): "doc_status_storage": args.doc_status_storage, "graph_storage": args.graph_storage, "vector_storage": args.vector_storage, - "enable_llm_cache": args.enable_llm_cache, + "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, }, "update_status": update_status, } diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index da443558..9a619f9e 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -361,7 +361,7 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace: args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int) # Inject LLM cache configuration - args.enable_llm_cache = get_env_value( + args.enable_llm_cache_for_extract = get_env_value( "ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool @@ -460,8 +460,8 @@ def display_splash_screen(args: 
argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.cosine_threshold}") ASCIIColors.white(" ├─ Top-K: ", end="") ASCIIColors.yellow(f"{args.top_k}") - ASCIIColors.white(" └─ LLM Cache Enabled: ", end="") - ASCIIColors.yellow(f"{args.enable_llm_cache}") + ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="") + ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}") # System Configuration ASCIIColors.magenta("\n💾 Storage Configuration:") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index b06520fc..a91aa6fa 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -354,6 +354,7 @@ class LightRAG: namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), + global_config=asdict(self), # Add global_config to ensure cache works properly embedding_func=self.embedding_func, ) @@ -404,18 +405,8 @@ class LightRAG: embedding_func=None, ) - if self.llm_response_cache and hasattr( - self.llm_response_cache, "global_config" - ): - hashing_kv = self.llm_response_cache - else: - hashing_kv = self.key_string_value_json_storage_cls( # type: ignore - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ) + # Directly use llm_response_cache, don't create a new object + hashing_kv = self.llm_response_cache self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -1260,16 +1251,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) elif param.mode == "naive": @@ -1279,16 +1261,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) elif param.mode == "mix": @@ -1301,16 +1274,7 @@ class LightRAG: self.text_chunks, param, asdict(self), - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache system_prompt=system_prompt, ) else: @@ -1344,14 +1308,7 @@ class LightRAG: text=query, param=param, global_config=asdict(self), - hashing_kv=self.llm_response_cache - or self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - global_config=asdict(self), - embedding_func=self.embedding_func, - ), + hashing_kv=self.llm_response_cache, # Directly use llm_response_cache ) param.hl_keywords = hl_keywords @@ -1375,16 +1332,7 @@ 
class LightRAG:
             self.text_chunks,
             param,
             asdict(self),
-            hashing_kv=self.llm_response_cache
-            if self.llm_response_cache
-            and hasattr(self.llm_response_cache, "global_config")
-            else self.key_string_value_json_storage_cls(
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            ),
+            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
         )
     elif param.mode == "naive":
         response = await naive_query(
@@ -1393,16 +1341,7 @@ class LightRAG:
             self.text_chunks,
             param,
             asdict(self),
-            hashing_kv=self.llm_response_cache
-            if self.llm_response_cache
-            and hasattr(self.llm_response_cache, "global_config")
-            else self.key_string_value_json_storage_cls(
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            ),
+            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
         )
     elif param.mode == "mix":
         response = await mix_kg_vector_query(
@@ -1414,16 +1353,7 @@ class LightRAG:
             self.text_chunks,
             param,
             asdict(self),
-            hashing_kv=self.llm_response_cache
-            if self.llm_response_cache
-            and hasattr(self.llm_response_cache, "global_config")
-            else self.key_string_value_json_storage_cls(
-                namespace=make_namespace(
-                    self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
-                ),
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            ),
+            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
         )
     else:
         raise ValueError(f"Unknown mode {param.mode}")
diff --git a/lightrag/operate.py b/lightrag/operate.py
index d16e170c..9ba3b06d 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -410,7 +410,6 @@ async def extract_entities(
             _prompt,
             "default",
             cache_type="extract",
-            force_llm_cache=True,
         )
         if cached_return:
             logger.debug(f"Found cache for {arg_hash}")
@@ -432,6 +431,7 @@ async def extract_entities(
                     cache_type="extract",
                 ),
             )
+            logger.info(f"Extract: saved cache for {arg_hash}")
             return res
 
         if history_messages:
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 1b65097e..02c3236d 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -633,15 +633,15 @@ async def handle_cache(
     prompt,
     mode="default",
     cache_type=None,
-    force_llm_cache=False,
 ):
     """Generic cache handling function"""
-    if hashing_kv is None or not (
-        force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
-    ):
+    if hashing_kv is None:
         return None, None, None, None
 
-    if mode != "default":
+    if mode != "default":  # handle cache for all types of queries
+        if not hashing_kv.global_config.get("enable_llm_cache"):
+            return None, None, None, None
+
         # Get embedding cache configuration
         embedding_cache_config = hashing_kv.global_config.get(
             "embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
         use_llm_check = embedding_cache_config.get("use_llm_check", False)
 
         quantized = min_val = max_val = None
-        if is_embedding_cache_enabled:
-            # Use embedding cache
+        if is_embedding_cache_enabled:  # Use embedding similarity to match cache
            current_embedding = await hashing_kv.embedding_func([prompt])
            llm_model_func = hashing_kv.global_config.get("llm_model_func")
            quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -674,8 +673,13 @@ async def handle_cache(
             logger.debug(f"Embedding cache missed(mode:{mode} type:{cache_type})")
             return None, quantized, min_val, max_val
 
-    # For default mode or is_embedding_cache_enabled is False, use regular cache
-    # default mode is for extract_entities or naive query
+    else:  # handle cache for entity extraction
+        if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
+            return None, None, None, None
+
+    # Here are the conditions under which code reaches this point:
+    # 1. All query modes: enable_llm_cache is True and embedding similarity is not enabled
+    # 2. Entity extract: enable_llm_cache_for_entity_extract is True
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
     else:

From c938989920b6b8c90f561ff4007f1e12f2f33596 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 9 Mar 2025 23:33:03 +0800
Subject: [PATCH 23/33] Fix llm cache save problem in json_kv storage

---
 lightrag/kg/json_doc_status_impl.py | 4 ++--
 lightrag/kg/json_kv_impl.py         | 6 ++----
 lightrag/operate.py                 | 2 +-
 3 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 67a4705a..11766fa7 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -96,12 +96,12 @@ class JsonDocStatusStorage(DocStatusStorage):
         write_json(data_dict, self._file_name)
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
+
         await self.index_done_callback()
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 5070c0b4..b90bf1d8 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -98,11 +98,9 @@ class JsonKVStorage(BaseKVStorage):
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
             return
+        logger.info(f"Inserting {len(data)} to {self.namespace}")
         async with self._storage_lock:
-            left_data = {k: v for k, v in data.items() if k not in self._data}
-            if left_data:
-                logger.info(f"Process {os.getpid()} KV inserting {len(left_data)} to {self.namespace}")
-                self._data.update(left_data)
+            self._data.update(data)
 
     async def delete(self, ids: list[str]) -> None:
         async with self._storage_lock:
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 9ba3b06d..cfd8b6f8 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -403,6 +403,7 @@ async def extract_entities(
         else:
             _prompt = input_text
 
+        # TODO: add cache_type="extract"
         arg_hash = compute_args_hash(_prompt)
         cached_return, _1, _2, _3 = await handle_cache(
             llm_response_cache,

From 4977c718f1f0bc62eec4a855721ecc4c71ae5558 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 00:12:35 +0800
Subject: [PATCH 24/33] Improve KV storage initialization logic

---
 lightrag/kg/json_doc_status_impl.py | 2 +-
 lightrag/kg/json_kv_impl.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 11766fa7..b5249540 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -33,10 +33,10 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def initialize(self):
         """Initialize storage data"""
         self._storage_lock = get_storage_lock()
-        self._data = await get_namespace_data(self.namespace)
         async with get_data_init_lock():
            # check need_init must before get_namespace_data
             need_init = await try_initialize_namespace(self.namespace)
+            self._data = await get_namespace_data(self.namespace)
             if need_init:
                 loaded_data = load_json(self._file_name) or {}
                 async with self._storage_lock:
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index b90bf1d8..81439151 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -29,10 +29,10 @@ class JsonKVStorage(BaseKVStorage):
     async def initialize(self):
         """Initialize storage data"""
         self._storage_lock = get_storage_lock()
-        self._data = await get_namespace_data(self.namespace)
         async with get_data_init_lock():
             # check need_init must before get_namespace_data
             need_init = await try_initialize_namespace(self.namespace)
+            self._data = await get_namespace_data(self.namespace)
             if need_init:
                 loaded_data = load_json(self._file_name) or {}
                 async with self._storage_lock:

From d2708b966d5f623b9ee3d68736d51e2c83063b30 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 01:17:25 +0800
Subject: [PATCH 25/33] Added update flag to avoid persistence if no data is
 changed for KV storage

---
 lightrag/kg/json_doc_status_impl.py | 22 ++++++++++++----
 lightrag/kg/json_kv_impl.py         | 41 ++++++++++++++++++-----------
 lightrag/kg/shared_storage.py       | 15 +++++++++++
 3 files changed, 58 insertions(+), 20 deletions(-)

diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index b5249540..c33059ad 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -16,6 +16,9 @@ from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
     get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
 
@@ -29,10 +32,13 @@ class JsonDocStatusStorage(DocStatusStorage):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None
 
     async def initialize(self):
         """Initialize storage data"""
         self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
         async with get_data_init_lock():
             # check need_init must before get_namespace_data
             need_init = await try_initialize_namespace(self.namespace)
@@ -89,11 +95,13 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            data_dict = (
-                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
-            )
-            logger.info(f"Process {os.getpid()} doc status writing {len(data_dict)} records to {self.namespace}")
-            write_json(data_dict, self._file_name)
+            if self.storage_updated:
+                data_dict = (
+                    dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+                )
+                logger.info(f"Process {os.getpid()} doc status writing {len(data_dict)} records to {self.namespace}")
+                write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
@@ -101,6 +109,7 @@ class JsonDocStatusStorage(DocStatusStorage):
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
+            await set_all_update_flags(self.namespace)
 
         await self.index_done_callback()
 
@@ -112,9 +121,12 @@ class JsonDocStatusStorage(DocStatusStorage):
         async with self._storage_lock:
             for doc_id in doc_ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()
 
     async def drop(self) -> None:
         """Drop the storage"""
         async with self._storage_lock:
             self._data.clear()
+            await set_all_update_flags(self.namespace)
+        await self.index_done_callback()
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 81439151..c69b53ec 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -14,6 +14,9 @@ from .shared_storage import (
     get_namespace_data,
     get_storage_lock,
     get_data_init_lock,
+    get_update_flag,
+    set_all_update_flags,
+    clear_all_update_flags,
     try_initialize_namespace,
 )
 
@@ -25,10 +28,13 @@ class JsonKVStorage(BaseKVStorage):
         working_dir = self.global_config["working_dir"]
         self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
         self._data = None
+        self._storage_lock = None
+        self.storage_updated = None
 
     async def initialize(self):
         """Initialize storage data"""
         self._storage_lock = get_storage_lock()
+        self.storage_updated = await get_update_flag(self.namespace)
         async with get_data_init_lock():
             # check need_init must before get_namespace_data
             need_init = await try_initialize_namespace(self.namespace)
@@ -51,21 +57,24 @@ class JsonKVStorage(BaseKVStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            data_dict = (
-                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
-            )
-
-            # Calculate data count based on namespace
-            if self.namespace.endswith("cache"):
-                # For cache namespaces, sum the cache entries across all cache types
-                data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values()
-                                if isinstance(first_level_dict, dict))
-            else:
-                # For non-cache namespaces, use the original count method
-                data_count = len(data_dict)
-
-            logger.info(f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}")
-            write_json(data_dict, self._file_name)
+            if self.storage_updated:
+                data_dict = (
+                    dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+                )
+
+                # Calculate data count based on namespace
+                if self.namespace.endswith("cache"):
+                    # For cache namespaces, sum the cache entries across all cache types
+                    data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values()
+                                    if isinstance(first_level_dict, dict))
+                else:
+                    # For non-cache namespaces, use the original count method
+                    data_count = len(data_dict)
+
+                logger.info(f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}")
+                write_json(data_dict, self._file_name)
+                await clear_all_update_flags(self.namespace)
+
 
     async def get_all(self) -> dict[str, Any]:
         """Get all data from storage
@@ -101,9 +110,11 @@ class JsonKVStorage(BaseKVStorage):
         logger.info(f"Inserting {len(data)} to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
+            await set_all_update_flags(self.namespace)
 
     async def delete(self, ids: list[str]) -> None:
         async with self._storage_lock:
             for doc_id in ids:
                 self._data.pop(doc_id, None)
+            await set_all_update_flags(self.namespace)
         await self.index_done_callback()
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index e3c25d34..9ce04d23 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -344,6 +344,21 @@ async def set_all_update_flags(namespace: str):
         else:
             _update_flags[namespace][i] = True
 
+async def clear_all_update_flags(namespace: str):
+    """Clear all update flags of namespace indicating all workers need to reload data from files"""
+    global _update_flags
+    if _update_flags is None:
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
+
+    async with get_internal_lock():
+        if namespace not in _update_flags:
+            raise ValueError(f"Namespace {namespace} not found in update flags")
+        # Clear flags for both modes
+        for i in range(len(_update_flags[namespace])):
+            if is_multiprocess:
+                _update_flags[namespace][i].value = False
+            else:
+                _update_flags[namespace][i] = False
 
 async def get_all_update_flags_status() -> Dict[str, list]:
     """
     Get update flags status for all namespaces.

From 46610682ce9ce6197e7319f805408ef475d1ed0d Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 01:45:58 +0800
Subject: [PATCH 26/33] Avoid redundant llm cache updates

---
 lightrag/kg/json_doc_status_impl.py |  3 ++-
 lightrag/kg/json_kv_impl.py         |  3 ++-
 lightrag/utils.py                   | 30 +++++++++++++++++++++++++----
 3 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index c33059ad..3c1fb4c2 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -20,6 +20,7 @@ from .shared_storage import (
     set_all_update_flags,
     clear_all_update_flags,
     try_initialize_namespace,
+    is_multiprocess,
 )
 
 
@@ -95,7 +96,7 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if self.storage_updated:
+            if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index c69b53ec..b5d963fb 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -18,6 +18,7 @@ from .shared_storage import (
     set_all_update_flags,
     clear_all_update_flags,
     try_initialize_namespace,
+    is_multiprocess,
 )
 
 
@@ -57,7 +58,7 @@ class JsonKVStorage(BaseKVStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if self.storage_updated:
+            if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 02c3236d..56548420 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -705,9 +705,22 @@ class CacheData:
 
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
-    if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
+    """Save data to cache, with improved handling for streaming responses and duplicate content.
+
+    Args:
+        hashing_kv: The key-value storage for caching
+        cache_data: The cache data to save
+    """
+    # Skip if storage is None or content is empty
+    if hashing_kv is None or not cache_data.content:
         return
-
+
+    # If content is a streaming response, don't cache it
+    if hasattr(cache_data.content, "__aiter__"):
+        logger.debug("Streaming response detected, skipping cache")
+        return
+
+    # Get existing cache data
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = (
             await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
@@ -715,7 +728,15 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         )
     else:
         mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
-
+
+    # Check if we already have identical content cached
+    if cache_data.args_hash in mode_cache:
+        existing_content = mode_cache[cache_data.args_hash].get("return")
+        if existing_content == cache_data.content:
+            logger.info(f"Cache content unchanged for {cache_data.args_hash}, skipping update")
+            return
+
+    # Update cache with new content
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
         "cache_type": cache_data.cache_type,
@@ -729,7 +750,8 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "embedding_max": cache_data.max_val,
         "original_prompt": cache_data.prompt,
     }
-
+
+    # Only upsert if there's actual new content
     await hashing_kv.upsert({cache_data.mode: mode_cache})

From 14e1b31d1cd5273d188ecb7ecc3df4d17ff95075 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 02:05:55 +0800
Subject: [PATCH 27/33] Improved logging clarity in storage operations

---
 lightrag/kg/json_doc_status_impl.py | 2 +-
 lightrag/kg/json_kv_impl.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 3c1fb4c2..5b378c17 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -107,7 +107,7 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
             return
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
             await set_all_update_flags(self.namespace)
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index b5d963fb..6c855a25 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -108,7 +108,7 @@ class JsonKVStorage(BaseKVStorage):
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
             return
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.info(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
             await set_all_update_flags(self.namespace)

From 4065a7df92cbe388741b463c4e48d3863920ef87 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 02:07:19 +0800
Subject: [PATCH 28/33] Fix linting

---
 lightrag/api/utils_api.py           |  6 ++----
 lightrag/kg/json_doc_status_impl.py |  8 +++++--
 lightrag/kg/json_kv_impl.py         | 33 +++++++++++++++++++----------
 lightrag/kg/shared_storage.py       |  2 ++
 lightrag/lightrag.py                |  4 +++-
 lightrag/utils.py                   | 16 ++++++++------
 6 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py
index 9a619f9e..ffe63abd 100644
--- a/lightrag/api/utils_api.py
+++ b/lightrag/api/utils_api.py
@@ -359,12 +359,10 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
     # Inject chunk configuration
     args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
     args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
-
+
     # Inject LLM cache configuration
     args.enable_llm_cache_for_extract = get_env_value(
-        "ENABLE_LLM_CACHE_FOR_EXTRACT",
-        False,
-        bool
+        "ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
     )
 
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 5b378c17..4502397b 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -96,11 +96,15 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
-                logger.info(f"Process {os.getpid()} doc status writing {len(data_dict)} records to {self.namespace}")
+                logger.info(
+                    f"Process {os.getpid()} doc status writing {len(data_dict)} records to {self.namespace}"
+                )
                 write_json(data_dict, self._file_name)
                 await clear_all_update_flags(self.namespace)
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 6c855a25..80abe92e 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -44,21 +44,28 @@ class JsonKVStorage(BaseKVStorage):
                 loaded_data = load_json(self._file_name) or {}
                 async with self._storage_lock:
                     self._data.update(loaded_data)
-
+
                     # Calculate data count based on namespace
                     if self.namespace.endswith("cache"):
                         # For cache namespaces, sum the cache entries across all cache types
-                        data_count = sum(len(first_level_dict) for first_level_dict in loaded_data.values()
-                                        if isinstance(first_level_dict, dict))
+                        data_count = sum(
+                            len(first_level_dict)
+                            for first_level_dict in loaded_data.values()
+                            if isinstance(first_level_dict, dict)
+                        )
                     else:
                         # For non-cache namespaces, use the original count method
                         data_count = len(loaded_data)
-
-                    logger.info(f"Process {os.getpid()} KV load {self.namespace} with {data_count} records")
+
+                    logger.info(
+                        f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
+                    )
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if (is_multiprocess and self.storage_updated.value) or (not is_multiprocess and self.storage_updated):
+            if (is_multiprocess and self.storage_updated.value) or (
+                not is_multiprocess and self.storage_updated
+            ):
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
 
                 # Calculate data count based on namespace
                 if self.namespace.endswith("cache"):
                     # For cache namespaces, sum the cache entries across all cache types
-                    data_count = sum(len(first_level_dict) for first_level_dict in data_dict.values()
-                                    if isinstance(first_level_dict, dict))
+                    data_count = sum(
+                        len(first_level_dict)
+                        for first_level_dict in data_dict.values()
+                        if isinstance(first_level_dict, dict)
+                    )
                 else:
                     # For non-cache namespaces, use the original count method
                     data_count = len(data_dict)
-
-                logger.info(f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}")
+
+                logger.info(
+                    f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}"
+                )
                write_json(data_dict, self._file_name)
                 await clear_all_update_flags(self.namespace)
-
 
     async def get_all(self) -> dict[str, Any]:
         """Get all data from storage
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 9ce04d23..63ff1f0d 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -344,6 +344,7 @@ async def set_all_update_flags(namespace: str):
         else:
             _update_flags[namespace][i] = True
 
+
 async def clear_all_update_flags(namespace: str):
     """Clear all update flags of namespace indicating all workers need to reload data from files"""
     global _update_flags
@@ -360,6 +361,7 @@ async def clear_all_update_flags(namespace: str):
         else:
             _update_flags[namespace][i] = False
 
+
 async def get_all_update_flags_status() -> Dict[str, list]:
     """
     Get update flags status for all namespaces.
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index a91aa6fa..ceb47a01 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -354,7 +354,9 @@ class LightRAG:
             namespace=make_namespace(
                 self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
             ),
-            global_config=asdict(self),  # Add global_config to ensure cache works properly
+            global_config=asdict(
+                self
+            ),  # Add global_config to ensure cache works properly
             embedding_func=self.embedding_func,
         )
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 56548420..e8f79610 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -706,7 +706,7 @@ class CacheData:
 
 async def save_to_cache(hashing_kv, cache_data: CacheData):
     """Save data to cache, with improved handling for streaming responses and duplicate content.
-
+
     Args:
         hashing_kv: The key-value storage for caching
         cache_data: The cache data to save
@@ -714,12 +714,12 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     # Skip if storage is None or content is empty
     if hashing_kv is None or not cache_data.content:
         return
-
+
     # If content is a streaming response, don't cache it
     if hasattr(cache_data.content, "__aiter__"):
         logger.debug("Streaming response detected, skipping cache")
         return
-
+
     # Get existing cache data
     if exists_func(hashing_kv, "get_by_mode_and_id"):
         mode_cache = (
@@ -728,14 +728,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         )
     else:
         mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
-
+
     # Check if we already have identical content cached
     if cache_data.args_hash in mode_cache:
         existing_content = mode_cache[cache_data.args_hash].get("return")
         if existing_content == cache_data.content:
-            logger.info(f"Cache content unchanged for {cache_data.args_hash}, skipping update")
+            logger.info(
+                f"Cache content unchanged for {cache_data.args_hash}, skipping update"
+            )
             return
-
+
     # Update cache with new content
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
@@ -750,7 +752,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
         "embedding_max": cache_data.max_val,
         "original_prompt": cache_data.prompt,
     }
-
+
     # Only upsert if there's actual new content
     await hashing_kv.upsert({cache_data.mode: mode_cache})

From 46610682ce9ce6197e7319f805408ef475d1ed0d Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 15:41:00 +0800
Subject: [PATCH 29/33] Fix data persistence issue in single-process mode

In single-process mode, data updates and persistence were not working
properly because the update flags were not being correctly handled
between different objects.
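
Both branches of get_update_flag now hand back an object exposing a writable
.value attribute, so storage code can flip the flag the same way in either
mode. A minimal sketch of the idea (make_update_flag is an illustrative
helper, not part of this patch):

    from multiprocessing import Manager

    class MutableBoolean:
        # Single-process stand-in for Manager().Value("b", ...): exposes
        # the same mutable .value attribute.
        def __init__(self, initial_value: bool = False):
            self.value = initial_value

    def make_update_flag(is_multiprocess: bool, manager=None):
        # Illustrative helper: both branches return an object with .value,
        # so callers never need to branch on the process mode again.
        if is_multiprocess and manager is not None:
            return manager.Value("b", False)
        return MutableBoolean(False)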
---
 lightrag/kg/json_doc_status_impl.py |  5 +----
 lightrag/kg/json_kv_impl.py         |  5 +----
 lightrag/kg/shared_storage.py       | 13 ++++++++++---
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py
index 4502397b..57a34ae5 100644
--- a/lightrag/kg/json_doc_status_impl.py
+++ b/lightrag/kg/json_doc_status_impl.py
@@ -20,7 +20,6 @@ from .shared_storage import (
     set_all_update_flags,
     clear_all_update_flags,
     try_initialize_namespace,
-    is_multiprocess,
 )
 
 
@@ -96,9 +95,7 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if (is_multiprocess and self.storage_updated.value) or (
-                not is_multiprocess and self.storage_updated
-            ):
+            if self.storage_updated.value:
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index 80abe92e..e7deaf15 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -18,7 +18,6 @@ from .shared_storage import (
     set_all_update_flags,
     clear_all_update_flags,
     try_initialize_namespace,
-    is_multiprocess,
 )
 
 
@@ -63,9 +62,7 @@ class JsonKVStorage(BaseKVStorage):
 
     async def index_done_callback(self) -> None:
         async with self._storage_lock:
-            if (is_multiprocess and self.storage_updated.value) or (
-                not is_multiprocess and self.storage_updated
-            ):
+            if self.storage_updated.value:
                 data_dict = (
                     dict(self._data) if hasattr(self._data, "_getvalue") else self._data
                 )
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 63ff1f0d..9bf072be 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -322,7 +322,12 @@ async def get_update_flag(namespace: str):
     if is_multiprocess and _manager is not None:
         new_update_flag = _manager.Value("b", False)
     else:
-        new_update_flag = False
+        # Create a simple mutable object to store a boolean value for compatibility with multiprocess
+        class MutableBoolean:
+            def __init__(self, initial_value=False):
+                self.value = initial_value
+
+        new_update_flag = MutableBoolean(False)
 
     _update_flags[namespace].append(new_update_flag)
     return new_update_flag
@@ -342,7 +347,8 @@ async def set_all_update_flags(namespace: str):
         if is_multiprocess:
             _update_flags[namespace][i].value = True
         else:
-            _update_flags[namespace][i] = True
+            # Use .value attribute instead of direct assignment
+            _update_flags[namespace][i].value = True
 
 
 async def clear_all_update_flags(namespace: str):
@@ -359,7 +365,8 @@ async def clear_all_update_flags(namespace: str):
         if is_multiprocess:
             _update_flags[namespace][i].value = False
         else:
-            _update_flags[namespace][i] = False
+            # Use .value attribute instead of direct assignment
+            _update_flags[namespace][i].value = False
 
 
 async def get_all_update_flags_status() -> Dict[str, list]:

From 57a41eedb89e7d0d876b5b9b792aba4d06dd7fde Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 15:41:46 +0800
Subject: [PATCH 30/33] Fix linting

---
 lightrag/kg/shared_storage.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 9bf072be..382e490b 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -326,7 +326,7 @@ async def get_update_flag(namespace: str):
         class MutableBoolean:
             def __init__(self, initial_value=False):
                 self.value = initial_value
-
+
         new_update_flag = MutableBoolean(False)
     _update_flags[namespace].append(new_update_flag)

From 3cca18c59c3d81cad721154a78673d131593d4b2 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 10 Mar 2025 16:48:59 +0800
Subject: [PATCH 31/33] Refactor pipeline status updates and entity
 extraction.

- Let all parallel jobs use one pipeline_status object
- Improved thread safety with pipeline_status_lock
- Only pipeline jobs can add messages to pipeline_status
- Marked insert_custom_chunks as deprecated
---
 lightrag/lightrag.py | 22 ++++++++++++----------
 lightrag/operate.py  | 35 ++++++++++++++++++++-------------
 2 files changed, 34 insertions(+), 23 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 6abd7a17..5b42fa3d 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -583,6 +583,7 @@ class LightRAG:
             split_by_character, split_by_character_only
         )
 
+    # TODO: deprecated, use insert instead
     def insert_custom_chunks(
         self,
         full_text: str,
@@ -594,6 +595,7 @@ class LightRAG:
             self.ainsert_custom_chunks(full_text, text_chunks, doc_id)
         )
 
+    # TODO: deprecated, use ainsert instead
     async def ainsert_custom_chunks(
         self, full_text: str, text_chunks: list[str], doc_id: str | None = None
     ) -> None:
@@ -885,7 +887,7 @@ class LightRAG:
                     self.chunks_vdb.upsert(chunks)
                 )
                 entity_relation_task = asyncio.create_task(
-                    self._process_entity_relation_graph(chunks)
+                    self._process_entity_relation_graph(chunks, pipeline_status, pipeline_status_lock)
                 )
                 full_docs_task = asyncio.create_task(
                     self.full_docs.upsert(
@@ -1000,21 +1002,23 @@ class LightRAG:
                 pipeline_status["latest_message"] = log_message
                 pipeline_status["history_messages"].append(log_message)
 
-    async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
+    async def _process_entity_relation_graph(self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None) -> None:
         try:
             await extract_entities(
                 chunk,
                 knowledge_graph_inst=self.chunk_entity_relation_graph,
                 entity_vdb=self.entities_vdb,
                 relationships_vdb=self.relationships_vdb,
-                llm_response_cache=self.llm_response_cache,
                 global_config=asdict(self),
+                pipeline_status=pipeline_status,
+                pipeline_status_lock=pipeline_status_lock,
+                llm_response_cache=self.llm_response_cache,
             )
         except Exception as e:
             logger.error("Failed to extract entities and relationships")
             raise e
 
-    async def _insert_done(self) -> None:
+    async def _insert_done(self, pipeline_status=None, pipeline_status_lock=None) -> None:
         tasks = [
             cast(StorageNameSpace, storage_inst).index_done_callback()
             for storage_inst in [  # type: ignore
@@ -1033,12 +1037,10 @@ class LightRAG:
         log_message = "All Insert done"
         logger.info(log_message)
 
-        # Get pipeline_status and update latest_message and history_messages
-        from lightrag.kg.shared_storage import get_namespace_data
-
-        pipeline_status = await get_namespace_data("pipeline_status")
-        pipeline_status["latest_message"] = log_message
-        pipeline_status["history_messages"].append(log_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = log_message
+                pipeline_status["history_messages"].append(log_message)
 
     def insert_custom_kg(
         self, custom_kg: dict[str, Any], full_doc_id: str = None
diff --git a/lightrag/operate.py b/lightrag/operate.py
index ba39fe89..5d6b7c7d 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -340,11 +340,10 @@ async def extract_entities(
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict[str, str],
+    pipeline_status: dict = None,
+
pipeline_status_lock = None, llm_response_cache: BaseKVStorage | None = None, ) -> None: - from lightrag.kg.shared_storage import get_namespace_data - - pipeline_status = await get_namespace_data("pipeline_status") use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] enable_llm_cache_for_entity_extract: bool = global_config[ @@ -507,8 +506,10 @@ async def extract_entities( relations_count = len(maybe_edges) log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)" logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) return dict(maybe_nodes), dict(maybe_edges) tasks = [_process_single_content(c) for c in ordered_chunks] @@ -547,25 +548,33 @@ async def extract_entities( if not (all_entities_data or all_relationships_data): log_message = "Didn't extract any entities and relationships." logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) return if not all_entities_data: log_message = "Didn't extract any entities" logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) if not all_relationships_data: log_message = "Didn't extract any relationships" logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)" logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) verbose_debug( f"New entities:{all_entities_data}, relationships:{all_relationships_data}" ) From 5d64f3b0a03fbb856504f5729530166d1107f495 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 10 Mar 2025 17:14:14 +0800 Subject: [PATCH 32/33] Improved auto-scan task initialization and status tracking. 
- Added autoscan status tracking in pipeline - Ensured auto-scan runs only once per startup --- lightrag/api/lightrag_server.py | 34 +++++++++++++++++---------------- lightrag/kg/shared_storage.py | 1 + 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8871650a..fd09a691 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -141,23 +141,25 @@ def create_app(args): try: # Initialize database connections await rag.initialize_storages() - await initialize_pipeline_status() - # Auto scan documents if enabled - if args.auto_scan_at_startup: - # Check if a task is already running (with lock protection) - pipeline_status = await get_namespace_data("pipeline_status") - should_start_task = False - async with get_pipeline_status_lock(): - if not pipeline_status.get("busy", False): - should_start_task = True - # Only start the task if no other task is running - if should_start_task: - # Create background task - task = asyncio.create_task(run_scanning_process(rag, doc_manager)) - app.state.background_tasks.add(task) - task.add_done_callback(app.state.background_tasks.discard) - logger.info("Auto scan task started at startup.") + await initialize_pipeline_status() + pipeline_status = await get_namespace_data("pipeline_status") + + should_start_autoscan = False + async with get_pipeline_status_lock(): + # Auto scan documents if enabled + if args.auto_scan_at_startup: + if not pipeline_status.get("autoscanned", False): + pipeline_status["autoscanned"] = True + should_start_autoscan = True + + # Only run auto scan when no other process started it first + if should_start_autoscan: + # Create background task + task = asyncio.create_task(run_scanning_process(rag, doc_manager)) + app.state.background_tasks.add(task) + task.add_done_callback(app.state.background_tasks.discard) + logger.info(f"Process {os.getpid()} auto scan task started at startup.") ASCIIColors.green("\nServer is ready to accept connections! 
🚀\n") diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 382e490b..736887a6 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -286,6 +286,7 @@ async def initialize_pipeline_status(): history_messages = _manager.list() if is_multiprocess else [] pipeline_namespace.update( { + "autoscanned": False, # Auto-scan started "busy": False, # Control concurrent processes "job_name": "Default Job", # Current job name (indexing files/indexing texts) "job_start": None, # Job start time From bbff3ed0abc26b5b403e75c432ca1fd95b6071c3 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 10 Mar 2025 17:30:40 +0800 Subject: [PATCH 33/33] Fix linting --- lightrag/lightrag.py | 12 +++++++++--- lightrag/operate.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5b42fa3d..3cd379b6 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -887,7 +887,9 @@ class LightRAG: self.chunks_vdb.upsert(chunks) ) entity_relation_task = asyncio.create_task( - self._process_entity_relation_graph(chunks, pipeline_status, pipeline_status_lock) + self._process_entity_relation_graph( + chunks, pipeline_status, pipeline_status_lock + ) ) full_docs_task = asyncio.create_task( self.full_docs.upsert( @@ -1002,7 +1004,9 @@ class LightRAG: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) - async def _process_entity_relation_graph(self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None) -> None: + async def _process_entity_relation_graph( + self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None + ) -> None: try: await extract_entities( chunk, @@ -1018,7 +1022,9 @@ class LightRAG: logger.error("Failed to extract entities and relationships") raise e - async def _insert_done(self, pipeline_status=None, pipeline_status_lock=None) -> None: + async def _insert_done( + self, pipeline_status=None, pipeline_status_lock=None + ) -> None: tasks = [ cast(StorageNameSpace, storage_inst).index_done_callback() for storage_inst in [ # type: ignore diff --git a/lightrag/operate.py b/lightrag/operate.py index 5d6b7c7d..e352ff79 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -341,7 +341,7 @@ async def extract_entities( relationships_vdb: BaseVectorStorage, global_config: dict[str, str], pipeline_status: dict = None, - pipeline_status_lock = None, + pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, ) -> None: use_llm_func: callable = global_config["llm_model_func"]
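
Taken together, patches 31-33 thread an optional pipeline_status dict and its
lock through the extraction path instead of fetching shared state inside
operate.py. A rough sketch of how a caller can publish progress under that
lock (names as used in the patches above; the helper itself is illustrative,
not part of the series):

    from lightrag.kg.shared_storage import (
        get_namespace_data,
        get_pipeline_status_lock,
    )

    async def publish_progress(message: str) -> None:
        # Illustrative helper: fetch the shared status dict, then mutate it
        # only while holding the pipeline status lock, as extract_entities does.
        pipeline_status = await get_namespace_data("pipeline_status")
        async with get_pipeline_status_lock():
            pipeline_status["latest_message"] = message
            pipeline_status["history_messages"].append(message)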