added typing

This commit is contained in:
Yannick Stephan
2025-02-14 23:49:39 +01:00
parent cf6e327bf4
commit e6520ad6a2
3 changed files with 13 additions and 9 deletions

View File

@@ -982,7 +982,10 @@ class LightRAG:
await self._insert_done()
def query(
self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
self,
query: str,
param: QueryParam = QueryParam(),
prompt: str | None = None
) -> str | Iterator[str]:
"""
Perform a sync query.
@@ -996,7 +999,8 @@ class LightRAG:
str: The result of the query execution.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param, prompt))
return loop.run_until_complete(self.aquery(query, param, prompt)) # type: ignore
async def aquery(
self,
@@ -1455,7 +1459,7 @@ class LightRAG:
async def get_entity_info(
self, entity_name: str, include_vector_data: bool = False
):
) -> dict[str, str | None | dict[str, str]]:
"""Get detailed information of an entity
Args:
@@ -1475,7 +1479,7 @@ class LightRAG:
node_data = await self.chunk_entity_relation_graph.get_node(entity_name)
source_id = node_data.get("source_id") if node_data else None
result = {
result: dict[str, str | None | dict[str, str]] = {
"entity_name": entity_name,
"source_id": source_id,
"graph_data": node_data,
@@ -1531,7 +1535,7 @@ class LightRAG:
)
source_id = edge_data.get("source_id") if edge_data else None
result = {
result: dict[str, str | None | dict[str, str]] = {
"src_entity": src_entity,
"tgt_entity": tgt_entity,
"source_id": source_id,