cleaned import

Yannick Stephan
2025-02-09 11:24:08 +01:00
parent 61fd3e6127
commit 4cce14e65e
8 changed files with 62 additions and 70 deletions

View File

@@ -1,6 +1,8 @@
+from enum import Enum
 import os
 from dataclasses import dataclass, field
 from typing import (
+    Optional,
     TypedDict,
     Union,
     Literal,
@@ -8,6 +10,8 @@ from typing import (
     Any,
 )
+import numpy as np
 from .utils import EmbeddingFunc
@@ -99,9 +103,7 @@ class BaseKVStorage(StorageNameSpace):
     async def drop(self) -> None:
         raise NotImplementedError
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         raise NotImplementedError
@@ -148,12 +150,12 @@ class BaseGraphStorage(StorageNameSpace):
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
         raise NotImplementedError("Node embedding is not used in lightrag.")
-    async def get_all_labels(self) -> List[str]:
+    async def get_all_labels(self) -> list[str]:
         raise NotImplementedError
     async def get_knowledge_graph(
         self, node_label: str, max_depth: int = 5
-    ) -> Dict[str, List[Dict]]:
+    ) -> dict[str, list[dict]]:
         raise NotImplementedError
@@ -177,20 +179,20 @@ class DocProcessingStatus:
     updated_at: str  # ISO format timestamp
     chunks_count: Optional[int] = None  # Number of chunks after splitting
     error: Optional[str] = None  # Error message if failed
-    metadata: Dict[str, Any] = field(default_factory=dict)  # Additional metadata
+    metadata: dict[str, Any] = field(default_factory=dict)  # Additional metadata
 class DocStatusStorage(BaseKVStorage):
     """Base class for document status storage"""
-    async def get_status_counts(self) -> Dict[str, int]:
+    async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
         raise NotImplementedError
-    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all failed documents"""
         raise NotImplementedError
-    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
+    async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
         raise NotImplementedError
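A note on the typing change above: swapping `Dict`/`List` for the builtin `dict`/`list` generics relies on PEP 585, so these annotations require Python 3.9+ (or `from __future__ import annotations` on older interpreters). A minimal sketch of the equivalence, with illustrative variable names only:

from typing import Any, Dict  # old spelling kept here only for comparison

old_style: Dict[str, Any] = {"status": "pending"}  # pre-PEP 585 annotation
new_style: dict[str, Any] = {"status": "pending"}  # builtin generic, Python 3.9+

assert old_style == new_style  # same runtime value; only the annotation spelling differs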

View File

@@ -51,8 +51,6 @@ class JsonKVStorage(BaseKVStorage):
     async def drop(self) -> None:
         self._data = {}
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         result = [v for _, v in self._data.items() if v["status"] == status]
         return result if result else None
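Callers elsewhere in this commit pass `status=DocStatus.FAILED` into a parameter annotated as `str`, and the JSON backend compares it against the stored `v["status"]` string. That only lines up if `DocStatus` is a string-valued enum, which the `from enum import Enum` import added to base.py suggests. A hedged sketch of that assumption and of a call against the consolidated signature (the member values are illustrative, not copied from the diff):

from enum import Enum

class DocStatus(str, Enum):  # assumed shape: str mix-in so members compare equal to raw strings
    PENDING = "pending"
    PROCESSING = "processing"
    PROCESSED = "processed"
    FAILED = "failed"

async def show_failed(storage) -> None:
    # storage is any BaseKVStorage implementation with the one-line signature above
    failed = await storage.get_by_status(status=DocStatus.FAILED)
    for doc in failed or []:  # the method returns None when nothing matches
        print(doc["id"], doc.get("error"))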

View File

@@ -77,9 +77,7 @@ class MongoKVStorage(BaseKVStorage):
         """Drop the collection"""
         await self._data.drop()
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Get documents by status and ids"""
         return self._data.find({"status": status})
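One caveat in the Mongo variant: Motor's `find()` returns an async cursor, not a list, so the body above does not strictly satisfy the annotated `Union[list[dict[str, Any]], None]`. A hedged sketch of a version that materializes the cursor, assuming `self._data` is a Motor collection and reusing the module's existing imports:

    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
        """Get documents by status"""
        # to_list drains the AsyncIOMotorCursor so the return type matches the base class
        docs = await self._data.find({"status": status}).to_list(length=None)
        return docs or None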

View File

@@ -229,9 +229,7 @@ class OracleKVStorage(BaseKVStorage):
         res = [{k: v} for k, v in dict_res.items()]
         return res
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}

View File

@@ -231,9 +231,7 @@ class PGKVStorage(BaseKVStorage):
         else:
             return await self.db.query(sql, params, multirows=True)
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}

View File

@@ -59,9 +59,7 @@ class RedisKVStorage(BaseKVStorage):
         if keys:
             await self._redis.delete(*keys)
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         pipe = self._redis.pipeline()
         for key in await self._redis.keys(f"{self.namespace}:*"):
             pipe.hgetall(key)
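The Redis hunk is cut off after the pipelined `hgetall` calls. How such a pipeline is typically drained is sketched below; the filtering step is an assumption about how the hashes store their `status` field, not the file's actual tail:

        # queued HGETALL commands run in one round trip
        results = await pipe.execute()
        # keep only hashes whose "status" field matches; return None when nothing is found
        matching = [doc for doc in results if doc and doc.get("status") == status]
        return matching if matching else None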

View File

@@ -322,9 +322,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         merge_sql = SQL_TEMPLATES["insert_relationship"]
         await self.db.execute(merge_sql, data)
-    async def get_by_status(
-        self, status: str
-    ) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)

View File

@@ -4,11 +4,16 @@ from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Type, Union
+from typing import Any, Type, Union, cast
 import traceback
 from .operate import (
     chunking_by_token_size,
-    extract_entities
+    extract_entities,
+    extract_keywords_only,
+    kg_query,
+    kg_query_with_keywords,
+    mix_kg_vector_query,
+    naive_query,
     # local_query,global_query,hybrid_query,,
 )
@@ -19,18 +24,21 @@ from .utils import (
     convert_response_to_json,
     logger,
     set_logger,
-    statistic_data
+    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
     BaseVectorStorage,
     DocStatus,
+    QueryParam,
+    StorageNameSpace,
 )
 from .namespace import NameSpace, make_namespace
 from .prompt import GRAPH_FIELD_SEP
 STORAGES = {
     "NetworkXStorage": ".kg.networkx_impl",
     "JsonKVStorage": ".kg.json_kv_impl",
@@ -351,9 +359,10 @@ class LightRAG:
     )
     async def ainsert(
-        self, string_or_strings: Union[str, list[str]],
+        self,
+        string_or_strings: Union[str, list[str]],
         split_by_character: str | None = None,
-        split_by_character_only: bool = False
+        split_by_character_only: bool = False,
     ):
         """Insert documents with checkpoint support
@@ -368,7 +377,6 @@ class LightRAG:
         await self.apipeline_process_chunks(split_by_character, split_by_character_only)
         await self.apipeline_process_extract_graph()
-
     def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
@@ -482,25 +490,21 @@ class LightRAG:
         logger.info(f"Stored {len(new_docs)} new unique documents")
     async def apipeline_process_chunks(
         self,
         split_by_character: str | None = None,
-        split_by_character_only: bool = False
+        split_by_character_only: bool = False,
     ) -> None:
         """Get pendding documents, split into chunks,insert chunks"""
         # 1. get all pending and failed documents
         to_process_doc_keys: list[str] = []
         # Process failes
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.FAILED
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.FAILED)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
         # Process Pending
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.PENDING
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.PENDING)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
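The FAILED-then-PENDING retrieval above is repeated verbatim in apipeline_process_extract_graph further down. A hedged sketch of a hypothetical helper that would fold the two calls together; the helper name is illustrative and not part of this commit:

    async def _collect_pending_doc_ids(self) -> list[str]:
        """Hypothetical helper: ids of documents that still need processing."""
        doc_ids: list[str] = []
        for status in (DocStatus.FAILED, DocStatus.PENDING):
            docs = await self.full_docs.get_by_status(status=status)
            if docs:
                doc_ids.extend(doc["id"] for doc in docs)
        return doc_ids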
@@ -526,11 +530,11 @@ class LightRAG:
                 batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
             ):
                 doc_status: dict[str, Any] = {
                     "content_summary": doc["content_summary"],
                     "content_length": doc["content_length"],
                     "status": DocStatus.PROCESSING,
                     "created_at": doc["created_at"],
                     "updated_at": datetime.now().isoformat(),
                 }
                 try:
                     await self.doc_status.upsert({doc_id: doc_status})
@@ -564,14 +568,16 @@ class LightRAG:
                 except Exception as e:
                     doc_status.update(
                         {
                             "status": DocStatus.FAILED,
                             "error": str(e),
                             "updated_at": datetime.now().isoformat(),
                         }
                     )
                     await self.doc_status.upsert({doc_id: doc_status})
-                    logger.error(f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}")
+                    logger.error(
+                        f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
+                    )
                     continue
     async def apipeline_process_extract_graph(self):
@@ -580,16 +586,12 @@ class LightRAG:
         to_process_doc_keys: list[str] = []
         # Process failes
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.FAILED
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.FAILED)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
         # Process Pending
-        to_process_docs = await self.full_docs.get_by_status(
-            status=DocStatus.PENDING
-        )
+        to_process_docs = await self.full_docs.get_by_status(status=DocStatus.PENDING)
         if to_process_docs:
             to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
@@ -606,7 +608,7 @@ class LightRAG:
         async def process_chunk(chunk_id: str):
             async with semaphore:
-                chunks:dict[str, Any] = {
+                chunks: dict[str, Any] = {
                     i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
                 }
                 # Extract and store entities and relationships
@@ -1051,7 +1053,7 @@ class LightRAG:
             return content
         return content[:max_length] + "..."
-    async def get_processing_status(self) -> Dict[str, int]:
+    async def get_processing_status(self) -> dict[str, int]:
         """Get current document processing status counts
         Returns: