From 182aee2e146bfb764d592a750a87d2104d05d35b Mon Sep 17 00:00:00 2001
From: frederikhendrix
Date: Mon, 7 Apr 2025 19:09:31 +0200
Subject: [PATCH] get_node added and all to base.py and to neo4j_impl.py file

---
 lightrag/base.py          |  20 +++++
 lightrag/kg/neo4j_impl.py | 156 ++++++++++++++++++++++++++++++++++++++
 lightrag/operate.py       | 107 ++++++++++++++------------
 3 files changed, 234 insertions(+), 49 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index 5cf5ab61..f30cbf17 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -309,6 +309,26 @@ class BaseGraphStorage(StorageNameSpace, ABC):
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """Upsert a node into the graph."""
 
+    @abstractmethod
+    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
+        """Get nodes as a batch using UNWIND"""
+
+    @abstractmethod
+    async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
+        """Node degrees as a batch using UNWIND"""
+
+    @abstractmethod
+    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
+        """Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
+
+    @abstractmethod
+    async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
+        """Get edges as a batch using UNWIND"""
+
+    @abstractmethod
+    async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
+        """Get nodes edges as a batch using UNWIND"""
+
     @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """Upsert an edge into the graph."""
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index b84a0c6a..1e6c8711 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -314,6 +314,37 @@ class Neo4JStorage(BaseGraphStorage):
                 logger.error(f"Error getting node for {node_id}: {str(e)}")
                 raise
 
+    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
+        """
+        Retrieve multiple nodes in one query using UNWIND.
+
+        Args:
+            node_ids: List of node entity IDs to fetch.
+
+        Returns:
+            A dictionary mapping each found node_id to its node data; IDs with no matching node are omitted.
+        """
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+                UNWIND $node_ids AS id
+                MATCH (n:base {entity_id: id})
+                RETURN n.entity_id AS entity_id, n
+            """
+            result = await session.run(query, node_ids=node_ids)
+            nodes = {}
+            async for record in result:
+                entity_id = record["entity_id"]
+                node = record["n"]
+                node_dict = dict(node)
+                # Remove the 'base' label if present in a 'labels' property
+                if "labels" in node_dict:
+                    node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
+                nodes[entity_id] = node_dict
+            await result.consume()  # Make sure to consume the result fully
+            return nodes
+
     async def node_degree(self, node_id: str) -> int:
         """Get the degree (number of relationships) of a node with the given label.
         If multiple nodes have the same label, returns the degree of the first node.
@@ -357,6 +388,41 @@ class Neo4JStorage(BaseGraphStorage):
                 logger.error(f"Error getting node degree for {node_id}: {str(e)}")
                 raise
 
+    async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
+        """
+        Retrieve the degree for multiple nodes in a single query using UNWIND.
+
+        Args:
+            node_ids: List of node labels (entity_id values) to look up.
+
+        Returns:
+            A dictionary mapping each node_id to its degree (number of relationships).
+            If a node is not found, its degree will be set to 0.
+        """
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+                UNWIND $node_ids AS id
+                MATCH (n:base {entity_id: id})
+                RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
+            """
+            result = await session.run(query, node_ids=node_ids)
+            degrees = {}
+            async for record in result:
+                entity_id = record["entity_id"]
+                degrees[entity_id] = record["degree"]
+            await result.consume()  # Ensure result is fully consumed
+
+            # For any node_id that did not return a record, set degree to 0.
+            for nid in node_ids:
+                if nid not in degrees:
+                    logger.warning(f"No node found with label '{nid}'")
+                    degrees[nid] = 0
+
+            logger.debug(f"Neo4j batch node degree query returned: {degrees}")
+            return degrees
+
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
         """Get the total degree (sum of relationships) of two nodes.
 
@@ -376,6 +442,30 @@ class Neo4JStorage(BaseGraphStorage):
 
         degrees = int(src_degree) + int(trg_degree)
         return degrees
+
+    async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
+        """
+        Calculate the combined degree for each edge (sum of the source and target node degrees)
+        in batch using the already implemented node_degrees_batch.
+
+        Args:
+            edge_pairs: List of (src, tgt) tuples.
+
+        Returns:
+            A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
+        """
+        # Collect unique node IDs from all edge pairs.
+        unique_node_ids = {src for src, _ in edge_pairs}
+        unique_node_ids.update({tgt for _, tgt in edge_pairs})
+
+        # Get degrees for all nodes in one go.
+        degrees = await self.node_degrees_batch(list(unique_node_ids))
+
+        # Sum up degrees for each edge pair.
+        edge_degrees = {}
+        for src, tgt in edge_pairs:
+            edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
+        return edge_degrees
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
@@ -463,6 +553,43 @@ class Neo4JStorage(BaseGraphStorage):
             )
             raise
 
+    async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
+        """
+        Retrieve edge properties for multiple (src, tgt) pairs in one query.
+
+        Args:
+            pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
+
+        Returns:
+            A dictionary mapping (src, tgt) tuples to their edge properties; pairs with no matching edge are omitted.
+        """
+        async with self._driver.session(
+            database=self._DATABASE, default_access_mode="READ"
+        ) as session:
+            query = """
+                UNWIND $pairs AS pair
+                MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
+                RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
+            """
+            result = await session.run(query, pairs=pairs)
+            edges_dict = {}
+            async for record in result:
+                src = record["src_id"]
+                tgt = record["tgt_id"]
+                edges = record["edges"]
+                if edges and len(edges) > 0:
+                    edge_props = edges[0]  # choose the first if multiple exist
+                    # Ensure required keys exist with defaults
+                    for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
+                        if key not in edge_props:
+                            edge_props[key] = default
+                    edges_dict[(src, tgt)] = edge_props
+                else:
+                    # No edge found – set default edge properties
+                    edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
+            await result.consume()
+            return edges_dict
+
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
         """Retrieves all edges (relationships) for a particular node identified by its label.
 
@@ -523,6 +650,35 @@ class Neo4JStorage(BaseGraphStorage):
             logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
             raise
 
+    async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
+        """
+        Batch retrieve edges for multiple nodes in one query using UNWIND.
+
+        Args:
+            node_ids: List of node IDs (entity_id) for which to retrieve edges.
+
+        Returns:
+            A dictionary mapping each node ID to its list of edge tuples (source, target).
+        """
+        async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
+            query = """
+                UNWIND $node_ids AS id
+                MATCH (n:base {entity_id: id})
+                OPTIONAL MATCH (n)-[r]-(connected:base)
+                RETURN id AS queried_id, n.entity_id AS source_entity_id, connected.entity_id AS target_entity_id
+            """
+            result = await session.run(query, node_ids=node_ids)
+            # Initialize the dictionary with empty lists for each node ID
+            edges_dict = {node_id: [] for node_id in node_ids}
+            async for record in result:
+                queried_id = record["queried_id"]
+                source_label = record["source_entity_id"]
+                target_label = record["target_entity_id"]
+                if source_label and target_label:
+                    edges_dict[queried_id].append((source_label, target_label))
+            await result.consume()  # Ensure results are fully consumed
+            return edges_dict
+
     @retry(
         stop=stop_after_attempt(3),
         wait=wait_exponential(multiplier=1, min=4, max=10),
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 0e223bb6..a5d804bc 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1233,16 +1233,20 @@ async def _get_node_data(
     if not len(results):
         return "", "", ""
 
-    # get entity information
-    node_datas, node_degrees = await asyncio.gather(
-        asyncio.gather(
-            *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
-        ),
-        asyncio.gather(
-            *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
-        ),
+
+    # Extract all entity IDs from your results list
+    node_ids = [r["entity_name"] for r in results]
+
+    # Call the batch node retrieval and degree functions concurrently.
+    nodes_dict, degrees_dict = await asyncio.gather(
+        knowledge_graph_inst.get_nodes_batch(node_ids),
+        knowledge_graph_inst.node_degrees_batch(node_ids)
     )
 
+    # Now, if you need the node data and degree in order:
+    node_datas = [nodes_dict.get(nid) for nid in node_ids]
+    node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids]
+
     if not all([n is not None for n in node_datas]):
         logger.warning("Some nodes are missing, maybe the storage is damaged")
 
@@ -1374,9 +1378,10 @@ async def _find_most_related_text_unit_from_entities(
         all_one_hop_nodes.update([e[1] for e in this_edges])
 
     all_one_hop_nodes = list(all_one_hop_nodes)
-    all_one_hop_nodes_data = await asyncio.gather(
-        *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
-    )
+
+    # Batch retrieve one-hop node data using get_nodes_batch
+    all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
+    all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
 
     # Add null check for node data
     all_one_hop_text_units_lookup = {
@@ -1512,29 +1517,34 @@ async def _get_edge_data(
     if not len(results):
         return "", "", ""
 
-    edge_datas, edge_degree = await asyncio.gather(
-        asyncio.gather(
-            *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
-        ),
-        asyncio.gather(
-            *[
-                knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"])
-                for r in results
-            ]
-        ),
+    # Prepare edge pairs in two forms:
+    # For the batch edge properties function, use dicts.
+    edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
+    # For edge degrees, use tuples.
+    edge_pairs_tuples = [(r["src_id"], r["tgt_id"]) for r in results]
+
+    # Call the batched functions concurrently.
+    edge_data_dict, edge_degrees_dict = await asyncio.gather(
+        knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
+        knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
     )
 
-    edge_datas = [
-        {
-            "src_id": k["src_id"],
-            "tgt_id": k["tgt_id"],
-            "rank": d,
-            "created_at": k.get("__created_at__", None),
-            **v,
-        }
-        for k, v, d in zip(results, edge_datas, edge_degree)
-        if v is not None
-    ]
+    # Reconstruct edge_datas list in the same order as results.
+    edge_datas = []
+    for k in results:
+        pair = (k["src_id"], k["tgt_id"])
+        edge_props = edge_data_dict.get(pair)
+        if edge_props is not None:
+            # Use edge degree from the batch as rank.
+            combined = {
+                "src_id": k["src_id"],
+                "tgt_id": k["tgt_id"],
+                "rank": edge_degrees_dict.get(pair, k.get("rank", 0)),
+                "created_at": k.get("__created_at__", None),
+                **edge_props,
+            }
+            edge_datas.append(combined)
+
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
@@ -1640,24 +1650,23 @@ async def _find_most_related_entities_from_relationships(
             entity_names.append(e["tgt_id"])
             seen.add(e["tgt_id"])
 
-    node_datas, node_degrees = await asyncio.gather(
-        asyncio.gather(
-            *[
-                knowledge_graph_inst.get_node(entity_name)
-                for entity_name in entity_names
-            ]
-        ),
-        asyncio.gather(
-            *[
-                knowledge_graph_inst.node_degree(entity_name)
-                for entity_name in entity_names
-            ]
-        ),
+    # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
+    nodes_dict, degrees_dict = await asyncio.gather(
+        knowledge_graph_inst.get_nodes_batch(entity_names),
+        knowledge_graph_inst.node_degrees_batch(entity_names)
     )
-    node_datas = [
-        {**n, "entity_name": k, "rank": d}
-        for k, n, d in zip(entity_names, node_datas, node_degrees)
-    ]
+
+    # Rebuild the list in the same order as entity_names
+    node_datas = []
+    for entity_name in entity_names:
+        node = nodes_dict.get(entity_name)
+        degree = degrees_dict.get(entity_name, 0)
+        if node is None:
+            logger.warning(f"Node '{entity_name}' not found in batch retrieval.")
+            continue
+        # Combine the node data with the entity name and computed degree (as rank)
+        combined = {**node, "entity_name": entity_name, "rank": degree}
+        node_datas.append(combined)
 
     len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(