diff --git a/lightrag/base.py b/lightrag/base.py
index 7a3b4f5f..9b3e5f00 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -90,7 +90,7 @@ class BaseKVStorage(StorageNameSpace):
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         raise NotImplementedError
 
-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         """return un-exist keys"""
         raise NotImplementedError
 
diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py
index ff184dbd..c61d088d 100644
--- a/lightrag/kg/json_kv_impl.py
+++ b/lightrag/kg/json_kv_impl.py
@@ -38,7 +38,7 @@ class JsonKVStorage(BaseKVStorage):
             for id in ids
         ]
 
-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         return set([s for s in data if s not in self._data])
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py
index 31aa836a..179b17a3 100644
--- a/lightrag/kg/jsondocstatus_impl.py
+++ b/lightrag/kg/jsondocstatus_impl.py
@@ -52,17 +52,16 @@
 import os
 from dataclasses import dataclass
 from typing import Any, Union
 
-from lightrag.utils import (
-    logger,
-    load_json,
-    write_json,
-)
-
 from lightrag.base import (
-    DocStatus,
     DocProcessingStatus,
+    DocStatus,
     DocStatusStorage,
 )
+from lightrag.utils import (
+    load_json,
+    logger,
+    write_json,
+)
 
 @dataclass
@@ -75,15 +74,17 @@ class JsonDocStatusStorage(DocStatusStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Loaded document status storage with {len(self._data)} records")
 
-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that should be processed (not in storage or not successfully processed)"""
-        return set(
-            [
-                k
-                for k in data
-                if k not in self._data or self._data[k]["status"] != DocStatus.PROCESSED
-            ]
-        )
+        return {k for k, _ in self._data.items() if k in data}
+
+    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+        result: list[dict[str, Any]] = []
+        for id in ids:
+            data = self._data.get(id, None)
+            if data:
+                result.append(data)
+        return result
 
     async def get_status_counts(self) -> dict[str, int]:
         """Get counts of documents in each status"""
@@ -94,11 +95,19 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def get_failed_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all failed documents"""
-        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED}
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.FAILED
+        }
 
     async def get_pending_docs(self) -> dict[str, DocProcessingStatus]:
         """Get all pending documents"""
-        return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING}
+        return {
+            k: DocProcessingStatus(**v)
+            for k, v in self._data.items()
+            if v["status"] == DocStatus.PENDING
+        }
 
     async def index_done_callback(self):
         """Save data to file after indexing"""
@@ -118,7 +127,11 @@ class JsonDocStatusStorage(DocStatusStorage):
 
     async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]:
         """Get document status by ID"""
-        return self._data.get(doc_id)
+        data = self._data.get(doc_id)
+        if data:
+            return DocProcessingStatus(**data)
+        else:
+            return None
 
     async def delete(self, doc_ids: list[str]):
         """Delete document status by IDs"""
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 35902d37..1294a26a 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -35,7 +35,7 @@ class MongoKVStorage(BaseKVStorage):
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
         return list(self._data.find({"_id": {"$in": ids}}))
 
-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         existing_ids = [
             str(x["_id"]) for x in self._data.find({"_id": {"$in": data}}, {"_id": 1})
         ]
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 77fe6198..63df869e 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -421,7 +421,7 @@ class PGDocStatusStorage(DocStatusStorage):
     def __post_init__(self):
         pass
 
-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         keys = ",".join([f"'{_id}'" for _id in data])
         sql = (
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index 05da41b7..ef95d6db 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -32,7 +32,7 @@ class RedisKVStorage(BaseKVStorage):
         results = await pipe.execute()
         return [json.loads(result) if result else None for result in results]
 
-    async def filter_keys(self, data: list[str]) -> set[str]:
+    async def filter_keys(self, data: set[str]) -> set[str]:
         pipe = self._redis.pipeline()
         for key in data:
             pipe.exists(f"{self.namespace}:{key}")
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 5d00c508..3bd3cc8f 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -1,28 +1,11 @@
 import asyncio
 import os
-from tqdm.asyncio import tqdm as tqdm_async
+from collections.abc import Coroutine
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Coroutine, Optional, Type, Union, cast
-from .operate import (
-    chunking_by_token_size,
-    extract_entities,
-    extract_keywords_only,
-    kg_query,
-    kg_query_with_keywords,
-    mix_kg_vector_query,
-    naive_query,
-)
+from typing import Any, Callable, Optional, Type, Union, cast
 
-from .utils import (
-    EmbeddingFunc,
-    compute_mdhash_id,
-    limit_async_func_call,
-    convert_response_to_json,
-    logger,
-    set_logger,
-)
 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
@@ -33,10 +16,25 @@ from .base import (
     QueryParam,
     StorageNameSpace,
 )
-
 from .namespace import NameSpace, make_namespace
-
+from .operate import (
+    chunking_by_token_size,
+    extract_entities,
+    extract_keywords_only,
+    kg_query,
+    kg_query_with_keywords,
+    mix_kg_vector_query,
+    naive_query,
+)
 from .prompt import GRAPH_FIELD_SEP
+from .utils import (
+    EmbeddingFunc,
+    compute_mdhash_id,
+    convert_response_to_json,
+    limit_async_func_call,
+    logger,
+    set_logger,
+)
 
 
 STORAGES = {
@@ -67,7 +65,6 @@ STORAGES = {
 
 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""
-
     # Get the caller's module and package
     import inspect
 
@@ -113,7 +110,7 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 @dataclass
 class LightRAG:
     working_dir: str = field(
-        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
+        default_factory=lambda: f'./lightrag_cache_{datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}'
     )
     # Default not to use embedding cache
     embedding_cache_config: dict = field(
@@ -496,15 +493,15 @@ class LightRAG:
         }
 
         # 3. Filter out already processed documents
-        add_doc_keys: set[str] = set()
+        new_doc_keys: set[str] = set()
         # Get docs ids
-        in_process_keys = list(new_docs.keys())
+        in_process_keys = set(new_docs.keys())
         # Get in progress docs ids
-        excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
+        excluded_ids = await self.doc_status.filter_keys(list(in_process_keys))
         # Exclude already in process
-        add_doc_keys = new_docs.keys() - excluded_ids
+        new_doc_keys = in_process_keys - excluded_ids
         # Filter
-        new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
+        new_docs = {doc_id: new_docs[doc_id] for doc_id in new_doc_keys}
 
         if not new_docs:
             logger.info("All documents have been processed or are duplicates")
@@ -562,15 +559,12 @@ class LightRAG:
 
         # 3. iterate over batches
         tasks: dict[str, list[Coroutine[Any, Any, None]]] = {}
-        for batch_idx, ids_doc_processing_status in tqdm_async(
-            enumerate(batch_docs_list),
-            desc="Process Batches",
-        ):
+
+        logger.info(f"Number of batches to process: {len(batch_docs_list)}.")
+
+        for batch_idx, ids_doc_processing_status in enumerate(batch_docs_list):
             # 4. iterate over batch
-            for id_doc_processing_status in tqdm_async(
-                ids_doc_processing_status,
-                desc=f"Process Batch {batch_idx}",
-            ):
+            for id_doc_processing_status in ids_doc_processing_status:
                 id_doc, status_doc = id_doc_processing_status
                 # Update status in processing
                 await self.doc_status.upsert(
@@ -644,6 +638,7 @@ class LightRAG:
                             }
                         )
                         continue
+            logger.info(f"Completed batch {batch_idx + 1} of {len(batch_docs_list)}.")
 
     async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
         try:
@@ -895,7 +890,6 @@ class LightRAG:
         1. Extract keywords from the 'query' using new function in operate.py.
         2. Then run the standard aquery() flow with the final prompt (formatted_question).
         """
-
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
             self.aquery_with_separate_keyword_extraction(query, prompt, param)
         )
@@ -908,7 +902,6 @@ class LightRAG:
         1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
        2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
         """
-
         # ---------------------
         # STEP 1: Keyword Extraction
         # ---------------------