cleaned code
This commit is contained in:
@@ -84,9 +84,6 @@ class BaseVectorStorage(StorageNameSpace):
|
|||||||
class BaseKVStorage(StorageNameSpace):
|
class BaseKVStorage(StorageNameSpace):
|
||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
|
|
||||||
async def all_keys(self) -> list[str]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -103,9 +100,6 @@ class BaseKVStorage(StorageNameSpace):
|
|||||||
async def drop(self) -> None:
|
async def drop(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseGraphStorage(StorageNameSpace):
|
class BaseGraphStorage(StorageNameSpace):
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
|
|
||||||
from lightrag.utils import (
|
from lightrag.utils import (
|
||||||
logger,
|
logger,
|
||||||
@@ -22,9 +22,6 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
||||||
|
|
||||||
async def all_keys(self) -> list[str]:
|
|
||||||
return list(self._data.keys())
|
|
||||||
|
|
||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
write_json(self._data, self._file_name)
|
write_json(self._data, self._file_name)
|
||||||
|
|
||||||
@@ -50,7 +47,3 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
|
|
||||||
async def drop(self) -> None:
|
async def drop(self) -> None:
|
||||||
self._data = {}
|
self._data = {}
|
||||||
|
|
||||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
|
||||||
result = [v for _, v in self._data.items() if v["status"] == status]
|
|
||||||
return result if result else None
|
|
||||||
|
@@ -29,9 +29,6 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
self._data = database.get_collection(self.namespace)
|
self._data = database.get_collection(self.namespace)
|
||||||
logger.info(f"Use MongoDB as KV {self.namespace}")
|
logger.info(f"Use MongoDB as KV {self.namespace}")
|
||||||
|
|
||||||
async def all_keys(self) -> list[str]:
|
|
||||||
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
|
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any]:
|
async def get_by_id(self, id: str) -> dict[str, Any]:
|
||||||
return self._data.find_one({"_id": id})
|
return self._data.find_one({"_id": id})
|
||||||
|
|
||||||
@@ -77,11 +74,6 @@ class MongoKVStorage(BaseKVStorage):
|
|||||||
"""Drop the collection"""
|
"""Drop the collection"""
|
||||||
await self._data.drop()
|
await self._data.drop()
|
||||||
|
|
||||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
|
||||||
"""Get documents by status and ids"""
|
|
||||||
return self._data.find({"status": status})
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoGraphStorage(BaseGraphStorage):
|
class MongoGraphStorage(BaseGraphStorage):
|
||||||
"""
|
"""
|
||||||
|
@@ -229,12 +229,6 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
res = [{k: v} for k, v in dict_res.items()]
|
res = [{k: v} for k, v in dict_res.items()]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
|
||||||
"""Specifically for llm_response_cache."""
|
|
||||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
|
||||||
params = {"workspace": self.db.workspace, "status": status}
|
|
||||||
return await self.db.query(SQL, params, multirows=True)
|
|
||||||
|
|
||||||
async def filter_keys(self, keys: list[str]) -> set[str]:
|
async def filter_keys(self, keys: list[str]) -> set[str]:
|
||||||
"""Return keys that don't exist in storage"""
|
"""Return keys that don't exist in storage"""
|
||||||
SQL = SQL_TEMPLATES["filter_keys"].format(
|
SQL = SQL_TEMPLATES["filter_keys"].format(
|
||||||
|
@@ -237,16 +237,6 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
params = {"workspace": self.db.workspace, "status": status}
|
params = {"workspace": self.db.workspace, "status": status}
|
||||||
return await self.db.query(SQL, params, multirows=True)
|
return await self.db.query(SQL, params, multirows=True)
|
||||||
|
|
||||||
async def all_keys(self) -> list[dict]:
|
|
||||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
|
||||||
sql = "select workspace,mode,id from lightrag_llm_cache"
|
|
||||||
res = await self.db.query(sql, multirows=True)
|
|
||||||
return res
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def filter_keys(self, keys: List[str]) -> Set[str]:
|
async def filter_keys(self, keys: List[str]) -> Set[str]:
|
||||||
"""Filter out duplicated content"""
|
"""Filter out duplicated content"""
|
||||||
sql = SQL_TEMPLATES["filter_keys"].format(
|
sql = SQL_TEMPLATES["filter_keys"].format(
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Union
|
from typing import Any
|
||||||
from tqdm.asyncio import tqdm as tqdm_async
|
from tqdm.asyncio import tqdm as tqdm_async
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
@@ -21,10 +21,6 @@ class RedisKVStorage(BaseKVStorage):
|
|||||||
self._redis = Redis.from_url(redis_url, decode_responses=True)
|
self._redis = Redis.from_url(redis_url, decode_responses=True)
|
||||||
logger.info(f"Use Redis as KV {self.namespace}")
|
logger.info(f"Use Redis as KV {self.namespace}")
|
||||||
|
|
||||||
async def all_keys(self) -> list[str]:
|
|
||||||
keys = await self._redis.keys(f"{self.namespace}:*")
|
|
||||||
return [key.split(":", 1)[-1] for key in keys]
|
|
||||||
|
|
||||||
async def get_by_id(self, id):
|
async def get_by_id(self, id):
|
||||||
data = await self._redis.get(f"{self.namespace}:{id}")
|
data = await self._redis.get(f"{self.namespace}:{id}")
|
||||||
return json.loads(data) if data else None
|
return json.loads(data) if data else None
|
||||||
@@ -58,10 +54,3 @@ class RedisKVStorage(BaseKVStorage):
|
|||||||
keys = await self._redis.keys(f"{self.namespace}:*")
|
keys = await self._redis.keys(f"{self.namespace}:*")
|
||||||
if keys:
|
if keys:
|
||||||
await self._redis.delete(*keys)
|
await self._redis.delete(*keys)
|
||||||
|
|
||||||
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
|
||||||
pipe = self._redis.pipeline()
|
|
||||||
for key in await self._redis.keys(f"{self.namespace}:*"):
|
|
||||||
pipe.hgetall(key)
|
|
||||||
results = await pipe.execute()
|
|
||||||
return [data for data in results if data.get("status") == status] or None
|
|
||||||
|
@@ -29,6 +29,7 @@ from .base import (
|
|||||||
BaseKVStorage,
|
BaseKVStorage,
|
||||||
BaseVectorStorage,
|
BaseVectorStorage,
|
||||||
DocStatus,
|
DocStatus,
|
||||||
|
DocStatusStorage,
|
||||||
QueryParam,
|
QueryParam,
|
||||||
StorageNameSpace,
|
StorageNameSpace,
|
||||||
)
|
)
|
||||||
@@ -319,7 +320,7 @@ class LightRAG:
|
|||||||
|
|
||||||
# Initialize document status storage
|
# Initialize document status storage
|
||||||
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
||||||
self.doc_status: BaseKVStorage = self.doc_status_storage_cls(
|
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
|
||||||
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
|
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
|
||||||
global_config=global_config,
|
global_config=global_config,
|
||||||
embedding_func=None,
|
embedding_func=None,
|
||||||
@@ -394,10 +395,8 @@ class LightRAG:
|
|||||||
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
||||||
split_by_character is None, this parameter is ignored.
|
split_by_character is None, this parameter is ignored.
|
||||||
"""
|
"""
|
||||||
await self.apipeline_process_documents(string_or_strings)
|
await self.apipeline_enqueue_documents(string_or_strings)
|
||||||
await self.apipeline_process_enqueue_documents(
|
await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
|
||||||
split_by_character, split_by_character_only
|
|
||||||
)
|
|
||||||
|
|
||||||
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
|
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
@@ -496,8 +495,13 @@ class LightRAG:
|
|||||||
|
|
||||||
# 3. Filter out already processed documents
|
# 3. Filter out already processed documents
|
||||||
add_doc_keys: set[str] = set()
|
add_doc_keys: set[str] = set()
|
||||||
excluded_ids = await self.doc_status.all_keys()
|
# Get docs ids
|
||||||
|
in_process_keys = list(new_docs.keys())
|
||||||
|
# Get in progress docs ids
|
||||||
|
excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
|
||||||
|
# Exclude already in process
|
||||||
add_doc_keys = new_docs.keys() - excluded_ids
|
add_doc_keys = new_docs.keys() - excluded_ids
|
||||||
|
# Filter
|
||||||
new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
|
new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
|
||||||
|
|
||||||
if not new_docs:
|
if not new_docs:
|
||||||
@@ -513,12 +517,12 @@ class LightRAG:
|
|||||||
to_process_doc_keys: list[str] = []
|
to_process_doc_keys: list[str] = []
|
||||||
|
|
||||||
# Fetch failed documents
|
# Fetch failed documents
|
||||||
failed_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
|
failed_docs = await self.doc_status.get_failed_docs()
|
||||||
if failed_docs:
|
if failed_docs:
|
||||||
to_process_doc_keys.extend([doc["id"] for doc in failed_docs])
|
to_process_doc_keys.extend([doc["id"] for doc in failed_docs])
|
||||||
|
|
||||||
# Fetch pending documents
|
# Fetch pending documents
|
||||||
pending_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
|
pending_docs = await self.doc_status.get_pending_docs()
|
||||||
if pending_docs:
|
if pending_docs:
|
||||||
to_process_doc_keys.extend([doc["id"] for doc in pending_docs])
|
to_process_doc_keys.extend([doc["id"] for doc in pending_docs])
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user