added type and cleaned code

This commit is contained in:
Yannick Stephan
2025-02-14 23:42:52 +01:00
parent dfa8681924
commit cf6e327bf4
2 changed files with 35 additions and 15 deletions

View File

@@ -95,7 +95,7 @@ class StorageNameSpace:
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
meta_fields: set = field(default_factory=set)
meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
raise NotImplementedError
@@ -130,50 +130,64 @@ class BaseKVStorage(StorageNameSpace):
@dataclass
class BaseGraphStorage(StorageNameSpace):
embedding_func: EmbeddingFunc = None
embedding_func: EmbeddingFunc | None = None
"""Check if a node exists in the graph."""
async def has_node(self, node_id: str) -> bool:
raise NotImplementedError
"""Check if an edge exists in the graph."""
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
raise NotImplementedError
"""Get the degree of a node."""
async def node_degree(self, node_id: str) -> int:
raise NotImplementedError
"""Get the degree of an edge."""
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
raise NotImplementedError
async def get_node(self, node_id: str) -> Union[dict, None]:
"""Get a node by its id."""
async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
raise NotImplementedError
"""Get an edge by its source and target node ids."""
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> Union[dict, None]:
) -> Union[dict[str, str], None]:
raise NotImplementedError
"""Get all edges connected to a node."""
async def get_node_edges(
self, source_node_id: str
) -> Union[list[tuple[str, str]], None]:
raise NotImplementedError
async def upsert_node(self, node_id: str, node_data: dict[str, str]):
"""Upsert a node into the graph."""
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
raise NotImplementedError
"""Upsert an edge into the graph."""
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
):
self, source_node_id: str,
target_node_id: str,
edge_data: dict[str, str]
) -> None:
raise NotImplementedError
async def delete_node(self, node_id: str):
"""Delete a node from the graph."""
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
"""Embed nodes using an algorithm."""
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.")
"""Get all labels in the graph."""
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
"""Get a knowledge graph of a node."""
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:

View File

@@ -6,7 +6,7 @@ import configparser
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, Callable, Optional, Union, cast
from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast
from .base import (
BaseGraphStorage,
@@ -983,7 +983,7 @@ class LightRAG:
def query(
self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
) -> str:
) -> str | Iterator[str]:
"""
Perform a sync query.
@@ -1003,7 +1003,7 @@ class LightRAG:
query: str,
param: QueryParam = QueryParam(),
prompt: str | None = None,
) -> str:
) -> str | AsyncIterator[str]:
"""
Perform a async query.
@@ -1081,7 +1081,10 @@ class LightRAG:
return response
def query_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
self,
query: str,
prompt: str,
param: QueryParam = QueryParam()
):
"""
1. Extract keywords from the 'query' using new function in operate.py.
@@ -1093,7 +1096,10 @@ class LightRAG:
)
async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam()
self,
query: str,
prompt: str,
param: QueryParam = QueryParam()
):
"""
1. Calls extract_keywords_only to get HL/LL keywords from 'query'.