cleaned import

Yannick Stephan
2025-02-09 11:24:08 +01:00
parent 61fd3e6127
commit 4cce14e65e
8 changed files with 62 additions and 70 deletions

View File

@@ -1,6 +1,8 @@
from enum import Enum
import os
from dataclasses import dataclass, field
from typing import (
Optional,
TypedDict,
Union,
Literal,
@@ -8,6 +10,8 @@ from typing import (
Any,
)
import numpy as np
from .utils import EmbeddingFunc
@@ -99,9 +103,7 @@ class BaseKVStorage(StorageNameSpace):
async def drop(self) -> None:
raise NotImplementedError
async def get_by_status(
self, status: str
) -> Union[list[dict[str, Any]], None]:
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
raise NotImplementedError
@@ -148,12 +150,12 @@ class BaseGraphStorage(StorageNameSpace):
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
raise NotImplementedError("Node embedding is not used in lightrag.")
async def get_all_labels(self) -> List[str]:
async def get_all_labels(self) -> list[str]:
raise NotImplementedError
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> Dict[str, List[Dict]]:
) -> dict[str, list[dict]]:
raise NotImplementedError
@@ -177,20 +179,20 @@ class DocProcessingStatus:
updated_at: str # ISO format timestamp
chunks_count: Optional[int] = None # Number of chunks after splitting
error: Optional[str] = None # Error message if failed
metadata: Dict[str, Any] = field(default_factory=dict) # Additional metadata
metadata: dict[str, Any] = field(default_factory=dict) # Additional metadata
class DocStatusStorage(BaseKVStorage):
"""Base class for document status storage"""
async def get_status_counts(self) -> Dict[str, int]:
async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
raise NotImplementedError
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all failed documents"""
raise NotImplementedError
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
"""Get all pending documents"""
raise NotImplementedError
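
For orientation, here is a minimal in-memory sketch of the get_by_status / get_status_counts signatures shown above, written with the built-in dict/list generics this commit standardizes on. The class below is illustrative only and is not part of the repository.

# Hypothetical in-memory stand-in for the DocStatusStorage interface above.
from typing import Any, Union


class InMemoryDocStatusStorage:
    def __init__(self) -> None:
        self._data: dict[str, dict[str, Any]] = {}

    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
        # Return all stored docs whose status matches, or None when there are none.
        result = [v for v in self._data.values() if v.get("status") == status]
        return result if result else None

    async def get_status_counts(self) -> dict[str, int]:
        # Tally how many documents sit in each status bucket.
        counts: dict[str, int] = {}
        for doc in self._data.values():
            counts[doc["status"]] = counts.get(doc["status"], 0) + 1
        return counts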

View File

@@ -51,8 +51,6 @@ class JsonKVStorage(BaseKVStorage):
async def drop(self) -> None:
self._data = {}
async def get_by_status(
self, status: str
) -> Union[list[dict[str, Any]], None]:
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
result = [v for _, v in self._data.items() if v["status"] == status]
return result if result else None

View File

@@ -77,9 +77,7 @@ class MongoKVStorage(BaseKVStorage):
"""Drop the collection"""
await self._data.drop()
async def get_by_status(
self, status: str
) -> Union[list[dict[str, Any]], None]:
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
"""Get documents by status and ids"""
return self._data.find({"status": status})

View File

@@ -229,9 +229,7 @@ class OracleKVStorage(BaseKVStorage):
res = [{k: v} for k, v in dict_res.items()]
return res
async def get_by_status(
self, status: str
) -> Union[list[dict[str, Any]], None]:
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
"""Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}

View File

@@ -231,9 +231,7 @@ class PGKVStorage(BaseKVStorage):
else:
return await self.db.query(sql, params, multirows=True)
async def get_by_status(
self, status: str
) -> Union[list[dict[str, Any]], None]:
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
"""Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}

View File

@@ -59,9 +59,7 @@ class RedisKVStorage(BaseKVStorage):
if keys:
await self._redis.delete(*keys)
async def get_by_status(
self, status: str
) -> Union[list[dict[str, Any]], None]:
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
pipe = self._redis.pipeline()
for key in await self._redis.keys(f"{self.namespace}:*"):
pipe.hgetall(key)

View File

@@ -322,9 +322,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
merge_sql = SQL_TEMPLATES["insert_relationship"]
await self.db.execute(merge_sql, data)
async def get_by_status(
self, status: str
) -> Union[list[dict[str, Any]], None]:
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
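
The Oracle, Postgres, and TiDB variants above all resolve their query through the same "get_by_status_" + namespace lookup. A hedged sketch of that dispatch follows; the SQL text and namespace key are placeholders, and only the lookup-and-params shape mirrors the code above.

# Placeholder SQL and namespace; only the template lookup and params dict mirror the diff.
from typing import Any

SQL_TEMPLATES: dict[str, str] = {
    "get_by_status_llm_response_cache": (
        "SELECT * FROM llm_cache WHERE workspace = :workspace AND status = :status"
    ),
}


def build_status_query(namespace: str, workspace: str, status: str) -> tuple[str, dict[str, Any]]:
    # Resolve the per-namespace query, then bind the workspace/status parameters.
    sql = SQL_TEMPLATES["get_by_status_" + namespace]
    params = {"workspace": workspace, "status": status}
    return sql, params


print(build_status_query("llm_response_cache", "default", "pending"))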

View File

@@ -4,11 +4,16 @@ from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
from typing import Any, Type, Union
from typing import Any, Type, Union, cast
import traceback
from .operate import (
chunking_by_token_size,
extract_entities
extract_entities,
extract_keywords_only,
kg_query,
kg_query_with_keywords,
mix_kg_vector_query,
naive_query,
# local_query,global_query,hybrid_query,,
)
@@ -19,18 +24,21 @@ from .utils import (
convert_response_to_json,
logger,
set_logger,
statistic_data
statistic_data,
)
from .base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocStatus,
QueryParam,
StorageNameSpace,
)
from .namespace import NameSpace, make_namespace
from .prompt import GRAPH_FIELD_SEP
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
@@ -351,9 +359,10 @@ class LightRAG:
)
async def ainsert(
self, string_or_strings: Union[str, list[str]],
split_by_character: str | None = None,
split_by_character_only: bool = False
self,
string_or_strings: Union[str, list[str]],
split_by_character: str | None = None,
split_by_character_only: bool = False,
):
"""Insert documents with checkpoint support
@@ -368,7 +377,6 @@ class LightRAG:
await self.apipeline_process_chunks(split_by_character, split_by_character_only)
await self.apipeline_process_extract_graph()
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
loop = always_get_an_event_loop()
return loop.run_until_complete(
@@ -482,31 +490,27 @@ class LightRAG:
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_chunks(
self,
split_by_character: str | None = None,
split_by_character_only: bool = False
) -> None:
self,
split_by_character: str | None = None,
split_by_character_only: bool = False,
) -> None:
"""Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents
to_process_doc_keys: list[str] = []
# Process failed
to_process_docs = await self.full_docs.get_by_status(
status=DocStatus.FAILED
)
to_process_docs = await self.full_docs.get_by_status(status=DocStatus.FAILED)
if to_process_docs:
to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
# Process Pending
to_process_docs = await self.full_docs.get_by_status(
status=DocStatus.PENDING
)
to_process_docs = await self.full_docs.get_by_status(status=DocStatus.PENDING)
if to_process_docs:
to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
if not to_process_doc_keys:
logger.info("All documents have been processed or are duplicates")
return
return
full_docs_ids = await self.full_docs.get_by_ids(to_process_doc_keys)
new_docs = {}
@@ -515,8 +519,8 @@ class LightRAG:
if not new_docs:
logger.info("All documents have been processed or are duplicates")
return
return
# 2. split docs into chunks, insert chunks, update doc status
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
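
For reference, a standalone illustration of the insert_batch_size slicing this loop drives; batch_of is a hypothetical helper, not the project's code.

# Illustrative batching over a dict of pending docs; batch_of() is a stand-in helper.
from typing import Any


def batch_of(docs: dict[str, Any], start: int, size: int) -> dict[str, Any]:
    # Take a slice of the (key, value) pairs and rebuild a dict for that batch.
    return dict(list(docs.items())[start : start + size])


new_docs = {f"doc-{n}": {"content_length": n} for n in range(23)}
batch_size = 10  # mirrors addon_params.get("insert_batch_size", 10)

for i in range(0, len(new_docs), batch_size):
    batch_docs = batch_of(new_docs, i, batch_size)
    print(f"Processing batch {i // batch_size + 1}: {len(batch_docs)} docs")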
@@ -526,11 +530,11 @@ class LightRAG:
batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}"
):
doc_status: dict[str, Any] = {
"content_summary": doc["content_summary"],
"content_length": doc["content_length"],
"status": DocStatus.PROCESSING,
"created_at": doc["created_at"],
"updated_at": datetime.now().isoformat(),
"content_summary": doc["content_summary"],
"content_length": doc["content_length"],
"status": DocStatus.PROCESSING,
"created_at": doc["created_at"],
"updated_at": datetime.now().isoformat(),
}
try:
await self.doc_status.upsert({doc_id: doc_status})
@@ -564,14 +568,16 @@ class LightRAG:
except Exception as e:
doc_status.update(
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status})
logger.error(f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}")
logger.error(
f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
)
continue
async def apipeline_process_extract_graph(self):
@@ -580,22 +586,18 @@ class LightRAG:
to_process_doc_keys: list[str] = []
# Process failed
to_process_docs = await self.full_docs.get_by_status(
status=DocStatus.FAILED
)
to_process_docs = await self.full_docs.get_by_status(status=DocStatus.FAILED)
if to_process_docs:
to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
# Process Pending
to_process_docs = await self.full_docs.get_by_status(
status=DocStatus.PENDING
)
to_process_docs = await self.full_docs.get_by_status(status=DocStatus.PENDING)
if to_process_docs:
to_process_doc_keys.extend([doc["id"] for doc in to_process_docs])
if not to_process_doc_keys:
logger.info("All documents have been processed or are duplicates")
return
return
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
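
A self-contained sketch of the failed-then-pending collection loop reindented in the hunk above; _StubStorage is a stand-in for the real full_docs storage, and the plain status strings stand in for DocStatus.FAILED and DocStatus.PENDING.

# Stand-alone sketch of the failed/pending id collection; _StubStorage is hypothetical.
import asyncio
from typing import Any, Union


class _StubStorage:
    def __init__(self, docs: list[dict[str, Any]]) -> None:
        self._docs = docs

    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
        result = [d for d in self._docs if d["status"] == status]
        return result if result else None


async def collect_to_process(storage: _StubStorage) -> list[str]:
    to_process_doc_keys: list[str] = []
    for status in ("failed", "pending"):  # stands in for DocStatus.FAILED / DocStatus.PENDING
        docs = await storage.get_by_status(status=status)
        if docs:
            to_process_doc_keys.extend(doc["id"] for doc in docs)
    return to_process_doc_keys


docs = [
    {"id": "a", "status": "failed"},
    {"id": "b", "status": "processed"},
    {"id": "c", "status": "pending"},
]
print(asyncio.run(collect_to_process(_StubStorage(docs))))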
@@ -606,7 +608,7 @@ class LightRAG:
async def process_chunk(chunk_id: str):
async with semaphore:
chunks:dict[str, Any] = {
chunks: dict[str, Any] = {
i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
}
# Extract and store entities and relationships
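
For context, a runnable sketch of the semaphore-bounded concurrency pattern process_chunk relies on; the sleep stands in for the real entity/relationship extraction step.

# Bounded-concurrency sketch; the worker body is a placeholder for extraction work.
import asyncio


async def process_chunk(chunk_id: str, semaphore: asyncio.Semaphore) -> str:
    async with semaphore:          # at most N chunks are processed concurrently
        await asyncio.sleep(0.01)  # placeholder for entity/relationship extraction
        return chunk_id


async def main() -> None:
    semaphore = asyncio.Semaphore(4)
    chunk_ids = [f"chunk-{n}" for n in range(10)]
    results = await asyncio.gather(*(process_chunk(c, semaphore) for c in chunk_ids))
    print(results)


asyncio.run(main())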
@@ -1051,7 +1053,7 @@ class LightRAG:
return content
return content[:max_length] + "..."
async def get_processing_status(self) -> Dict[str, int]:
async def get_processing_status(self) -> dict[str, int]:
"""Get current document processing status counts
Returns: