cleanup code
This commit is contained in:
@@ -91,7 +91,7 @@ class BaseKVStorage(StorageNameSpace):
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
@@ -103,10 +103,13 @@ class BaseKVStorage(StorageNameSpace):
|
||||
|
||||
async def drop(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
|
||||
|
||||
async def get_by_status_and_ids(
|
||||
self, status: str
|
||||
) -> Union[list[dict[str, Any]], None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseGraphStorage(StorageNameSpace):
|
||||
embedding_func: EmbeddingFunc = None
|
||||
|
@@ -12,6 +12,7 @@ from lightrag.base import (
|
||||
BaseKVStorage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
@@ -30,7 +31,7 @@ class JsonKVStorage(BaseKVStorage):
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
return self._data.get(id, None)
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
return [
|
||||
(
|
||||
{k: v for k, v in self._data[id].items()}
|
||||
@@ -50,6 +51,8 @@ class JsonKVStorage(BaseKVStorage):
|
||||
async def drop(self) -> None:
|
||||
self._data = {}
|
||||
|
||||
async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
|
||||
async def get_by_status_and_ids(
|
||||
self, status: str
|
||||
) -> Union[list[dict[str, Any]], None]:
|
||||
result = [v for _, v in self._data.items() if v["status"] == status]
|
||||
return result if result else None
|
||||
|
@@ -35,7 +35,7 @@ class MongoKVStorage(BaseKVStorage):
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
return self._data.find_one({"_id": id})
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
return list(self._data.find({"_id": {"$in": ids}}))
|
||||
|
||||
async def filter_keys(self, data: list[str]) -> set[str]:
|
||||
@@ -77,7 +77,9 @@ class MongoKVStorage(BaseKVStorage):
|
||||
"""Drop the collection"""
|
||||
await self._data.drop()
|
||||
|
||||
async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
|
||||
async def get_by_status_and_ids(
|
||||
self, status: str
|
||||
) -> Union[list[dict[str, Any]], None]:
|
||||
"""Get documents by status"""
|
||||
return self._data.find({"status": status})
|
||||
|
||||
|
@@ -326,7 +326,8 @@ class OracleKVStorage(BaseKVStorage):
|
||||
(NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
|
||||
):
|
||||
logger.info("full doc and chunk data had been saved into oracle db!")
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class OracleVectorDBStorage(BaseVectorStorage):
|
||||
# should pass db object to self.db
|
||||
|
@@ -213,7 +213,7 @@ class PGKVStorage(BaseKVStorage):
|
||||
return None
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
"""Get doc_chunks data by id"""
|
||||
sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
@@ -237,12 +237,14 @@ class PGKVStorage(BaseKVStorage):
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
|
||||
|
||||
async def get_by_status_and_ids(
|
||||
self, status: str
|
||||
) -> Union[list[dict[str, Any]], None]:
|
||||
"""Specifically for llm_response_cache."""
|
||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
|
||||
async def all_keys(self) -> list[dict]:
|
||||
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
||||
|
@@ -29,7 +29,7 @@ class RedisKVStorage(BaseKVStorage):
|
||||
data = await self._redis.get(f"{self.namespace}:{id}")
|
||||
return json.loads(data) if data else None
|
||||
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
pipe = self._redis.pipeline()
|
||||
for id in ids:
|
||||
pipe.get(f"{self.namespace}:{id}")
|
||||
@@ -58,11 +58,12 @@ class RedisKVStorage(BaseKVStorage):
|
||||
keys = await self._redis.keys(f"{self.namespace}:*")
|
||||
if keys:
|
||||
await self._redis.delete(*keys)
|
||||
|
||||
async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
|
||||
|
||||
async def get_by_status_and_ids(
|
||||
self, status: str
|
||||
) -> Union[list[dict[str, Any]], None]:
|
||||
pipe = self._redis.pipeline()
|
||||
for key in await self._redis.keys(f"{self.namespace}:*"):
|
||||
pipe.hgetall(key)
|
||||
results = await pipe.execute()
|
||||
return [data for data in results if data.get("status") == status] or None
|
||||
|
||||
|
@@ -122,7 +122,7 @@ class TiDBKVStorage(BaseKVStorage):
|
||||
return None
|
||||
|
||||
# Query by id
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
|
||||
"""Get doc_chunks data by id"""
|
||||
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
|
||||
ids=",".join([f"'{id}'" for id in ids])
|
||||
@@ -333,10 +333,13 @@ class TiDBVectorDBStorage(BaseVectorStorage):
|
||||
merge_sql = SQL_TEMPLATES["insert_relationship"]
|
||||
await self.db.execute(merge_sql, data)
|
||||
|
||||
async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
|
||||
async def get_by_status_and_ids(
|
||||
self, status: str
|
||||
) -> Union[list[dict[str, Any]], None]:
|
||||
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
||||
params = {"workspace": self.db.workspace, "status": status}
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
return await self.db.query(SQL, params, multirows=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TiDBGraphStorage(BaseGraphStorage):
|
||||
|
@@ -629,12 +629,7 @@ class LightRAG:
|
||||
# 4. Store original document
|
||||
for doc_id, doc in new_docs.items():
|
||||
await self.full_docs.upsert(
|
||||
{
|
||||
doc_id: {
|
||||
"content": doc["content"],
|
||||
"status": DocStatus.PENDING
|
||||
}
|
||||
}
|
||||
{doc_id: {"content": doc["content"], "status": DocStatus.PENDING}}
|
||||
)
|
||||
logger.info(f"Stored {len(new_docs)} new unique documents")
|
||||
|
||||
@@ -642,10 +637,14 @@ class LightRAG:
|
||||
"""Get pending documents, split into chunks, insert chunks"""
|
||||
# 1. get all pending and failed documents
|
||||
_todo_doc_keys = []
|
||||
|
||||
_failed_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.FAILED)
|
||||
_pendding_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.PENDING)
|
||||
|
||||
|
||||
_failed_doc = await self.full_docs.get_by_status_and_ids(
|
||||
status=DocStatus.FAILED
|
||||
)
|
||||
_pendding_doc = await self.full_docs.get_by_status_and_ids(
|
||||
status=DocStatus.PENDING
|
||||
)
|
||||
|
||||
if _failed_doc:
|
||||
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
|
||||
if _pendding_doc:
|
||||
@@ -685,15 +684,19 @@ class LightRAG:
|
||||
)
|
||||
}
|
||||
chunk_cnt += len(chunks)
|
||||
|
||||
|
||||
try:
|
||||
# Store chunks in vector database
|
||||
await self.chunks_vdb.upsert(chunks)
|
||||
# Update doc status
|
||||
await self.text_chunks.upsert({**chunks, "status": DocStatus.PENDING})
|
||||
await self.text_chunks.upsert(
|
||||
{**chunks, "status": DocStatus.PENDING}
|
||||
)
|
||||
except Exception as e:
|
||||
# Mark as failed if any step fails
|
||||
await self.text_chunks.upsert({**chunks, "status": DocStatus.FAILED})
|
||||
await self.text_chunks.upsert(
|
||||
{**chunks, "status": DocStatus.FAILED}
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
import traceback
|
||||
@@ -707,8 +710,12 @@ class LightRAG:
|
||||
"""Get pending or failed chunks, extract entities and relationships from each chunk"""
|
||||
# 1. get all pending and failed chunks
|
||||
_todo_chunk_keys = []
|
||||
_failed_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.FAILED)
|
||||
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.PENDING)
|
||||
_failed_chunks = await self.text_chunks.get_by_status_and_ids(
|
||||
status=DocStatus.FAILED
|
||||
)
|
||||
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(
|
||||
status=DocStatus.PENDING
|
||||
)
|
||||
if _failed_chunks:
|
||||
_todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
|
||||
if _pendding_chunks:
|
||||
@@ -742,11 +749,15 @@ class LightRAG:
|
||||
if maybe_new_kg is None:
|
||||
logger.info("No entities or relationships extracted!")
|
||||
# Update status to processed
|
||||
await self.text_chunks.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
|
||||
await self.text_chunks.upsert(
|
||||
{chunk_id: {"status": DocStatus.PROCESSED}}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract entities and relationships")
|
||||
# Mark as failed if any step fails
|
||||
await self.text_chunks.upsert({chunk_id: {"status": DocStatus.FAILED}})
|
||||
await self.text_chunks.upsert(
|
||||
{chunk_id: {"status": DocStatus.FAILED}}
|
||||
)
|
||||
raise e
|
||||
|
||||
with tqdm_async(
|
||||
|
Reference in New Issue
Block a user