cleaned code
@@ -84,9 +84,6 @@ class BaseVectorStorage(StorageNameSpace):
 class BaseKVStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc

-    async def all_keys(self) -> list[str]:
-        raise NotImplementedError
-
     async def get_by_id(self, id: str) -> dict[str, Any]:
         raise NotImplementedError

@@ -103,9 +100,6 @@ class BaseKVStorage(StorageNameSpace):
     async def drop(self) -> None:
         raise NotImplementedError

-    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
-        raise NotImplementedError
-

 @dataclass
 class BaseGraphStorage(StorageNameSpace):
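Taken together, the two hunks above remove the generic key-listing and status helpers from the KV base class. For orientation, here is a minimal sketch of the interface that survives, reconstructed from the context lines; the StorageNameSpace and EmbeddingFunc stubs are placeholders for the real definitions in lightrag and are not part of this commit:

from dataclasses import dataclass
from typing import Any, Callable

EmbeddingFunc = Callable[..., Any]  # placeholder for lightrag.utils.EmbeddingFunc


@dataclass
class StorageNameSpace:  # placeholder stub; the real class carries namespace/config
    namespace: str = ""


@dataclass
class BaseKVStorage(StorageNameSpace):
    # all_keys() and get_by_status() are gone after this commit;
    # per-id access and drop() remain the abstract surface.
    embedding_func: EmbeddingFunc = None

    async def get_by_id(self, id: str) -> dict[str, Any]:
        raise NotImplementedError

    async def drop(self) -> None:
        raise NotImplementedError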
@@ -1,7 +1,7 @@
 import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any

 from lightrag.utils import (
     logger,
@@ -21,10 +21,7 @@ class JsonKVStorage(BaseKVStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         self._lock = asyncio.Lock()
         logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

-    async def all_keys(self) -> list[str]:
-        return list(self._data.keys())
-
     async def index_done_callback(self):
         write_json(self._data, self._file_name)

@@ -49,8 +46,4 @@ class JsonKVStorage(BaseKVStorage):
         self._data.update(left_data)

     async def drop(self) -> None:
         self._data = {}
-
-    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
-        result = [v for _, v in self._data.items() if v["status"] == status]
-        return result if result else None
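For reference, the removed JSON-backend filter distilled into a standalone sketch; after this commit, status lookups are routed through the dedicated document-status storage used in the lightrag.py hunks further down, not through each KV backend:

# Sketch of the logic deleted above: an in-memory status filter over
# the JSON KV dict. Not part of the commit, shown only for reference.
def filter_by_status(data: dict[str, dict], status: str) -> list[dict] | None:
    result = [v for v in data.values() if v.get("status") == status]
    return result or None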
@@ -29,9 +29,6 @@ class MongoKVStorage(BaseKVStorage):
         self._data = database.get_collection(self.namespace)
         logger.info(f"Use MongoDB as KV {self.namespace}")

-    async def all_keys(self) -> list[str]:
-        return [x["_id"] for x in self._data.find({}, {"_id": 1})]
-
     async def get_by_id(self, id: str) -> dict[str, Any]:
         return self._data.find_one({"_id": id})

@@ -77,11 +74,6 @@ class MongoKVStorage(BaseKVStorage):
         """Drop the collection"""
         await self._data.drop()
-
-    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
-        """Get documents by status and ids"""
-        return self._data.find({"status": status})
-

 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
     """
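One detail worth flagging in the removed Mongo variant: find() returns a cursor, not the list promised by the annotation. If the equivalent query were ever needed again, a minimal sketch that materializes the cursor, assuming a Motor async collection as the awaited drop() above suggests:

from motor.motor_asyncio import AsyncIOMotorCollection


async def get_docs_by_status(
    collection: AsyncIOMotorCollection, status: str
) -> list[dict]:
    # find() only builds a cursor; to_list() actually executes the query.
    return await collection.find({"status": status}).to_list(length=None)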
@@ -229,12 +229,6 @@ class OracleKVStorage(BaseKVStorage):
             res = [{k: v} for k, v in dict_res.items()]
             return res

-    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
-        """Specifically for llm_response_cache."""
-        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
-        params = {"workspace": self.db.workspace, "status": status}
-        return await self.db.query(SQL, params, multirows=True)
-
     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
@@ -237,16 +237,6 @@ class PGKVStorage(BaseKVStorage):
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)

-    async def all_keys(self) -> list[dict]:
-        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
-            sql = "select workspace,mode,id from lightrag_llm_cache"
-            res = await self.db.query(sql, multirows=True)
-            return res
-        else:
-            logger.error(
-                f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
-            )
-
     async def filter_keys(self, keys: List[str]) -> Set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
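The removed Postgres helper only worked for the LLM-response-cache namespace and logged an error otherwise. Distilled into a standalone sketch; the db.query(..., multirows=True) call shape and the table name come from the hunk itself, while the free-standing signature is assumed:

async def llm_cache_keys(db) -> list[dict]:
    # One row per cached LLM response, keyed by workspace, mode, and id.
    sql = "select workspace, mode, id from lightrag_llm_cache"
    return await db.query(sql, multirows=True)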
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Union
+from typing import Any
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
@@ -20,11 +20,7 @@ class RedisKVStorage(BaseKVStorage):
         redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")

-    async def all_keys(self) -> list[str]:
-        keys = await self._redis.keys(f"{self.namespace}:*")
-        return [key.split(":", 1)[-1] for key in keys]
-
     async def get_by_id(self, id):
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None
@@ -57,11 +53,4 @@ class RedisKVStorage(BaseKVStorage):
     async def drop(self) -> None:
         keys = await self._redis.keys(f"{self.namespace}:*")
         if keys:
             await self._redis.delete(*keys)
-
-    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
-        pipe = self._redis.pipeline()
-        for key in await self._redis.keys(f"{self.namespace}:*"):
-            pipe.hgetall(key)
-        results = await pipe.execute()
-        return [data for data in results if data.get("status") == status] or None
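Both removed Redis helpers, and the surviving drop(), rely on KEYS, which walks the entire keyspace in one blocking call. A non-blocking sketch using scan_iter from redis.asyncio, assuming the same "{namespace}:{id}" key scheme and decode_responses=True as above:

from redis.asyncio import Redis


async def all_namespace_ids(redis: Redis, namespace: str) -> list[str]:
    ids = []
    # SCAN iterates incrementally instead of blocking the server like KEYS.
    async for key in redis.scan_iter(match=f"{namespace}:*"):
        ids.append(key.split(":", 1)[-1])
    return ids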
@@ -29,6 +29,7 @@ from .base import (
     BaseKVStorage,
     BaseVectorStorage,
     DocStatus,
+    DocStatusStorage,
     QueryParam,
     StorageNameSpace,
 )
@@ -319,7 +320,7 @@ class LightRAG:

         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
-        self.doc_status: BaseKVStorage = self.doc_status_storage_cls(
+        self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
             namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
             global_config=global_config,
             embedding_func=None,
@@ -394,10 +395,8 @@ class LightRAG:
             split_by_character_only: if split_by_character_only is True, split the string by character only, when
             split_by_character is None, this parameter is ignored.
         """
-        await self.apipeline_process_documents(string_or_strings)
-        await self.apipeline_process_enqueue_documents(
-            split_by_character, split_by_character_only
-        )
+        await self.apipeline_enqueue_documents(string_or_strings)
+        await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)

     def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
         loop = always_get_an_event_loop()
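The rename makes the two pipeline stages explicit: enqueue documents first, then process the queue. A hypothetical caller; only the two method names come from the hunk above, and the constructor arguments are illustrative:

import asyncio


async def main() -> None:
    rag = LightRAG(working_dir="./rag_storage")  # illustrative args
    await rag.apipeline_enqueue_documents(["first document", "second document"])
    await rag.apipeline_process_enqueue_documents(
        split_by_character=None, split_by_character_only=False
    )


asyncio.run(main())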
@@ -496,8 +495,13 @@ class LightRAG:

         # 3. Filter out already processed documents
-        add_doc_keys: set[str] = set()
-        excluded_ids = await self.doc_status.all_keys()
+        # Get docs ids
+        in_process_keys = list(new_docs.keys())
+        # Get in progress docs ids
+        excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
+        # Exclude already in process
+        add_doc_keys = new_docs.keys() - excluded_ids
+        # Filter
         new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}

         if not new_docs:
@@ -513,12 +517,12 @@ class LightRAG:
         to_process_doc_keys: list[str] = []

         # Fetch failed documents
-        failed_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
+        failed_docs = await self.doc_status.get_failed_docs()
         if failed_docs:
             to_process_doc_keys.extend([doc["id"] for doc in failed_docs])

         # Fetch pending documents
-        pending_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
+        pending_docs = await self.doc_status.get_pending_docs()
         if pending_docs:
             to_process_doc_keys.extend([doc["id"] for doc in pending_docs])

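The calls above imply that DocStatusStorage gains status-specific accessors in place of the generic get_by_status. A minimal sketch of that interface, reconstructed from the call sites; the return type is assumed from the doc["id"] access pattern:

from typing import Any


class DocStatusStorage(BaseKVStorage):  # BaseKVStorage as sketched earlier
    """Assumed shape, inferred from the call sites above."""

    async def get_failed_docs(self) -> list[dict[str, Any]]:
        raise NotImplementedError

    async def get_pending_docs(self) -> list[dict[str, Any]]:
        raise NotImplementedError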