diff --git a/lightrag/base.py b/lightrag/base.py index b1fe50a2..e70dddd1 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -96,7 +96,7 @@ class StorageNameSpace: class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set[str] = field(default_factory=set) - + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: raise NotImplementedError diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 593b5734..8a65a46c 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -982,7 +982,10 @@ class LightRAG: await self._insert_done() def query( - self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None + self, + query: str, + param: QueryParam = QueryParam(), + prompt: str | None = None ) -> str | Iterator[str]: """ Perform a sync query. @@ -996,7 +999,8 @@ class LightRAG: str: The result of the query execution. """ loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery(query, param, prompt)) + + return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore async def aquery( self, @@ -1455,7 +1459,7 @@ class LightRAG: async def get_entity_info( self, entity_name: str, include_vector_data: bool = False - ): + ) -> dict[str, str | None | dict[str, str]]: """Get detailed information of an entity Args: @@ -1475,7 +1479,7 @@ class LightRAG: node_data = await self.chunk_entity_relation_graph.get_node(entity_name) source_id = node_data.get("source_id") if node_data else None - result = { + result: dict[str, str | None | dict[str, str]] = { "entity_name": entity_name, "source_id": source_id, "graph_data": node_data, @@ -1531,7 +1535,7 @@ class LightRAG: ) source_id = edge_data.get("source_id") if edge_data else None - result = { + result: dict[str, str | None | dict[str, str]] = { "src_entity": src_entity, "tgt_entity": tgt_entity, "source_id": source_id, diff --git a/lightrag/operate.py b/lightrag/operate.py index a961cfd9..d6cc9f3c 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2,7 +2,7 @@ import asyncio import json import re from tqdm.asyncio import tqdm as tqdm_async -from typing import Any, Union +from typing import Any, AsyncIterator, Union from collections import Counter, defaultdict from .utils import ( logger, @@ -780,7 +780,7 @@ async def mix_kg_vector_query( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, -) -> str: +) -> str | AsyncIterator[str]: """ Hybrid retrieval implementation combining knowledge graph and vector search. @@ -1505,7 +1505,7 @@ async def naive_query( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, -): +) -> str | AsyncIterator[str]: # Handle cache use_model_func = global_config["llm_model_func"] args_hash = compute_args_hash(query_param.mode, query, cache_type="query")