support pipeline mode

This commit is contained in:
jin
2025-01-16 12:52:37 +08:00
parent 17a2ec2bc4
commit d5ae6669ea
5 changed files with 374 additions and 323 deletions

View File

@@ -87,12 +87,14 @@ async def main():
# We use Oracle DB as the KV/vector/graph storage
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
rag = LightRAG(
working_dir=WORKING_DIR,
# log_level="DEBUG",
working_dir=WORKING_DIR,
entity_extract_max_gleaning = 1,
enable_llm_cache=False,
embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90},
enable_llm_cache=True,
enable_llm_cache_for_entity_extract = True,
embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90},
chunk_token_size=CHUNK_TOKEN_SIZE,
llm_model_max_token_size = MAX_TOKENS,
@@ -106,34 +108,30 @@ async def main():
graph_storage = "OracleGraphStorage",
kv_storage = "OracleKVStorage",
vector_storage="OracleVectorDBStorage",
doc_status_storage="OracleDocStatusStorage",
addon_params = {"example_number":1, "language":"Simplfied Chinese"},
addon_params = {"example_number":1,
"language":"Simplfied Chinese",
"entity_types": ["organization", "person", "geo", "event"],
"insert_batch_size":2,
}
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
rag.key_string_value_json_storage_cls.db = oracle_db
rag.vector_db_storage_cls.db = oracle_db
rag.graph_storage_cls.db = oracle_db
rag.doc_status_storage_cls.db = oracle_db
rag.doc_status.db = oracle_db
rag.full_docs.db = oracle_db
rag.text_chunks.db = oracle_db
rag.llm_response_cache.db = oracle_db
rag.key_string_value_json_storage_cls.db = oracle_db
rag.chunks_vdb.db = oracle_db
rag.relationships_vdb.db = oracle_db
rag.entities_vdb.db = oracle_db
rag.graph_storage_cls.db = oracle_db
rag.chunk_entity_relation_graph.db = oracle_db
rag.llm_response_cache.db = oracle_db
rag.set_storage_client(db_client = oracle_db)
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
# Extract and Insert into LightRAG storage
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read())
with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f:
all_text = f.read()
texts = [x for x in all_text.split("\n") if x]
# New mode: use the pipeline
await rag.apipeline_process_documents(texts)
await rag.apipeline_process_chunks()
await rag.apipeline_process_extract_graph()
# Old method: use ainsert
#await rag.ainsert(texts)
# Perform search in different modes
modes = ["naive", "local", "global", "hybrid"]
for mode in modes:

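The demo above replaces the single ainsert() call with the new three-stage pipeline. A minimal calling-pattern sketch (not part of the diff; storage wiring omitted, and rag is assumed to be the LightRAG instance configured above):

import asyncio

async def run_pipeline(rag, texts):
    # Stage 1: de-duplicate, assign doc ids, store docs with an initial PENDING status
    await rag.apipeline_process_documents(texts)
    # Stage 2: chunk every PENDING/FAILED document and upsert the chunks
    await rag.apipeline_process_chunks()
    # Stage 3: extract entities/relationships from PENDING/FAILED chunks into the graph
    await rag.apipeline_process_extract_graph()

# asyncio.run(run_pipeline(rag, ["first document", "second document"]))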
View File

@@ -12,9 +12,6 @@ from ..base import (
BaseGraphStorage,
BaseKVStorage,
BaseVectorStorage,
DocStatusStorage,
DocStatus,
DocProcessingStatus,
)
import oracledb
@@ -156,8 +153,6 @@ class OracleDB:
if data is None:
await cursor.execute(sql)
else:
# print(data)
# print(sql)
await cursor.execute(sql, data)
await connection.commit()
except Exception as e:
@@ -175,7 +170,7 @@ class OracleKVStorage(BaseKVStorage):
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config.get("embedding_batch_num",10)
################ QUERY METHODS ################
@@ -204,12 +199,11 @@ class OracleKVStorage(BaseKVStorage):
array_res = await self.db.query(SQL, params, multirows=True)
res = {}
for row in array_res:
res[row["id"]] = row
res[row["id"]] = row
return res
else:
return None
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""get doc_chunks data based on id"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
@@ -228,8 +222,7 @@ class OracleKVStorage(BaseKVStorage):
dict_res[mode] = {}
for row in res:
dict_res[row["mode"]][row["id"]] = row
res = [{k: v} for k, v in dict_res.items()]
res = [{k: v} for k, v in dict_res.items()]
if res:
data = res # [{"data":i} for i in res]
# print(data)
@@ -237,38 +230,42 @@ class OracleKVStorage(BaseKVStorage):
else:
return None
async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]:
"""Specifically for llm_response_cache."""
if ids is not None:
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
else:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
res = await self.db.query(SQL, params, multirows=True)
if res:
return res
else:
return None
async def filter_keys(self, keys: list[str]) -> set[str]:
"""remove duplicated"""
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
)
params = {"workspace": self.db.workspace}
try:
await self.db.query(SQL, params)
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(SQL)
print(params)
res = await self.db.query(SQL, params, multirows=True)
data = None
if res:
exist_keys = [key["id"] for key in res]
data = set([s for s in keys if s not in exist_keys])
return data
else:
exist_keys = []
data = set([s for s in keys if s not in exist_keys])
return data
return set(keys)
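filter_keys now returns the ids that are not yet stored (or all of the input keys when the lookup comes back empty), which is how the pipeline skips documents it has already ingested. A hedged sketch of that usage, mirroring apipeline_process_documents later in this commit:

async def keep_new_docs(full_docs, new_docs: dict) -> dict:
    # full_docs is assumed to be an OracleKVStorage bound to the "full_docs" namespace
    missing = await full_docs.filter_keys(list(new_docs.keys()))
    return {doc_id: doc for doc_id, doc in new_docs.items() if doc_id in missing}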
################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
# print(self._data)
# values = []
if self.namespace == "text_chunks":
list_data = [
{
"__id__": k,
"id": k,
**{k1: v1 for k1, v1 in v.items()},
}
for k, v in data.items()
@@ -284,33 +281,30 @@ class OracleKVStorage(BaseKVStorage):
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
# print(list_data)
merge_sql = SQL_TEMPLATES["merge_chunk"]
for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"]
data = {
"check_id": item["__id__"],
"id": item["__id__"],
_data = {
"id": item["id"],
"content": item["content"],
"workspace": self.db.workspace,
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": item["__vector__"],
"status": item["status"],
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, _data)
if self.namespace == "full_docs":
for k, v in self._data.items():
for k, v in data.items():
# values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"]
data = {
_data = {
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
await self.db.execute(merge_sql, _data)
if self.namespace == "llm_response_cache":
for mode, items in data.items():
@@ -325,102 +319,20 @@ class OracleKVStorage(BaseKVStorage):
}
await self.db.execute(upsert_sql, _data)
return left_data
return None
async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(
table_name=N_T[self.namespace]
)
params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params)
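With OracleDocStatusStorage going away below, per-row state is tracked through the status column via change_status and get_by_status_and_ids. A minimal lifecycle sketch, assuming full_docs is an OracleKVStorage bound to the full_docs namespace:

from lightrag.base import DocStatus

async def requeue_failed_docs(full_docs):
    # Fetch every row currently marked FAILED (ids=None means no id filter)
    failed = await full_docs.get_by_status_and_ids(status=DocStatus.FAILED, ids=None)
    for row in failed or []:
        # Flip it back to PENDING so the next pipeline run picks it up again
        await full_docs.change_status(row["id"], DocStatus.PENDING)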
async def index_done_callback(self):
if self.namespace in ["full_docs", "text_chunks"]:
logger.info("full doc and chunk data had been saved into oracle db!")
@dataclass
class OracleDocStatusStorage(DocStatusStorage):
"""Oracle implementation of document status storage"""
# should pass db object to self.db
db: OracleDB = None
meta_fields = None
def __post_init__(self):
pass
async def filter_keys(self, ids: list[str]) -> set[str]:
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
ids = ",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if res:
existed = set([element["id"] for element in res])
return set(ids) - existed
else:
return set(ids)
async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
SQL = SQL_TEMPLATES["get_status_counts"]
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, True)
# Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
counts = {}
for doc in res:
counts[doc["status"]] = doc["count"]
return counts
async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
"""Get all documents by status"""
SQL = SQL_TEMPLATES["get_docs_by_status"]
params = {"workspace": self.db.workspace, "status": status}
res = await self.db.query(SQL, params, True)
# Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
# Converting to be a dict
return {
element["id"]: DocProcessingStatus(
#content_summary=element["content_summary"],
content_summary = "",
content_length=element["CONTENT_LENGTH"],
status=element["STATUS"],
created_at=element["CREATETIME"],
updated_at=element["UPDATETIME"],
chunks_count=-1,
#chunks_count=element["chunks_count"],
)
for element in res
}
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def index_done_callback(self):
"""Save data after indexing, but for ORACLE, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into ORACLE db!")
async def upsert(self, data: dict[str, dict]):
"""Update or insert document status
Args:
data: Dictionary of document IDs and their status data
"""
SQL = SQL_TEMPLATES["merge_doc_status"]
for k, v in data.items():
# chunks_count is optional
params = {
"workspace": self.db.workspace,
"id": k,
"content_summary": v["content_summary"],
"content_length": v["content_length"],
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
"status": v["status"],
}
await self.db.execute(SQL, params)
return data
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db
@@ -466,7 +378,7 @@ class OracleGraphStorage(BaseGraphStorage):
def __post_init__(self):
"""从graphml文件加载图"""
self._max_batch_size = self.global_config["embedding_batch_num"]
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################
@@ -500,7 +412,6 @@ class OracleGraphStorage(BaseGraphStorage):
"content": content,
"content_vector": content_vector,
}
# print(merge_sql)
await self.db.execute(merge_sql, data)
# self._graph.add_node(node_id, **node_data)
@@ -718,9 +629,10 @@ TABLES = {
},
"LIGHTRAG_DOC_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
id varchar(256) PRIMARY KEY,
id varchar(256),
workspace varchar(1024),
full_doc_id varchar(256),
status varchar(256),
chunk_order_index NUMBER,
tokens NUMBER,
content CLOB,
@@ -795,9 +707,9 @@ TABLES = {
SQL_TEMPLATES = {
# SQL for KVStorage
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
@@ -808,24 +720,34 @@ SQL_TEMPLATES = {
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
"get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
USING DUAL
ON (a.id = :check_id)
WHEN NOT MATCHED THEN
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
ON (id = :id and workspace = :workspace)
WHEN NOT MATCHED THEN INSERT
(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
USING DUAL
@@ -838,27 +760,7 @@ SQL_TEMPLATES = {
return_value = :return_value,
cache_mode = :cache_mode,
updatetime = SYSDATE""",
"get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN ({ids})",
"get_status_counts": """SELECT status as "status", COUNT(1) as "count"
FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""",
"get_docs_by_status": """select content_length,status,
TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime
from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""",
"merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status)
WHEN MATCHED THEN UPDATE
SET content_summary = :content_summary,
content_length = :content_length,
chunks_count = :chunks_count,
status = :status,
updatetime = SYSDATE""",
# SQL for VectorStorage
"entities": """SELECT name as entity_name FROM

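The merge_* templates above are insert-if-absent upserts: MERGE ... USING DUAL matches on (id, workspace) and inserts only when no row exists. The same pattern in isolation with the async python-oracledb driver (connection settings are placeholders and the column list is trimmed; the real code reuses the shared pool held by OracleDB):

import oracledb  # async mode of python-oracledb, the same driver OracleDB wraps

MERGE_SQL = """MERGE INTO LIGHTRAG_DOC_CHUNKS
    USING DUAL ON (id = :id AND workspace = :workspace)
    WHEN NOT MATCHED THEN INSERT (id, workspace, content, status)
    VALUES (:id, :workspace, :content, :status)"""

async def upsert_chunk(user: str, password: str, dsn: str, row: dict):
    # Placeholder standalone connection; LightRAG itself goes through OracleDB's shared pool
    conn = await oracledb.connect_async(user=user, password=password, dsn=dsn)
    try:
        cursor = conn.cursor()
        await cursor.execute(MERGE_SQL, row)  # named binds map onto the :placeholders
        await conn.commit()
    finally:
        await conn.close()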
View File

@@ -26,6 +26,7 @@ from .utils import (
convert_response_to_json,
logger,
set_logger,
statistic_data
)
from .base import (
BaseGraphStorage,
@@ -36,22 +37,31 @@ from .base import (
DocStatus,
)
from .storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
JsonDocStatusStorage,
)
from .prompt import GRAPH_FIELD_SEP
STORAGES = {
"JsonKVStorage": '.storage',
"NanoVectorDBStorage": '.storage',
"NetworkXStorage": '.storage',
"JsonDocStatusStorage": '.storage',
# future KG integrations
# from .kg.ArangoDB_impl import (
# GraphStorage as ArangoDBStorage
# )
"Neo4JStorage":".kg.neo4j_impl",
"OracleKVStorage":".kg.oracle_impl",
"OracleGraphStorage":".kg.oracle_impl",
"OracleVectorDBStorage":".kg.oracle_impl",
"MilvusVectorDBStorge":".kg.milvus_impl",
"MongoKVStorage":".kg.mongo_impl",
"ChromaVectorDBStorage":".kg.chroma_impl",
"TiDBKVStorage":".kg.tidb_impl",
"TiDBVectorDBStorage":".kg.tidb_impl",
"TiDBGraphStorage":".kg.tidb_impl",
"PGKVStorage":".kg.postgres_impl",
"PGVectorStorage":".kg.postgres_impl",
"AGEStorage":".kg.age_impl",
"PGGraphStorage":".kg.postgres_impl",
"GremlinStorage":".kg.gremlin_impl",
"PGDocStatusStorage":".kg.postgres_impl",
}
def lazy_external_import(module_name: str, class_name: str):
"""Lazily import a class from an external module based on the package of the caller."""
@@ -65,36 +75,13 @@ def lazy_external_import(module_name: str, class_name: str):
def import_class(*args, **kwargs):
import importlib
# Import the module using importlib
module = importlib.import_module(module_name, package=package)
# Get the class from the module and instantiate it
cls = getattr(module, class_name)
return cls(*args, **kwargs)
return import_class
Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
OracleDocStatusStorage = lazy_external_import(".kg.oracle_impl", "OracleDocStatusStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage")
PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage")
AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage")
GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage")
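The eager per-backend imports above are replaced by the STORAGES table plus lazy_external_import, so optional drivers load only when their storage name is requested. A simplified sketch of the lookup that _get_storage_class now performs (the real helper defers the import until instantiation):

import importlib

def resolve_storage(storage_name: str, package: str = "lightrag"):
    # STORAGES is the module-level table defined above; this is the lookup in its simplest form.
    import_path = STORAGES[storage_name]                  # e.g. ".kg.oracle_impl"
    module = importlib.import_module(import_path, package=package)
    return getattr(module, storage_name)                  # e.g. the OracleKVStorage class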
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
"""
Ensure that there is always an event loop available.
@@ -198,52 +185,64 @@ class LightRAG:
logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# @TODO: should move all storage setup here to leverage initial start params attached to self.
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self._get_storage_class()[self.kv_storage]
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
self.vector_storage
]
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
self.graph_storage
]
if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir)
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
)
# show config
global_config=asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage)
self.key_string_value_json_storage_cls = partial(
self.key_string_value_json_storage_cls,
global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls,
global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls,
global_config=global_config
)
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace="json_doc_status_storage",
embedding_func=None,
)
self.llm_response_cache = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
)
####
# add embedding func by walter
####
self.full_docs = self.key_string_value_json_storage_cls(
namespace="full_docs",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
self.text_chunks = self.key_string_value_json_storage_cls(
namespace="text_chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
####
@@ -252,73 +251,64 @@ class LightRAG:
self.entities_vdb = self.vector_db_storage_cls(
namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"entity_name"},
)
self.relationships_vdb = self.vector_db_storage_cls(
namespace="relationships",
global_config=asdict(self),
embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"},
)
self.chunks_vdb = self.vector_db_storage_cls(
namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func,
)
if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(
self.llm_model_func,
hashing_kv=self.llm_response_cache
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
hashing_kv=hashing_kv,
**self.llm_model_kwargs,
)
)
# Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status = self.doc_status_storage_cls(
namespace="doc_status",
global_config=asdict(self),
global_config=global_config,
embedding_func=None,
)
def _get_storage_class(self) -> dict:
return {
# kv storage
"JsonKVStorage": JsonKVStorage,
"OracleKVStorage": OracleKVStorage,
"OracleDocStatusStorage":OracleDocStatusStorage,
"MongoKVStorage": MongoKVStorage,
"TiDBKVStorage": TiDBKVStorage,
# vector storage
"NanoVectorDBStorage": NanoVectorDBStorage,
"OracleVectorDBStorage": OracleVectorDBStorage,
"MilvusVectorDBStorge": MilvusVectorDBStorge,
"ChromaVectorDBStorage": ChromaVectorDBStorage,
"TiDBVectorDBStorage": TiDBVectorDBStorage,
# graph storage
"NetworkXStorage": NetworkXStorage,
"Neo4JStorage": Neo4JStorage,
"OracleGraphStorage": OracleGraphStorage,
"AGEStorage": AGEStorage,
"PGGraphStorage": PGGraphStorage,
"PGKVStorage": PGKVStorage,
"PGDocStatusStorage": PGDocStatusStorage,
"PGVectorStorage": PGVectorStorage,
"TiDBGraphStorage": TiDBGraphStorage,
"GremlinStorage": GremlinStorage,
# "ArangoDBStorage": ArangoDBStorage
"JsonDocStatusStorage": JsonDocStatusStorage,
}
def _get_storage_class(self, storage_name: str):
import_path = STORAGES[storage_name]
storage_class = lazy_external_import(import_path, storage_name)
return storage_class
def set_storage_client(self,db_client):
# Now only tested on Oracle Database
for storage in [self.vector_db_storage_cls,
self.graph_storage_cls,
self.doc_status, self.full_docs,
self.text_chunks,
self.llm_response_cache,
self.key_string_value_json_storage_cls,
self.chunks_vdb,
self.relationships_vdb,
self.entities_vdb,
self.graph_storage_cls,
self.chunk_entity_relation_graph,
self.llm_response_cache]:
# set client
storage.db = db_client
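set_storage_client points every storage instance at one external client, which is what the Oracle demo relies on. A hedged usage sketch (OracleDB construction details are assumptions based on the example script; only Oracle has been tested, per the comment above):

from lightrag import LightRAG
from lightrag.kg.oracle_impl import OracleDB

async def build_rag(oracle_config: dict, working_dir: str) -> LightRAG:
    oracle_db = OracleDB(config=oracle_config)   # assumed constructor, mirrors the demo script
    rag = LightRAG(
        working_dir=working_dir,
        kv_storage="OracleKVStorage",
        vector_storage="OracleVectorDBStorage",
        graph_storage="OracleGraphStorage",
    )
    rag.set_storage_client(db_client=oracle_db)  # every storage now shares the same connection pool
    return rag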
def insert(
self, string_or_strings, split_by_character=None, split_by_character_only=False
@@ -358,6 +348,11 @@ class LightRAG:
}
for content in unique_contents
}
# 3. Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
# 3. Filter out already processed documents
_add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
@@ -406,12 +401,7 @@ class LightRAG:
}
# Update status with chunks information
doc_status.update(
{
"chunks_count": len(chunks),
"updated_at": datetime.now().isoformat(),
}
)
doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()})
await self.doc_status.upsert({doc_id: doc_status})
try:
@@ -435,30 +425,16 @@ class LightRAG:
self.chunk_entity_relation_graph = maybe_new_kg
# Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
await self.text_chunks.upsert(chunks)
# Update status to processed
doc_status.update(
{
"status": DocStatus.PROCESSED,
"updated_at": datetime.now().isoformat(),
}
)
doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()})
await self.doc_status.upsert({doc_id: doc_status})
except Exception as e:
# Mark as failed if any step fails
doc_status.update(
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()})
await self.doc_status.upsert({doc_id: doc_status})
raise e
@@ -540,6 +516,174 @@ class LightRAG:
if update_storage:
await self._insert_done()
async def apipeline_process_documents(self, string_or_strings):
"""Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs
Args:
string_or_strings: Single document string or list of document strings
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents")
# 2. Generate document IDs and initial status
new_docs = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": None,
}
for content in unique_contents
}
# 3. Filter out already processed documents
_not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
if len(_not_stored_doc_keys) < len(new_docs):
logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents")
new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
if not new_docs:
logger.info(f"All documents have been processed or are duplicates")
return None
# 4. Store original document
for doc_id, doc in new_docs.items():
await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
await self.full_docs.change_status(doc_id, DocStatus.PENDING)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_chunks(self):
"""Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents
_todo_doc_keys = []
_failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
_pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
if _failed_doc:
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
if _pendding_doc:
_todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
if not _todo_doc_keys:
logger.info("All documents have been processed or are duplicates")
return None
else:
logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
new_docs = {
doc["id"]: doc
for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
}
# 2. split docs into chunks, insert chunks, update doc status
chunk_cnt = 0
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(), desc=f"Level 1 - Splitting docs in batch {i//batch_size + 1}"
):
try:
# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
"status": DocStatus.PENDING,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
chunk_cnt += len(chunks)
await self.text_chunks.upsert(chunks)
await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED)
try:
# Store chunks in vector database
await self.chunks_vdb.upsert(chunks)
# Update doc status
await self.full_docs.change_status(doc_id, DocStatus.PROCESSED)
except Exception as e:
# Mark as failed if any step fails
await self.full_docs.change_status(doc_id, DocStatus.FAILED)
raise e
except Exception as e:
import traceback
error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
continue
logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
async def apipeline_process_extract_graph(self):
"""Get pendding or failed chunks, extract entities and relationships from each chunk"""
# 1. get all pending and failed chunks
_todo_chunk_keys = []
_failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
if _failed_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
if _pendding_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
if not _todo_chunk_keys:
logger.info("All chunks have been processed or are duplicates")
return None
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously
async def process_chunk(chunk_id):
async with semaphore:
chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])}
# Extract and store entities and relationships
try:
maybe_new_kg = await extract_entities(
chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.info("No entities or relationships extracted!")
# Update status to processed
await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED)
except Exception as e:
logger.error("Failed to extract entities and relationships")
# Mark as failed if any step fails
await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
raise e
with tqdm_async(total=len(_todo_chunk_keys),
desc="\nLevel 1 - Processing chunks",
unit="chunk",
position=0) as progress:
tasks = []
for chunk_id in _todo_chunk_keys:
task = asyncio.create_task(process_chunk(chunk_id))
tasks.append(task)
for future in asyncio.as_completed(tasks):
await future
progress.update(1)
progress.set_postfix({
'LLM call': statistic_data["llm_call"],
'LLM cache': statistic_data["llm_cache"],
})
# Ensure all indexes are updated after each document
await self._insert_done()
async def _insert_done(self):
tasks = []
for storage_inst in [

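apipeline_process_extract_graph bounds concurrency with an asyncio.Semaphore and advances a tqdm bar as tasks complete. The same pattern in isolation, with a generic worker standing in for process_chunk:

import asyncio
from tqdm.asyncio import tqdm as tqdm_async

async def bounded_process(items, worker, limit: int = 10):
    # At most `limit` workers run concurrently; the bar ticks as tasks finish,
    # mirroring the semaphore + as_completed loop in apipeline_process_extract_graph.
    semaphore = asyncio.Semaphore(limit)

    async def _run(item):
        async with semaphore:
            return await worker(item)

    tasks = [asyncio.create_task(_run(item)) for item in items]
    with tqdm_async(total=len(tasks), desc="Processing", unit="item", position=0) as progress:
        for future in asyncio.as_completed(tasks):
            await future
            progress.update(1)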
View File

@@ -20,6 +20,7 @@ from .utils import (
handle_cache,
save_to_cache,
CacheData,
statistic_data
)
from .base import (
BaseGraphStorage,
@@ -371,14 +372,16 @@ async def extract_entities(
if need_to_restore:
llm_response_cache.global_config = global_config
if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
return cached_return
statistic_data["llm_call"] += 1
if history_messages:
res: str = await use_llm_func(
input_text, history_messages=history_messages
)
else:
res: str = await use_llm_func(input_text)
res: str = await use_llm_func(input_text)
await save_to_cache(
llm_response_cache,
CacheData(args_hash=arg_hash, content=res, prompt=_prompt),
@@ -459,10 +462,8 @@ async def extract_entities(
now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"])
]
print(
logger.debug(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
flush=True,
)
return dict(maybe_nodes), dict(maybe_edges)
@@ -470,8 +471,8 @@ async def extract_entities(
for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
total=len(ordered_chunks),
desc="Extracting entities from chunks",
unit="chunk",
desc="Level 2 - Extracting entities and relationships",
unit="chunk", position=1,leave=False
):
results.append(await result)
@@ -482,7 +483,7 @@ async def extract_entities(
maybe_nodes[k].extend(v)
for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v)
logger.info("Inserting entities into storage...")
logger.debug("Inserting entities into storage...")
all_entities_data = []
for result in tqdm_async(
asyncio.as_completed(
@@ -492,12 +493,12 @@ async def extract_entities(
]
),
total=len(maybe_nodes),
desc="Inserting entities",
unit="entity",
desc="Level 3 - Inserting entities",
unit="entity", position=2,leave=False
):
all_entities_data.append(await result)
logger.info("Inserting relationships into storage...")
logger.debug("Inserting relationships into storage...")
all_relationships_data = []
for result in tqdm_async(
asyncio.as_completed(
@@ -509,8 +510,8 @@ async def extract_entities(
]
),
total=len(maybe_edges),
desc="Inserting relationships",
unit="relationship",
desc="Level 3 - Inserting relationships",
unit="relationship", position=3,leave=False
):
all_relationships_data.append(await result)

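The progress bars in operate.py now carry fixed position values and leave=False so the Level 1/2/3 bars nest on separate terminal rows instead of interleaving. A standalone sketch of that layout:

import asyncio
from tqdm.asyncio import tqdm as tqdm_async

async def fake_work(i):
    await asyncio.sleep(0.01)
    return i

async def demo_nested_bars():
    with tqdm_async(total=3, desc="Level 1 - Processing chunks", unit="chunk", position=0) as outer:
        for _ in range(3):
            tasks = [fake_work(i) for i in range(5)]
            # The inner bar sits on its own row and is cleared once its batch finishes
            for fut in tqdm_async(asyncio.as_completed(tasks), total=len(tasks),
                                  desc="Level 2 - Extracting entities and relationships",
                                  unit="chunk", position=1, leave=False):
                await fut
            outer.update(1)

# asyncio.run(demo_nested_bars())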
View File

@@ -30,8 +30,13 @@ class UnlimitedSemaphore:
ENCODER = None
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
logger = logging.getLogger("lightrag")
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def set_logger(log_file: str):
logger.setLevel(logging.DEBUG)
@@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
return None, None, None, None
# For naive mode, only use simple cache matching
if mode == "naive":
#if mode == "naive":
if mode == "default":
if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else: