cleanup code

Yannick Stephan
2025-02-08 23:58:15 +01:00
parent 2929d1fc39
commit 50c7f26262
8 changed files with 63 additions and 37 deletions

View File

@@ -91,7 +91,7 @@ class BaseKVStorage(StorageNameSpace):
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         raise NotImplementedError
 
     async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         raise NotImplementedError
 
     async def filter_keys(self, data: list[str]) -> set[str]:
@@ -103,10 +103,13 @@ class BaseKVStorage(StorageNameSpace):
     async def drop(self) -> None:
         raise NotImplementedError
 
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         raise NotImplementedError
 
 @dataclass
 class BaseGraphStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc = None
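
Editor's note: the hunk above only re-wraps the signature, but it is the abstract contract every KV backend in this commit implements. A minimal sketch of a conforming subclass, mirroring the per-record "status" field used by JsonKVStorage below (the class name and in-memory dict are illustrative, not from the repo):

from dataclasses import dataclass
from typing import Any, Union

from lightrag.base import BaseKVStorage


@dataclass
class InMemoryKVStorage(BaseKVStorage):  # hypothetical example class
    def __post_init__(self):
        self._data: dict[str, dict[str, Any]] = {}

    async def get_by_status_and_ids(
        self, status: str
    ) -> Union[list[dict[str, Any]], None]:
        # Mirror the JsonKVStorage semantics: None when nothing matches.
        result = [v for v in self._data.values() if v.get("status") == status]
        return result if result else None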

View File

@@ -12,6 +12,7 @@ from lightrag.base import (
     BaseKVStorage,
 )
+
 @dataclass
 class JsonKVStorage(BaseKVStorage):
     def __post_init__(self):
@@ -30,7 +31,7 @@ class JsonKVStorage(BaseKVStorage):
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.get(id, None)
 
     async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         return [
             (
                 {k: v for k, v in self._data[id].items()}
@@ -50,6 +51,8 @@ class JsonKVStorage(BaseKVStorage):
     async def drop(self) -> None:
         self._data = {}
 
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         result = [v for _, v in self._data.items() if v["status"] == status]
         return result if result else None
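
A hedged usage sketch (editor's illustration, not in the commit): the method returns None rather than an empty list when nothing matches, which is why the LightRAG callers further down guard with plain truthiness checks. The string statuses stand in for the DocStatus constants:

async def demo(kv: JsonKVStorage) -> None:
    # Hypothetical records; "status" values stand in for DocStatus members.
    await kv.upsert({"doc-1": {"content": "a", "status": "pending"}})
    assert await kv.get_by_status_and_ids(status="processed") is None
    pending = await kv.get_by_status_and_ids(status="pending")
    assert pending is not None and pending[0]["content"] == "a"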

View File

@@ -35,7 +35,7 @@ class MongoKVStorage(BaseKVStorage):
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return self._data.find_one({"_id": id})
 
     async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         return list(self._data.find({"_id": {"$in": ids}}))
 
     async def filter_keys(self, data: list[str]) -> set[str]:
@@ -77,7 +77,9 @@ class MongoKVStorage(BaseKVStorage):
"""Drop the collection""" """Drop the collection"""
await self._data.drop() await self._data.drop()
async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]: async def get_by_status_and_ids(
self, status: str
) -> Union[list[dict[str, Any]], None]:
"""Get documents by status and ids""" """Get documents by status and ids"""
return self._data.find({"status": status}) return self._data.find({"status": status})
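
An aside, untouched by this commit: with an async driver such as Motor, find() returns a cursor rather than a list, so this method hands the caller a cursor despite the list[...] annotation. A sketch of a materializing variant, assuming self._data is an AsyncIOMotorCollection:

async def get_by_status_and_ids(
    self, status: str
) -> Union[list[dict[str, Any]], None]:
    """Get documents by status (sketch: materialize the Motor cursor)."""
    result = await self._data.find({"status": status}).to_list(length=None)
    return result if result else None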

View File

@@ -326,7 +326,8 @@ class OracleKVStorage(BaseKVStorage):
             (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
         ):
             logger.info("full doc and chunk data had been saved into oracle db!")
 
+
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
     # should pass db object to self.db

View File

@@ -213,7 +213,7 @@ class PGKVStorage(BaseKVStorage):
             return None
 
     # Query by id
     async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         """Get doc_chunks data by id"""
         sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
@@ -237,12 +237,14 @@ class PGKVStorage(BaseKVStorage):
             return res
         else:
             return None
 
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)
 
     async def all_keys(self) -> list[dict]:
         if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
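
Another aside, unrelated to the re-wrapping above: get_by_ids builds its IN-list by interpolating quoted ids into the SQL string. A parameterized sketch, assuming an asyncpg connection (the table and column names are illustrative, not from the repo):

# Sketch only: asyncpg binds the Python list to a Postgres array, so no
# manual quoting of ids is needed.
sql = "SELECT * FROM doc_chunks WHERE workspace = $1 AND id = ANY($2::varchar[])"
rows = await connection.fetch(sql, workspace, ids)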

View File

@@ -29,7 +29,7 @@ class RedisKVStorage(BaseKVStorage):
         data = await self._redis.get(f"{self.namespace}:{id}")
         return json.loads(data) if data else None
 
     async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         pipe = self._redis.pipeline()
         for id in ids:
             pipe.get(f"{self.namespace}:{id}")
@@ -58,11 +58,12 @@ class RedisKVStorage(BaseKVStorage):
         keys = await self._redis.keys(f"{self.namespace}:*")
         if keys:
             await self._redis.delete(*keys)
 
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         pipe = self._redis.pipeline()
         for key in await self._redis.keys(f"{self.namespace}:*"):
             pipe.hgetall(key)
         results = await pipe.execute()
         return [data for data in results if data.get("status") == status] or None
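
Worth flagging (pre-existing, not changed here): get_by_id reads JSON strings via get(), while this scan reads hashes via hgetall(), so the two paths assume different Redis value types. A sketch of a string-consistent scan, assuming redis.asyncio with decode_responses=True:

import json

async def get_by_status(redis_client, namespace: str, status: str):
    # scan_iter avoids the blocking KEYS call on large keyspaces.
    results = []
    async for key in redis_client.scan_iter(f"{namespace}:*"):
        raw = await redis_client.get(key)
        data = json.loads(raw) if raw else None
        if data and data.get("status") == status:
            results.append(data)
    return results or None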

View File

@@ -122,7 +122,7 @@ class TiDBKVStorage(BaseKVStorage):
             return None
 
     # Query by id
     async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]:
         """Get doc_chunks data by id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
@@ -333,10 +333,13 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             merge_sql = SQL_TEMPLATES["insert_relationship"]
             await self.db.execute(merge_sql, data)
 
-    async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]:
+    async def get_by_status_and_ids(
+        self, status: str
+    ) -> Union[list[dict[str, Any]], None]:
         SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
         params = {"workspace": self.db.workspace, "status": status}
         return await self.db.query(SQL, params, multirows=True)
 
 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):

View File

@@ -629,12 +629,7 @@ class LightRAG:
         # 4. Store original document
         for doc_id, doc in new_docs.items():
             await self.full_docs.upsert(
-                {
-                    doc_id: {
-                        "content": doc["content"],
-                        "status": DocStatus.PENDING
-                    }
-                }
+                {doc_id: {"content": doc["content"], "status": DocStatus.PENDING}}
             )
         logger.info(f"Stored {len(new_docs)} new unique documents")
@@ -642,10 +637,14 @@ class LightRAG:
"""Get pendding documents, split into chunks,insert chunks""" """Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents # 1. get all pending and failed documents
_todo_doc_keys = [] _todo_doc_keys = []
_failed_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.FAILED) _failed_doc = await self.full_docs.get_by_status_and_ids(
_pendding_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.PENDING) status=DocStatus.FAILED
)
_pendding_doc = await self.full_docs.get_by_status_and_ids(
status=DocStatus.PENDING
)
if _failed_doc: if _failed_doc:
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
if _pendding_doc: if _pendding_doc:
@@ -685,15 +684,19 @@ class LightRAG:
                     )
                 }
                 chunk_cnt += len(chunks)
 
                 try:
                     # Store chunks in vector database
                     await self.chunks_vdb.upsert(chunks)
                     # Update doc status
-                    await self.text_chunks.upsert({**chunks, "status": DocStatus.PENDING})
+                    await self.text_chunks.upsert(
+                        {**chunks, "status": DocStatus.PENDING}
+                    )
                 except Exception as e:
                     # Mark as failed if any step fails
-                    await self.text_chunks.upsert({**chunks, "status": DocStatus.FAILED})
+                    await self.text_chunks.upsert(
+                        {**chunks, "status": DocStatus.FAILED}
+                    )
                     raise e
             except Exception as e:
                 import traceback
@@ -707,8 +710,12 @@ class LightRAG:
"""Get pendding or failed chunks, extract entities and relationships from each chunk""" """Get pendding or failed chunks, extract entities and relationships from each chunk"""
# 1. get all pending and failed chunks # 1. get all pending and failed chunks
_todo_chunk_keys = [] _todo_chunk_keys = []
_failed_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.FAILED) _failed_chunks = await self.text_chunks.get_by_status_and_ids(
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.PENDING) status=DocStatus.FAILED
)
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(
status=DocStatus.PENDING
)
if _failed_chunks: if _failed_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
if _pendding_chunks: if _pendding_chunks:
@@ -742,11 +749,15 @@ class LightRAG:
                     if maybe_new_kg is None:
                         logger.info("No entities or relationships extracted!")
 
                     # Update status to processed
-                    await self.text_chunks.upsert({chunk_id: {"status": DocStatus.PROCESSED}})
+                    await self.text_chunks.upsert(
+                        {chunk_id: {"status": DocStatus.PROCESSED}}
+                    )
                 except Exception as e:
                     logger.error("Failed to extract entities and relationships")
                     # Mark as failed if any step fails
-                    await self.text_chunks.upsert({chunk_id: {"status": DocStatus.FAILED}})
+                    await self.text_chunks.upsert(
+                        {chunk_id: {"status": DocStatus.FAILED}}
+                    )
                     raise e
 
         with tqdm_async(
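
Taken together, the hunks in this file preserve one lifecycle: records start PENDING, move to PROCESSED on success, and are set to FAILED on error; since the fetch step queries both FAILED and PENDING, failures are retried on the next run. A condensed sketch of that re-queue logic, reusing the names from the diff (`storage` stands in for full_docs or text_chunks):

async def todo_keys(storage) -> list[str]:
    # Failed records are re-queued alongside pending ones each run.
    todo: list[str] = []
    for status in (DocStatus.FAILED, DocStatus.PENDING):
        batch = await storage.get_by_status_and_ids(status=status)
        if batch:
            todo.extend(doc["id"] for doc in batch)
    return todo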