cleaned code

This commit is contained in:
Yannick Stephan
2025-02-09 14:55:52 +01:00
parent 58d776561d
commit 82481ecf28
7 changed files with 18 additions and 62 deletions

View File

@@ -84,9 +84,6 @@ class BaseVectorStorage(StorageNameSpace):
class BaseKVStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
async def all_keys(self) -> list[str]:
raise NotImplementedError
async def get_by_id(self, id: str) -> dict[str, Any]:
raise NotImplementedError
@@ -103,9 +100,6 @@ class BaseKVStorage(StorageNameSpace):
async def drop(self) -> None:
raise NotImplementedError
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
raise NotImplementedError
@dataclass
class BaseGraphStorage(StorageNameSpace):

View File

@@ -1,7 +1,7 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Union
from typing import Any
from lightrag.utils import (
logger,
@@ -21,10 +21,7 @@ class JsonKVStorage(BaseKVStorage):
self._data: dict[str, Any] = load_json(self._file_name) or {}
self._lock = asyncio.Lock()
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def all_keys(self) -> list[str]:
return list(self._data.keys())
async def index_done_callback(self):
write_json(self._data, self._file_name)
@@ -49,8 +46,4 @@ class JsonKVStorage(BaseKVStorage):
self._data.update(left_data)
async def drop(self) -> None:
self._data = {}
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
result = [v for _, v in self._data.items() if v["status"] == status]
return result if result else None
self._data = {}

View File

@@ -29,9 +29,6 @@ class MongoKVStorage(BaseKVStorage):
self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as KV {self.namespace}")
async def all_keys(self) -> list[str]:
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
async def get_by_id(self, id: str) -> dict[str, Any]:
return self._data.find_one({"_id": id})
@@ -77,11 +74,6 @@ class MongoKVStorage(BaseKVStorage):
"""Drop the collection"""
await self._data.drop()
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
"""Get documents by status and ids"""
return self._data.find({"status": status})
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""

View File

@@ -229,12 +229,6 @@ class OracleKVStorage(BaseKVStorage):
res = [{k: v} for k, v in dict_res.items()]
return res
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
"""Specifically for llm_response_cache."""
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
async def filter_keys(self, keys: list[str]) -> set[str]:
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(

View File

@@ -237,16 +237,6 @@ class PGKVStorage(BaseKVStorage):
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
async def all_keys(self) -> list[dict]:
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
sql = "select workspace,mode,id from lightrag_llm_cache"
res = await self.db.query(sql, multirows=True)
return res
else:
logger.error(
f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
)
async def filter_keys(self, keys: List[str]) -> Set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Union
from typing import Any
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
@@ -20,11 +20,7 @@ class RedisKVStorage(BaseKVStorage):
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
self._redis = Redis.from_url(redis_url, decode_responses=True)
logger.info(f"Use Redis as KV {self.namespace}")
async def all_keys(self) -> list[str]:
keys = await self._redis.keys(f"{self.namespace}:*")
return [key.split(":", 1)[-1] for key in keys]
async def get_by_id(self, id):
data = await self._redis.get(f"{self.namespace}:{id}")
return json.loads(data) if data else None
@@ -57,11 +53,4 @@ class RedisKVStorage(BaseKVStorage):
async def drop(self) -> None:
keys = await self._redis.keys(f"{self.namespace}:*")
if keys:
await self._redis.delete(*keys)
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
pipe = self._redis.pipeline()
for key in await self._redis.keys(f"{self.namespace}:*"):
pipe.hgetall(key)
results = await pipe.execute()
return [data for data in results if data.get("status") == status] or None
await self._redis.delete(*keys)

View File

@@ -29,6 +29,7 @@ from .base import (
BaseKVStorage,
BaseVectorStorage,
DocStatus,
DocStatusStorage,
QueryParam,
StorageNameSpace,
)
@@ -319,7 +320,7 @@ class LightRAG:
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status: BaseKVStorage = self.doc_status_storage_cls(
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
global_config=global_config,
embedding_func=None,
@@ -394,10 +395,8 @@ class LightRAG:
split_by_character_only: if split_by_character_only is True, split the string by character only, when
split_by_character is None, this parameter is ignored.
"""
await self.apipeline_process_documents(string_or_strings)
await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character_only
)
await self.apipeline_enqueue_documents(string_or_strings)
await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
loop = always_get_an_event_loop()
@@ -496,8 +495,13 @@ class LightRAG:
# 3. Filter out already processed documents
add_doc_keys: set[str] = set()
excluded_ids = await self.doc_status.all_keys()
# Get docs ids
in_process_keys = list(new_docs.keys())
# Get in progress docs ids
excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
# Exclude already in process
add_doc_keys = new_docs.keys() - excluded_ids
# Filter
new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
if not new_docs:
@@ -513,12 +517,12 @@ class LightRAG:
to_process_doc_keys: list[str] = []
# Fetch failed documents
failed_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
failed_docs = await self.doc_status.get_failed_docs()
if failed_docs:
to_process_doc_keys.extend([doc["id"] for doc in failed_docs])
# Fetch pending documents
pending_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
pending_docs = await self.doc_status.get_pending_docs()
if pending_docs:
to_process_doc_keys.extend([doc["id"] for doc in pending_docs])