diff --git a/.gitignore b/.gitignore index 0e0ec299..ec95f8a5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,4 @@ rag_storage venv/ examples/input/ examples/output/ -.DS_Store \ No newline at end of file +.DS_Store diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index 8a5439e2..6de6e0a7 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -89,49 +89,45 @@ async def main(): rag = LightRAG( # log_level="DEBUG", working_dir=WORKING_DIR, - entity_extract_max_gleaning = 1, - + entity_extract_max_gleaning=1, enable_llm_cache=True, - enable_llm_cache_for_entity_extract = True, - embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90}, - - + enable_llm_cache_for_entity_extract=True, + embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90}, chunk_token_size=CHUNK_TOKEN_SIZE, - llm_model_max_token_size = MAX_TOKENS, + llm_model_max_token_size=MAX_TOKENS, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( embedding_dim=embedding_dimension, max_token_size=500, func=embedding_func, - ), - - graph_storage = "OracleGraphStorage", - kv_storage = "OracleKVStorage", + ), + graph_storage="OracleGraphStorage", + kv_storage="OracleKVStorage", vector_storage="OracleVectorDBStorage", - - addon_params = {"example_number":1, - "language":"Simplfied Chinese", - "entity_types": ["organization", "person", "geo", "event"], - "insert_batch_size":2, - } + addon_params={ + "example_number": 1, + "language": "Simplfied Chinese", + "entity_types": ["organization", "person", "geo", "event"], + "insert_batch_size": 2, + }, ) - # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.set_storage_client(db_client = oracle_db) + # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool + rag.set_storage_client(db_client=oracle_db) # Extract and Insert into LightRAG storage - with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f: + with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f: all_text = f.read() texts = [x for x in all_text.split("\n") if x] - + # New mode use pipeline await rag.apipeline_process_documents(texts) - await rag.apipeline_process_chunks() + await rag.apipeline_process_chunks() await rag.apipeline_process_extract_graph() # Old method use ainsert - #await rag.ainsert(texts) - + # await rag.ainsert(texts) + # Perform search in different modes modes = ["naive", "local", "global", "hybrid"] for mode in modes: diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index c9deed4e..e30b6909 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -3,7 +3,7 @@ import asyncio # import html # import os from dataclasses import dataclass -from typing import Union, List, Dict, Set, Any, Tuple +from typing import Union import numpy as np import array @@ -170,7 +170,7 @@ class OracleKVStorage(BaseKVStorage): def __post_init__(self): self._data = {} - self._max_batch_size = self.global_config.get("embedding_batch_num",10) + self._max_batch_size = self.global_config.get("embedding_batch_num", 10) ################ QUERY METHODS ################ @@ -190,7 +190,7 @@ class OracleKVStorage(BaseKVStorage): return res else: return None - + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: """Specifically for llm_response_cache.""" SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] @@ -199,11 +199,11 @@ class OracleKVStorage(BaseKVStorage): array_res = await self.db.query(SQL, params, multirows=True) res = {} for row in array_res: - res[row["id"]] = row + res[row["id"]] = row return res else: return None - + async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( @@ -222,7 +222,7 @@ class OracleKVStorage(BaseKVStorage): dict_res[mode] = {} for row in res: dict_res[row["mode"]][row["id"]] = row - res = [{k: v} for k, v in dict_res.items()] + res = [{k: v} for k, v in dict_res.items()] if res: data = res # [{"data":i} for i in res] # print(data) @@ -230,7 +230,9 @@ class OracleKVStorage(BaseKVStorage): else: return None - async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]: + async def get_by_status_and_ids( + self, status: str, ids: list[str] + ) -> Union[list[dict], None]: """Specifically for llm_response_cache.""" if ids is not None: SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( @@ -244,7 +246,7 @@ class OracleKVStorage(BaseKVStorage): return res else: return None - + async def filter_keys(self, keys: list[str]) -> set[str]: """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( @@ -258,7 +260,6 @@ class OracleKVStorage(BaseKVStorage): return data else: return set(keys) - ################ INSERT METHODS ################ async def upsert(self, data: dict[str, dict]): @@ -281,7 +282,7 @@ class OracleKVStorage(BaseKVStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - + merge_sql = SQL_TEMPLATES["merge_chunk"] for item in list_data: _data = { @@ -320,11 +321,9 @@ class OracleKVStorage(BaseKVStorage): await self.db.execute(upsert_sql, _data) return None - + async def change_status(self, id: str, status: str): - SQL = SQL_TEMPLATES["change_status"].format( - table_name=N_T[self.namespace] - ) + SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace]) params = {"workspace": self.db.workspace, "id": id, "status": status} await self.db.execute(SQL, params) @@ -673,8 +672,8 @@ TABLES = { }, "LIGHTRAG_LLM_CACHE": { "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( - id varchar(256) PRIMARY KEY, - workspace varchar(1024), + id varchar(256) PRIMARY KEY, + workspace varchar(1024), cache_mode varchar(256), model_name varchar(256), original_prompt clob, @@ -708,47 +707,32 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", - "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", - "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", - "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", - "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", - "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})", - "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})", - "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", - "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", - "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", - "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a USING DUAL ON (a.id = :id and a.workspace = :workspace) WHEN NOT MATCHED THEN INSERT(id,content,workspace) values(:id,:content,:workspace)""", - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS USING DUAL ON (id = :id and workspace = :workspace) WHEN NOT MATCHED THEN INSERT (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, - "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a USING DUAL ON (a.id = :id) @@ -760,8 +744,6 @@ SQL_TEMPLATES = { return_value = :return_value, cache_mode = :cache_mode, updatetime = SYSDATE""", - - # SQL for VectorStorage "entities": """SELECT name as entity_name FROM (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance @@ -818,7 +800,7 @@ SQL_TEMPLATES = { INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) WHEN MATCHED THEN - UPDATE SET + UPDATE SET entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL @@ -827,7 +809,7 @@ SQL_TEMPLATES = { INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) WHEN MATCHED THEN - UPDATE SET + UPDATE SET weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 7d8cdf45..0902fc50 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -26,7 +26,7 @@ from .utils import ( convert_response_to_json, logger, set_logger, - statistic_data + statistic_data, ) from .base import ( BaseGraphStorage, @@ -39,30 +39,30 @@ from .base import ( from .prompt import GRAPH_FIELD_SEP -STORAGES = { - "JsonKVStorage": '.storage', - "NanoVectorDBStorage": '.storage', - "NetworkXStorage": '.storage', - "JsonDocStatusStorage": '.storage', - - "Neo4JStorage":".kg.neo4j_impl", - "OracleKVStorage":".kg.oracle_impl", - "OracleGraphStorage":".kg.oracle_impl", - "OracleVectorDBStorage":".kg.oracle_impl", - "MilvusVectorDBStorge":".kg.milvus_impl", - "MongoKVStorage":".kg.mongo_impl", - "ChromaVectorDBStorage":".kg.chroma_impl", - "TiDBKVStorage":".kg.tidb_impl", - "TiDBVectorDBStorage":".kg.tidb_impl", - "TiDBGraphStorage":".kg.tidb_impl", - "PGKVStorage":".kg.postgres_impl", - "PGVectorStorage":".kg.postgres_impl", - "AGEStorage":".kg.age_impl", - "PGGraphStorage":".kg.postgres_impl", - "GremlinStorage":".kg.gremlin_impl", - "PGDocStatusStorage":".kg.postgres_impl", +STORAGES = { + "JsonKVStorage": ".storage", + "NanoVectorDBStorage": ".storage", + "NetworkXStorage": ".storage", + "JsonDocStatusStorage": ".storage", + "Neo4JStorage": ".kg.neo4j_impl", + "OracleKVStorage": ".kg.oracle_impl", + "OracleGraphStorage": ".kg.oracle_impl", + "OracleVectorDBStorage": ".kg.oracle_impl", + "MilvusVectorDBStorge": ".kg.milvus_impl", + "MongoKVStorage": ".kg.mongo_impl", + "ChromaVectorDBStorage": ".kg.chroma_impl", + "TiDBKVStorage": ".kg.tidb_impl", + "TiDBVectorDBStorage": ".kg.tidb_impl", + "TiDBGraphStorage": ".kg.tidb_impl", + "PGKVStorage": ".kg.postgres_impl", + "PGVectorStorage": ".kg.postgres_impl", + "AGEStorage": ".kg.age_impl", + "PGGraphStorage": ".kg.postgres_impl", + "GremlinStorage": ".kg.gremlin_impl", + "PGDocStatusStorage": ".kg.postgres_impl", } + def lazy_external_import(module_name: str, class_name: str): """Lazily import a class from an external module based on the package of the caller.""" @@ -75,6 +75,7 @@ def lazy_external_import(module_name: str, class_name: str): def import_class(*args, **kwargs): import importlib + module = importlib.import_module(module_name, package=package) cls = getattr(module, class_name) return cls(*args, **kwargs) @@ -190,7 +191,7 @@ class LightRAG: os.makedirs(self.working_dir) # show config - global_config=asdict(self) + global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") @@ -198,31 +199,33 @@ class LightRAG: self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.embedding_func ) - # Initialize all storages - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage) - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage) - + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self._get_storage_class(self.kv_storage) + ) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( + self.vector_storage + ) + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( + self.graph_storage + ) + self.key_string_value_json_storage_cls = partial( - self.key_string_value_json_storage_cls, - global_config=global_config + self.key_string_value_json_storage_cls, global_config=global_config ) self.vector_db_storage_cls = partial( - self.vector_db_storage_cls, - global_config=global_config + self.vector_db_storage_cls, global_config=global_config ) self.graph_storage_cls = partial( - self.graph_storage_cls, - global_config=global_config + self.graph_storage_cls, global_config=global_config ) self.json_doc_status_storage = self.key_string_value_json_storage_cls( namespace="json_doc_status_storage", - embedding_func=None, + embedding_func=None, ) self.llm_response_cache = self.key_string_value_json_storage_cls( @@ -264,13 +267,15 @@ class LightRAG: embedding_func=self.embedding_func, ) - if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"): + if self.llm_response_cache and hasattr( + self.llm_response_cache, "global_config" + ): hashing_kv = self.llm_response_cache else: hashing_kv = self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - embedding_func=None, - ) + namespace="llm_response_cache", + embedding_func=None, + ) self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -292,21 +297,24 @@ class LightRAG: import_path = STORAGES[storage_name] storage_class = lazy_external_import(import_path, storage_name) return storage_class - - def set_storage_client(self,db_client): + + def set_storage_client(self, db_client): # Now only tested on Oracle Database - for storage in [self.vector_db_storage_cls, - self.graph_storage_cls, - self.doc_status, self.full_docs, - self.text_chunks, - self.llm_response_cache, - self.key_string_value_json_storage_cls, - self.chunks_vdb, - self.relationships_vdb, - self.entities_vdb, - self.graph_storage_cls, - self.chunk_entity_relation_graph, - self.llm_response_cache]: + for storage in [ + self.vector_db_storage_cls, + self.graph_storage_cls, + self.doc_status, + self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.key_string_value_json_storage_cls, + self.chunks_vdb, + self.relationships_vdb, + self.entities_vdb, + self.graph_storage_cls, + self.chunk_entity_relation_graph, + self.llm_response_cache, + ]: # set client storage.db = db_client @@ -348,11 +356,6 @@ class LightRAG: } for content in unique_contents } - - # 3. Store original document and chunks - await self.full_docs.upsert( - {doc_id: {"content": doc["content"]}} - ) # 3. Filter out already processed documents _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) @@ -401,7 +404,12 @@ class LightRAG: } # Update status with chunks information - doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "chunks_count": len(chunks), + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) try: @@ -425,16 +433,30 @@ class LightRAG: self.chunk_entity_relation_graph = maybe_new_kg - + # Store original document and chunks + await self.full_docs.upsert( + {doc_id: {"content": doc["content"]}} + ) await self.text_chunks.upsert(chunks) # Update status to processed - doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "status": DocStatus.PROCESSED, + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) except Exception as e: # Mark as failed if any step fails - doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()}) + doc_status.update( + { + "status": DocStatus.FAILED, + "error": str(e), + "updated_at": datetime.now().isoformat(), + } + ) await self.doc_status.upsert({doc_id: doc_status}) raise e @@ -527,7 +549,9 @@ class LightRAG: # 1. Remove duplicate contents from the list unique_contents = list(set(doc.strip() for doc in string_or_strings)) - logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents") + logger.info( + f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents" + ) # 2. Generate document IDs and initial status new_docs = { @@ -542,28 +566,34 @@ class LightRAG: for content in unique_contents } - # 3. Filter out already processed documents + # 3. Filter out already processed documents _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) if len(_not_stored_doc_keys) < len(new_docs): - logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents") + logger.info( + f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents" + ) new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys} if not new_docs: - logger.info(f"All documents have been processed or are duplicates") + logger.info("All documents have been processed or are duplicates") return None - # 4. Store original document + # 4. Store original document for doc_id, doc in new_docs.items(): await self.full_docs.upsert({doc_id: {"content": doc["content"]}}) await self.full_docs.change_status(doc_id, DocStatus.PENDING) logger.info(f"Stored {len(new_docs)} new unique documents") - + async def apipeline_process_chunks(self): - """Get pendding documents, split into chunks,insert chunks""" - # 1. get all pending and failed documents + """Get pendding documents, split into chunks,insert chunks""" + # 1. get all pending and failed documents _todo_doc_keys = [] - _failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) - _pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) + _failed_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) if _failed_doc: _todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) if _pendding_doc: @@ -573,10 +603,9 @@ class LightRAG: return None else: logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents") - + new_docs = { - doc["id"]: doc - for doc in await self.full_docs.get_by_ids(_todo_doc_keys) + doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys) } # 2. split docs into chunks, insert chunks, update doc status @@ -585,8 +614,9 @@ class LightRAG: for i in range(0, len(new_docs), batch_size): batch_docs = dict(list(new_docs.items())[i : i + batch_size]) for doc_id, doc in tqdm_async( - batch_docs.items(), desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}" - ): + batch_docs.items(), + desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}", + ): try: # Generate chunks from document chunks = { @@ -616,18 +646,23 @@ class LightRAG: await self.full_docs.change_status(doc_id, DocStatus.FAILED) raise e except Exception as e: - import traceback - error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" - logger.error(error_msg) - continue - logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents") + import traceback + + error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + continue + logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents") async def apipeline_process_extract_graph(self): """Get pendding or failed chunks, extract entities and relationships from each chunk""" - # 1. get all pending and failed chunks + # 1. get all pending and failed chunks _todo_chunk_keys = [] - _failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) - _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) + _failed_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) if _failed_chunks: _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) if _pendding_chunks: @@ -635,15 +670,19 @@ class LightRAG: if not _todo_chunk_keys: logger.info("All chunks have been processed or are duplicates") return None - + # Process documents in batches batch_size = self.addon_params.get("insert_batch_size", 10) - semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously + semaphore = asyncio.Semaphore( + batch_size + ) # Control the number of tasks that are processed simultaneously - async def process_chunk(chunk_id): + async def process_chunk(chunk_id): async with semaphore: - chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])} + chunks = { + i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id]) + } # Extract and store entities and relationships try: maybe_new_kg = await extract_entities( @@ -662,25 +701,29 @@ class LightRAG: logger.error("Failed to extract entities and relationships") # Mark as failed if any step fails await self.text_chunks.change_status(chunk_id, DocStatus.FAILED) - raise e + raise e - with tqdm_async(total=len(_todo_chunk_keys), - desc="\nLevel 1 - Processing chunks", - unit="chunk", - position=0) as progress: + with tqdm_async( + total=len(_todo_chunk_keys), + desc="\nLevel 1 - Processing chunks", + unit="chunk", + position=0, + ) as progress: tasks = [] for chunk_id in _todo_chunk_keys: task = asyncio.create_task(process_chunk(chunk_id)) tasks.append(task) - + for future in asyncio.as_completed(tasks): await future progress.update(1) - progress.set_postfix({ - 'LLM call': statistic_data["llm_call"], - 'LLM cache': statistic_data["llm_cache"], - }) - + progress.set_postfix( + { + "LLM call": statistic_data["llm_call"], + "LLM cache": statistic_data["llm_cache"], + } + ) + # Ensure all indexes are updated after each document await self._insert_done() diff --git a/lightrag/operate.py b/lightrag/operate.py index f9e48dbf..fbfd1fbd 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -20,7 +20,7 @@ from .utils import ( handle_cache, save_to_cache, CacheData, - statistic_data + statistic_data, ) from .base import ( BaseGraphStorage, @@ -105,7 +105,9 @@ async def _handle_entity_relation_summary( llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] summary_max_tokens = global_config["entity_summary_to_max_tokens"] - language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"]) + language = global_config["addon_params"].get( + "language", PROMPTS["DEFAULT_LANGUAGE"] + ) tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name) if len(tokens) < summary_max_tokens: # No need for summary @@ -360,7 +362,7 @@ async def extract_entities( llm_response_cache.global_config = new_config need_to_restore = True if history_messages: - history = json.dumps(history_messages,ensure_ascii=False) + history = json.dumps(history_messages, ensure_ascii=False) _prompt = history + "\n" + input_text else: _prompt = input_text @@ -381,7 +383,7 @@ async def extract_entities( input_text, history_messages=history_messages ) else: - res: str = await use_llm_func(input_text) + res: str = await use_llm_func(input_text) await save_to_cache( llm_response_cache, CacheData(args_hash=arg_hash, content=res, prompt=_prompt), @@ -394,7 +396,7 @@ async def extract_entities( return await use_llm_func(input_text) async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): - """"Prpocess a single chunk + """ "Prpocess a single chunk Args: chunk_key_dp (tuple[str, TextChunkSchema]): ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) @@ -472,7 +474,9 @@ async def extract_entities( asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), total=len(ordered_chunks), desc="Level 2 - Extracting entities and relationships", - unit="chunk", position=1,leave=False + unit="chunk", + position=1, + leave=False, ): results.append(await result) @@ -494,7 +498,9 @@ async def extract_entities( ), total=len(maybe_nodes), desc="Level 3 - Inserting entities", - unit="entity", position=2,leave=False + unit="entity", + position=2, + leave=False, ): all_entities_data.append(await result) @@ -511,7 +517,9 @@ async def extract_entities( ), total=len(maybe_edges), desc="Level 3 - Inserting relationships", - unit="relationship", position=3,leave=False + unit="relationship", + position=3, + leave=False, ): all_relationships_data.append(await result) diff --git a/lightrag/utils.py b/lightrag/utils.py index a83c0382..ce556ab2 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -41,7 +41,7 @@ logging.getLogger("httpx").setLevel(logging.WARNING) def set_logger(log_file: str): logger.setLevel(logging.DEBUG) - file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler = logging.FileHandler(log_file, encoding="utf-8") file_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( @@ -458,7 +458,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): return None, None, None, None # For naive mode, only use simple cache matching - #if mode == "naive": + # if mode == "naive": if mode == "default": if exists_func(hashing_kv, "get_by_mode_and_id"): mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} @@ -479,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): quantized = min_val = max_val = None if is_embedding_cache_enabled: # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"].func #["func"] + embedding_model_func = hashing_kv.global_config[ + "embedding_func" + ].func # ["func"] llm_model_func = hashing_kv.global_config.get("llm_model_func") current_embedding = await embedding_model_func([prompt])