diff --git a/lightrag/base.py b/lightrag/base.py index 2b655549..a53c8a83 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -91,7 +91,7 @@ class BaseKVStorage(StorageNameSpace): async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: raise NotImplementedError - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: raise NotImplementedError async def filter_keys(self, data: list[str]) -> set[str]: @@ -103,10 +103,13 @@ class BaseKVStorage(StorageNameSpace): async def drop(self) -> None: raise NotImplementedError - - async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]: + + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict[str, Any]], None]: raise NotImplementedError - + + @dataclass class BaseGraphStorage(StorageNameSpace): embedding_func: EmbeddingFunc = None diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 927ffe32..70a60aa2 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -12,6 +12,7 @@ from lightrag.base import ( BaseKVStorage, ) + @dataclass class JsonKVStorage(BaseKVStorage): def __post_init__(self): @@ -30,7 +31,7 @@ class JsonKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return self._data.get(id, None) - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: return [ ( {k: v for k, v in self._data[id].items()} @@ -50,6 +51,8 @@ class JsonKVStorage(BaseKVStorage): async def drop(self) -> None: self._data = {} - async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]: + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict[str, Any]], None]: result = [v for _, v in self._data.items() if v["status"] == status] return result if result else None diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 7cfdb994..ce703dfb 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -35,7 +35,7 @@ class MongoKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return self._data.find_one({"_id": id}) - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: return list(self._data.find({"_id": {"$in": ids}})) async def filter_keys(self, data: list[str]) -> set[str]: @@ -77,7 +77,9 @@ class MongoKVStorage(BaseKVStorage): """Drop the collection""" await self._data.drop() - async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]: + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict[str, Any]], None]: """Get documents by status and ids""" return self._data.find({"status": status}) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index f51a5eb8..3c064eba 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -326,7 +326,8 @@ class OracleKVStorage(BaseKVStorage): (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS), ): logger.info("full doc and chunk data had been saved into oracle db!") - + + @dataclass class OracleVectorDBStorage(BaseVectorStorage): # should pass db object to self.db diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index c6757765..ba11fea7 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -213,7 +213,7 @@ class PGKVStorage(BaseKVStorage): return None # Query by id - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: """Get doc_chunks data by id""" sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) @@ -237,12 +237,14 @@ class PGKVStorage(BaseKVStorage): return res else: return None - - async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]: + + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict[str, Any]], None]: """Specifically for llm_response_cache.""" SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] params = {"workspace": self.db.workspace, "status": status} - return await self.db.query(SQL, params, multirows=True) + return await self.db.query(SQL, params, multirows=True) async def all_keys(self) -> list[dict]: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index b3ff890f..095cc3b6 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -29,7 +29,7 @@ class RedisKVStorage(BaseKVStorage): data = await self._redis.get(f"{self.namespace}:{id}") return json.loads(data) if data else None - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: pipe = self._redis.pipeline() for id in ids: pipe.get(f"{self.namespace}:{id}") @@ -58,11 +58,12 @@ class RedisKVStorage(BaseKVStorage): keys = await self._redis.keys(f"{self.namespace}:*") if keys: await self._redis.delete(*keys) - - async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]: + + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict[str, Any]], None]: pipe = self._redis.pipeline() for key in await self._redis.keys(f"{self.namespace}:*"): pipe.hgetall(key) results = await pipe.execute() return [data for data in results if data.get("status") == status] or None - diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 81450e87..b8e6e985 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -122,7 +122,7 @@ class TiDBKVStorage(BaseKVStorage): return None # Query by id - async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: + async def get_by_ids(self, ids: list[str]) -> list[Union[dict[str, Any], None]]: """根据 id 获取 doc_chunks 数据""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) @@ -333,10 +333,13 @@ class TiDBVectorDBStorage(BaseVectorStorage): merge_sql = SQL_TEMPLATES["insert_relationship"] await self.db.execute(merge_sql, data) - async def get_by_status_and_ids(self, status: str) -> Union[list[dict[str, Any]], None]: + async def get_by_status_and_ids( + self, status: str + ) -> Union[list[dict[str, Any]], None]: SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] params = {"workspace": self.db.workspace, "status": status} - return await self.db.query(SQL, params, multirows=True) + return await self.db.query(SQL, params, multirows=True) + @dataclass class TiDBGraphStorage(BaseGraphStorage): diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 0ae47d1f..7a87e0e7 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -629,12 +629,7 @@ class LightRAG: # 4. Store original document for doc_id, doc in new_docs.items(): await self.full_docs.upsert( - { - doc_id: { - "content": doc["content"], - "status": DocStatus.PENDING - } - } + {doc_id: {"content": doc["content"], "status": DocStatus.PENDING}} ) logger.info(f"Stored {len(new_docs)} new unique documents") @@ -642,10 +637,14 @@ class LightRAG: """Get pendding documents, split into chunks,insert chunks""" # 1. get all pending and failed documents _todo_doc_keys = [] - - _failed_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.FAILED) - _pendding_doc = await self.full_docs.get_by_status_and_ids(status=DocStatus.PENDING) - + + _failed_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.FAILED + ) + _pendding_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.PENDING + ) + if _failed_doc: _todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) if _pendding_doc: @@ -685,15 +684,19 @@ class LightRAG: ) } chunk_cnt += len(chunks) - + try: # Store chunks in vector database await self.chunks_vdb.upsert(chunks) # Update doc status - await self.text_chunks.upsert({**chunks, "status": DocStatus.PENDING}) + await self.text_chunks.upsert( + {**chunks, "status": DocStatus.PENDING} + ) except Exception as e: # Mark as failed if any step fails - await self.text_chunks.upsert({**chunks, "status": DocStatus.FAILED}) + await self.text_chunks.upsert( + {**chunks, "status": DocStatus.FAILED} + ) raise e except Exception as e: import traceback @@ -707,8 +710,12 @@ class LightRAG: """Get pendding or failed chunks, extract entities and relationships from each chunk""" # 1. get all pending and failed chunks _todo_chunk_keys = [] - _failed_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.FAILED) - _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status=DocStatus.PENDING) + _failed_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.FAILED + ) + _pendding_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.PENDING + ) if _failed_chunks: _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) if _pendding_chunks: @@ -742,11 +749,15 @@ class LightRAG: if maybe_new_kg is None: logger.info("No entities or relationships extracted!") # Update status to processed - await self.text_chunks.upsert({chunk_id: {"status": DocStatus.PROCESSED}}) + await self.text_chunks.upsert( + {chunk_id: {"status": DocStatus.PROCESSED}} + ) except Exception as e: logger.error("Failed to extract entities and relationships") # Mark as failed if any step fails - await self.text_chunks.upsert({chunk_id: {"status": DocStatus.FAILED}}) + await self.text_chunks.upsert( + {chunk_id: {"status": DocStatus.FAILED}} + ) raise e with tqdm_async(