added typing
This commit is contained in:
@@ -982,7 +982,10 @@ class LightRAG:
|
|||||||
await self._insert_done()
|
await self._insert_done()
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
|
self,
|
||||||
|
query: str,
|
||||||
|
param: QueryParam = QueryParam(),
|
||||||
|
prompt: str | None = None
|
||||||
) -> str | Iterator[str]:
|
) -> str | Iterator[str]:
|
||||||
"""
|
"""
|
||||||
Perform a sync query.
|
Perform a sync query.
|
||||||
@@ -996,7 +999,8 @@ class LightRAG:
|
|||||||
str: The result of the query execution.
|
str: The result of the query execution.
|
||||||
"""
|
"""
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
return loop.run_until_complete(self.aquery(query, param, prompt))
|
|
||||||
|
return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
|
||||||
|
|
||||||
async def aquery(
|
async def aquery(
|
||||||
self,
|
self,
|
||||||
@@ -1455,7 +1459,7 @@ class LightRAG:
|
|||||||
|
|
||||||
async def get_entity_info(
|
async def get_entity_info(
|
||||||
self, entity_name: str, include_vector_data: bool = False
|
self, entity_name: str, include_vector_data: bool = False
|
||||||
):
|
) -> dict[str, str | None | dict[str, str]]:
|
||||||
"""Get detailed information of an entity
|
"""Get detailed information of an entity
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1475,7 +1479,7 @@ class LightRAG:
|
|||||||
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
|
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
|
||||||
source_id = node_data.get("source_id") if node_data else None
|
source_id = node_data.get("source_id") if node_data else None
|
||||||
|
|
||||||
result = {
|
result: dict[str, str | None | dict[str, str]] = {
|
||||||
"entity_name": entity_name,
|
"entity_name": entity_name,
|
||||||
"source_id": source_id,
|
"source_id": source_id,
|
||||||
"graph_data": node_data,
|
"graph_data": node_data,
|
||||||
@@ -1531,7 +1535,7 @@ class LightRAG:
|
|||||||
)
|
)
|
||||||
source_id = edge_data.get("source_id") if edge_data else None
|
source_id = edge_data.get("source_id") if edge_data else None
|
||||||
|
|
||||||
result = {
|
result: dict[str, str | None | dict[str, str]] = {
|
||||||
"src_entity": src_entity,
|
"src_entity": src_entity,
|
||||||
"tgt_entity": tgt_entity,
|
"tgt_entity": tgt_entity,
|
||||||
"source_id": source_id,
|
"source_id": source_id,
|
||||||
|
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
from typing import Any, Union
|
from typing import Any, AsyncIterator, Union
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
from .utils import (
|
from .utils import (
|
||||||
logger,
|
logger,
|
||||||
@@ -780,7 +780,7 @@ async def mix_kg_vector_query(
|
|||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
) -> str:
|
) -> str | AsyncIterator[str]:
|
||||||
"""
|
"""
|
||||||
Hybrid retrieval implementation combining knowledge graph and vector search.
|
Hybrid retrieval implementation combining knowledge graph and vector search.
|
||||||
|
|
||||||
@@ -1505,7 +1505,7 @@ async def naive_query(
|
|||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
):
|
) -> str | AsyncIterator[str]:
|
||||||
# Handle cache
|
# Handle cache
|
||||||
use_model_func = global_config["llm_model_func"]
|
use_model_func = global_config["llm_model_func"]
|
||||||
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
|
||||||
|
Reference in New Issue
Block a user