diff --git a/README.md b/README.md index 456d9a72..62dc032b 100644 --- a/README.md +++ b/README.md @@ -408,6 +408,21 @@ rag = LightRAG( with open("./newText.txt") as f: rag.insert(f.read()) ``` + +### Insert using Pipeline +The `apipeline_enqueue_documents` and `apipeline_process_enqueue_documents` functions allow you to perform incremental insertion of documents into the graph. + +This is useful for scenarios where you want to process documents in the background while still allowing the main thread to continue executing. + +A background routine can then process the newly enqueued documents. + +```python +rag = LightRAG(...) +await rag.apipeline_enqueue_documents(string_or_strings) +# Your routine in a loop +await rag.apipeline_process_enqueue_documents() +``` + ### Separate Keyword Extraction We've introduced a new function `query_with_separate_keyword_extraction` to enhance the keyword extraction capabilities. This function separates the keyword extraction process from the user's prompt, focusing solely on the query to improve the relevance of extracted keywords. diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 47020fd6..f5269fae 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -121,9 +121,8 @@ async def main(): texts = [x for x in all_text.split("\n") if x] # New mode use pipeline - await rag.apipeline_process_documents(texts) - await rag.apipeline_process_chunks() - await rag.apipeline_process_extract_graph() + await rag.apipeline_enqueue_documents(texts) + await rag.apipeline_process_enqueue_documents() # Old method use ainsert # await rag.ainsert(texts) diff --git a/lightrag/base.py b/lightrag/base.py index e71cac3f..7a3b4f5f 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -1,20 +1,18 @@ +from enum import Enum import os from dataclasses import dataclass, field from typing import ( + Optional, TypedDict, Union, Literal, - Generic, TypeVar, - Optional, - Dict, Any, - List, ) -from enum import Enum import numpy as np + from .utils import EmbeddingFunc TextChunkSchema = TypedDict( @@ -45,7 +43,7 @@ class QueryParam: hl_keywords: list[str] = field(default_factory=list) ll_keywords: list[str] = field(default_factory=list) # Conversation history support - conversation_history: list[dict] = field( + conversation_history: list[dict[str, str]] = field( default_factory=list ) # Format: [{"role": "user/assistant", "content": "message"}] history_turns: int = ( @@ -56,7 +54,7 @@ class QueryParam: @dataclass class StorageNameSpace: namespace: str - global_config: dict + global_config: dict[str, Any] async def index_done_callback(self): """commit the storage operations after indexing""" @@ -72,10 +70,10 @@ class BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc meta_fields: set = field(default_factory=set) - async def query(self, query: str, top_k: int) -> list[dict]: + async def query(self, query: str, top_k: int) -> list[dict[str, Any]]: raise NotImplementedError - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Use 'content' field from value for embedding, use key as id. 
If embedding_func is None, use 'embedding' field from value """ @@ -83,28 +81,23 @@ class BaseVectorStorage(StorageNameSpace): @dataclass -class BaseKVStorage(Generic[T], StorageNameSpace): +class BaseKVStorage(StorageNameSpace): embedding_func: EmbeddingFunc - async def all_keys(self) -> list[str]: + async def get_by_id(self, id: str) -> dict[str, Any]: raise NotImplementedError - async def get_by_id(self, id: str) -> Union[T, None]: - raise NotImplementedError - - async def get_by_ids( - self, ids: list[str], fields: Union[set[str], None] = None - ) -> list[Union[T, None]]: + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: raise NotImplementedError async def filter_keys(self, data: list[str]) -> set[str]: """return un-exist keys""" raise NotImplementedError - async def upsert(self, data: dict[str, T]): + async def upsert(self, data: dict[str, Any]) -> None: raise NotImplementedError - async def drop(self): + async def drop(self) -> None: raise NotImplementedError @@ -151,12 +144,12 @@ class BaseGraphStorage(StorageNameSpace): async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: raise NotImplementedError("Node embedding is not used in lightrag.") - async def get_all_labels(self) -> List[str]: + async def get_all_labels(self) -> list[str]: raise NotImplementedError async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 - ) -> Dict[str, List[Dict]]: + ) -> dict[str, list[dict]]: raise NotImplementedError @@ -173,27 +166,37 @@ class DocStatus(str, Enum): class DocProcessingStatus: """Document processing status data structure""" - content_summary: str # First 100 chars of document content - content_length: int # Total length of document - status: DocStatus # Current processing status - created_at: str # ISO format timestamp - updated_at: str # ISO format timestamp - chunks_count: Optional[int] = None # Number of chunks after splitting - error: Optional[str] = None # Error message if failed - metadata: Dict[str, Any] = field(default_factory=dict) # Additional metadata + content: str + """Original content of the document""" + content_summary: str + """First 100 chars of document content, used for preview""" + content_length: int + """Total length of document""" + status: DocStatus + """Current processing status""" + created_at: str + """ISO format timestamp when document was created""" + updated_at: str + """ISO format timestamp when document was last updated""" + chunks_count: Optional[int] = None + """Number of chunks after splitting, used for processing""" + error: Optional[str] = None + """Error message if failed""" + metadata: dict[str, Any] = field(default_factory=dict) + """Additional metadata""" class DocStatusStorage(BaseKVStorage): """Base class for document status storage""" - async def get_status_counts(self) -> Dict[str, int]: + async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" raise NotImplementedError - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: """Get all failed documents""" raise NotImplementedError - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: """Get all pending documents""" raise NotImplementedError diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 2fb753fe..ff184dbd 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -1,63 +1,13 @@ 
-""" -JsonDocStatus Storage Module -======================= - -This module provides a storage interface for graphs using NetworkX, a popular Python library for creating, manipulating, and studying the structure, dynamics, and functions of complex networks. - -The `NetworkXStorage` class extends the `BaseGraphStorage` class from the LightRAG library, providing methods to load, save, manipulate, and query graphs using NetworkX. - -Author: lightrag team -Created: 2024-01-25 -License: MIT - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Version: 1.0.0 - -Dependencies: - - NetworkX - - NumPy - - LightRAG - - graspologic - -Features: - - Load and save graphs in various formats (e.g., GEXF, GraphML, JSON) - - Query graph nodes and edges - - Calculate node and edge degrees - - Embed nodes using various algorithms (e.g., Node2Vec) - - Remove nodes and edges from the graph - -Usage: - from lightrag.storage.networkx_storage import NetworkXStorage - -""" - import asyncio import os from dataclasses import dataclass +from typing import Any from lightrag.utils import ( logger, load_json, write_json, ) - from lightrag.base import ( BaseKVStorage, ) @@ -68,25 +18,20 @@ class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} + self._data: dict[str, Any] = load_json(self._file_name) or {} self._lock = asyncio.Lock() logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - async def all_keys(self) -> list[str]: - return list(self._data.keys()) - async def index_done_callback(self): write_json(self._data, self._file_name) - async def get_by_id(self, id): - return self._data.get(id, None) + async def get_by_id(self, id: str) -> dict[str, Any]: + return self._data.get(id, {}) - async def get_by_ids(self, ids, fields=None): - if fields is None: - return [self._data.get(id, None) for id in ids] + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: return [ ( - {k: v for k, v in self._data[id].items() if k in fields} + {k: v for k, v in self._data[id].items()} if self._data.get(id, None) else None ) @@ -96,39 +41,9 @@ class JsonKVStorage(BaseKVStorage): async def filter_keys(self, data: list[str]) -> set[str]: return set([s for s in data if s not in self._data]) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: left_data = {k: v for k, v in data.items() if k not in 
self._data} self._data.update(left_data) - return left_data - async def drop(self): + async def drop(self) -> None: self._data = {} - - async def filter(self, filter_func): - """Filter key-value pairs based on a filter function - - Args: - filter_func: The filter function, which takes a value as an argument and returns a boolean value - - Returns: - Dict: Key-value pairs that meet the condition - """ - result = {} - async with self._lock: - for key, value in self._data.items(): - if filter_func(value): - result[key] = value - return result - - async def delete(self, ids: list[str]): - """Delete data with specified IDs - - Args: - ids: List of IDs to delete - """ - async with self._lock: - for id in ids: - if id in self._data: - del self._data[id] - await self.index_done_callback() - logger.info(f"Successfully deleted {len(ids)} items from {self.namespace}") diff --git a/lightrag/kg/jsondocstatus_impl.py b/lightrag/kg/jsondocstatus_impl.py index 8f326170..31aa836a 100644 --- a/lightrag/kg/jsondocstatus_impl.py +++ b/lightrag/kg/jsondocstatus_impl.py @@ -50,7 +50,7 @@ Usage: import os from dataclasses import dataclass -from typing import Union, Dict +from typing import Any, Union from lightrag.utils import ( logger, @@ -72,7 +72,7 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} + self._data: dict[str, Any] = load_json(self._file_name) or {} logger.info(f"Loaded document status storage with {len(self._data)} records") async def filter_keys(self, data: list[str]) -> set[str]: @@ -85,18 +85,18 @@ class JsonDocStatusStorage(DocStatusStorage): ] ) - async def get_status_counts(self) -> Dict[str, int]: + async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" counts = {status: 0 for status in DocStatus} for doc in self._data.values(): counts[doc["status"]] += 1 return counts - async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]: + async def get_failed_docs(self) -> dict[str, DocProcessingStatus]: """Get all failed documents""" return {k: v for k, v in self._data.items() if v["status"] == DocStatus.FAILED} - async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]: + async def get_pending_docs(self) -> dict[str, DocProcessingStatus]: """Get all pending documents""" return {k: v for k, v in self._data.items() if v["status"] == DocStatus.PENDING} @@ -104,7 +104,7 @@ class JsonDocStatusStorage(DocStatusStorage): """Save data to file after indexing""" write_json(self._data, self._file_name) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: """Update or insert document status Args: @@ -112,10 +112,9 @@ class JsonDocStatusStorage(DocStatusStorage): """ self._data.update(data) await self.index_done_callback() - return data - async def get_by_id(self, id: str): - return self._data.get(id) + async def get_by_id(self, id: str) -> dict[str, Any]: + return self._data.get(id, {}) async def get(self, doc_id: str) -> Union[DocProcessingStatus, None]: """Get document status by ID""" diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 7afc4240..35902d37 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -12,7 +12,7 @@ if not pm.is_installed("motor"): from pymongo import MongoClient from motor.motor_asyncio import AsyncIOMotorClient -from typing import 
Union, List, Tuple +from typing import Any, Union, List, Tuple from ..utils import logger from ..base import BaseKVStorage, BaseGraphStorage @@ -29,21 +29,11 @@ class MongoKVStorage(BaseKVStorage): self._data = database.get_collection(self.namespace) logger.info(f"Use MongoDB as KV {self.namespace}") - async def all_keys(self) -> list[str]: - return [x["_id"] for x in self._data.find({}, {"_id": 1})] - - async def get_by_id(self, id): + async def get_by_id(self, id: str) -> dict[str, Any]: return self._data.find_one({"_id": id}) - async def get_by_ids(self, ids, fields=None): - if fields is None: - return list(self._data.find({"_id": {"$in": ids}})) - return list( - self._data.find( - {"_id": {"$in": ids}}, - {field: 1 for field in fields}, - ) - ) + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + return list(self._data.find({"_id": {"$in": ids}})) async def filter_keys(self, data: list[str]) -> set[str]: existing_ids = [ @@ -51,7 +41,7 @@ class MongoKVStorage(BaseKVStorage): ] return set([s for s in data if s not in existing_ids]) - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): for mode, items in data.items(): for k, v in tqdm_async(items.items(), desc="Upserting"): @@ -66,7 +56,6 @@ class MongoKVStorage(BaseKVStorage): for k, v in tqdm_async(data.items(), desc="Upserting"): self._data.update_one({"_id": k}, {"$set": v}, upsert=True) data[k]["_id"] = k - return data async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): @@ -81,9 +70,9 @@ class MongoKVStorage(BaseKVStorage): else: return None - async def drop(self): - """ """ - pass + async def drop(self) -> None: + """Drop the collection""" + await self._data.drop() @dataclass diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index a1a05759..b648c9bc 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -4,7 +4,7 @@ import asyncio # import html # import os from dataclasses import dataclass -from typing import Union +from typing import Any, Union import numpy as np import array import pipmaster as pm @@ -181,7 +181,7 @@ class OracleKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict, None]: + async def get_by_id(self, id: str) -> dict[str, Any]: """get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -191,12 +191,9 @@ class OracleKVStorage(BaseKVStorage): res = {} for row in array_res: res[row["id"]] = row - else: - res = await self.db.query(SQL, params) - if res: return res else: - return None + return await self.db.query(SQL, params) async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" @@ -211,7 +208,7 @@ class OracleKVStorage(BaseKVStorage): else: return None - async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) @@ -230,29 +227,7 @@ class OracleKVStorage(BaseKVStorage): for row in res: dict_res[row["mode"]][row["id"]] = row res = [{k: v} for k, v in dict_res.items()] - if res: - data = 
res # [{"data":i} for i in res] - # print(data) - return data - else: - return None - - async def get_by_status_and_ids( - self, status: str, ids: list[str] - ) -> Union[list[dict], None]: - """Specifically for llm_response_cache.""" - if ids is not None: - SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - else: - SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] - params = {"workspace": self.db.workspace, "status": status} - res = await self.db.query(SQL, params, multirows=True) - if res: - return res - else: - return None + return res async def filter_keys(self, keys: list[str]) -> set[str]: """Return keys that don't exist in storage""" @@ -270,7 +245,7 @@ class OracleKVStorage(BaseKVStorage): return set(keys) ################ INSERT METHODS ################ - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): list_data = [ { @@ -328,14 +303,6 @@ class OracleKVStorage(BaseKVStorage): } await self.db.execute(upsert_sql, _data) - return None - - async def change_status(self, id: str, status: str): - SQL = SQL_TEMPLATES["change_status"].format( - table_name=namespace_to_table_name(self.namespace) - ) - params = {"workspace": self.db.workspace, "id": id, "status": status} - await self.db.execute(SQL, params) async def index_done_callback(self): if is_namespace( @@ -745,7 +712,6 @@ SQL_TEMPLATES = { "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL ON (a.id = :id and a.workspace = :workspace) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 8884d92e..77fe6198 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -30,7 +30,6 @@ from ..base import ( DocStatus, DocProcessingStatus, BaseGraphStorage, - T, ) from ..namespace import NameSpace, is_namespace @@ -184,7 +183,7 @@ class PGKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict, None]: + async def get_by_id(self, id: str) -> dict[str, Any]: """Get doc_full data by id.""" sql = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} @@ -193,12 +192,9 @@ class PGKVStorage(BaseKVStorage): res = {} for row in array_res: res[row["id"]] = row - else: - res = await self.db.query(sql, params) - if res: return res else: - return None + return await self.db.query(sql, params) async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" @@ -214,7 +210,7 @@ class PGKVStorage(BaseKVStorage): return None # Query by id - async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]: + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get doc_chunks data by id""" sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) @@ -231,23 +227,15 @@ class PGKVStorage(BaseKVStorage): dict_res[mode] = {} for row in array_res: 
dict_res[row["mode"]][row["id"]] = row - res = [{k: v} for k, v in dict_res.items()] + return [{k: v} for k, v in dict_res.items()] else: - res = await self.db.query(sql, params, multirows=True) - if res: - return res - else: - return None + return await self.db.query(sql, params, multirows=True) - async def all_keys(self) -> list[dict]: - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - sql = "select workspace,mode,id from lightrag_llm_cache" - res = await self.db.query(sql, multirows=True) - return res - else: - logger.error( - f"all_keys is only implemented for llm_response_cache, not for {self.namespace}" - ) + async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: + """Specifically for llm_response_cache.""" + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + return await self.db.query(SQL, params, multirows=True) async def filter_keys(self, keys: List[str]) -> Set[str]: """Filter out duplicated content""" @@ -270,7 +258,7 @@ class PGKVStorage(BaseKVStorage): print(params) ################ INSERT METHODS ################ - async def upsert(self, data: Dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): pass elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): @@ -447,14 +435,15 @@ class PGDocStatusStorage(DocStatusStorage): existed = set([element["id"] for element in result]) return set(data) - existed - async def get_by_id(self, id: str) -> Union[T, None]: + async def get_by_id(self, id: str) -> dict[str, Any]: sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" params = {"workspace": self.db.workspace, "id": id} result = await self.db.query(sql, params, True) if result is None or result == []: - return None + return {} else: return DocProcessingStatus( + content=result[0]["content"], content_length=result[0]["content_length"], content_summary=result[0]["content_summary"], status=result[0]["status"], @@ -483,10 +472,9 @@ class PGDocStatusStorage(DocStatusStorage): sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1" params = {"workspace": self.db.workspace, "status": status} result = await self.db.query(sql, params, True) - # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...] 
- # Converting to be a dict return { element["id"]: DocProcessingStatus( + content=result[0]["content"], content_summary=element["content_summary"], content_length=element["content_length"], status=element["status"], @@ -518,6 +506,7 @@ class PGDocStatusStorage(DocStatusStorage): sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status) values($1,$2,$3,$4,$5,$6) on conflict(id,workspace) do update set + content = EXCLUDED.content, content_summary = EXCLUDED.content_summary, content_length = EXCLUDED.content_length, chunks_count = EXCLUDED.chunks_count, @@ -530,6 +519,7 @@ class PGDocStatusStorage(DocStatusStorage): { "workspace": self.db.workspace, "id": k, + "content": v["content"], "content_summary": v["content_summary"], "content_length": v["content_length"], "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 147ea5f3..05da41b7 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -1,4 +1,5 @@ import os +from typing import Any from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm @@ -20,29 +21,15 @@ class RedisKVStorage(BaseKVStorage): self._redis = Redis.from_url(redis_url, decode_responses=True) logger.info(f"Use Redis as KV {self.namespace}") - async def all_keys(self) -> list[str]: - keys = await self._redis.keys(f"{self.namespace}:*") - return [key.split(":", 1)[-1] for key in keys] - async def get_by_id(self, id): data = await self._redis.get(f"{self.namespace}:{id}") return json.loads(data) if data else None - async def get_by_ids(self, ids, fields=None): + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: pipe = self._redis.pipeline() for id in ids: pipe.get(f"{self.namespace}:{id}") results = await pipe.execute() - - if fields: - # Filter fields if specified - return [ - {field: value.get(field) for field in fields if field in value} - if (value := json.loads(result)) - else None - for result in results - ] - return [json.loads(result) if result else None for result in results] async def filter_keys(self, data: list[str]) -> set[str]: @@ -54,7 +41,7 @@ class RedisKVStorage(BaseKVStorage): existing_ids = {data[i] for i, exists in enumerate(results) if exists} return set(data) - existing_ids - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: pipe = self._redis.pipeline() for k, v in tqdm_async(data.items(), desc="Upserting"): pipe.set(f"{self.namespace}:{k}", json.dumps(v)) @@ -62,9 +49,8 @@ class RedisKVStorage(BaseKVStorage): for k in data: data[k]["_id"] = k - return data - async def drop(self): + async def drop(self) -> None: keys = await self._redis.keys(f"{self.namespace}:*") if keys: await self._redis.delete(*keys) diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index cb819d47..1f454639 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -1,7 +1,7 @@ import asyncio import os from dataclasses import dataclass -from typing import Union +from typing import Any, Union import numpy as np import pipmaster as pm @@ -108,33 +108,20 @@ class TiDBKVStorage(BaseKVStorage): ################ QUERY METHODS ################ - async def get_by_id(self, id: str) -> Union[dict, None]: - """根据 id 获取 doc_full 数据.""" + async def get_by_id(self, id: str) -> dict[str, Any]: + """Fetch doc_full data by id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"id": id} # 
print("get_by_id:"+SQL) - res = await self.db.query(SQL, params) - if res: - data = res # {"data":res} - # print (data) - return data - else: - return None + return await self.db.query(SQL, params) # Query by id - async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: - """根据 id 获取 doc_chunks 数据""" + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Fetch doc_chunks data by id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) - # print("get_by_ids:"+SQL) - res = await self.db.query(SQL, multirows=True) - if res: - data = res # [{"data":i} for i in res] - # print(data) - return data - else: - return None + return await self.db.query(SQL, multirows=True) async def filter_keys(self, keys: list[str]) -> set[str]: """过滤掉重复内容""" @@ -158,7 +145,7 @@ class TiDBKVStorage(BaseKVStorage): return data ################ INSERT full_doc AND chunks ################ - async def upsert(self, data: dict[str, dict]): + async def upsert(self, data: dict[str, Any]) -> None: left_data = {k: v for k, v in data.items() if k not in self._data} self._data.update(left_data) if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): @@ -335,6 +322,11 @@ class TiDBVectorDBStorage(BaseVectorStorage): merge_sql = SQL_TEMPLATES["insert_relationship"] await self.db.execute(merge_sql, data) + async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + return await self.db.query(SQL, params, multirows=True) + @dataclass class TiDBGraphStorage(BaseGraphStorage): diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6b925be3..5d00c508 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -4,17 +4,15 @@ from tqdm.asyncio import tqdm as tqdm_async from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Type, cast, Dict - +from typing import Any, Callable, Coroutine, Optional, Type, Union, cast from .operate import ( chunking_by_token_size, extract_entities, - # local_query,global_query,hybrid_query, - kg_query, - naive_query, - mix_kg_vector_query, extract_keywords_only, + kg_query, kg_query_with_keywords, + mix_kg_vector_query, + naive_query, ) from .utils import ( @@ -24,15 +22,16 @@ from .utils import ( convert_response_to_json, logger, set_logger, - statistic_data, ) from .base import ( BaseGraphStorage, BaseKVStorage, BaseVectorStorage, - StorageNameSpace, - QueryParam, + DocProcessingStatus, DocStatus, + DocStatusStorage, + QueryParam, + StorageNameSpace, ) from .namespace import NameSpace, make_namespace @@ -176,15 +175,26 @@ class LightRAG: enable_llm_cache_for_entity_extract: bool = True # extension - addon_params: dict = field(default_factory=dict) - convert_response_to_json_func: callable = convert_response_to_json + addon_params: dict[str, Any] = field(default_factory=dict) + convert_response_to_json_func: Callable[[str], dict[str, Any]] = ( + convert_response_to_json + ) # Add new field for document status storage type doc_status_storage: str = field(default="JsonDocStatusStorage") # Custom Chunking Function - chunking_func: callable = chunking_by_token_size - chunking_func_kwargs: dict = field(default_factory=dict) + chunking_func: Callable[ + [ + str, + Optional[str], + bool, + int, + int, + str, + ], + list[dict[str, Any]], + ] = chunking_by_token_size def 
__post_init__(self): os.makedirs(self.log_dir, exist_ok=True) @@ -245,19 +255,19 @@ class LightRAG: #### # add embedding func by walter #### - self.full_docs = self.key_string_value_json_storage_cls( + self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS ), embedding_func=self.embedding_func, ) - self.text_chunks = self.key_string_value_json_storage_cls( + self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS ), embedding_func=self.embedding_func, ) - self.chunk_entity_relation_graph = self.graph_storage_cls( + self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( namespace=make_namespace( self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION ), @@ -281,7 +291,7 @@ class LightRAG: embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) - self.chunks_vdb = self.vector_db_storage_cls( + self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( namespace=make_namespace( self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS ), @@ -310,7 +320,7 @@ class LightRAG: # Initialize document status storage self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) - self.doc_status = self.doc_status_storage_cls( + self.doc_status: DocStatusStorage = self.doc_status_storage_cls( namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), global_config=global_config, embedding_func=None, @@ -351,17 +361,12 @@ class LightRAG: storage.db = db_client def insert( - self, string_or_strings, split_by_character=None, split_by_character_only=False + self, + string_or_strings: Union[str, list[str]], + split_by_character: str | None = None, + split_by_character_only: bool = False, ): - loop = always_get_an_event_loop() - return loop.run_until_complete( - self.ainsert(string_or_strings, split_by_character, split_by_character_only) - ) - - async def ainsert( - self, string_or_strings, split_by_character=None, split_by_character_only=False - ): - """Insert documents with checkpoint support + """Sync Insert documents with checkpoint support Args: string_or_strings: Single document string or list of document strings @@ -370,154 +375,30 @@ class LightRAG: split_by_character_only: if split_by_character_only is True, split the string by character only, when split_by_character is None, this parameter is ignored. """ - if isinstance(string_or_strings, str): - string_or_strings = [string_or_strings] + loop = always_get_an_event_loop() + return loop.run_until_complete( + self.ainsert(string_or_strings, split_by_character, split_by_character_only) + ) - # 1. Remove duplicate contents from the list - unique_contents = list(set(doc.strip() for doc in string_or_strings)) + async def ainsert( + self, + string_or_strings: Union[str, list[str]], + split_by_character: str | None = None, + split_by_character_only: bool = False, + ): + """Async Insert documents with checkpoint support - # 2. Generate document IDs and initial status - new_docs = { - compute_mdhash_id(content, prefix="doc-"): { - "content": content, - "content_summary": self._get_content_summary(content), - "content_length": len(content), - "status": DocStatus.PENDING, - "created_at": datetime.now().isoformat(), - "updated_at": datetime.now().isoformat(), - } - for content in unique_contents - } - - # 3. 
Filter out already processed documents - # _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) - _add_doc_keys = set() - for doc_id in new_docs.keys(): - current_doc = await self.doc_status.get_by_id(doc_id) - - if current_doc is None: - _add_doc_keys.add(doc_id) - continue # skip to the next doc_id - - status = None - if isinstance(current_doc, dict): - status = current_doc["status"] - else: - status = current_doc.status - - if status == DocStatus.FAILED: - _add_doc_keys.add(doc_id) - - new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} - - if not new_docs: - logger.info("All documents have been processed or are duplicates") - return - - logger.info(f"Processing {len(new_docs)} new unique documents") - - # Process documents in batches - batch_size = self.addon_params.get("insert_batch_size", 10) - for i in range(0, len(new_docs), batch_size): - batch_docs = dict(list(new_docs.items())[i : i + batch_size]) - - for doc_id, doc in tqdm_async( - batch_docs.items(), desc=f"Processing batch {i // batch_size + 1}" - ): - try: - # Update status to processing - doc_status = { - "content_summary": doc["content_summary"], - "content_length": doc["content_length"], - "status": DocStatus.PROCESSING, - "created_at": doc["created_at"], - "updated_at": datetime.now().isoformat(), - } - await self.doc_status.upsert({doc_id: doc_status}) - - # Generate chunks from document - chunks = { - compute_mdhash_id(dp["content"], prefix="chunk-"): { - **dp, - "full_doc_id": doc_id, - } - for dp in self.chunking_func( - doc["content"], - split_by_character=split_by_character, - split_by_character_only=split_by_character_only, - overlap_token_size=self.chunk_overlap_token_size, - max_token_size=self.chunk_token_size, - tiktoken_model=self.tiktoken_model_name, - **self.chunking_func_kwargs, - ) - } - - # Update status with chunks information - doc_status.update( - { - "chunks_count": len(chunks), - "updated_at": datetime.now().isoformat(), - } - ) - await self.doc_status.upsert({doc_id: doc_status}) - - try: - # Store chunks in vector database - await self.chunks_vdb.upsert(chunks) - - # Extract and store entities and relationships - maybe_new_kg = await extract_entities( - chunks, - knowledge_graph_inst=self.chunk_entity_relation_graph, - entity_vdb=self.entities_vdb, - relationships_vdb=self.relationships_vdb, - llm_response_cache=self.llm_response_cache, - global_config=asdict(self), - ) - - if maybe_new_kg is None: - raise Exception( - "Failed to extract entities and relationships" - ) - - self.chunk_entity_relation_graph = maybe_new_kg - - # Store original document and chunks - await self.full_docs.upsert( - {doc_id: {"content": doc["content"]}} - ) - await self.text_chunks.upsert(chunks) - - # Update status to processed - doc_status.update( - { - "status": DocStatus.PROCESSED, - "updated_at": datetime.now().isoformat(), - } - ) - await self.doc_status.upsert({doc_id: doc_status}) - - except Exception as e: - # Mark as failed if any step fails - doc_status.update( - { - "status": DocStatus.FAILED, - "error": str(e), - "updated_at": datetime.now().isoformat(), - } - ) - await self.doc_status.upsert({doc_id: doc_status}) - raise e - - except Exception as e: - import traceback - - error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" - logger.error(error_msg) - continue - else: - # Only update index when processing succeeds - await self._insert_done() + Args: + string_or_strings: Single document string or list of document strings + 
split_by_character: if split_by_character is not None, split the string by character, if chunk longer than + chunk_size, split the sub chunk by token size. + split_by_character_only: if split_by_character_only is True, split the string by character only, when + split_by_character is None, this parameter is ignored. + """ + await self.apipeline_enqueue_documents(string_or_strings) + await self.apipeline_process_enqueue_documents( + split_by_character, split_by_character_only + ) def insert_custom_chunks(self, full_text: str, text_chunks: list[str]): loop = always_get_an_event_loop() @@ -586,10 +467,14 @@ class LightRAG: if update_storage: await self._insert_done() - async def apipeline_process_documents(self, string_or_strings): - """Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs - Args: - string_or_strings: Single document string or list of document strings + async def apipeline_enqueue_documents(self, string_or_strings: str | list[str]): + """ + Pipeline for Processing Documents + + 1. Remove duplicate contents from the list + 2. Generate document IDs and initial status + 3. Filter out already processed documents + 4. Enqueue document in status """ if isinstance(string_or_strings, str): string_or_strings = [string_or_strings] @@ -597,183 +482,187 @@ class LightRAG: # 1. Remove duplicate contents from the list unique_contents = list(set(doc.strip() for doc in string_or_strings)) - logger.info( - f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents" - ) - # 2. Generate document IDs and initial status - new_docs = { + new_docs: dict[str, Any] = { compute_mdhash_id(content, prefix="doc-"): { "content": content, "content_summary": self._get_content_summary(content), "content_length": len(content), "status": DocStatus.PENDING, "created_at": datetime.now().isoformat(), - "updated_at": None, + "updated_at": datetime.now().isoformat(), } for content in unique_contents } # 3. Filter out already processed documents - _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) - if len(_not_stored_doc_keys) < len(new_docs): - logger.info( - f"Skipping {len(new_docs) - len(_not_stored_doc_keys)} already existing documents" - ) - new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys} + add_doc_keys: set[str] = set() + # Get docs ids + in_process_keys = list(new_docs.keys()) + # Get in progress docs ids + excluded_ids = await self.doc_status.get_by_ids(in_process_keys) + # Exclude already in process + add_doc_keys = new_docs.keys() - excluded_ids + # Filter + new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys} if not new_docs: logger.info("All documents have been processed or are duplicates") - return None + return - # 4. Store original document - for doc_id, doc in new_docs.items(): - await self.full_docs.upsert({doc_id: {"content": doc["content"]}}) - await self.full_docs.change_status(doc_id, DocStatus.PENDING) + # 4. Store status document + await self.doc_status.upsert(new_docs) logger.info(f"Stored {len(new_docs)} new unique documents") - async def apipeline_process_chunks(self): - """Get pendding documents, split into chunks,insert chunks""" - # 1. 
get all pending and failed documents - _todo_doc_keys = [] - _failed_doc = await self.full_docs.get_by_status_and_ids( - status=DocStatus.FAILED, ids=None - ) - _pendding_doc = await self.full_docs.get_by_status_and_ids( - status=DocStatus.PENDING, ids=None - ) - if _failed_doc: - _todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) - if _pendding_doc: - _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc]) - if not _todo_doc_keys: - logger.info("All documents have been processed or are duplicates") - return None - else: - logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents") + async def apipeline_process_enqueue_documents( + self, + split_by_character: str | None = None, + split_by_character_only: bool = False, + ) -> None: + """ + Process pending documents by splitting them into chunks, processing + each chunk for entity and relation extraction, and updating the + document status. - new_docs = { - doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys) - } + 1. Get all pending and failed documents + 2. Split document content into chunks + 3. Process each chunk for entity and relation extraction + 4. Update the document status + """ + # 1. get all pending and failed documents + to_process_docs: dict[str, DocProcessingStatus] = {} + + # Fetch failed documents + failed_docs = await self.doc_status.get_failed_docs() + to_process_docs.update(failed_docs) + + pending_docs = await self.doc_status.get_pending_docs() + to_process_docs.update(pending_docs) + + if not to_process_docs: + logger.info("All documents have been processed or are duplicates") + return + + to_process_docs_ids = list(to_process_docs.keys()) + + # Get allready processed documents (text chunks and full docs) + text_chunks_processed_doc_ids = await self.text_chunks.filter_keys( + to_process_docs_ids + ) + full_docs_processed_doc_ids = await self.full_docs.filter_keys( + to_process_docs_ids + ) # 2. split docs into chunks, insert chunks, update doc status - chunk_cnt = 0 batch_size = self.addon_params.get("insert_batch_size", 10) - for i in range(0, len(new_docs), batch_size): - batch_docs = dict(list(new_docs.items())[i : i + batch_size]) - for doc_id, doc in tqdm_async( - batch_docs.items(), - desc=f"Level 1 - Spliting doc in batch {i // batch_size + 1}", + batch_docs_list = [ + list(to_process_docs.items())[i : i + batch_size] + for i in range(0, len(to_process_docs), batch_size) + ] + + # 3. iterate over batches + tasks: dict[str, list[Coroutine[Any, Any, None]]] = {} + for batch_idx, ids_doc_processing_status in tqdm_async( + enumerate(batch_docs_list), + desc="Process Batches", + ): + # 4. 
iterate over batch + for id_doc_processing_status in tqdm_async( + ids_doc_processing_status, + desc=f"Process Batch {batch_idx}", ): - try: - # Generate chunks from document - chunks = { - compute_mdhash_id(dp["content"], prefix="chunk-"): { - **dp, - "full_doc_id": doc_id, - "status": DocStatus.PENDING, - } - for dp in chunking_by_token_size( - doc["content"], - overlap_token_size=self.chunk_overlap_token_size, - max_token_size=self.chunk_token_size, - tiktoken_model=self.tiktoken_model_name, - ) - } - chunk_cnt += len(chunks) - await self.text_chunks.upsert(chunks) - await self.text_chunks.change_status(doc_id, DocStatus.PROCESSING) - - try: - # Store chunks in vector database - await self.chunks_vdb.upsert(chunks) - # Update doc status - await self.full_docs.change_status(doc_id, DocStatus.PROCESSED) - except Exception as e: - # Mark as failed if any step fails - await self.full_docs.change_status(doc_id, DocStatus.FAILED) - raise e - except Exception as e: - import traceback - - error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" - logger.error(error_msg) - continue - logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents") - - async def apipeline_process_extract_graph(self): - """Get pendding or failed chunks, extract entities and relationships from each chunk""" - # 1. get all pending and failed chunks - _todo_chunk_keys = [] - _failed_chunks = await self.text_chunks.get_by_status_and_ids( - status=DocStatus.FAILED, ids=None - ) - _pendding_chunks = await self.text_chunks.get_by_status_and_ids( - status=DocStatus.PENDING, ids=None - ) - if _failed_chunks: - _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) - if _pendding_chunks: - _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks]) - if not _todo_chunk_keys: - logger.info("All chunks have been processed or are duplicates") - return None - - # Process documents in batches - batch_size = self.addon_params.get("insert_batch_size", 10) - - semaphore = asyncio.Semaphore( - batch_size - ) # Control the number of tasks that are processed simultaneously - - async def process_chunk(chunk_id): - async with semaphore: - chunks = { - i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id]) - } - # Extract and store entities and relationships - try: - maybe_new_kg = await extract_entities( - chunks, - knowledge_graph_inst=self.chunk_entity_relation_graph, - entity_vdb=self.entities_vdb, - relationships_vdb=self.relationships_vdb, - llm_response_cache=self.llm_response_cache, - global_config=asdict(self), - ) - if maybe_new_kg is None: - logger.info("No entities or relationships extracted!") - # Update status to processed - await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED) - except Exception as e: - logger.error("Failed to extract entities and relationships") - # Mark as failed if any step fails - await self.text_chunks.change_status(chunk_id, DocStatus.FAILED) - raise e - - with tqdm_async( - total=len(_todo_chunk_keys), - desc="\nLevel 1 - Processing chunks", - unit="chunk", - position=0, - ) as progress: - tasks = [] - for chunk_id in _todo_chunk_keys: - task = asyncio.create_task(process_chunk(chunk_id)) - tasks.append(task) - - for future in asyncio.as_completed(tasks): - await future - progress.update(1) - progress.set_postfix( + id_doc, status_doc = id_doc_processing_status + # Update status in processing + await self.doc_status.upsert( { - "LLM call": statistic_data["llm_call"], - "LLM cache": statistic_data["llm_cache"], + id_doc: { + 
"status": DocStatus.PROCESSING, + "updated_at": datetime.now().isoformat(), + "content_summary": status_doc.content_summary, + "content_length": status_doc.content_length, + "created_at": status_doc.created_at, + } } ) + # Generate chunks from document + chunks: dict[str, Any] = { + compute_mdhash_id(dp["content"], prefix="chunk-"): { + **dp, + "full_doc_id": id_doc_processing_status, + } + for dp in self.chunking_func( + status_doc.content, + split_by_character, + split_by_character_only, + self.chunk_overlap_token_size, + self.chunk_token_size, + self.tiktoken_model_name, + ) + } - # Ensure all indexes are updated after each document - await self._insert_done() + # Ensure chunk insertion and graph processing happen sequentially, not in parallel + await self.chunks_vdb.upsert(chunks) + await self._process_entity_relation_graph(chunks) + + tasks[id_doc] = [] + # Check if document already processed the doc + if id_doc not in full_docs_processed_doc_ids: + tasks[id_doc].append( + self.full_docs.upsert({id_doc: {"content": status_doc.content}}) + ) + + # Check if chunks already processed the doc + if id_doc not in text_chunks_processed_doc_ids: + tasks[id_doc].append(self.text_chunks.upsert(chunks)) + + # Process document (text chunks and full docs) in parallel + for id_doc_processing_status, task in tasks.items(): + try: + await asyncio.gather(*task) + await self.doc_status.upsert( + { + id_doc_processing_status: { + "status": DocStatus.PROCESSED, + "chunks_count": len(chunks), + "updated_at": datetime.now().isoformat(), + } + } + ) + await self._insert_done() + + except Exception as e: + logger.error( + f"Failed to process document {id_doc_processing_status}: {str(e)}" + ) + await self.doc_status.upsert( + { + id_doc_processing_status: { + "status": DocStatus.FAILED, + "error": str(e), + "updated_at": datetime.now().isoformat(), + } + } + ) + continue + + async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None: + try: + new_kg = await extract_entities( + chunk, + knowledge_graph_inst=self.chunk_entity_relation_graph, + entity_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + llm_response_cache=self.llm_response_cache, + global_config=asdict(self), + ) + if new_kg is None: + logger.info("No entities or relationships extracted!") + else: + self.chunk_entity_relation_graph = new_kg + + except Exception as e: + logger.error("Failed to extract entities and relationships") + raise e async def _insert_done(self): tasks = [] @@ -1169,7 +1058,7 @@ class LightRAG: return content return content[:max_length] + "..." 
- async def get_processing_status(self) -> Dict[str, int]: + async def get_processing_status(self) -> dict[str, int]: """Get current document processing status counts Returns: diff --git a/lightrag/operate.py b/lightrag/operate.py index c8c50f61..811b4194 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2,7 +2,7 @@ import asyncio import json import re from tqdm.asyncio import tqdm as tqdm_async -from typing import Union +from typing import Any, Union from collections import Counter, defaultdict from .utils import ( logger, @@ -36,15 +36,14 @@ import time def chunking_by_token_size( content: str, - split_by_character=None, - split_by_character_only=False, - overlap_token_size=128, - max_token_size=1024, - tiktoken_model="gpt-4o", - **kwargs, -): + split_by_character: Union[str, None] = None, + split_by_character_only: bool = False, + overlap_token_size: int = 128, + max_token_size: int = 1024, + tiktoken_model: str = "gpt-4o", +) -> list[dict[str, Any]]: tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model) - results = [] + results: list[dict[str, Any]] = [] if split_by_character: raw_chunks = content.split(split_by_character) new_chunks = [] @@ -568,7 +567,7 @@ async def kg_query( knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, query_param: QueryParam, global_config: dict, hashing_kv: BaseKVStorage = None, @@ -777,7 +776,7 @@ async def mix_kg_vector_query( entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, chunks_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, query_param: QueryParam, global_config: dict, hashing_kv: BaseKVStorage = None, @@ -969,7 +968,7 @@ async def _build_query_context( knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, query_param: QueryParam, ): # ll_entities_context, ll_relations_context, ll_text_units_context = "", "", "" @@ -1052,7 +1051,7 @@ async def _get_node_data( query, knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, query_param: QueryParam, ): # get similar entities @@ -1145,7 +1144,7 @@ async def _get_node_data( async def _find_most_related_text_unit_from_entities( node_datas: list[dict], query_param: QueryParam, - text_chunks_db: BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, knowledge_graph_inst: BaseGraphStorage, ): text_units = [ @@ -1268,7 +1267,7 @@ async def _get_edge_data( keywords, knowledge_graph_inst: BaseGraphStorage, relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, query_param: QueryParam, ): results = await relationships_vdb.query(keywords, top_k=query_param.top_k) @@ -1421,7 +1420,7 @@ async def _find_most_related_entities_from_relationships( async def _find_related_text_unit_from_relationships( edge_datas: list[dict], query_param: QueryParam, - text_chunks_db: BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, knowledge_graph_inst: BaseGraphStorage, ): text_units = [ @@ -1496,7 +1495,7 @@ def combine_contexts(entities, relationships, sources): async def naive_query( query, chunks_vdb: BaseVectorStorage, - text_chunks_db: 
BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, query_param: QueryParam, global_config: dict, hashing_kv: BaseKVStorage = None, @@ -1599,7 +1598,7 @@ async def kg_query_with_keywords( knowledge_graph_inst: BaseGraphStorage, entities_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage[TextChunkSchema], + text_chunks_db: BaseKVStorage, query_param: QueryParam, global_config: dict, hashing_kv: BaseKVStorage = None, diff --git a/lightrag/utils.py b/lightrag/utils.py index ed0b6c06..28d9bfaa 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -98,7 +98,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]: return None -def convert_response_to_json(response: str) -> dict: +def convert_response_to_json(response: str) -> dict[str, Any]: json_str = locate_json_string_body_from_string(response) assert json_str is not None, f"Unable to parse JSON from response: {response}" try: