added typing

This commit is contained in:
Yannick Stephan
2025-02-14 23:49:39 +01:00
parent cf6e327bf4
commit e6520ad6a2
3 changed files with 13 additions and 9 deletions

View File

@@ -96,7 +96,7 @@ class StorageNameSpace:
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
raise NotImplementedError

View File

@@ -982,7 +982,10 @@ class LightRAG:
await self._insert_done()
def query(
self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
self,
query: str,
param: QueryParam = QueryParam(),
prompt: str | None = None
) -> str | Iterator[str]:
"""
Perform a sync query.
@@ -996,7 +999,8 @@ class LightRAG:
str: The result of the query execution.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param, prompt))
return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
async def aquery(
self,
@@ -1455,7 +1459,7 @@ class LightRAG:
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
):
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity
Args:
@@ -1475,7 +1479,7 @@ class LightRAG:
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
source_id = node_data.get("source_id") if node_data else None
result = {
result: dict[str, str | None | dict[str, str]] = {
"entity_name": entity_name,
"source_id": source_id,
"graph_data": node_data,
@@ -1531,7 +1535,7 @@ class LightRAG:
)
source_id = edge_data.get("source_id") if edge_data else None
result = {
result: dict[str, str | None | dict[str, str]] = {
"src_entity": src_entity,
"tgt_entity": tgt_entity,
"source_id": source_id,

View File

@@ -2,7 +2,7 @@ import asyncio
import json
import re
from tqdm.asyncio import tqdm as tqdm_async
from typing import Any, Union
from typing import Any, AsyncIterator, Union
from collections import Counter, defaultdict
from .utils import (
logger,
@@ -780,7 +780,7 @@ async def mix_kg_vector_query(
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
) -> str:
) -> str | AsyncIterator[str]:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1505,7 +1505,7 @@ async def naive_query(
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
):
) -> str | AsyncIterator[str]:
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")