diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 7489d923..00000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.gitignore b/.gitignore index 5749adb5..ec95f8a5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ rag_storage venv/ examples/input/ examples/output/ +.DS_Store diff --git a/get_all_edges_nx.py b/examples/get_all_edges_nx.py similarity index 100% rename from get_all_edges_nx.py rename to examples/get_all_edges_nx.py diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index bbb69319..6de6e0a7 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/" APIKEY = "ocigenerativeai" CHATMODEL = "cohere.command-r-plus" EMBEDMODEL = "cohere.embed-multilingual-v3.0" - +CHUNK_TOKEN_SIZE = 1024 +MAX_TOKENS = 4000 if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -86,30 +87,46 @@ async def main(): # We use Oracle DB as the KV/vector/graph storage # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( - enable_llm_cache=False, + # log_level="DEBUG", working_dir=WORKING_DIR, - chunk_token_size=512, + entity_extract_max_gleaning=1, + enable_llm_cache=True, + enable_llm_cache_for_entity_extract=True, + embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90}, + chunk_token_size=CHUNK_TOKEN_SIZE, + llm_model_max_token_size=MAX_TOKENS, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( embedding_dim=embedding_dimension, - max_token_size=512, + max_token_size=500, func=embedding_func, ), graph_storage="OracleGraphStorage", kv_storage="OracleKVStorage", vector_storage="OracleVectorDBStorage", + addon_params={ + "example_number": 1, + "language": "Simplfied Chinese", + "entity_types": ["organization", "person", "geo", "event"], + "insert_batch_size": 2, + }, ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.graph_storage_cls.db = oracle_db - rag.key_string_value_json_storage_cls.db = oracle_db - rag.vector_db_storage_cls.db = oracle_db - # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c - rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func + rag.set_storage_client(db_client=oracle_db) # Extract and Insert into LightRAG storage - with open("./dickens/demo.txt", "r", encoding="utf-8") as f: - await rag.ainsert(f.read()) + with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f: + all_text = f.read() + texts = [x for x in all_text.split("\n") if x] + + # New mode use pipeline + await rag.apipeline_process_documents(texts) + await rag.apipeline_process_chunks() + await rag.apipeline_process_extract_graph() + + # Old method use ainsert + # await rag.ainsert(texts) # Perform search in different modes modes = ["naive", "local", "global", "hybrid"] diff --git a/test.py b/examples/test.py similarity index 100% rename from test.py rename to examples/test.py diff --git a/test_chromadb.py b/examples/test_chromadb.py similarity index 100% rename from test_chromadb.py rename to examples/test_chromadb.py diff --git a/test_neo4j.py b/examples/test_neo4j.py similarity index 100% rename from test_neo4j.py rename to examples/test_neo4j.py diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 34745312..e30b6909 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -153,8 +153,6 @@ 
class OracleDB: if data is None: await cursor.execute(sql) else: - # print(data) - # print(sql) await cursor.execute(sql, data) await connection.commit() except Exception as e: @@ -167,35 +165,64 @@ class OracleDB: @dataclass class OracleKVStorage(BaseKVStorage): # should pass db object to self.db + db: OracleDB = None + meta_fields = None + def __post_init__(self): self._data = {} - self._max_batch_size = self.global_config["embedding_batch_num"] + self._max_batch_size = self.global_config.get("embedding_batch_num", 10) ################ QUERY METHODS ################ async def get_by_id(self, id: str) -> Union[dict, None]: - """根据 id 获取 doc_full 数据.""" + """get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} # print("get_by_id:"+SQL) - res = await self.db.query(SQL, params) + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(SQL, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + else: + res = await self.db.query(SQL, params) if res: - data = res # {"data":res} - # print (data) - return data + return res + else: + return None + + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: + """Specifically for llm_response_cache.""" + SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] + params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id} + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(SQL, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + return res else: return None - # Query by id async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: - """根据 id 获取 doc_chunks 数据""" + """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} # print("get_by_ids:"+SQL) - # print(params) res = await self.db.query(SQL, params, multirows=True) + if "llm_response_cache" == self.namespace: + modes = set() + dict_res: dict[str, dict] = {} + for row in res: + modes.add(row["mode"]) + for mode in modes: + if mode not in dict_res: + dict_res[mode] = {} + for row in res: + dict_res[row["mode"]][row["id"]] = row + res = [{k: v} for k, v in dict_res.items()] if res: data = res # [{"data":i} for i in res] # print(data) @@ -203,38 +230,43 @@ class OracleKVStorage(BaseKVStorage): else: return None + async def get_by_status_and_ids( + self, status: str, ids: list[str] + ) -> Union[list[dict], None]: + """Specifically for llm_response_cache.""" + if ids is not None: + SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + else: + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + res = await self.db.query(SQL, params, multirows=True) + if res: + return res + else: + return None + async def filter_keys(self, keys: list[str]) -> set[str]: - """过滤掉重复内容""" + """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) ) params = {"workspace": self.db.workspace} - try: - await self.db.query(SQL, params) - except Exception as e: - logger.error(f"Oracle database error: {e}") - print(SQL) - print(params) res = await self.db.query(SQL, params, multirows=True) - data = None if res: exist_keys = [key["id"] for 
key in res] data = set([s for s in keys if s not in exist_keys]) + return data else: - exist_keys = [] - data = set([s for s in keys if s not in exist_keys]) - return data + return set(keys) ################ INSERT METHODS ################ async def upsert(self, data: dict[str, dict]): - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - # print(self._data) - # values = [] if self.namespace == "text_chunks": list_data = [ { - "__id__": k, + "id": k, **{k1: v1 for k1, v1 in v.items()}, } for k, v in data.items() @@ -250,35 +282,50 @@ class OracleKVStorage(BaseKVStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - # print(list_data) + + merge_sql = SQL_TEMPLATES["merge_chunk"] for item in list_data: - merge_sql = SQL_TEMPLATES["merge_chunk"] - data = { - "check_id": item["__id__"], - "id": item["__id__"], + _data = { + "id": item["id"], "content": item["content"], "workspace": self.db.workspace, "tokens": item["tokens"], "chunk_order_index": item["chunk_order_index"], "full_doc_id": item["full_doc_id"], "content_vector": item["__vector__"], + "status": item["status"], } - # print(merge_sql) - await self.db.execute(merge_sql, data) - + await self.db.execute(merge_sql, _data) if self.namespace == "full_docs": - for k, v in self._data.items(): + for k, v in data.items(): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] - data = { - "check_id": k, + _data = { "id": k, "content": v["content"], "workspace": self.db.workspace, } - # print(merge_sql) - await self.db.execute(merge_sql, data) - return left_data + await self.db.execute(merge_sql, _data) + + if self.namespace == "llm_response_cache": + for mode, items in data.items(): + for k, v in items.items(): + upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] + _data = { + "workspace": self.db.workspace, + "id": k, + "original_prompt": v["original_prompt"], + "return_value": v["return"], + "cache_mode": mode, + } + + await self.db.execute(upsert_sql, _data) + return None + + async def change_status(self, id: str, status: str): + SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace]) + params = {"workspace": self.db.workspace, "id": id, "status": status} + await self.db.execute(SQL, params) async def index_done_callback(self): if self.namespace in ["full_docs", "text_chunks"]: @@ -287,6 +334,8 @@ class OracleKVStorage(BaseKVStorage): @dataclass class OracleVectorDBStorage(BaseVectorStorage): + # should pass db object to self.db + db: OracleDB = None cosine_better_than_threshold: float = 0.2 def __post_init__(self): @@ -328,7 +377,7 @@ class OracleGraphStorage(BaseGraphStorage): def __post_init__(self): """从graphml文件加载图""" - self._max_batch_size = self.global_config["embedding_batch_num"] + self._max_batch_size = self.global_config.get("embedding_batch_num", 10) #################### insert method ################ @@ -362,7 +411,6 @@ class OracleGraphStorage(BaseGraphStorage): "content": content, "content_vector": content_vector, } - # print(merge_sql) await self.db.execute(merge_sql, data) # self._graph.add_node(node_id, **node_data) @@ -564,20 +612,26 @@ N_T = { TABLES = { "LIGHTRAG_DOC_FULL": { "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( - id varchar(256)PRIMARY KEY, + id varchar(256), workspace varchar(1024), doc_name varchar(1024), content CLOB, meta JSON, + content_summary varchar(1024), + content_length NUMBER, + status varchar(256), + chunks_count NUMBER, createtime TIMESTAMP DEFAULT 
CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL + updatetime TIMESTAMP DEFAULT NULL, + error varchar(4096) )""" }, "LIGHTRAG_DOC_CHUNKS": { "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - id varchar(256) PRIMARY KEY, + id varchar(256), workspace varchar(1024), full_doc_id varchar(256), + status varchar(256), chunk_order_index NUMBER, tokens NUMBER, content CLOB, @@ -619,9 +673,15 @@ TABLES = { "LIGHTRAG_LLM_CACHE": { "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( id varchar(256) PRIMARY KEY, - send clob, - return clob, - model varchar(1024), + workspace varchar(1024), + cache_mode varchar(256), + model_name varchar(256), + original_prompt clob, + return_value clob, + embedding CLOB, + embedding_shape NUMBER, + embedding_min NUMBER, + embedding_max NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL )""" @@ -646,23 +706,44 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage - "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", - "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", - "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", + "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})", + "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})", + "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", + "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace) values(:id,:content,:workspace) - """, - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a - 
USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", + "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = :id and a.workspace = :workspace) + WHEN NOT MATCHED THEN + INSERT(id,content,workspace) values(:id,:content,:workspace)""", + "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS + USING DUAL + ON (id = :id and workspace = :workspace) + WHEN NOT MATCHED THEN INSERT + (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) + values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, + "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a + USING DUAL + ON (a.id = :id) + WHEN NOT MATCHED THEN + INSERT (workspace,id,original_prompt,return_value,cache_mode) + VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode) + WHEN MATCHED THEN UPDATE + SET original_prompt = :original_prompt, + return_value = :return_value, + cache_mode = :cache_mode, + updatetime = SYSDATE""", # SQL for VectorStorage "entities": """SELECT name as entity_name FROM (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance @@ -714,16 +795,22 @@ SQL_TEMPLATES = { COLUMNS (a.name as source_name,b.name as target_name))""", "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL - ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.name=:name) WHEN NOT MATCHED THEN INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) - values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL - ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name) WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cacdfc50..ad79afaa 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -28,6 +28,7 @@ from .utils import ( convert_response_to_json, logger, set_logger, + statistic_data, ) 
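The reworked SQL templates above turn the old insert-only MERGE statements into true upserts keyed on id and workspace, with WHEN MATCHED branches for nodes, edges, and the LLM cache. A minimal sketch of how one of them is driven through the shared OracleDB client follows; it assumes the oracle_db object from examples/lightrag_oracle_demo.py and a hypothetical cache id, and is not part of the patch.

# Illustrative only. Assumes an initialized OracleDB client (as built in
# examples/lightrag_oracle_demo.py); bind names mirror upsert_llm_response_cache.
import asyncio
from lightrag.kg.oracle_impl import SQL_TEMPLATES

async def cache_upsert_example(oracle_db):
    # Running this twice with the same id updates the row (WHEN MATCHED branch)
    # instead of raising a unique-constraint error.
    await oracle_db.execute(
        SQL_TEMPLATES["upsert_llm_response_cache"],
        {
            "workspace": oracle_db.workspace,
            "id": "cache-123",                      # hypothetical cache key
            "original_prompt": "What did Scrooge say?",
            "return_value": "Bah! Humbug!",
            "cache_mode": "global",
        },
    )

# asyncio.run(cache_upsert_example(oracle_db))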
from .base import ( BaseGraphStorage, @@ -38,21 +39,30 @@ from .base import ( DocStatus, ) -from .storage import ( - JsonKVStorage, - NanoVectorDBStorage, - NetworkXStorage, - JsonDocStatusStorage, -) - from .prompt import GRAPH_FIELD_SEP - -# future KG integrations - -# from .kg.ArangoDB_impl import ( -# GraphStorage as ArangoDBStorage -# ) +STORAGES = { + "JsonKVStorage": ".storage", + "NanoVectorDBStorage": ".storage", + "NetworkXStorage": ".storage", + "JsonDocStatusStorage": ".storage", + "Neo4JStorage": ".kg.neo4j_impl", + "OracleKVStorage": ".kg.oracle_impl", + "OracleGraphStorage": ".kg.oracle_impl", + "OracleVectorDBStorage": ".kg.oracle_impl", + "MilvusVectorDBStorge": ".kg.milvus_impl", + "MongoKVStorage": ".kg.mongo_impl", + "ChromaVectorDBStorage": ".kg.chroma_impl", + "TiDBKVStorage": ".kg.tidb_impl", + "TiDBVectorDBStorage": ".kg.tidb_impl", + "TiDBGraphStorage": ".kg.tidb_impl", + "PGKVStorage": ".kg.postgres_impl", + "PGVectorStorage": ".kg.postgres_impl", + "AGEStorage": ".kg.age_impl", + "PGGraphStorage": ".kg.postgres_impl", + "GremlinStorage": ".kg.gremlin_impl", + "PGDocStatusStorage": ".kg.postgres_impl", +} def lazy_external_import(module_name: str, class_name: str): @@ -68,34 +78,13 @@ def lazy_external_import(module_name: str, class_name: str): def import_class(*args, **kwargs): import importlib - # Import the module using importlib module = importlib.import_module(module_name, package=package) - - # Get the class from the module and instantiate it cls = getattr(module, class_name) return cls(*args, **kwargs) return import_class -Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") -OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") -OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") -OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") -MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") -MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") -ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") -TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") -TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") -TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage") -PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage") -PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage") -AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage") -PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage") -GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage") -PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage") - - def always_get_an_event_loop() -> asyncio.AbstractEventLoop: """ Ensure that there is always an event loop available. @@ -199,34 +188,51 @@ class LightRAG: logger.setLevel(self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") - - _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) - logger.debug(f"LightRAG init with param:\n {_print_config}\n") - - # @TODO: should move all storage setup here to leverage initial start params attached to self. 
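The new STORAGES registry above replaces the block of per-backend lazy_external_import assignments: a backend is looked up by name and its module is imported only when the class is actually instantiated. A small self-contained sketch of that pattern, with illustrative entries and an explicit package argument as a simplification (the real helper infers the package from the caller):

# Standalone sketch of the registry + lazy-import pattern; not the exact
# LightRAG internals. The explicit `package` default is an assumption here.
import importlib

STORAGES = {
    "JsonKVStorage": ".storage",
    "OracleKVStorage": ".kg.oracle_impl",
}

def lazy_external_import(module_name: str, class_name: str, package: str = "lightrag"):
    def import_class(*args, **kwargs):
        # The import happens only when the returned factory is called.
        module = importlib.import_module(module_name, package=package)
        return getattr(module, class_name)(*args, **kwargs)
    return import_class

def get_storage_class(storage_name: str):
    return lazy_external_import(STORAGES[storage_name], storage_name)

# kv_factory = get_storage_class("JsonKVStorage")   # no Oracle driver import needed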
- - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( - self._get_storage_class()[self.kv_storage] - ) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[ - self.vector_storage - ] - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ - self.graph_storage - ] - if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) - self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - global_config=asdict(self), + # show config + global_config = asdict(self) + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) + logger.debug(f"LightRAG init with param:\n {_print_config}\n") + + # Init LLM + self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( + self.embedding_func + ) + + # Initialize all storages + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self._get_storage_class(self.kv_storage) + ) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( + self.vector_storage + ) + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( + self.graph_storage + ) + + self.key_string_value_json_storage_cls = partial( + self.key_string_value_json_storage_cls, global_config=global_config + ) + + self.vector_db_storage_cls = partial( + self.vector_db_storage_cls, global_config=global_config + ) + + self.graph_storage_cls = partial( + self.graph_storage_cls, global_config=global_config + ) + + self.json_doc_status_storage = self.key_string_value_json_storage_cls( + namespace="json_doc_status_storage", embedding_func=None, ) - self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( - self.embedding_func + self.llm_response_cache = self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + embedding_func=None, ) #### @@ -234,17 +240,14 @@ class LightRAG: #### self.full_docs = self.key_string_value_json_storage_cls( namespace="full_docs", - global_config=asdict(self), embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( namespace="text_chunks", - global_config=asdict(self), embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( namespace="chunk_entity_relation", - global_config=asdict(self), embedding_func=self.embedding_func, ) #### @@ -253,72 +256,69 @@ class LightRAG: self.entities_vdb = self.vector_db_storage_cls( namespace="entities", - global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) self.relationships_vdb = self.vector_db_storage_cls( namespace="relationships", - global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) self.chunks_vdb = self.vector_db_storage_cls( namespace="chunks", - global_config=asdict(self), embedding_func=self.embedding_func, ) + if self.llm_response_cache and hasattr( + self.llm_response_cache, "global_config" + ): + hashing_kv = self.llm_response_cache + else: + hashing_kv = self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + embedding_func=None, + ) + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - 
global_config=asdict(self), - embedding_func=None, - ), + hashing_kv=hashing_kv, **self.llm_model_kwargs, ) ) # Initialize document status storage - self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage] + self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) self.doc_status = self.doc_status_storage_cls( namespace="doc_status", - global_config=asdict(self), + global_config=global_config, embedding_func=None, ) - def _get_storage_class(self) -> dict: - return { - # kv storage - "JsonKVStorage": JsonKVStorage, - "OracleKVStorage": OracleKVStorage, - "MongoKVStorage": MongoKVStorage, - "TiDBKVStorage": TiDBKVStorage, - # vector storage - "NanoVectorDBStorage": NanoVectorDBStorage, - "OracleVectorDBStorage": OracleVectorDBStorage, - "MilvusVectorDBStorge": MilvusVectorDBStorge, - "ChromaVectorDBStorage": ChromaVectorDBStorage, - "TiDBVectorDBStorage": TiDBVectorDBStorage, - # graph storage - "NetworkXStorage": NetworkXStorage, - "Neo4JStorage": Neo4JStorage, - "OracleGraphStorage": OracleGraphStorage, - "AGEStorage": AGEStorage, - "PGGraphStorage": PGGraphStorage, - "PGKVStorage": PGKVStorage, - "PGDocStatusStorage": PGDocStatusStorage, - "PGVectorStorage": PGVectorStorage, - "TiDBGraphStorage": TiDBGraphStorage, - "GremlinStorage": GremlinStorage, - # "ArangoDBStorage": ArangoDBStorage - "JsonDocStatusStorage": JsonDocStatusStorage, - } + def _get_storage_class(self, storage_name: str) -> dict: + import_path = STORAGES[storage_name] + storage_class = lazy_external_import(import_path, storage_name) + return storage_class + + def set_storage_client(self, db_client): + # Now only tested on Oracle Database + for storage in [ + self.vector_db_storage_cls, + self.graph_storage_cls, + self.doc_status, + self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.key_string_value_json_storage_cls, + self.chunks_vdb, + self.relationships_vdb, + self.entities_vdb, + self.graph_storage_cls, + self.chunk_entity_relation_graph, + self.llm_response_cache, + ]: + # set client + storage.db = db_client def insert( self, string_or_strings, split_by_character=None, split_by_character_only=False @@ -540,6 +540,195 @@ class LightRAG: if update_storage: await self._insert_done() + async def apipeline_process_documents(self, string_or_strings): + """Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs + Args: + string_or_strings: Single document string or list of document strings + """ + if isinstance(string_or_strings, str): + string_or_strings = [string_or_strings] + + # 1. Remove duplicate contents from the list + unique_contents = list(set(doc.strip() for doc in string_or_strings)) + + logger.info( + f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents" + ) + + # 2. Generate document IDs and initial status + new_docs = { + compute_mdhash_id(content, prefix="doc-"): { + "content": content, + "content_summary": self._get_content_summary(content), + "content_length": len(content), + "status": DocStatus.PENDING, + "created_at": datetime.now().isoformat(), + "updated_at": None, + } + for content in unique_contents + } + + # 3. 
Filter out already processed documents + _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) + if len(_not_stored_doc_keys) < len(new_docs): + logger.info( + f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents" + ) + new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys} + + if not new_docs: + logger.info("All documents have been processed or are duplicates") + return None + + # 4. Store original document + for doc_id, doc in new_docs.items(): + await self.full_docs.upsert({doc_id: {"content": doc["content"]}}) + await self.full_docs.change_status(doc_id, DocStatus.PENDING) + logger.info(f"Stored {len(new_docs)} new unique documents") + + async def apipeline_process_chunks(self): + """Get pendding documents, split into chunks,insert chunks""" + # 1. get all pending and failed documents + _todo_doc_keys = [] + _failed_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) + if _failed_doc: + _todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) + if _pendding_doc: + _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc]) + if not _todo_doc_keys: + logger.info("All documents have been processed or are duplicates") + return None + else: + logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents") + + new_docs = { + doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys) + } + + # 2. split docs into chunks, insert chunks, update doc status + chunk_cnt = 0 + batch_size = self.addon_params.get("insert_batch_size", 10) + for i in range(0, len(new_docs), batch_size): + batch_docs = dict(list(new_docs.items())[i : i + batch_size]) + for doc_id, doc in tqdm_async( + batch_docs.items(), + desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}", + ): + try: + # Generate chunks from document + chunks = { + compute_mdhash_id(dp["content"], prefix="chunk-"): { + **dp, + "full_doc_id": doc_id, + "status": DocStatus.PENDING, + } + for dp in chunking_by_token_size( + doc["content"], + overlap_token_size=self.chunk_overlap_token_size, + max_token_size=self.chunk_token_size, + tiktoken_model=self.tiktoken_model_name, + ) + } + chunk_cnt += len(chunks) + await self.text_chunks.upsert(chunks) + await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED) + + try: + # Store chunks in vector database + await self.chunks_vdb.upsert(chunks) + # Update doc status + await self.full_docs.change_status(doc_id, DocStatus.PROCESSED) + except Exception as e: + # Mark as failed if any step fails + await self.full_docs.change_status(doc_id, DocStatus.FAILED) + raise e + except Exception as e: + import traceback + + error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + continue + logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents") + + async def apipeline_process_extract_graph(self): + """Get pendding or failed chunks, extract entities and relationships from each chunk""" + # 1. 
get all pending and failed chunks + _todo_chunk_keys = [] + _failed_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) + if _failed_chunks: + _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) + if _pendding_chunks: + _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks]) + if not _todo_chunk_keys: + logger.info("All chunks have been processed or are duplicates") + return None + + # Process documents in batches + batch_size = self.addon_params.get("insert_batch_size", 10) + + semaphore = asyncio.Semaphore( + batch_size + ) # Control the number of tasks that are processed simultaneously + + async def process_chunk(chunk_id): + async with semaphore: + chunks = { + i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id]) + } + # Extract and store entities and relationships + try: + maybe_new_kg = await extract_entities( + chunks, + knowledge_graph_inst=self.chunk_entity_relation_graph, + entity_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + llm_response_cache=self.llm_response_cache, + global_config=asdict(self), + ) + if maybe_new_kg is None: + logger.info("No entities or relationships extracted!") + # Update status to processed + await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED) + except Exception as e: + logger.error("Failed to extract entities and relationships") + # Mark as failed if any step fails + await self.text_chunks.change_status(chunk_id, DocStatus.FAILED) + raise e + + with tqdm_async( + total=len(_todo_chunk_keys), + desc="\nLevel 1 - Processing chunks", + unit="chunk", + position=0, + ) as progress: + tasks = [] + for chunk_id in _todo_chunk_keys: + task = asyncio.create_task(process_chunk(chunk_id)) + tasks.append(task) + + for future in asyncio.as_completed(tasks): + await future + progress.update(1) + progress.set_postfix( + { + "LLM call": statistic_data["llm_call"], + "LLM cache": statistic_data["llm_cache"], + } + ) + + # Ensure all indexes are updated after each document + await self._insert_done() + async def _insert_done(self): tasks = [] for storage_inst in [ diff --git a/lightrag/operate.py b/lightrag/operate.py index 7df489b3..e1406904 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -20,6 +20,7 @@ from .utils import ( handle_cache, save_to_cache, CacheData, + statistic_data, ) from .base import ( BaseGraphStorage, @@ -96,6 +97,10 @@ async def _handle_entity_relation_summary( description: str, global_config: dict, ) -> str: + """Handle entity relation summary + For each entity or relation, input is the combined description of already existing description and new description. + If too long, use LLM to summarize. 
+ """ use_llm_func: callable = global_config["llm_model_func"] llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] @@ -176,6 +181,7 @@ async def _merge_nodes_then_upsert( knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): + """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] already_source_ids = [] already_description = [] @@ -356,7 +362,7 @@ async def extract_entities( llm_response_cache.global_config = new_config need_to_restore = True if history_messages: - history = json.dumps(history_messages) + history = json.dumps(history_messages, ensure_ascii=False) _prompt = history + "\n" + input_text else: _prompt = input_text @@ -368,8 +374,10 @@ async def extract_entities( if need_to_restore: llm_response_cache.global_config = global_config if cached_return: + logger.debug(f"Found cache for {arg_hash}") + statistic_data["llm_cache"] += 1 return cached_return - + statistic_data["llm_call"] += 1 if history_messages: res: str = await use_llm_func( input_text, history_messages=history_messages @@ -388,6 +396,11 @@ async def extract_entities( return await use_llm_func(input_text) async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): + """ "Prpocess a single chunk + Args: + chunk_key_dp (tuple[str, TextChunkSchema]): + ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) + """ nonlocal already_processed, already_entities, already_relations chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] @@ -451,10 +464,8 @@ async def extract_entities( now_ticks = PROMPTS["process_tickers"][ already_processed % len(PROMPTS["process_tickers"]) ] - print( + logger.debug( f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", - end="", - flush=True, ) return dict(maybe_nodes), dict(maybe_edges) @@ -462,8 +473,10 @@ async def extract_entities( for result in tqdm_async( asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), total=len(ordered_chunks), - desc="Extracting entities from chunks", + desc="Level 2 - Extracting entities and relationships", unit="chunk", + position=1, + leave=False, ): results.append(await result) @@ -474,7 +487,7 @@ async def extract_entities( maybe_nodes[k].extend(v) for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) - logger.info("Inserting entities into storage...") + logger.debug("Inserting entities into storage...") all_entities_data = [] for result in tqdm_async( asyncio.as_completed( @@ -484,12 +497,14 @@ async def extract_entities( ] ), total=len(maybe_nodes), - desc="Inserting entities", + desc="Level 3 - Inserting entities", unit="entity", + position=2, + leave=False, ): all_entities_data.append(await result) - logger.info("Inserting relationships into storage...") + logger.debug("Inserting relationships into storage...") all_relationships_data = [] for result in tqdm_async( asyncio.as_completed( @@ -501,8 +516,10 @@ async def extract_entities( ] ), total=len(maybe_edges), - desc="Inserting relationships", + desc="Level 3 - Inserting relationships", unit="relationship", + position=3, + leave=False, ): all_relationships_data.append(await result) diff --git a/lightrag/utils.py b/lightrag/utils.py index 1f6bf405..ce556ab2 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -30,13 +30,18 @@ class UnlimitedSemaphore: ENCODER = None 
+statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} + logger = logging.getLogger("lightrag") +# Set httpx logging level to WARNING +logging.getLogger("httpx").setLevel(logging.WARNING) + def set_logger(log_file: str): logger.setLevel(logging.DEBUG) - file_handler = logging.FileHandler(log_file) + file_handler = logging.FileHandler(log_file, encoding="utf-8") file_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( @@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): return None, None, None, None # For naive mode, only use simple cache matching - if mode == "naive": + # if mode == "naive": + if mode == "default": if exists_func(hashing_kv, "get_by_mode_and_id"): mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} else: @@ -473,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): quantized = min_val = max_val = None if is_embedding_cache_enabled: # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + embedding_model_func = hashing_kv.global_config[ + "embedding_func" + ].func # ["func"] llm_model_func = hashing_kv.global_config.get("llm_model_func") current_embedding = await embedding_model_func([prompt])
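Taken together with the demo changes, the new flow replaces a single ainsert() call with a shared storage client plus a three-stage pipeline. A condensed usage sketch (construction of rag and oracle_db follows examples/lightrag_oracle_demo.py and is elided; the input strings are illustrative):

# Condensed sketch of the new pipeline API, mirroring the Oracle demo above.
# rag (a LightRAG configured with Oracle storages) and oracle_db are assumed
# to be built as in examples/lightrag_oracle_demo.py.
import asyncio

async def index_documents(rag, oracle_db):
    # One connection pool is shared by every Oracle-backed storage instance.
    rag.set_storage_client(db_client=oracle_db)

    texts = ["First document ...", "Second document ..."]   # illustrative inputs

    # Stage 1: de-duplicate, assign doc- ids, store docs with PENDING status.
    await rag.apipeline_process_documents(texts)
    # Stage 2: chunk PENDING/FAILED docs, embed the chunks, mark docs PROCESSED.
    await rag.apipeline_process_chunks()
    # Stage 3: extract entities and relations from PENDING/FAILED chunks into the graph.
    await rag.apipeline_process_extract_graph()

# asyncio.run(index_documents(rag, oracle_db))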