support pipeline mode

This commit is contained in:
jin
2025-01-16 12:52:37 +08:00
parent 17a2ec2bc4
commit d5ae6669ea
5 changed files with 374 additions and 323 deletions

View File

@@ -87,12 +87,14 @@ async def main():
# We use Oracle DB as the KV/vector/graph storage # We use Oracle DB as the KV/vector/graph storage
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
rag = LightRAG( rag = LightRAG(
# log_level="DEBUG",
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
entity_extract_max_gleaning = 1, entity_extract_max_gleaning = 1,
enable_llm_cache=False, enable_llm_cache=True,
embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90},
enable_llm_cache_for_entity_extract = True, enable_llm_cache_for_entity_extract = True,
embedding_cache_config= None, # {"enabled": True,"similarity_threshold": 0.90},
chunk_token_size=CHUNK_TOKEN_SIZE, chunk_token_size=CHUNK_TOKEN_SIZE,
llm_model_max_token_size = MAX_TOKENS, llm_model_max_token_size = MAX_TOKENS,
@@ -106,33 +108,29 @@ async def main():
graph_storage = "OracleGraphStorage", graph_storage = "OracleGraphStorage",
kv_storage = "OracleKVStorage", kv_storage = "OracleKVStorage",
vector_storage="OracleVectorDBStorage", vector_storage="OracleVectorDBStorage",
doc_status_storage="OracleDocStatusStorage",
addon_params = {"example_number":1, "language":"Simplified Chinese"}, addon_params = {"example_number":1,
"language":"Simplified Chinese",
"entity_types": ["organization", "person", "geo", "event"],
"insert_batch_size":2,
}
) )
# Set the KV/vector/graph storage's `db` property so all operations use the same connection pool # Set the KV/vector/graph storage's `db` property so all operations use the same connection pool
rag.key_string_value_json_storage_cls.db = oracle_db rag.set_storage_client(db_client = oracle_db)
rag.vector_db_storage_cls.db = oracle_db
rag.graph_storage_cls.db = oracle_db
rag.doc_status_storage_cls.db = oracle_db
rag.doc_status.db = oracle_db
rag.full_docs.db = oracle_db
rag.text_chunks.db = oracle_db
rag.llm_response_cache.db = oracle_db
rag.key_string_value_json_storage_cls.db = oracle_db
rag.chunks_vdb.db = oracle_db
rag.relationships_vdb.db = oracle_db
rag.entities_vdb.db = oracle_db
rag.graph_storage_cls.db = oracle_db
rag.chunk_entity_relation_graph.db = oracle_db
rag.llm_response_cache.db = oracle_db
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
# Extract and Insert into LightRAG storage # Extract and Insert into LightRAG storage
with open("./dickens/demo.txt", "r", encoding="utf-8") as f: with open(WORKING_DIR+"/docs.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read()) all_text = f.read()
texts = [x for x in all_text.split("\n") if x]
# New mode: use the pipeline
await rag.apipeline_process_documents(texts)
await rag.apipeline_process_chunks()
await rag.apipeline_process_extract_graph()
# Old method: use ainsert
#await rag.ainsert(texts)
# Perform search in different modes # Perform search in different modes
modes = ["naive", "local", "global", "hybrid"] modes = ["naive", "local", "global", "hybrid"]

View File

@@ -12,9 +12,6 @@ from ..base import (
BaseGraphStorage, BaseGraphStorage,
BaseKVStorage, BaseKVStorage,
BaseVectorStorage, BaseVectorStorage,
DocStatusStorage,
DocStatus,
DocProcessingStatus,
) )
import oracledb import oracledb
@@ -156,8 +153,6 @@ class OracleDB:
if data is None: if data is None:
await cursor.execute(sql) await cursor.execute(sql)
else: else:
# print(data)
# print(sql)
await cursor.execute(sql, data) await cursor.execute(sql, data)
await connection.commit() await connection.commit()
except Exception as e: except Exception as e:
@@ -175,7 +170,7 @@ class OracleKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
self._data = {} self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config.get("embedding_batch_num",10)
################ QUERY METHODS ################ ################ QUERY METHODS ################
@@ -209,7 +204,6 @@ class OracleKVStorage(BaseKVStorage):
else: else:
return None return None
# Query by id
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
"""get doc_chunks data based on id""" """get doc_chunks data based on id"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
@@ -229,7 +223,6 @@ class OracleKVStorage(BaseKVStorage):
for row in res: for row in res:
dict_res[row["mode"]][row["id"]] = row dict_res[row["mode"]][row["id"]] = row
res = [{k: v} for k, v in dict_res.items()] res = [{k: v} for k, v in dict_res.items()]
if res: if res:
data = res # [{"data":i} for i in res] data = res # [{"data":i} for i in res]
# print(data) # print(data)
@@ -237,38 +230,42 @@ class OracleKVStorage(BaseKVStorage):
else: else:
return None return None
async def get_by_status_and_ids(self, status: str, ids: list[str]) -> Union[list[dict], None]:
"""Specifically for llm_response_cache."""
if ids is not None:
SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
else:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
res = await self.db.query(SQL, params, multirows=True)
if res:
return res
else:
return None
async def filter_keys(self, keys: list[str]) -> set[str]: async def filter_keys(self, keys: list[str]) -> set[str]:
"""remove duplicated""" """Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["filter_keys"].format( SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
) )
params = {"workspace": self.db.workspace} params = {"workspace": self.db.workspace}
try:
await self.db.query(SQL, params)
except Exception as e:
logger.error(f"Oracle database error: {e}")
print(SQL)
print(params)
res = await self.db.query(SQL, params, multirows=True) res = await self.db.query(SQL, params, multirows=True)
data = None
if res: if res:
exist_keys = [key["id"] for key in res] exist_keys = [key["id"] for key in res]
data = set([s for s in keys if s not in exist_keys]) data = set([s for s in keys if s not in exist_keys])
return data
else: else:
exist_keys = [] return set(keys)
data = set([s for s in keys if s not in exist_keys])
return data
################ INSERT METHODS ################ ################ INSERT METHODS ################
async def upsert(self, data: dict[str, dict]): async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
# print(self._data)
# values = []
if self.namespace == "text_chunks": if self.namespace == "text_chunks":
list_data = [ list_data = [
{ {
"__id__": k, "id": k,
**{k1: v1 for k1, v1 in v.items()}, **{k1: v1 for k1, v1 in v.items()},
} }
for k, v in data.items() for k, v in data.items()
@@ -284,33 +281,30 @@ class OracleKVStorage(BaseKVStorage):
embeddings = np.concatenate(embeddings_list) embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data): for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i] d["__vector__"] = embeddings[i]
# print(list_data)
merge_sql = SQL_TEMPLATES["merge_chunk"]
for item in list_data: for item in list_data:
merge_sql = SQL_TEMPLATES["merge_chunk"] _data = {
data = { "id": item["id"],
"check_id": item["__id__"],
"id": item["__id__"],
"content": item["content"], "content": item["content"],
"workspace": self.db.workspace, "workspace": self.db.workspace,
"tokens": item["tokens"], "tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"], "chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"], "full_doc_id": item["full_doc_id"],
"content_vector": item["__vector__"], "content_vector": item["__vector__"],
"status": item["status"],
} }
# print(merge_sql) await self.db.execute(merge_sql, _data)
await self.db.execute(merge_sql, data)
if self.namespace == "full_docs": if self.namespace == "full_docs":
for k, v in self._data.items(): for k, v in data.items():
# values.clear() # values.clear()
merge_sql = SQL_TEMPLATES["merge_doc_full"] merge_sql = SQL_TEMPLATES["merge_doc_full"]
data = { _data = {
"id": k, "id": k,
"content": v["content"], "content": v["content"],
"workspace": self.db.workspace, "workspace": self.db.workspace,
} }
# print(merge_sql) await self.db.execute(merge_sql, _data)
await self.db.execute(merge_sql, data)
if self.namespace == "llm_response_cache": if self.namespace == "llm_response_cache":
for mode, items in data.items(): for mode, items in data.items():
@@ -325,102 +319,20 @@ class OracleKVStorage(BaseKVStorage):
} }
await self.db.execute(upsert_sql, _data) await self.db.execute(upsert_sql, _data)
return left_data return None
async def change_status(self, id: str, status: str):
SQL = SQL_TEMPLATES["change_status"].format(
table_name=N_T[self.namespace]
)
params = {"workspace": self.db.workspace, "id": id, "status": status}
await self.db.execute(SQL, params)
async def index_done_callback(self): async def index_done_callback(self):
if self.namespace in ["full_docs", "text_chunks"]: if self.namespace in ["full_docs", "text_chunks"]:
logger.info("full doc and chunk data had been saved into oracle db!") logger.info("full doc and chunk data had been saved into oracle db!")
@dataclass
class OracleDocStatusStorage(DocStatusStorage):
"""Oracle implementation of document status storage"""
# should pass db object to self.db
db: OracleDB = None
meta_fields = None
def __post_init__(self):
pass
async def filter_keys(self, ids: list[str]) -> set[str]:
"""Return keys that don't exist in storage"""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
ids = ",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, True)
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
if res:
existed = set([element["id"] for element in res])
return set(ids) - existed
else:
return set(ids)
async def get_status_counts(self) -> Dict[str, int]:
"""Get counts of documents in each status"""
SQL = SQL_TEMPLATES["get_status_counts"]
params = {"workspace": self.db.workspace}
res = await self.db.query(SQL, params, True)
# Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
counts = {}
for doc in res:
counts[doc["status"]] = doc["count"]
return counts
async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
"""Get all documents by status"""
SQL = SQL_TEMPLATES["get_docs_by_status"]
params = {"workspace": self.db.workspace, "status": status}
res = await self.db.query(SQL, params, True)
# Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
# Converting to be a dict
return {
element["id"]: DocProcessingStatus(
#content_summary=element["content_summary"],
content_summary = "",
content_length=element["CONTENT_LENGTH"],
status=element["STATUS"],
created_at=element["CREATETIME"],
updated_at=element["UPDATETIME"],
chunks_count=-1,
#chunks_count=element["chunks_count"],
)
for element in res
}
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all failed documents"""
return await self.get_docs_by_status(DocStatus.FAILED)
async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
"""Get all pending documents"""
return await self.get_docs_by_status(DocStatus.PENDING)
async def index_done_callback(self):
"""Save data after indexing, but for ORACLE, we already saved them during the upsert stage, so no action to take here"""
logger.info("Doc status had been saved into ORACLE db!")
async def upsert(self, data: dict[str, dict]):
"""Update or insert document status
Args:
data: Dictionary of document IDs and their status data
"""
SQL = SQL_TEMPLATES["merge_doc_status"]
for k, v in data.items():
# chunks_count is optional
params = {
"workspace": self.db.workspace,
"id": k,
"content_summary": v["content_summary"],
"content_length": v["content_length"],
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
"status": v["status"],
}
await self.db.execute(SQL, params)
return data
@dataclass @dataclass
class OracleVectorDBStorage(BaseVectorStorage): class OracleVectorDBStorage(BaseVectorStorage):
# should pass db object to self.db # should pass db object to self.db
@@ -466,7 +378,7 @@ class OracleGraphStorage(BaseGraphStorage):
def __post_init__(self): def __post_init__(self):
"""从graphml文件加载图""" """从graphml文件加载图"""
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
#################### insert method ################ #################### insert method ################
@@ -500,7 +412,6 @@ class OracleGraphStorage(BaseGraphStorage):
"content": content, "content": content,
"content_vector": content_vector, "content_vector": content_vector,
} }
# print(merge_sql)
await self.db.execute(merge_sql, data) await self.db.execute(merge_sql, data)
# self._graph.add_node(node_id, **node_data) # self._graph.add_node(node_id, **node_data)
@@ -718,9 +629,10 @@ TABLES = {
}, },
"LIGHTRAG_DOC_CHUNKS": { "LIGHTRAG_DOC_CHUNKS": {
"ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
id varchar(256) PRIMARY KEY, id varchar(256),
workspace varchar(1024), workspace varchar(1024),
full_doc_id varchar(256), full_doc_id varchar(256),
status varchar(256),
chunk_order_index NUMBER, chunk_order_index NUMBER,
tokens NUMBER, tokens NUMBER,
content CLOB, content CLOB,
@@ -795,9 +707,9 @@ TABLES = {
SQL_TEMPLATES = { SQL_TEMPLATES = {
# SQL for KVStorage # SQL for KVStorage
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
"get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
@@ -808,24 +720,34 @@ SQL_TEMPLATES = {
"get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
"get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
"get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
"get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
"get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
"change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
"merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL USING DUAL
ON (a.id = :id and a.workspace = :workspace) ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN WHEN NOT MATCHED THEN
INSERT(id,content,workspace) values(:id,:content,:workspace)""", INSERT(id,content,workspace) values(:id,:content,:workspace)""",
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
USING DUAL USING DUAL
ON (a.id = :check_id) ON (id = :id and workspace = :workspace)
WHEN NOT MATCHED THEN WHEN NOT MATCHED THEN INSERT
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
"upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
USING DUAL USING DUAL
@@ -839,26 +761,6 @@ SQL_TEMPLATES = {
cache_mode = :cache_mode, cache_mode = :cache_mode,
updatetime = SYSDATE""", updatetime = SYSDATE""",
"get_by_id_doc_status": "SELECT id FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace AND id IN ({ids})",
"get_status_counts": """SELECT status as "status", COUNT(1) as "count"
FROM LIGHTRAG_DOC_FULL WHERE workspace=:workspace GROUP BY STATUS""",
"get_docs_by_status": """select content_length,status,
TO_CHAR(created_at,'YYYY-MM-DD HH24:MI:SS') as created_at,TO_CHAR(updatetime,'YYYY-MM-DD HH24:MI:SS') as updatetime
from LIGHTRAG_DOC_STATUS where workspace=:workspace and status=:status""",
"merge_doc_status":"""MERGE INTO LIGHTRAG_DOC_FULL a
USING DUAL
ON (a.id = :id and a.workspace = :workspace)
WHEN NOT MATCHED THEN
INSERT (id,content_summary,content_length,chunks_count,status) values(:id,:content_summary,:content_length,:chunks_count,:status)
WHEN MATCHED THEN UPDATE
SET content_summary = :content_summary,
content_length = :content_length,
chunks_count = :chunks_count,
status = :status,
updatetime = SYSDATE""",
# SQL for VectorStorage # SQL for VectorStorage
"entities": """SELECT name as entity_name FROM "entities": """SELECT name as entity_name FROM

View File

@@ -26,6 +26,7 @@ from .utils import (
convert_response_to_json, convert_response_to_json,
logger, logger,
set_logger, set_logger,
statistic_data
) )
from .base import ( from .base import (
BaseGraphStorage, BaseGraphStorage,
@@ -36,22 +37,31 @@ from .base import (
DocStatus, DocStatus,
) )
from .storage import (
JsonKVStorage,
NanoVectorDBStorage,
NetworkXStorage,
JsonDocStatusStorage,
)
from .prompt import GRAPH_FIELD_SEP from .prompt import GRAPH_FIELD_SEP
STORAGES = {
"JsonKVStorage": '.storage',
"NanoVectorDBStorage": '.storage',
"NetworkXStorage": '.storage',
"JsonDocStatusStorage": '.storage',
# future KG integrations "Neo4JStorage":".kg.neo4j_impl",
"OracleKVStorage":".kg.oracle_impl",
# from .kg.ArangoDB_impl import ( "OracleGraphStorage":".kg.oracle_impl",
# GraphStorage as ArangoDBStorage "OracleVectorDBStorage":".kg.oracle_impl",
# ) "MilvusVectorDBStorge":".kg.milvus_impl",
"MongoKVStorage":".kg.mongo_impl",
"ChromaVectorDBStorage":".kg.chroma_impl",
"TiDBKVStorage":".kg.tidb_impl",
"TiDBVectorDBStorage":".kg.tidb_impl",
"TiDBGraphStorage":".kg.tidb_impl",
"PGKVStorage":".kg.postgres_impl",
"PGVectorStorage":".kg.postgres_impl",
"AGEStorage":".kg.age_impl",
"PGGraphStorage":".kg.postgres_impl",
"GremlinStorage":".kg.gremlin_impl",
"PGDocStatusStorage":".kg.postgres_impl",
}
def lazy_external_import(module_name: str, class_name: str): def lazy_external_import(module_name: str, class_name: str):
"""Lazily import a class from an external module based on the package of the caller.""" """Lazily import a class from an external module based on the package of the caller."""
@@ -65,36 +75,13 @@ def lazy_external_import(module_name: str, class_name: str):
def import_class(*args, **kwargs): def import_class(*args, **kwargs):
import importlib import importlib
# Import the module using importlib
module = importlib.import_module(module_name, package=package) module = importlib.import_module(module_name, package=package)
# Get the class from the module and instantiate it
cls = getattr(module, class_name) cls = getattr(module, class_name)
return cls(*args, **kwargs) return cls(*args, **kwargs)
return import_class return import_class
Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
OracleDocStatusStorage = lazy_external_import(".kg.oracle_impl", "OracleDocStatusStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage")
PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage")
AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage")
GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage")
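The eager imports above give way to the STORAGES registry resolved through `lazy_external_import`, so optional backends are only imported when they are actually used. A stripped-down, standalone sketch of the pattern (the explicit `package` parameter stands in for the caller introspection in the real helper, and the registry shown is a small illustrative subset):

import importlib

STORAGES = {
    "JsonKVStorage": ".storage",
    "OracleKVStorage": ".kg.oracle_impl",
}

def lazy_external_import(module_name: str, class_name: str, package: str = "lightrag"):
    """Return a factory that imports the backend module only when the class is instantiated."""
    def import_class(*args, **kwargs):
        module = importlib.import_module(module_name, package=package)
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)
    return import_class

# The oracledb-backed module is only imported once this factory is actually called:
OracleKVStorage = lazy_external_import(STORAGES["OracleKVStorage"], "OracleKVStorage")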
def always_get_an_event_loop() -> asyncio.AbstractEventLoop: def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
""" """
Ensure that there is always an event loop available. Ensure that there is always an event loop available.
@@ -198,34 +185,49 @@ class LightRAG:
logger.setLevel(self.log_level) logger.setLevel(self.log_level)
logger.info(f"Logger initialized for working directory: {self.working_dir}") logger.info(f"Logger initialized for working directory: {self.working_dir}")
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# @TODO: should move all storage setup here to leverage initial start params attached to self.
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
self._get_storage_class()[self.kv_storage]
)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
self.vector_storage
]
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
self.graph_storage
]
if not os.path.exists(self.working_dir): if not os.path.exists(self.working_dir):
logger.info(f"Creating working directory {self.working_dir}") logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir) os.makedirs(self.working_dir)
self.llm_response_cache = self.key_string_value_json_storage_cls( # show config
namespace="llm_response_cache", global_config=asdict(self)
global_config=asdict(self), _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init LLM
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
self.embedding_func
)
# Initialize all storages
self.key_string_value_json_storage_cls: Type[BaseKVStorage] = self._get_storage_class(self.kv_storage)
self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(self.vector_storage)
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(self.graph_storage)
self.key_string_value_json_storage_cls = partial(
self.key_string_value_json_storage_cls,
global_config=global_config
)
self.vector_db_storage_cls = partial(
self.vector_db_storage_cls,
global_config=global_config
)
self.graph_storage_cls = partial(
self.graph_storage_cls,
global_config=global_config
)
self.json_doc_status_storage = self.key_string_value_json_storage_cls(
namespace="json_doc_status_storage",
embedding_func=None, embedding_func=None,
) )
self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( self.llm_response_cache = self.key_string_value_json_storage_cls(
self.embedding_func namespace="llm_response_cache",
embedding_func=None,
) )
#### ####
@@ -233,17 +235,14 @@ class LightRAG:
#### ####
self.full_docs = self.key_string_value_json_storage_cls( self.full_docs = self.key_string_value_json_storage_cls(
namespace="full_docs", namespace="full_docs",
global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
self.text_chunks = self.key_string_value_json_storage_cls( self.text_chunks = self.key_string_value_json_storage_cls(
namespace="text_chunks", namespace="text_chunks",
global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
self.chunk_entity_relation_graph = self.graph_storage_cls( self.chunk_entity_relation_graph = self.graph_storage_cls(
namespace="chunk_entity_relation", namespace="chunk_entity_relation",
global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
#### ####
@@ -252,73 +251,64 @@ class LightRAG:
self.entities_vdb = self.vector_db_storage_cls( self.entities_vdb = self.vector_db_storage_cls(
namespace="entities", namespace="entities",
global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
meta_fields={"entity_name"}, meta_fields={"entity_name"},
) )
self.relationships_vdb = self.vector_db_storage_cls( self.relationships_vdb = self.vector_db_storage_cls(
namespace="relationships", namespace="relationships",
global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
meta_fields={"src_id", "tgt_id"}, meta_fields={"src_id", "tgt_id"},
) )
self.chunks_vdb = self.vector_db_storage_cls( self.chunks_vdb = self.vector_db_storage_cls(
namespace="chunks", namespace="chunks",
global_config=asdict(self),
embedding_func=self.embedding_func, embedding_func=self.embedding_func,
) )
if self.llm_response_cache and hasattr(self.llm_response_cache, "global_config"):
hashing_kv = self.llm_response_cache
else:
hashing_kv = self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
embedding_func=None,
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial( partial(
self.llm_model_func, self.llm_model_func,
hashing_kv=self.llm_response_cache hashing_kv=hashing_kv,
if self.llm_response_cache
and hasattr(self.llm_response_cache, "global_config")
else self.key_string_value_json_storage_cls(
namespace="llm_response_cache",
global_config=asdict(self),
embedding_func=None,
),
**self.llm_model_kwargs, **self.llm_model_kwargs,
) )
) )
# Initialize document status storage # Initialize document status storage
self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage] self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
self.doc_status = self.doc_status_storage_cls( self.doc_status = self.doc_status_storage_cls(
namespace="doc_status", namespace="doc_status",
global_config=asdict(self), global_config=global_config,
embedding_func=None, embedding_func=None,
) )
def _get_storage_class(self) -> dict: def _get_storage_class(self, storage_name: str) -> dict:
return { import_path = STORAGES[storage_name]
# kv storage storage_class = lazy_external_import(import_path, storage_name)
"JsonKVStorage": JsonKVStorage, return storage_class
"OracleKVStorage": OracleKVStorage,
"OracleDocStatusStorage":OracleDocStatusStorage, def set_storage_client(self,db_client):
"MongoKVStorage": MongoKVStorage, # Now only tested on Oracle Database
"TiDBKVStorage": TiDBKVStorage, for storage in [self.vector_db_storage_cls,
# vector storage self.graph_storage_cls,
"NanoVectorDBStorage": NanoVectorDBStorage, self.doc_status, self.full_docs,
"OracleVectorDBStorage": OracleVectorDBStorage, self.text_chunks,
"MilvusVectorDBStorge": MilvusVectorDBStorge, self.llm_response_cache,
"ChromaVectorDBStorage": ChromaVectorDBStorage, self.key_string_value_json_storage_cls,
"TiDBVectorDBStorage": TiDBVectorDBStorage, self.chunks_vdb,
# graph storage self.relationships_vdb,
"NetworkXStorage": NetworkXStorage, self.entities_vdb,
"Neo4JStorage": Neo4JStorage, self.graph_storage_cls,
"OracleGraphStorage": OracleGraphStorage, self.chunk_entity_relation_graph,
"AGEStorage": AGEStorage, self.llm_response_cache]:
"PGGraphStorage": PGGraphStorage, # set client
"PGKVStorage": PGKVStorage, storage.db = db_client
"PGDocStatusStorage": PGDocStatusStorage,
"PGVectorStorage": PGVectorStorage,
"TiDBGraphStorage": TiDBGraphStorage,
"GremlinStorage": GremlinStorage,
# "ArangoDBStorage": ArangoDBStorage
"JsonDocStatusStorage": JsonDocStatusStorage,
}
def insert( def insert(
self, string_or_strings, split_by_character=None, split_by_character_only=False self, string_or_strings, split_by_character=None, split_by_character_only=False
@@ -359,6 +349,11 @@ class LightRAG:
for content in unique_contents for content in unique_contents
} }
# 3. Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
# 3. Filter out already processed documents # 3. Filter out already processed documents
_add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys())) _add_doc_keys = await self.doc_status.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
@@ -406,12 +401,7 @@ class LightRAG:
} }
# Update status with chunks information # Update status with chunks information
doc_status.update( doc_status.update({"chunks_count": len(chunks),"updated_at": datetime.now().isoformat()})
{
"chunks_count": len(chunks),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status}) await self.doc_status.upsert({doc_id: doc_status})
try: try:
@@ -435,30 +425,16 @@ class LightRAG:
self.chunk_entity_relation_graph = maybe_new_kg self.chunk_entity_relation_graph = maybe_new_kg
# Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
await self.text_chunks.upsert(chunks) await self.text_chunks.upsert(chunks)
# Update status to processed # Update status to processed
doc_status.update( doc_status.update({"status": DocStatus.PROCESSED,"updated_at": datetime.now().isoformat()})
{
"status": DocStatus.PROCESSED,
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status}) await self.doc_status.upsert({doc_id: doc_status})
except Exception as e: except Exception as e:
# Mark as failed if any step fails # Mark as failed if any step fails
doc_status.update( doc_status.update({"status": DocStatus.FAILED,"error": str(e),"updated_at": datetime.now().isoformat()})
{
"status": DocStatus.FAILED,
"error": str(e),
"updated_at": datetime.now().isoformat(),
}
)
await self.doc_status.upsert({doc_id: doc_status}) await self.doc_status.upsert({doc_id: doc_status})
raise e raise e
@@ -540,6 +516,174 @@ class LightRAG:
if update_storage: if update_storage:
await self._insert_done() await self._insert_done()
async def apipeline_process_documents(self, string_or_strings):
"""Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs
Args:
string_or_strings: Single document string or list of document strings
"""
if isinstance(string_or_strings, str):
string_or_strings = [string_or_strings]
# 1. Remove duplicate contents from the list
unique_contents = list(set(doc.strip() for doc in string_or_strings))
logger.info(f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents")
# 2. Generate document IDs and initial status
new_docs = {
compute_mdhash_id(content, prefix="doc-"): {
"content": content,
"content_summary": self._get_content_summary(content),
"content_length": len(content),
"status": DocStatus.PENDING,
"created_at": datetime.now().isoformat(),
"updated_at": None,
}
for content in unique_contents
}
# 3. Filter out already processed documents
_not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
if len(_not_stored_doc_keys) < len(new_docs):
logger.info(f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents")
new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
if not new_docs:
logger.info(f"All documents have been processed or are duplicates")
return None
# 4. Store original document
for doc_id, doc in new_docs.items():
await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
await self.full_docs.change_status(doc_id, DocStatus.PENDING)
logger.info(f"Stored {len(new_docs)} new unique documents")
async def apipeline_process_chunks(self):
"""Get pendding documents, split into chunks,insert chunks"""
# 1. get all pending and failed documents
_todo_doc_keys = []
_failed_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
_pendding_doc = await self.full_docs.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
if _failed_doc:
_todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
if _pendding_doc:
_todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
if not _todo_doc_keys:
logger.info("All documents have been processed or are duplicates")
return None
else:
logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
new_docs = {
doc["id"]: doc
for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
}
# 2. split docs into chunks, insert chunks, update doc status
chunk_cnt = 0
batch_size = self.addon_params.get("insert_batch_size", 10)
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(), desc=f"Level 1 - Splitting docs in batch {i//batch_size + 1}"
):
try:
# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
"status": DocStatus.PENDING,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
chunk_cnt += len(chunks)
await self.text_chunks.upsert(chunks)
await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED)
try:
# Store chunks in vector database
await self.chunks_vdb.upsert(chunks)
# Update doc status
await self.full_docs.change_status(doc_id, DocStatus.PROCESSED)
except Exception as e:
# Mark as failed if any step fails
await self.full_docs.change_status(doc_id, DocStatus.FAILED)
raise e
except Exception as e:
import traceback
error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
continue
logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
async def apipeline_process_extract_graph(self):
"""Get pendding or failed chunks, extract entities and relationships from each chunk"""
# 1. get all pending and failed chunks
_todo_chunk_keys = []
_failed_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.FAILED,ids = None)
_pendding_chunks = await self.text_chunks.get_by_status_and_ids(status = DocStatus.PENDING,ids = None)
if _failed_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
if _pendding_chunks:
_todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
if not _todo_chunk_keys:
logger.info("All chunks have been processed or are duplicates")
return None
# Process documents in batches
batch_size = self.addon_params.get("insert_batch_size", 10)
semaphore = asyncio.Semaphore(batch_size) # Control the number of tasks that are processed simultaneously
async def process_chunk(chunk_id):
async with semaphore:
chunks = {i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])}
# Extract and store entities and relationships
try:
maybe_new_kg = await extract_entities(
chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
llm_response_cache=self.llm_response_cache,
global_config=asdict(self),
)
if maybe_new_kg is None:
logger.info("No entities or relationships extracted!")
# Update status to processed
await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED)
except Exception as e:
logger.error("Failed to extract entities and relationships")
# Mark as failed if any step fails
await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
raise e
with tqdm_async(total=len(_todo_chunk_keys),
desc="\nLevel 1 - Processing chunks",
unit="chunk",
position=0) as progress:
tasks = []
for chunk_id in _todo_chunk_keys:
task = asyncio.create_task(process_chunk(chunk_id))
tasks.append(task)
for future in asyncio.as_completed(tasks):
await future
progress.update(1)
progress.set_postfix({
'LLM call': statistic_data["llm_call"],
'LLM cache': statistic_data["llm_cache"],
})
# Ensure all indexes are updated after each document
await self._insert_done()
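The method above caps concurrent entity extraction with a semaphore sized by `insert_batch_size` and tracks completion with tqdm. A standalone sketch of that pattern with illustrative names (`process_all`, `worker`):

import asyncio
from tqdm.asyncio import tqdm as tqdm_async

async def process_all(chunk_ids, worker, max_concurrency=10):
    semaphore = asyncio.Semaphore(max_concurrency)

    async def guarded(chunk_id):
        async with semaphore:            # cap the number of in-flight extractions
            return await worker(chunk_id)

    tasks = [asyncio.create_task(guarded(cid)) for cid in chunk_ids]
    with tqdm_async(total=len(tasks), desc="Processing chunks", unit="chunk") as progress:
        for future in asyncio.as_completed(tasks):
            await future                 # re-raises any exception from the worker
            progress.update(1)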
async def _insert_done(self): async def _insert_done(self):
tasks = [] tasks = []
for storage_inst in [ for storage_inst in [

View File

@@ -20,6 +20,7 @@ from .utils import (
handle_cache, handle_cache,
save_to_cache, save_to_cache,
CacheData, CacheData,
statistic_data
) )
from .base import ( from .base import (
BaseGraphStorage, BaseGraphStorage,
@@ -371,8 +372,10 @@ async def extract_entities(
if need_to_restore: if need_to_restore:
llm_response_cache.global_config = global_config llm_response_cache.global_config = global_config
if cached_return: if cached_return:
logger.debug(f"Found cache for {arg_hash}")
statistic_data["llm_cache"] += 1
return cached_return return cached_return
statistic_data["llm_call"] += 1
if history_messages: if history_messages:
res: str = await use_llm_func( res: str = await use_llm_func(
input_text, history_messages=history_messages input_text, history_messages=history_messages
@@ -459,10 +462,8 @@ async def extract_entities(
now_ticks = PROMPTS["process_tickers"][ now_ticks = PROMPTS["process_tickers"][
already_processed % len(PROMPTS["process_tickers"]) already_processed % len(PROMPTS["process_tickers"])
] ]
print( logger.debug(
f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
end="",
flush=True,
) )
return dict(maybe_nodes), dict(maybe_edges) return dict(maybe_nodes), dict(maybe_edges)
@@ -470,8 +471,8 @@ async def extract_entities(
for result in tqdm_async( for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
total=len(ordered_chunks), total=len(ordered_chunks),
desc="Extracting entities from chunks", desc="Level 2 - Extracting entities and relationships",
unit="chunk", unit="chunk", position=1,leave=False
): ):
results.append(await result) results.append(await result)
@@ -482,7 +483,7 @@ async def extract_entities(
maybe_nodes[k].extend(v) maybe_nodes[k].extend(v)
for k, v in m_edges.items(): for k, v in m_edges.items():
maybe_edges[tuple(sorted(k))].extend(v) maybe_edges[tuple(sorted(k))].extend(v)
logger.info("Inserting entities into storage...") logger.debug("Inserting entities into storage...")
all_entities_data = [] all_entities_data = []
for result in tqdm_async( for result in tqdm_async(
asyncio.as_completed( asyncio.as_completed(
@@ -492,12 +493,12 @@ async def extract_entities(
] ]
), ),
total=len(maybe_nodes), total=len(maybe_nodes),
desc="Inserting entities", desc="Level 3 - Inserting entities",
unit="entity", unit="entity", position=2,leave=False
): ):
all_entities_data.append(await result) all_entities_data.append(await result)
logger.info("Inserting relationships into storage...") logger.debug("Inserting relationships into storage...")
all_relationships_data = [] all_relationships_data = []
for result in tqdm_async( for result in tqdm_async(
asyncio.as_completed( asyncio.as_completed(
@@ -509,8 +510,8 @@ async def extract_entities(
] ]
), ),
total=len(maybe_edges), total=len(maybe_edges),
desc="Inserting relationships", desc="Level 3 - Inserting relationships",
unit="relationship", unit="relationship", position=3,leave=False
): ):
all_relationships_data.append(await result) all_relationships_data.append(await result)
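The `statistic_data` counters incremented above feed the `set_postfix` display in the pipeline's progress bar. A minimal sketch of the same accounting in isolation (the `counted_llm_call` wrapper is hypothetical, not LightRAG API):

from lightrag.utils import statistic_data

async def counted_llm_call(use_llm_func, input_text, cached_return=None):
    if cached_return:
        statistic_data["llm_cache"] += 1   # served from the LLM response cache
        return cached_return
    statistic_data["llm_call"] += 1        # a real model call was made
    return await use_llm_func(input_text)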

View File

@@ -30,8 +30,13 @@ class UnlimitedSemaphore:
ENCODER = None ENCODER = None
statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
logger = logging.getLogger("lightrag") logger = logging.getLogger("lightrag")
# Set httpx logging level to WARNING
logging.getLogger("httpx").setLevel(logging.WARNING)
def set_logger(log_file: str): def set_logger(log_file: str):
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
return None, None, None, None return None, None, None, None
# For naive mode, only use simple cache matching # For naive mode, only use simple cache matching
if mode == "naive": #if mode == "naive":
if mode == "default":
if exists_func(hashing_kv, "get_by_mode_and_id"): if exists_func(hashing_kv, "get_by_mode_and_id"):
mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
else: else:
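Finally, utils.py adds the module-level `statistic_data` counters and quiets httpx logging, and `handle_cache` now applies the exact-match lookup to the `default` mode used during entity extraction. A small hedged sketch around the counters (`reset_statistics` is illustrative, not part of LightRAG):

import logging
from lightrag.utils import statistic_data

def reset_statistics():
    # Zero the shared counters between pipeline runs so per-run stats stay meaningful
    for key in ("llm_call", "llm_cache", "embed_call"):
        statistic_data[key] = 0

# Matches the hunk above: keep httpx request logs from interleaving with the progress bars
logging.getLogger("httpx").setLevel(logging.WARNING)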