From 7e526d343696b6345147e0b9cadd8aa2ecc67d9b Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Fri, 14 Feb 2025 23:52:05 +0100 Subject: [PATCH] cleaned code --- lightrag/base.py | 23 ++++++++++++++++++----- lightrag/lightrag.py | 25 ++++++++----------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/lightrag/base.py b/lightrag/base.py index e70dddd1..2b3e5cad 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -96,7 +96,7 @@ class StorageNameSpace: class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) - + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: raise NotImplementedError @@ -132,62 +132,75 @@ class BaseKVStorage(StorageNameSpace): class BaseGraphStorage(StorageNameSpace): embedding_func: EmbeddingFunc | None = None """Check if a node exists in the graph.""" + async def has_node(self, node_id: str) -> bool: raise NotImplementedError """Check if an edge exists in the graph.""" + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: raise NotImplementedError """Get the degree of a node.""" + async def node_degree(self, node_id: str) -> int: raise NotImplementedError """Get the degree of an edge.""" + async def edge_degree(self, src_id: str, tgt_id: str) -> int: raise NotImplementedError """Get a node by its id.""" + async def get_node(self, node_id: str) -> Union[dict[str, str], None]: raise NotImplementedError """Get an edge by its source and target node ids.""" + async def get_edge( self, source_node_id: str, target_node_id: str ) -> Union[dict[str, str], None]: raise NotImplementedError """Get all edges connected to a node.""" + async def get_node_edges( self, source_node_id: str ) -> Union[list[tuple[str, str]], None]: raise NotImplementedError """Upsert a node into the graph.""" + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: raise NotImplementedError """Upsert an edge into the graph.""" + async def upsert_edge( - self, source_node_id: str, - target_node_id: str, - edge_data: dict[str, str] + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ) -> None: raise NotImplementedError """Delete a node from the graph.""" + async def delete_node(self, node_id: str) -> None: raise NotImplementedError """Embed nodes using an algorithm.""" - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: + + async def embed_nodes( + self, algorithm: str + ) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError("Node embedding is not used in lightrag.") """Get all labels in the graph.""" + async def get_all_labels(self) -> list[str]: raise NotImplementedError """Get a knowledge graph of a node.""" + async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8a65a46c..08855d60 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -982,10 +982,7 @@ class LightRAG: await self._insert_done() def query( - self, - query: str, - param: QueryParam = QueryParam(), - prompt: str | None = None + self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None ) -> str | Iterator[str]: """ Perform a sync query. @@ -999,8 +996,8 @@ class LightRAG: str: The result of the query execution. """ loop = always_get_an_event_loop() - - return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore + + return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore async def aquery( self, @@ -1085,10 +1082,7 @@ class LightRAG: return response def query_with_separate_keyword_extraction( - self, - query: str, - prompt: str, - param: QueryParam = QueryParam() + self, query: str, prompt: str, param: QueryParam = QueryParam() ): """ 1. Extract keywords from the 'query' using new function in operate.py. @@ -1100,10 +1094,7 @@ class LightRAG: ) async def aquery_with_separate_keyword_extraction( - self, - query: str, - prompt: str, - param: QueryParam = QueryParam() + self, query: str, prompt: str, param: QueryParam = QueryParam() ): """ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. @@ -1127,8 +1118,8 @@ class LightRAG: ), ) - param.hl_keywords = (hl_keywords,) - param.ll_keywords = (ll_keywords,) + param.hl_keywords = hl_keywords + param.ll_keywords = ll_keywords # --------------------- # STEP 2: Final Query Logic @@ -1156,7 +1147,7 @@ class LightRAG: self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE ), global_config=asdict(self), - embedding_func=self.embedding_funcne, + embedding_func=self.embedding_func, ), ) elif param.mode == "naive":