support pipeline mode

jin
2025-01-16 12:58:15 +08:00
parent d5ae6669ea
commit 6ae8647285
6 changed files with 203 additions and 172 deletions

View File

@@ -89,38 +89,34 @@ async def main():
         rag = LightRAG(
             # log_level="DEBUG",
             working_dir=WORKING_DIR,
-            entity_extract_max_gleaning = 1,
+            entity_extract_max_gleaning=1,
             enable_llm_cache=True,
-            enable_llm_cache_for_entity_extract = True,
-            embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90},
+            enable_llm_cache_for_entity_extract=True,
+            embedding_cache_config=None,  # {"enabled": True,"similarity_threshold": 0.90},
             chunk_token_size=CHUNK_TOKEN_SIZE,
-            llm_model_max_token_size = MAX_TOKENS,
+            llm_model_max_token_size=MAX_TOKENS,
             llm_model_func=llm_model_func,
             embedding_func=EmbeddingFunc(
                 embedding_dim=embedding_dimension,
                 max_token_size=500,
                 func=embedding_func,
             ),
-            graph_storage = "OracleGraphStorage",
-            kv_storage = "OracleKVStorage",
+            graph_storage="OracleGraphStorage",
+            kv_storage="OracleKVStorage",
             vector_storage="OracleVectorDBStorage",
-            addon_params = {"example_number":1,
-                            "language":"Simplfied Chinese",
-                            "entity_types": ["organization", "person", "geo", "event"],
-                            "insert_batch_size":2,
-                            }
+            addon_params={
+                "example_number": 1,
+                "language": "Simplfied Chinese",
+                "entity_types": ["organization", "person", "geo", "event"],
+                "insert_batch_size": 2,
+            },
         )
         # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
-        rag.set_storage_client(db_client = oracle_db)
+        rag.set_storage_client(db_client=oracle_db)
         # Extract and Insert into LightRAG storage
-        with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f:
+        with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
             all_text = f.read()
             texts = [x for x in all_text.split("\n") if x]
@@ -130,7 +126,7 @@ async def main():
         await rag.apipeline_process_extract_graph()

         # Old method use ainsert
-        #await rag.ainsert(texts)
+        # await rag.ainsert(texts)

         # Perform search in different modes
         modes = ["naive", "local", "global", "hybrid"]

View File

@@ -3,7 +3,7 @@ import asyncio
# import html
 # import os
 from dataclasses import dataclass
-from typing import Union, List, Dict, Set, Any, Tuple
+from typing import Union

 import numpy as np
 import array
@@ -170,7 +170,7 @@ class OracleKVStorage(BaseKVStorage):
     def __post_init__(self):
         self._data = {}
-        self._max_batch_size = self.global_config.get("embedding_batch_num",10)
+        self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

     ################ QUERY METHODS ################
@@ -230,7 +230,9 @@ class OracleKVStorage(BaseKVStorage):
         else:
             return None

-    async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]:
+    async def get_by_status_and_ids(
+        self, status: str, ids: list[str]
+    ) -> Union[list[dict], None]:
         """Specifically for llm_response_cache."""
         if ids is not None:
             SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
@@ -259,7 +261,6 @@ class OracleKVStorage(BaseKVStorage):
         else:
             return set(keys)

     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict]):
         if self.namespace == "text_chunks":
@@ -322,9 +323,7 @@ class OracleKVStorage(BaseKVStorage):
         return None

     async def change_status(self, id: str, status: str):
-        SQL = SQL_TEMPLATES["change_status"].format(
-            table_name=N_T[self.namespace]
-        )
+        SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
         params = {"workspace": self.db.workspace, "id": id, "status": status}
         await self.db.execute(SQL, params)
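
The two methods above give each KV table a simple status contract: `get_by_status_and_ids` reads rows by status, and `change_status` flips a row's status (touching `updatetime` through the `change_status` SQL template further down). A rough, illustrative sketch of how a caller could drain unfinished rows under that contract; `drain_unfinished` and `process_one` are hypothetical names, not part of this commit, and `DocStatus` is the status enum used in the lightrag.py hunks below.

# Illustrative sketch only; drain_unfinished/process_one are hypothetical helpers.
async def drain_unfinished(storage, process_one):
    # Re-read rows that are still pending or previously failed
    failed = await storage.get_by_status_and_ids(status=DocStatus.FAILED, ids=None) or []
    pending = await storage.get_by_status_and_ids(status=DocStatus.PENDING, ids=None) or []
    for row in failed + pending:
        try:
            await process_one(row)  # caller-supplied coroutine doing the real work
            await storage.change_status(row["id"], DocStatus.PROCESSED)
        except Exception:
            await storage.change_status(row["id"], DocStatus.FAILED)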
@@ -708,47 +707,32 @@ TABLES = {
 SQL_TEMPLATES = {
     # SQL for KVStorage
     "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
     "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
     "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
         FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
     "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
         FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
     "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
         FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
     "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
     "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
     "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
     "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
     "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
     "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
     "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
     "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
     "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
         USING DUAL
         ON (a.id = :id and a.workspace = :workspace)
         WHEN NOT MATCHED THEN
         INSERT(id,content,workspace) values(:id,:content,:workspace)""",
     "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
         USING DUAL
         ON (id = :id and workspace = :workspace)
         WHEN NOT MATCHED THEN INSERT
         (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
         values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
     "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
         USING DUAL
         ON (a.id = :id)
@@ -760,8 +744,6 @@ SQL_TEMPLATES = {
         return_value = :return_value,
         cache_mode = :cache_mode,
         updatetime = SYSDATE""",
     # SQL for VectorStorage
     "entities": """SELECT name as entity_name FROM
         (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance

View File

@@ -26,7 +26,7 @@ from .utils import (
     convert_response_to_json,
     logger,
     set_logger,
-    statistic_data
+    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
@@ -40,29 +40,29 @@ from .base import (
 from .prompt import GRAPH_FIELD_SEP

 STORAGES = {
-    "JsonKVStorage": '.storage',
-    "NanoVectorDBStorage": '.storage',
-    "NetworkXStorage": '.storage',
-    "JsonDocStatusStorage": '.storage',
-    "Neo4JStorage":".kg.neo4j_impl",
-    "OracleKVStorage":".kg.oracle_impl",
-    "OracleGraphStorage":".kg.oracle_impl",
-    "OracleVectorDBStorage":".kg.oracle_impl",
-    "MilvusVectorDBStorge":".kg.milvus_impl",
-    "MongoKVStorage":".kg.mongo_impl",
-    "ChromaVectorDBStorage":".kg.chroma_impl",
-    "TiDBKVStorage":".kg.tidb_impl",
-    "TiDBVectorDBStorage":".kg.tidb_impl",
-    "TiDBGraphStorage":".kg.tidb_impl",
-    "PGKVStorage":".kg.postgres_impl",
-    "PGVectorStorage":".kg.postgres_impl",
-    "AGEStorage":".kg.age_impl",
-    "PGGraphStorage":".kg.postgres_impl",
-    "GremlinStorage":".kg.gremlin_impl",
-    "PGDocStatusStorage":".kg.postgres_impl",
+    "JsonKVStorage": ".storage",
+    "NanoVectorDBStorage": ".storage",
+    "NetworkXStorage": ".storage",
+    "JsonDocStatusStorage": ".storage",
+    "Neo4JStorage": ".kg.neo4j_impl",
+    "OracleKVStorage": ".kg.oracle_impl",
+    "OracleGraphStorage": ".kg.oracle_impl",
+    "OracleVectorDBStorage": ".kg.oracle_impl",
+    "MilvusVectorDBStorge": ".kg.milvus_impl",
+    "MongoKVStorage": ".kg.mongo_impl",
+    "ChromaVectorDBStorage": ".kg.chroma_impl",
+    "TiDBKVStorage": ".kg.tidb_impl",
+    "TiDBVectorDBStorage": ".kg.tidb_impl",
+    "TiDBGraphStorage": ".kg.tidb_impl",
+    "PGKVStorage": ".kg.postgres_impl",
+    "PGVectorStorage": ".kg.postgres_impl",
+    "AGEStorage": ".kg.age_impl",
+    "PGGraphStorage": ".kg.postgres_impl",
+    "GremlinStorage": ".kg.gremlin_impl",
+    "PGDocStatusStorage": ".kg.postgres_impl",
 }


 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""
@@ -75,6 +75,7 @@ def lazy_external_import(module_name: str, class_name: str):
     def import_class(*args, **kwargs):
         import importlib
+
         module = importlib.import_module(module_name, package=package)
         cls = getattr(module, class_name)
         return cls(*args, **kwargs)
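
The STORAGES table and `lazy_external_import` above are what let the demo pick backends by name (`kv_storage="OracleKVStorage"`, and so on): the class is resolved from its dotted module path only when it is first constructed, so optional drivers are not imported unless actually used. A standalone sketch of the same deferred-import idea is below; the module path and class name in it are placeholders, not names taken from this commit.

# Standalone sketch of the deferred-import pattern; "my_pkg.kg.example_impl" and
# ExampleStorage are placeholders for illustration only.
import importlib


def lazy_import(module_name: str, class_name: str):
    def make_instance(*args, **kwargs):
        module = importlib.import_module(module_name)  # imported on first use
        return getattr(module, class_name)(*args, **kwargs)

    return make_instance


ExampleStorage = lazy_import("my_pkg.kg.example_impl", "ExampleStorage")
# The backing module is only imported when ExampleStorage(...) is actually called.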
@@ -190,7 +191,7 @@ class LightRAG:
             os.makedirs(self.working_dir)

         # show config
-        global_config=asdict(self)
+        global_config = asdict(self)
         _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -199,25 +200,27 @@ class LightRAG:
             self.embedding_func
         )

         # Initialize all storages
-        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage)
-        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage)
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage)
+        self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
+            self._get_storage_class(self.kv_storage)
+        )
+        self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
+            self.vector_storage
+        )
+        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
+            self.graph_storage
+        )

         self.key_string_value_json_storage_cls = partial(
-            self.key_string_value_json_storage_cls,
-            global_config=global_config
+            self.key_string_value_json_storage_cls, global_config=global_config
         )
         self.vector_db_storage_cls = partial(
-            self.vector_db_storage_cls,
-            global_config=global_config
+            self.vector_db_storage_cls, global_config=global_config
         )
         self.graph_storage_cls = partial(
-            self.graph_storage_cls,
-            global_config=global_config
+            self.graph_storage_cls, global_config=global_config
         )

         self.json_doc_status_storage = self.key_string_value_json_storage_cls(
@@ -264,13 +267,15 @@ class LightRAG:
             embedding_func=self.embedding_func,
         )

-        if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"):
+        if self.llm_response_cache and hasattr(
+            self.llm_response_cache, "global_config"
+        ):
             hashing_kv = self.llm_response_cache
         else:
             hashing_kv = self.key_string_value_json_storage_cls(
                 namespace="llm_response_cache",
                 embedding_func=None,
             )

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(
@@ -293,20 +298,23 @@ class LightRAG:
         storage_class = lazy_external_import(import_path, storage_name)
         return storage_class

-    def set_storage_client(self,db_client):
+    def set_storage_client(self, db_client):
         # Now only tested on Oracle Database
-        for storage in [self.vector_db_storage_cls,
-                        self.graph_storage_cls,
-                        self.doc_status, self.full_docs,
-                        self.text_chunks,
-                        self.llm_response_cache,
-                        self.key_string_value_json_storage_cls,
-                        self.chunks_vdb,
-                        self.relationships_vdb,
-                        self.entities_vdb,
-                        self.graph_storage_cls,
-                        self.chunk_entity_relation_graph,
-                        self.llm_response_cache]:
+        for storage in [
+            self.vector_db_storage_cls,
+            self.graph_storage_cls,
+            self.doc_status,
+            self.full_docs,
+            self.text_chunks,
+            self.llm_response_cache,
+            self.key_string_value_json_storage_cls,
+            self.chunks_vdb,
+            self.relationships_vdb,
+            self.entities_vdb,
+            self.graph_storage_cls,
+            self.chunk_entity_relation_graph,
+            self.llm_response_cache,
+        ]:
             # set client
             storage.db = db_client
@@ -349,11 +357,6 @@ class LightRAG:
             for content in unique_contents
         }

-        # 3. Store original document and chunks
-        await self.full_docs.upsert(
-            {doc_id: {"content": doc["content"]}}
-        )
-
         # 3. Filter out already processed documents
         _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
@@ -401,7 +404,12 @@ class LightRAG:
                     }

                     # Update status with chunks information
-                    doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()})
+                    doc_status.update(
+                        {
+                            "chunks_count": len(chunks),
+                            "updated_at": datetime.now().isoformat(),
+                        }
+                    )
                     await self.doc_status.upsert({doc_id: doc_status})

                     try:
@@ -425,16 +433,30 @@ class LightRAG:
                         self.chunk_entity_relation_graph = maybe_new_kg

+                        # Store original document and chunks
+                        await self.full_docs.upsert(
+                            {doc_id: {"content": doc["content"]}}
+                        )
                         await self.text_chunks.upsert(chunks)

                         # Update status to processed
-                        doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()})
+                        doc_status.update(
+                            {
+                                "status": DocStatus.PROCESSED,
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
                         await self.doc_status.upsert({doc_id: doc_status})

                     except Exception as e:
                         # Mark as failed if any step fails
-                        doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()})
+                        doc_status.update(
+                            {
+                                "status": DocStatus.FAILED,
+                                "error": str(e),
+                                "updated_at": datetime.now().isoformat(),
+                            }
+                        )
                         await self.doc_status.upsert({doc_id: doc_status})
                         raise e
@@ -527,7 +549,9 @@ class LightRAG:
         # 1. Remove duplicate contents from the list
         unique_contents = list(set(doc.strip() for doc in string_or_strings))
-        logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents")
+        logger.info(
+            f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents"
+        )

         # 2. Generate document IDs and initial status
         new_docs = {
@@ -545,11 +569,13 @@ class LightRAG:
         # 3. Filter out already processed documents
         _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
         if len(_not_stored_doc_keys) < len(new_docs):
-            logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents")
+            logger.info(
+                f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents"
+            )
         new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}

         if not new_docs:
-            logger.info(f"All documents have been processed or are duplicates")
+            logger.info("All documents have been processed or are duplicates")
             return None

         # 4. Store original document
@@ -562,8 +588,12 @@ class LightRAG:
"""Get pendding documents, split into chunks,insert chunks""" """Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents # 1. get all pending and failed documents
_todo_doc_keys = [] _todo_doc_keys = []
_failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None) _failed_doc = await self.full_docs.get_by_status_and_ids(
_pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None) status=DocStatus.FAILED, ids=None
)
_pendding_doc = await self.full_docs.get_by_status_and_ids(
status=DocStatus.PENDING, ids=None
)
if _failed_doc: if _failed_doc:
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
if _pendding_doc: if _pendding_doc:
@@ -575,8 +605,7 @@ class LightRAG:
logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents") logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
new_docs = { new_docs = {
doc["id"]: doc doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
} }
# 2. split docs into chunks, insert chunks, update doc status # 2. split docs into chunks, insert chunks, update doc status
@@ -585,7 +614,8 @@ class LightRAG:
         for i in range(0, len(new_docs), batch_size):
             batch_docs = dict(list(new_docs.items())[i : i + batch_size])
             for doc_id, doc in tqdm_async(
-                batch_docs.items(), desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}"
+                batch_docs.items(),
+                desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}",
             ):
                 try:
                     # Generate chunks from document
@@ -616,18 +646,23 @@ class LightRAG:
                         await self.full_docs.change_status(doc_id, DocStatus.FAILED)
                         raise e
                 except Exception as e:
                     import traceback
+
                     error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
                     logger.error(error_msg)
                     continue
         logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")

     async def apipeline_process_extract_graph(self):
         """Get pendding or failed chunks, extract entities and relationships from each chunk"""
         # 1. get all pending and failed chunks
         _todo_chunk_keys = []
-        _failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
-        _pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
+        _failed_chunks = await self.text_chunks.get_by_status_and_ids(
+            status=DocStatus.FAILED, ids=None
+        )
+        _pendding_chunks = await self.text_chunks.get_by_status_and_ids(
+            status=DocStatus.PENDING, ids=None
+        )
         if _failed_chunks:
             _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
         if _pendding_chunks:
@@ -639,11 +674,15 @@ class LightRAG:
         # Process documents in batches
         batch_size = self.addon_params.get("insert_batch_size", 10)
-        semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously
+        semaphore = asyncio.Semaphore(
+            batch_size
+        )  # Control the number of tasks that are processed simultaneously

         async def process_chunk(chunk_id):
             async with semaphore:
-                chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])}
+                chunks = {
+                    i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
+                }
                 # Extract and store entities and relationships
                 try:
                     maybe_new_kg = await extract_entities(
@@ -664,10 +703,12 @@ class LightRAG:
                     await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
                     raise e

-        with tqdm_async(total=len(_todo_chunk_keys),
-                        desc="\nLevel 1 - Processing chunks",
-                        unit="chunk",
-                        position=0) as progress:
+        with tqdm_async(
+            total=len(_todo_chunk_keys),
+            desc="\nLevel 1 - Processing chunks",
+            unit="chunk",
+            position=0,
+        ) as progress:
             tasks = []
             for chunk_id in _todo_chunk_keys:
                 task = asyncio.create_task(process_chunk(chunk_id))
@@ -676,10 +717,12 @@ class LightRAG:
             for future in asyncio.as_completed(tasks):
                 await future
                 progress.update(1)
-                progress.set_postfix({
-                    'LLM call': statistic_data["llm_call"],
-                    'LLM cache': statistic_data["llm_cache"],
-                })
+                progress.set_postfix(
+                    {
+                        "LLM call": statistic_data["llm_call"],
+                        "LLM cache": statistic_data["llm_cache"],
+                    }
+                )

         # Ensure all indexes are updated after each document
         await self._insert_done()
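
The pattern above (a `Semaphore` sized by `insert_batch_size`, one task per chunk, and `asyncio.as_completed` driving a progress bar) bounds how many chunks are extracted concurrently. Below is a generic, self-contained sketch of the same bounded-concurrency idea, independent of LightRAG; the function and parameter names are illustrative only.

# Generic sketch of bounded-concurrency processing with a simple progress report;
# not LightRAG code, just the same asyncio pattern as the hunk above.
import asyncio


async def process_all(item_ids, worker, limit=10):
    semaphore = asyncio.Semaphore(limit)  # at most `limit` workers in flight

    async def bounded(item_id):
        async with semaphore:
            return await worker(item_id)

    tasks = [asyncio.create_task(bounded(i)) for i in item_ids]
    done = 0
    for future in asyncio.as_completed(tasks):
        await future
        done += 1
        print(f"{done}/{len(tasks)} items processed")

# Usage: asyncio.run(process_all(range(20), some_async_worker, limit=5))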

View File

@@ -20,7 +20,7 @@ from .utils import (
     handle_cache,
     save_to_cache,
     CacheData,
-    statistic_data
+    statistic_data,
 )
 from .base import (
     BaseGraphStorage,
@@ -105,7 +105,9 @@ async def _handle_entity_relation_summary(
     llm_max_tokens = global_config["llm_model_max_token_size"]
     tiktoken_model_name = global_config["tiktoken_model_name"]
     summary_max_tokens = global_config["entity_summary_to_max_tokens"]
-    language = global_config["addon_params"].get("language", PROMPTS["DEFAULT_LANGUAGE"])
+    language = global_config["addon_params"].get(
+        "language", PROMPTS["DEFAULT_LANGUAGE"]
+    )

     tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
     if len(tokens) < summary_max_tokens:  # No need for summary
@@ -360,7 +362,7 @@ async def extract_entities(
             llm_response_cache.global_config = new_config
             need_to_restore = True
         if history_messages:
-            history = json.dumps(history_messages,ensure_ascii=False)
+            history = json.dumps(history_messages, ensure_ascii=False)
             _prompt = history + "\n" + input_text
         else:
             _prompt = input_text
@@ -394,7 +396,7 @@ async def extract_entities(
         return await use_llm_func(input_text)

     async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
-        """"Prpocess a single chunk
+        """ "Prpocess a single chunk
         Args:
             chunk_key_dp (tuple[str, TextChunkSchema]):
                 ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
@@ -472,7 +474,9 @@ async def extract_entities(
         asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
         total=len(ordered_chunks),
         desc="Level 2 - Extracting entities and relationships",
-        unit="chunk", position=1,leave=False
+        unit="chunk",
+        position=1,
+        leave=False,
     ):
         results.append(await result)
@@ -494,7 +498,9 @@ async def extract_entities(
         ),
         total=len(maybe_nodes),
         desc="Level 3 - Inserting entities",
-        unit="entity", position=2,leave=False
+        unit="entity",
+        position=2,
+        leave=False,
     ):
         all_entities_data.append(await result)
@@ -511,7 +517,9 @@ async def extract_entities(
         ),
         total=len(maybe_edges),
         desc="Level 3 - Inserting relationships",
-        unit="relationship", position=3,leave=False
+        unit="relationship",
+        position=3,
+        leave=False,
     ):
         all_relationships_data.append(await result)

View File

@@ -41,7 +41,7 @@ logging.getLogger("httpx").setLevel(logging.WARNING)
 def set_logger(log_file: str):
     logger.setLevel(logging.DEBUG)

-    file_handler = logging.FileHandler(log_file, encoding='utf-8')
+    file_handler = logging.FileHandler(log_file, encoding="utf-8")
     file_handler.setLevel(logging.DEBUG)

     formatter = logging.Formatter(
@@ -458,7 +458,7 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
         return None, None, None, None

     # For naive mode, only use simple cache matching
-    #if mode == "naive":
+    # if mode == "naive":
     if mode == "default":
         if exists_func(hashing_kv, "get_by_mode_and_id"):
             mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
@@ -479,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
     quantized = min_val = max_val = None
     if is_embedding_cache_enabled:
         # Use embedding cache
-        embedding_model_func = hashing_kv.global_config["embedding_func"].func #["func"]
+        embedding_model_func = hashing_kv.global_config[
+            "embedding_func"
+        ].func  # ["func"]
         llm_model_func = hashing_kv.global_config.get("llm_model_func")

         current_embedding = await embedding_model_func([prompt])