cleaned code

This commit is contained in:
Yannick Stephan
2025-02-14 23:52:05 +01:00
parent e6520ad6a2
commit 7e526d3436
2 changed files with 26 additions and 22 deletions

View File

@@ -132,62 +132,75 @@ class BaseKVStorage(StorageNameSpace):
class BaseGraphStorage(StorageNameSpace): class BaseGraphStorage(StorageNameSpace):
embedding_func: EmbeddingFunc | None = None embedding_func: EmbeddingFunc | None = None
"""Check if a node exists in the graph.""" """Check if a node exists in the graph."""
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
raise NotImplementedError raise NotImplementedError
"""Check if an edge exists in the graph.""" """Check if an edge exists in the graph."""
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
raise NotImplementedError raise NotImplementedError
"""Get the degree of a node.""" """Get the degree of a node."""
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
raise NotImplementedError raise NotImplementedError
"""Get the degree of an edge.""" """Get the degree of an edge."""
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
raise NotImplementedError raise NotImplementedError
"""Get a node by its id.""" """Get a node by its id."""
async def get_node(self, node_id: str) -> Union[dict[str, str], None]: async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
raise NotImplementedError raise NotImplementedError
"""Get an edge by its source and target node ids.""" """Get an edge by its source and target node ids."""
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict[str, str], None]: ) -> Union[dict[str, str], None]:
raise NotImplementedError raise NotImplementedError
"""Get all edges connected to a node.""" """Get all edges connected to a node."""
async def get_node_edges( async def get_node_edges(
self, source_node_id: str self, source_node_id: str
) -> Union[list[tuple[str, str]], None]: ) -> Union[list[tuple[str, str]], None]:
raise NotImplementedError raise NotImplementedError
"""Upsert a node into the graph.""" """Upsert a node into the graph."""
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
raise NotImplementedError raise NotImplementedError
"""Upsert an edge into the graph.""" """Upsert an edge into the graph."""
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
target_node_id: str,
edge_data: dict[str, str]
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
"""Delete a node from the graph.""" """Delete a node from the graph."""
async def delete_node(self, node_id: str) -> None: async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
"""Embed nodes using an algorithm.""" """Embed nodes using an algorithm."""
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.") raise NotImplementedError("Node embedding is not used in lightrag.")
"""Get all labels in the graph.""" """Get all labels in the graph."""
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
raise NotImplementedError raise NotImplementedError
"""Get a knowledge graph of a node.""" """Get a knowledge graph of a node."""
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -982,10 +982,7 @@ class LightRAG:
await self._insert_done() await self._insert_done()
def query( def query(
self, self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
query: str,
param: QueryParam = QueryParam(),
prompt: str | None = None
) -> str | Iterator[str]: ) -> str | Iterator[str]:
""" """
Perform a sync query. Perform a sync query.
@@ -1000,7 +997,7 @@ class LightRAG:
""" """
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
async def aquery( async def aquery(
self, self,
@@ -1085,10 +1082,7 @@ class LightRAG:
return response return response
def query_with_separate_keyword_extraction( def query_with_separate_keyword_extraction(
self, self, query: str, prompt: str, param: QueryParam = QueryParam()
query: str,
prompt: str,
param: QueryParam = QueryParam()
): ):
""" """
1. Extract keywords from the 'query' using new function in operate.py. 1. Extract keywords from the 'query' using new function in operate.py.
@@ -1100,10 +1094,7 @@ class LightRAG:
) )
async def aquery_with_separate_keyword_extraction( async def aquery_with_separate_keyword_extraction(
self, self, query: str, prompt: str, param: QueryParam = QueryParam()
query: str,
prompt: str,
param: QueryParam = QueryParam()
): ):
""" """
1. Calls extract_keywords_only to get HL/LL keywords from 'query'. 1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
@@ -1127,8 +1118,8 @@ class LightRAG:
), ),
) )
param.hl_keywords = (hl_keywords,) param.hl_keywords = hl_keywords
param.ll_keywords = (ll_keywords,) param.ll_keywords = ll_keywords
# --------------------- # ---------------------
# STEP 2: Final Query Logic # STEP 2: Final Query Logic
@@ -1156,7 +1147,7 @@ class LightRAG:
self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
), ),
global_config=asdict(self), global_config=asdict(self),
embedding_func=self.embedding_funcne, embedding_func=self.embedding_func,
), ),
) )
elif param.mode == "naive": elif param.mode == "naive":