added type and cleaned code

This commit is contained in:
Yannick Stephan
2025-02-14 23:42:52 +01:00
parent dfa8681924
commit cf6e327bf4
2 changed files with 35 additions and 15 deletions

View File

@@ -95,7 +95,7 @@ class StorageNameSpace:
@dataclass @dataclass
class BaseVectorStorage(StorageNameSpace): class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc embedding_func: EmbeddingFunc
meta_fields: set = field(default_factory=set) meta_fields: set[str] = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
raise NotImplementedError raise NotImplementedError
@@ -130,50 +130,64 @@ class BaseKVStorage(StorageNameSpace):
@dataclass @dataclass
class BaseGraphStorage(StorageNameSpace): class BaseGraphStorage(StorageNameSpace):
embedding_func: EmbeddingFunc = None embedding_func: EmbeddingFunc | None = None
"""Check if a node exists in the graph."""
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
raise NotImplementedError raise NotImplementedError
"""Check if an edge exists in the graph."""
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
raise NotImplementedError raise NotImplementedError
"""Get the degree of a node."""
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
raise NotImplementedError raise NotImplementedError
"""Get the degree of an edge."""
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
raise NotImplementedError raise NotImplementedError
async def get_node(self, node_id: str) -> Union[dict, None]: """Get a node by its id."""
async def get_node(self, node_id: str) -> Union[dict[str, str], None]:
raise NotImplementedError raise NotImplementedError
"""Get an edge by its source and target node ids."""
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> Union[dict[str, str], None]:
raise NotImplementedError raise NotImplementedError
"""Get all edges connected to a node."""
async def get_node_edges( async def get_node_edges(
self, source_node_id: str self, source_node_id: str
) -> Union[list[tuple[str, str]], None]: ) -> Union[list[tuple[str, str]], None]:
raise NotImplementedError raise NotImplementedError
async def upsert_node(self, node_id: str, node_data: dict[str, str]): """Upsert a node into the graph."""
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
raise NotImplementedError raise NotImplementedError
"""Upsert an edge into the graph."""
async def upsert_edge( async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] self, source_node_id: str,
): target_node_id: str,
edge_data: dict[str, str]
) -> None:
raise NotImplementedError raise NotImplementedError
async def delete_node(self, node_id: str): """Delete a node from the graph."""
async def delete_node(self, node_id: str) -> None:
raise NotImplementedError raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: """Embed nodes using an algorithm."""
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.") raise NotImplementedError("Node embedding is not used in lightrag.")
"""Get all labels in the graph."""
async def get_all_labels(self) -> list[str]: async def get_all_labels(self) -> list[str]:
raise NotImplementedError raise NotImplementedError
"""Get a knowledge graph of a node."""
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph: ) -> KnowledgeGraph:

View File

@@ -6,7 +6,7 @@ import configparser
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Any, Callable, Optional, Union, cast from typing import Any, AsyncIterator, Callable, Iterator, Optional, Union, cast
from .base import ( from .base import (
BaseGraphStorage, BaseGraphStorage,
@@ -983,7 +983,7 @@ class LightRAG:
def query( def query(
self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None self, query: str, param: QueryParam = QueryParam(), prompt: str | None = None
) -> str: ) -> str | Iterator[str]:
""" """
Perform a sync query. Perform a sync query.
@@ -1003,7 +1003,7 @@ class LightRAG:
query: str, query: str,
param: QueryParam = QueryParam(), param: QueryParam = QueryParam(),
prompt: str | None = None, prompt: str | None = None,
) -> str: ) -> str | AsyncIterator[str]:
""" """
Perform a async query. Perform a async query.
@@ -1081,7 +1081,10 @@ class LightRAG:
return response return response
def query_with_separate_keyword_extraction( def query_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam() self,
query: str,
prompt: str,
param: QueryParam = QueryParam()
): ):
""" """
1. Extract keywords from the 'query' using new function in operate.py. 1. Extract keywords from the 'query' using new function in operate.py.
@@ -1093,7 +1096,10 @@ class LightRAG:
) )
async def aquery_with_separate_keyword_extraction( async def aquery_with_separate_keyword_extraction(
self, query: str, prompt: str, param: QueryParam = QueryParam() self,
query: str,
prompt: str,
param: QueryParam = QueryParam()
): ):
""" """
1. Calls extract_keywords_only to get HL/LL keywords from 'query'. 1. Calls extract_keywords_only to get HL/LL keywords from 'query'.