From 1699b10a255c8ab8e72f3738ff2852158baa8bb9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 27 Feb 2025 15:14:54 +0800 Subject: [PATCH] Refactor direct client/graph access to reduce redundant get calls in vector/graph ops --- lightrag/kg/nano_vector_db_impl.py | 25 +++++++------------ lightrag/kg/networkx_impl.py | 40 +++++++++++------------------- 2 files changed, 23 insertions(+), 42 deletions(-) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 64b0e720..953a19a7 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -97,8 +97,7 @@ class NanoVectorDBStorage(BaseVectorStorage): for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] with self._storage_lock: - client = self._get_client() - results = client.upsert(datas=list_data) + results = self._get_client().upsert(datas=list_data) return results else: # sometimes the embedding is not returned correctly. just log it. @@ -112,8 +111,7 @@ class NanoVectorDBStorage(BaseVectorStorage): embedding = embedding[0] with self._storage_lock: - client = self._get_client() - results = client.query( + results = self._get_client().query( query=embedding, top_k=top_k, better_than_threshold=self.cosine_better_than_threshold, @@ -131,8 +129,7 @@ class NanoVectorDBStorage(BaseVectorStorage): @property def client_storage(self): - client = self._get_client() - return getattr(client, "_NanoVectorDB__storage") + return getattr(self._get_client(), "_NanoVectorDB__storage") async def delete(self, ids: list[str]): """Delete vectors with specified IDs @@ -142,8 +139,7 @@ class NanoVectorDBStorage(BaseVectorStorage): """ try: with self._storage_lock: - client = self._get_client() - client.delete(ids) + self._get_client().delete(ids) logger.debug( f"Successfully deleted {len(ids)} vectors from {self.namespace}" ) @@ -158,10 +154,9 @@ class NanoVectorDBStorage(BaseVectorStorage): ) with self._storage_lock: - client = self._get_client() # Check if the entity exists - if client.get([entity_id]): - client.delete([entity_id]) + if self._get_client().get([entity_id]): + self._get_client().delete([entity_id]) logger.debug(f"Successfully deleted entity {entity_name}") else: logger.debug(f"Entity {entity_name} not found in storage") @@ -171,8 +166,7 @@ class NanoVectorDBStorage(BaseVectorStorage): async def delete_entity_relation(self, entity_name: str) -> None: try: with self._storage_lock: - client = self._get_client() - storage = getattr(client, "_NanoVectorDB__storage") + storage = getattr(self._get_client(), "_NanoVectorDB__storage") relations = [ dp for dp in storage["data"] @@ -184,7 +178,7 @@ class NanoVectorDBStorage(BaseVectorStorage): ids_to_delete = [relation["__id__"] for relation in relations] if ids_to_delete: - client.delete(ids_to_delete) + self._get_client().delete(ids_to_delete) logger.debug( f"Deleted {len(ids_to_delete)} relations for {entity_name}" ) @@ -195,5 +189,4 @@ class NanoVectorDBStorage(BaseVectorStorage): async def index_done_callback(self) -> None: with self._storage_lock: - client = self._get_client() - client.save() + self._get_client().save() diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index aec49e6c..db059393 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -115,65 +115,54 @@ class NetworkXStorage(BaseGraphStorage): async def index_done_callback(self) -> None: with self._storage_lock: - graph = self._get_graph() - NetworkXStorage.write_nx_graph(graph, self._graphml_xml_file) + NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file) async def has_node(self, node_id: str) -> bool: with self._storage_lock: - graph = self._get_graph() - return graph.has_node(node_id) + return self._get_graph().has_node(node_id) async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: with self._storage_lock: - graph = self._get_graph() - return graph.has_edge(source_node_id, target_node_id) + return self._get_graph().has_edge(source_node_id, target_node_id) async def get_node(self, node_id: str) -> dict[str, str] | None: with self._storage_lock: - graph = self._get_graph() - return graph.nodes.get(node_id) + return self._get_graph().nodes.get(node_id) async def node_degree(self, node_id: str) -> int: with self._storage_lock: - graph = self._get_graph() - return graph.degree(node_id) + return self._get_graph().degree(node_id) async def edge_degree(self, src_id: str, tgt_id: str) -> int: with self._storage_lock: - graph = self._get_graph() - return graph.degree(src_id) + graph.degree(tgt_id) + return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id) async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: with self._storage_lock: - graph = self._get_graph() - return graph.edges.get((source_node_id, target_node_id)) + return self._get_graph().edges.get((source_node_id, target_node_id)) async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: with self._storage_lock: - graph = self._get_graph() - if graph.has_node(source_node_id): - return list(graph.edges(source_node_id)) + if self._get_graph().has_node(source_node_id): + return list(self._get_graph().edges(source_node_id)) return None async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: with self._storage_lock: - graph = self._get_graph() - graph.add_node(node_id, **node_data) + self._get_graph().add_node(node_id, **node_data) async def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: with self._storage_lock: - graph = self._get_graph() - graph.add_edge(source_node_id, target_node_id, **edge_data) + self._get_graph().add_edge(source_node_id, target_node_id, **edge_data) async def delete_node(self, node_id: str) -> None: with self._storage_lock: - graph = self._get_graph() - if graph.has_node(node_id): - graph.remove_node(node_id) + if self._get_graph().has_node(node_id): + self._get_graph().remove_node(node_id) logger.debug(f"Node {node_id} deleted from the graph.") else: logger.warning(f"Node {node_id} not found in the graph for deletion.") @@ -227,9 +216,8 @@ class NetworkXStorage(BaseGraphStorage): [label1, label2, ...] # Alphabetically sorted label list """ with self._storage_lock: - graph = self._get_graph() labels = set() - for node in graph.nodes(): + for node in self._get_graph().nodes(): labels.add(str(node)) # Add node id as a label # Return sorted list