added typing
This commit is contained in:
@@ -982,7 +982,10 @@ class LightRAG:
|
||||
await self._insert_done()
|
||||
|
||||
def query(
|
||||
self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
|
||||
self,
|
||||
query: str,
|
||||
param: QueryParam = QueryParam(),
|
||||
prompt: str | None = None
|
||||
) -> str | Iterator[str]:
|
||||
"""
|
||||
Perform a sync query.
|
||||
@@ -996,7 +999,8 @@ class LightRAG:
|
||||
str: The result of the query execution.
|
||||
"""
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.aquery(query, param, prompt))
|
||||
|
||||
return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
|
||||
|
||||
async def aquery(
|
||||
self,
|
||||
@@ -1455,7 +1459,7 @@ class LightRAG:
|
||||
|
||||
async def get_entity_info(
|
||||
self, entity_name: str, include_vector_data: bool = False
|
||||
):
|
||||
) -> dict[str, str | None | dict[str, str]]:
|
||||
"""Get detailed information of an entity
|
||||
|
||||
Args:
|
||||
@@ -1475,7 +1479,7 @@ class LightRAG:
|
||||
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
|
||||
source_id = node_data.get("source_id") if node_data else None
|
||||
|
||||
result = {
|
||||
result: dict[str, str | None | dict[str, str]] = {
|
||||
"entity_name": entity_name,
|
||||
"source_id": source_id,
|
||||
"graph_data": node_data,
|
||||
@@ -1531,7 +1535,7 @@ class LightRAG:
|
||||
)
|
||||
source_id = edge_data.get("source_id") if edge_data else None
|
||||
|
||||
result = {
|
||||
result: dict[str, str | None | dict[str, str]] = {
|
||||
"src_entity": src_entity,
|
||||
"tgt_entity": tgt_entity,
|
||||
"source_id": source_id,
|
||||
|
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import re
|
||||
from tqdm.asyncio import tqdm as tqdm_async
|
||||
from typing import Any, Union
|
||||
from typing import Any, AsyncIterator, Union
|
||||
from collections import Counter, defaultdict
|
||||
from .utils import (
|
||||
logger,
|
||||
@@ -780,7 +780,7 @@ async def mix_kg_vector_query(
|
||||
query_param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
) -> str:
|
||||
) -> str | AsyncIterator[str]:
|
||||
"""
|
||||
Hybrid retrieval implementation combining knowledge graph and vector search.
|
||||
|
||||
@@ -1505,7 +1505,7 @@ async def naive_query(
|
||||
query_param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
):
|
||||
) -> str | AsyncIterator[str]:
|
||||
# Handle cache
|
||||
use_model_func = global_config["llm_model_func"]
|
||||
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
||||
|
Reference in New Issue
Block a user