From 99f24cd51eb5cdbd3d62a3326da9e5f28092a8e7 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sat, 12 Apr 2025 22:42:43 +0800
Subject: [PATCH] Make batch methods in BaseGraphStorage optional with default implementations

- Removing the @abstractmethod decorator
- Adding default implementations that call the corresponding non-batch methods
- Preserving full backward compatibility with existing implementations
---
 lightrag/base.py | 69 +++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 59 insertions(+), 10 deletions(-)

diff --git a/lightrag/base.py b/lightrag/base.py
index 5dd657b8..969a5a8c 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -361,25 +361,74 @@ class BaseGraphStorage(StorageNameSpace, ABC):
             or None if the node doesn't exist
         """
 
-    @abstractmethod
     async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
-        """Get nodes as a batch using UNWIND"""
+        """Get nodes as a batch using UNWIND
+
+        Default implementation fetches nodes one by one.
+        Override this method for better performance in storage backends
+        that support batch operations.
+        """
+        result = {}
+        for node_id in node_ids:
+            node = await self.get_node(node_id)
+            if node is not None:
+                result[node_id] = node
+        return result
 
-    @abstractmethod
     async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
-        """Node degrees as a batch using UNWIND"""
+        """Node degrees as a batch using UNWIND
+
+        Default implementation fetches node degrees one by one.
+        Override this method for better performance in storage backends
+        that support batch operations.
+        """
+        result = {}
+        for node_id in node_ids:
+            degree = await self.node_degree(node_id)
+            result[node_id] = degree
+        return result
 
-    @abstractmethod
     async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
-        """Edge degrees as a batch using UNWIND also uses node_degrees_batch"""
+        """Edge degrees as a batch using UNWIND also uses node_degrees_batch
+
+        Default implementation calculates edge degrees one by one.
+        Override this method for better performance in storage backends
+        that support batch operations.
+        """
+        result = {}
+        for src_id, tgt_id in edge_pairs:
+            degree = await self.edge_degree(src_id, tgt_id)
+            result[(src_id, tgt_id)] = degree
+        return result
 
-    @abstractmethod
     async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
-        """Get edges as a batch using UNWIND"""
+        """Get edges as a batch using UNWIND
+
+        Default implementation fetches edges one by one.
+        Override this method for better performance in storage backends
+        that support batch operations.
+        """
+        result = {}
+        for pair in pairs:
+            src_id = pair["src"]
+            tgt_id = pair["tgt"]
+            edge = await self.get_edge(src_id, tgt_id)
+            if edge is not None:
+                result[(src_id, tgt_id)] = edge
+        return result
 
-    @abstractmethod
     async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
-        """"Get nodes edges as a batch using UNWIND"""
+        """Get nodes edges as a batch using UNWIND
+
+        Default implementation fetches node edges one by one.
+        Override this method for better performance in storage backends
+        that support batch operations.
+        """
+        result = {}
+        for node_id in node_ids:
+            edges = await self.get_node_edges(node_id)
+            result[node_id] = edges if edges is not None else []
+        return result
 
     @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: