diff --git a/lightrag/base.py b/lightrag/base.py index 8e6a212d..b1fe50a2 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -95,7 +95,7 @@ class StorageNameSpace: @dataclass class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc - meta_fields: set = field(default_factory=set) + meta_fields: set[str] = field(default_factory=set) async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: raise NotImplementedError @@ -130,50 +130,64 @@ class BaseKVStorage(StorageNameSpace): @dataclass class BaseGraphStorage(StorageNameSpace): - embedding_func: EmbeddingFunc = None - + embedding_func: EmbeddingFunc | None = None + """Check if a node exists in the graph.""" async def has_node(self, node_id: str) -> bool: raise NotImplementedError + """Check if an edge exists in the graph.""" async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: raise NotImplementedError + """Get the degree of a node.""" async def node_degree(self, node_id: str) -> int: raise NotImplementedError + """Get the degree of an edge.""" async def edge_degree(self, src_id: str, tgt_id: str) -> int: raise NotImplementedError - async def get_node(self, node_id: str) -> Union[dict, None]: + """Get a node by its id.""" + async def get_node(self, node_id: str) -> Union[dict[str, str], None]: raise NotImplementedError + """Get an edge by its source and target node ids.""" async def get_edge( self, source_node_id: str, target_node_id: str - ) -> Union[dict, None]: + ) -> Union[dict[str, str], None]: raise NotImplementedError + """Get all edges connected to a node.""" async def get_node_edges( self, source_node_id: str ) -> Union[list[tuple[str, str]], None]: raise NotImplementedError - async def upsert_node(self, node_id: str, node_data: dict[str, str]): + """Upsert a node into the graph.""" + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: raise NotImplementedError + """Upsert an edge into the graph.""" async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ): + self, source_node_id: str, + target_node_id: str, + edge_data: dict[str, str] + ) -> None: raise NotImplementedError - async def delete_node(self, node_id: str): + """Delete a node from the graph.""" + async def delete_node(self, node_id: str) -> None: raise NotImplementedError - async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: + """Embed nodes using an algorithm.""" + async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]: raise NotImplementedError("Node embedding is not used in lightrag.") + """Get all labels in the graph.""" async def get_all_labels(self) -> list[str]: raise NotImplementedError + """Get a knowledge graph of a node.""" async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 15bb6cc2..593b5734 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -6,7 +6,7 @@ import configparser from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Any, Callable, Optional, Union, cast +from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast from .base import ( BaseGraphStorage, @@ -983,7 +983,7 @@ class LightRAG: def query( self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None - ) -> str: + ) -> str | Iterator[str]: """ Perform a sync query. @@ -1003,7 +1003,7 @@ class LightRAG: query: str, param: QueryParam = QueryParam(), prompt: str | None = None, - ) -> str: + ) -> str | AsyncIterator[str]: """ Perform a async query. @@ -1081,7 +1081,10 @@ class LightRAG: return response def query_with_separate_keyword_extraction( - self, query: str, prompt: str, param: QueryParam = QueryParam() + self, + query: str, + prompt: str, + param: QueryParam = QueryParam() ): """ 1. Extract keywords from the 'query' using new function in operate.py. @@ -1093,7 +1096,10 @@ class LightRAG: ) async def aquery_with_separate_keyword_extraction( - self, query: str, prompt: str, param: QueryParam = QueryParam() + self, + query: str, + prompt: str, + param: QueryParam = QueryParam() ): """ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'.