diff --git a/.gitignore b/.gitignore index 8ac420b1..ec95f8a5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,4 @@ rag_storage venv/ examples/input/ examples/output/ -test_results.json +.DS_Store diff --git a/README.md b/README.md index 9e3d87ed..1d25f378 100644 --- a/README.md +++ b/README.md @@ -330,6 +330,26 @@ rag = LightRAG( with open("./newText.txt") as f: rag.insert(f.read()) ``` +### Separate Keyword Extraction +We've introduced a new function `query_with_separate_keyword_extraction` to enhance the keyword extraction capabilities. This function separates the keyword extraction process from the user's prompt, focusing solely on the query to improve the relevance of extracted keywords. + +##### How It Works? +The function operates by dividing the input into two parts: +- `User Query` +- `Prompt` + +It then performs keyword extraction exclusively on the `user query`. This separation ensures that the extraction process is focused and relevant, unaffected by any additional language in the `prompt`. It also allows the `prompt` to serve purely for response formatting, maintaining the intent and clarity of the user's original question. + +##### Usage Example +This `example` shows how to tailor the function for educational content, focusing on detailed explanations for older students. + +```python +rag.query_with_separate_keyword_extraction( + query="Explain the law of gravity", + prompt="Provide a detailed explanation suitable for high school students studying physics.", + param=QueryParam(mode="hybrid") +) +``` ### Using Neo4J for Storage @@ -361,6 +381,7 @@ see test_neo4j.py for a working example. ### Using PostgreSQL for Storage For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE). * PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac. +* If you prefer docker, please start with this image if you are a beginner to avoid hiccups (DO read the overview): https://hub.docker.com/r/shangor/postgres-for-rag * How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py) * Create index for AGE example: (Change below `dickens` to your graph name if necessary) ``` diff --git a/examples/copy_llm_cache_to_another_storage.py b/examples/copy_llm_cache_to_another_storage.py new file mode 100644 index 00000000..b9378c7c --- /dev/null +++ b/examples/copy_llm_cache_to_another_storage.py @@ -0,0 +1,97 @@ +""" +Sometimes you need to switch a storage solution, but you want to save LLM token and time. +This handy script helps you to copy the LLM caches from one storage solution to another. +(Not all the storage impl are supported) +""" + +import asyncio +import logging +import os +from dotenv import load_dotenv + +from lightrag.kg.postgres_impl import PostgreSQLDB, PGKVStorage +from lightrag.storage import JsonKVStorage + +load_dotenv() +ROOT_DIR = os.environ.get("ROOT_DIR") +WORKING_DIR = f"{ROOT_DIR}/dickens" + +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +# AGE +os.environ["AGE_GRAPH_NAME"] = "chinese" + +postgres_db = PostgreSQLDB( + config={ + "host": "localhost", + "port": 15432, + "user": "rag", + "password": "rag", + "database": "r2", + } +) + + +async def copy_from_postgres_to_json(): + await postgres_db.initdb() + + from_llm_response_cache = PGKVStorage( + namespace="llm_response_cache", + global_config={"embedding_batch_num": 6}, + embedding_func=None, + db=postgres_db, + ) + + to_llm_response_cache = JsonKVStorage( + namespace="llm_response_cache", + global_config={"working_dir": WORKING_DIR}, + embedding_func=None, + ) + + kv = {} + for c_id in await from_llm_response_cache.all_keys(): + print(f"Copying {c_id}") + workspace = c_id["workspace"] + mode = c_id["mode"] + _id = c_id["id"] + postgres_db.workspace = workspace + obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id) + if mode not in kv: + kv[mode] = {} + kv[mode][_id] = obj[_id] + print(f"Object {obj}") + await to_llm_response_cache.upsert(kv) + await to_llm_response_cache.index_done_callback() + print("Mission accomplished!") + + +async def copy_from_json_to_postgres(): + await postgres_db.initdb() + + from_llm_response_cache = JsonKVStorage( + namespace="llm_response_cache", + global_config={"working_dir": WORKING_DIR}, + embedding_func=None, + ) + + to_llm_response_cache = PGKVStorage( + namespace="llm_response_cache", + global_config={"embedding_batch_num": 6}, + embedding_func=None, + db=postgres_db, + ) + + for mode in await from_llm_response_cache.all_keys(): + print(f"Copying {mode}") + caches = await from_llm_response_cache.get_by_id(mode) + for k, v in caches.items(): + item = {mode: {k: v}} + print(f"\tCopying {item}") + await to_llm_response_cache.upsert(item) + + +if __name__ == "__main__": + asyncio.run(copy_from_json_to_postgres()) diff --git a/get_all_edges_nx.py b/examples/get_all_edges_nx.py similarity index 100% rename from get_all_edges_nx.py rename to examples/get_all_edges_nx.py diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index bbb69319..6de6e0a7 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/" APIKEY = "ocigenerativeai" CHATMODEL = "cohere.command-r-plus" EMBEDMODEL = "cohere.embed-multilingual-v3.0" - +CHUNK_TOKEN_SIZE = 1024 +MAX_TOKENS = 4000 if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) @@ -86,30 +87,46 @@ async def main(): # We use Oracle DB as the KV/vector/graph storage # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( - enable_llm_cache=False, + # log_level="DEBUG", working_dir=WORKING_DIR, - chunk_token_size=512, + entity_extract_max_gleaning=1, + enable_llm_cache=True, + enable_llm_cache_for_entity_extract=True, + embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90}, + chunk_token_size=CHUNK_TOKEN_SIZE, + llm_model_max_token_size=MAX_TOKENS, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( embedding_dim=embedding_dimension, - max_token_size=512, + max_token_size=500, func=embedding_func, ), graph_storage="OracleGraphStorage", kv_storage="OracleKVStorage", vector_storage="OracleVectorDBStorage", + addon_params={ + "example_number": 1, + "language": "Simplfied Chinese", + "entity_types": ["organization", "person", "geo", "event"], + "insert_batch_size": 2, + }, ) # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool - rag.graph_storage_cls.db = oracle_db - rag.key_string_value_json_storage_cls.db = oracle_db - rag.vector_db_storage_cls.db = oracle_db - # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c - rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func + rag.set_storage_client(db_client=oracle_db) # Extract and Insert into LightRAG storage - with open("./dickens/demo.txt", "r", encoding="utf-8") as f: - await rag.ainsert(f.read()) + with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f: + all_text = f.read() + texts = [x for x in all_text.split("\n") if x] + + # New mode use pipeline + await rag.apipeline_process_documents(texts) + await rag.apipeline_process_chunks() + await rag.apipeline_process_extract_graph() + + # Old method use ainsert + # await rag.ainsert(texts) # Perform search in different modes modes = ["naive", "local", "global", "hybrid"] diff --git a/examples/query_keyword_separation_example.py b/examples/query_keyword_separation_example.py new file mode 100644 index 00000000..f11ce8c1 --- /dev/null +++ b/examples/query_keyword_separation_example.py @@ -0,0 +1,116 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.utils import EmbeddingFunc +import numpy as np +from dotenv import load_dotenv +import logging +from openai import AzureOpenAI + +logging.basicConfig(level=logging.INFO) + +load_dotenv() + +AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") +AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT") +AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") +AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") + +AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT") +AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION") + +WORKING_DIR = "./dickens" + +if os.path.exists(WORKING_DIR): + import shutil + + shutil.rmtree(WORKING_DIR) + +os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + client = AzureOpenAI( + api_key=AZURE_OPENAI_API_KEY, + api_version=AZURE_OPENAI_API_VERSION, + azure_endpoint=AZURE_OPENAI_ENDPOINT, + ) + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + if history_messages: + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + chat_completion = client.chat.completions.create( + model=AZURE_OPENAI_DEPLOYMENT, # model = "deployment_name". + messages=messages, + temperature=kwargs.get("temperature", 0), + top_p=kwargs.get("top_p", 1), + n=kwargs.get("n", 1), + ) + return chat_completion.choices[0].message.content + + +async def embedding_func(texts: list[str]) -> np.ndarray: + client = AzureOpenAI( + api_key=AZURE_OPENAI_API_KEY, + api_version=AZURE_EMBEDDING_API_VERSION, + azure_endpoint=AZURE_OPENAI_ENDPOINT, + ) + embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts) + + embeddings = [item.embedding for item in embedding.data] + return np.array(embeddings) + + +async def test_funcs(): + result = await llm_model_func("How are you?") + print("Resposta do llm_model_func: ", result) + + result = await embedding_func(["How are you?"]) + print("Resultado do embedding_func: ", result.shape) + print("Dimensão da embedding: ", result.shape[1]) + + +asyncio.run(test_funcs()) + +embedding_dimension = 3072 + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, + ), +) + +book1 = open("./book_1.txt", encoding="utf-8") +book2 = open("./book_2.txt", encoding="utf-8") + +rag.insert([book1.read(), book2.read()]) + + +# Example function demonstrating the new query_with_separate_keyword_extraction usage +async def run_example(): + query = "What are the top themes in this story?" + prompt = "Please simplify the response for a young audience." + + # Using the new method to ensure the keyword extraction is only applied to the query + response = rag.query_with_separate_keyword_extraction( + query=query, + prompt=prompt, + param=QueryParam(mode="hybrid"), # Adjust QueryParam mode as necessary + ) + + print("Extracted Response:", response) + + +# Run the example asynchronously +if __name__ == "__main__": + asyncio.run(run_example()) diff --git a/test.py b/examples/test.py similarity index 100% rename from test.py rename to examples/test.py diff --git a/test_chromadb.py b/examples/test_chromadb.py similarity index 100% rename from test_chromadb.py rename to examples/test_chromadb.py diff --git a/test_neo4j.py b/examples/test_neo4j.py similarity index 100% rename from test_neo4j.py rename to examples/test_neo4j.py diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 7a26a282..b6317d84 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.1.1" +__version__ = "1.1.2" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" diff --git a/lightrag/base.py b/lightrag/base.py index 94a39cf3..7b3504d0 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -31,6 +31,8 @@ class QueryParam: max_token_for_global_context: int = 4000 # Number of tokens for the entity descriptions max_token_for_local_context: int = 4000 + hl_keywords: list[str] = field(default_factory=list) + ll_keywords: list[str] = field(default_factory=list) @dataclass diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 34745312..e30b6909 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -153,8 +153,6 @@ class OracleDB: if data is None: await cursor.execute(sql) else: - # print(data) - # print(sql) await cursor.execute(sql, data) await connection.commit() except Exception as e: @@ -167,35 +165,64 @@ class OracleDB: @dataclass class OracleKVStorage(BaseKVStorage): # should pass db object to self.db + db: OracleDB = None + meta_fields = None + def __post_init__(self): self._data = {} - self._max_batch_size = self.global_config["embedding_batch_num"] + self._max_batch_size = self.global_config.get("embedding_batch_num", 10) ################ QUERY METHODS ################ async def get_by_id(self, id: str) -> Union[dict, None]: - """根据 id 获取 doc_full 数据.""" + """get doc_full data based on id.""" SQL = SQL_TEMPLATES["get_by_id_" + self.namespace] params = {"workspace": self.db.workspace, "id": id} # print("get_by_id:"+SQL) - res = await self.db.query(SQL, params) + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(SQL, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + else: + res = await self.db.query(SQL, params) if res: - data = res # {"data":res} - # print (data) - return data + return res + else: + return None + + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: + """Specifically for llm_response_cache.""" + SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] + params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id} + if "llm_response_cache" == self.namespace: + array_res = await self.db.query(SQL, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + return res else: return None - # Query by id async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]: - """根据 id 获取 doc_chunks 数据""" + """get doc_chunks data based on id""" SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( ids=",".join([f"'{id}'" for id in ids]) ) params = {"workspace": self.db.workspace} # print("get_by_ids:"+SQL) - # print(params) res = await self.db.query(SQL, params, multirows=True) + if "llm_response_cache" == self.namespace: + modes = set() + dict_res: dict[str, dict] = {} + for row in res: + modes.add(row["mode"]) + for mode in modes: + if mode not in dict_res: + dict_res[mode] = {} + for row in res: + dict_res[row["mode"]][row["id"]] = row + res = [{k: v} for k, v in dict_res.items()] if res: data = res # [{"data":i} for i in res] # print(data) @@ -203,38 +230,43 @@ class OracleKVStorage(BaseKVStorage): else: return None + async def get_by_status_and_ids( + self, status: str, ids: list[str] + ) -> Union[list[dict], None]: + """Specifically for llm_response_cache.""" + if ids is not None: + SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + else: + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + res = await self.db.query(SQL, params, multirows=True) + if res: + return res + else: + return None + async def filter_keys(self, keys: list[str]) -> set[str]: - """过滤掉重复内容""" + """Return keys that don't exist in storage""" SQL = SQL_TEMPLATES["filter_keys"].format( table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys]) ) params = {"workspace": self.db.workspace} - try: - await self.db.query(SQL, params) - except Exception as e: - logger.error(f"Oracle database error: {e}") - print(SQL) - print(params) res = await self.db.query(SQL, params, multirows=True) - data = None if res: exist_keys = [key["id"] for key in res] data = set([s for s in keys if s not in exist_keys]) + return data else: - exist_keys = [] - data = set([s for s in keys if s not in exist_keys]) - return data + return set(keys) ################ INSERT METHODS ################ async def upsert(self, data: dict[str, dict]): - left_data = {k: v for k, v in data.items() if k not in self._data} - self._data.update(left_data) - # print(self._data) - # values = [] if self.namespace == "text_chunks": list_data = [ { - "__id__": k, + "id": k, **{k1: v1 for k1, v1 in v.items()}, } for k, v in data.items() @@ -250,35 +282,50 @@ class OracleKVStorage(BaseKVStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["__vector__"] = embeddings[i] - # print(list_data) + + merge_sql = SQL_TEMPLATES["merge_chunk"] for item in list_data: - merge_sql = SQL_TEMPLATES["merge_chunk"] - data = { - "check_id": item["__id__"], - "id": item["__id__"], + _data = { + "id": item["id"], "content": item["content"], "workspace": self.db.workspace, "tokens": item["tokens"], "chunk_order_index": item["chunk_order_index"], "full_doc_id": item["full_doc_id"], "content_vector": item["__vector__"], + "status": item["status"], } - # print(merge_sql) - await self.db.execute(merge_sql, data) - + await self.db.execute(merge_sql, _data) if self.namespace == "full_docs": - for k, v in self._data.items(): + for k, v in data.items(): # values.clear() merge_sql = SQL_TEMPLATES["merge_doc_full"] - data = { - "check_id": k, + _data = { "id": k, "content": v["content"], "workspace": self.db.workspace, } - # print(merge_sql) - await self.db.execute(merge_sql, data) - return left_data + await self.db.execute(merge_sql, _data) + + if self.namespace == "llm_response_cache": + for mode, items in data.items(): + for k, v in items.items(): + upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] + _data = { + "workspace": self.db.workspace, + "id": k, + "original_prompt": v["original_prompt"], + "return_value": v["return"], + "cache_mode": mode, + } + + await self.db.execute(upsert_sql, _data) + return None + + async def change_status(self, id: str, status: str): + SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace]) + params = {"workspace": self.db.workspace, "id": id, "status": status} + await self.db.execute(SQL, params) async def index_done_callback(self): if self.namespace in ["full_docs", "text_chunks"]: @@ -287,6 +334,8 @@ class OracleKVStorage(BaseKVStorage): @dataclass class OracleVectorDBStorage(BaseVectorStorage): + # should pass db object to self.db + db: OracleDB = None cosine_better_than_threshold: float = 0.2 def __post_init__(self): @@ -328,7 +377,7 @@ class OracleGraphStorage(BaseGraphStorage): def __post_init__(self): """从graphml文件加载图""" - self._max_batch_size = self.global_config["embedding_batch_num"] + self._max_batch_size = self.global_config.get("embedding_batch_num", 10) #################### insert method ################ @@ -362,7 +411,6 @@ class OracleGraphStorage(BaseGraphStorage): "content": content, "content_vector": content_vector, } - # print(merge_sql) await self.db.execute(merge_sql, data) # self._graph.add_node(node_id, **node_data) @@ -564,20 +612,26 @@ N_T = { TABLES = { "LIGHTRAG_DOC_FULL": { "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( - id varchar(256)PRIMARY KEY, + id varchar(256), workspace varchar(1024), doc_name varchar(1024), content CLOB, meta JSON, + content_summary varchar(1024), + content_length NUMBER, + status varchar(256), + chunks_count NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updatetime TIMESTAMP DEFAULT NULL + updatetime TIMESTAMP DEFAULT NULL, + error varchar(4096) )""" }, "LIGHTRAG_DOC_CHUNKS": { "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - id varchar(256) PRIMARY KEY, + id varchar(256), workspace varchar(1024), full_doc_id varchar(256), + status varchar(256), chunk_order_index NUMBER, tokens NUMBER, content CLOB, @@ -619,9 +673,15 @@ TABLES = { "LIGHTRAG_LLM_CACHE": { "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( id varchar(256) PRIMARY KEY, - send clob, - return clob, - model varchar(1024), + workspace varchar(1024), + cache_mode varchar(256), + model_name varchar(256), + original_prompt clob, + return_value clob, + embedding CLOB, + embedding_shape NUMBER, + embedding_min NUMBER, + embedding_max NUMBER, createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updatetime TIMESTAMP DEFAULT NULL )""" @@ -646,23 +706,44 @@ TABLES = { SQL_TEMPLATES = { # SQL for KVStorage - "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", - "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", - "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})", - "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id", + "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id", + "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""", + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""", + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode" + FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""", + "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})", + "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})", + "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})", + "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})", + "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status", + "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status", "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})", - "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace) values(:id,:content,:workspace) - """, - "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a - USING DUAL - ON (a.id = :check_id) - WHEN NOT MATCHED THEN - INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector) - values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """, + "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id", + "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a + USING DUAL + ON (a.id = :id and a.workspace = :workspace) + WHEN NOT MATCHED THEN + INSERT(id,content,workspace) values(:id,:content,:workspace)""", + "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS + USING DUAL + ON (id = :id and workspace = :workspace) + WHEN NOT MATCHED THEN INSERT + (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status) + values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """, + "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a + USING DUAL + ON (a.id = :id) + WHEN NOT MATCHED THEN + INSERT (workspace,id,original_prompt,return_value,cache_mode) + VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode) + WHEN MATCHED THEN UPDATE + SET original_prompt = :original_prompt, + return_value = :return_value, + cache_mode = :cache_mode, + updatetime = SYSDATE""", # SQL for VectorStorage "entities": """SELECT name as entity_name FROM (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance @@ -714,16 +795,22 @@ SQL_TEMPLATES = { COLUMNS (a.name as source_name,b.name as target_name))""", "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a USING DUAL - ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.name=:name) WHEN NOT MATCHED THEN INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector) - values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a USING DUAL - ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id) + ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name) WHEN NOT MATCHED THEN INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector) - values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """, + values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) + WHEN MATCHED THEN + UPDATE SET + weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""", "get_all_nodes": """WITH t0 AS ( SELECT name AS id, entity_type AS label, entity_type, description, '["' || replace(source_chunk_id, '', '","') || '"]' source_chunk_ids diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index b93a345b..86072c9f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -231,6 +231,16 @@ class PGKVStorage(BaseKVStorage): else: return None + async def all_keys(self) -> list[dict]: + if "llm_response_cache" == self.namespace: + sql = "select workspace,mode,id from lightrag_llm_cache" + res = await self.db.query(sql, multirows=True) + return res + else: + logger.error( + f"all_keys is only implemented for llm_response_cache, not for {self.namespace}" + ) + async def filter_keys(self, keys: List[str]) -> Set[str]: """Filter out duplicated content""" sql = SQL_TEMPLATES["filter_keys"].format( @@ -412,7 +422,10 @@ class PGDocStatusStorage(DocStatusStorage): async def filter_keys(self, data: list[str]) -> set[str]: """Return keys that don't exist in storage""" - sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})" + keys = ",".join([f"'{_id}'" for _id in data]) + sql = ( + f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})" + ) result = await self.db.query(sql, {"workspace": self.db.workspace}, True) # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. if result is None: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 596fbdbf..ad79afaa 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -17,6 +17,8 @@ from .operate import ( kg_query, naive_query, mix_kg_vector_query, + extract_keywords_only, + kg_query_with_keywords, ) from .utils import ( @@ -26,6 +28,7 @@ from .utils import ( convert_response_to_json, logger, set_logger, + statistic_data, ) from .base import ( BaseGraphStorage, @@ -36,21 +39,30 @@ from .base import ( DocStatus, ) -from .storage import ( - JsonKVStorage, - NanoVectorDBStorage, - NetworkXStorage, - JsonDocStatusStorage, -) - from .prompt import GRAPH_FIELD_SEP - -# future KG integrations - -# from .kg.ArangoDB_impl import ( -# GraphStorage as ArangoDBStorage -# ) +STORAGES = { + "JsonKVStorage": ".storage", + "NanoVectorDBStorage": ".storage", + "NetworkXStorage": ".storage", + "JsonDocStatusStorage": ".storage", + "Neo4JStorage": ".kg.neo4j_impl", + "OracleKVStorage": ".kg.oracle_impl", + "OracleGraphStorage": ".kg.oracle_impl", + "OracleVectorDBStorage": ".kg.oracle_impl", + "MilvusVectorDBStorge": ".kg.milvus_impl", + "MongoKVStorage": ".kg.mongo_impl", + "ChromaVectorDBStorage": ".kg.chroma_impl", + "TiDBKVStorage": ".kg.tidb_impl", + "TiDBVectorDBStorage": ".kg.tidb_impl", + "TiDBGraphStorage": ".kg.tidb_impl", + "PGKVStorage": ".kg.postgres_impl", + "PGVectorStorage": ".kg.postgres_impl", + "AGEStorage": ".kg.age_impl", + "PGGraphStorage": ".kg.postgres_impl", + "GremlinStorage": ".kg.gremlin_impl", + "PGDocStatusStorage": ".kg.postgres_impl", +} def lazy_external_import(module_name: str, class_name: str): @@ -66,34 +78,13 @@ def lazy_external_import(module_name: str, class_name: str): def import_class(*args, **kwargs): import importlib - # Import the module using importlib module = importlib.import_module(module_name, package=package) - - # Get the class from the module and instantiate it cls = getattr(module, class_name) return cls(*args, **kwargs) return import_class -Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") -OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") -OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") -OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") -MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") -MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") -ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage") -TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage") -TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage") -TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage") -PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage") -PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage") -AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage") -PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage") -GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage") -PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage") - - def always_get_an_event_loop() -> asyncio.AbstractEventLoop: """ Ensure that there is always an event loop available. @@ -197,34 +188,51 @@ class LightRAG: logger.setLevel(self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") - - _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) - logger.debug(f"LightRAG init with param:\n {_print_config}\n") - - # @TODO: should move all storage setup here to leverage initial start params attached to self. - - self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( - self._get_storage_class()[self.kv_storage] - ) - self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[ - self.vector_storage - ] - self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[ - self.graph_storage - ] - if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) - self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - global_config=asdict(self), + # show config + global_config = asdict(self) + _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) + logger.debug(f"LightRAG init with param:\n {_print_config}\n") + + # Init LLM + self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( + self.embedding_func + ) + + # Initialize all storages + self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( + self._get_storage_class(self.kv_storage) + ) + self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( + self.vector_storage + ) + self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( + self.graph_storage + ) + + self.key_string_value_json_storage_cls = partial( + self.key_string_value_json_storage_cls, global_config=global_config + ) + + self.vector_db_storage_cls = partial( + self.vector_db_storage_cls, global_config=global_config + ) + + self.graph_storage_cls = partial( + self.graph_storage_cls, global_config=global_config + ) + + self.json_doc_status_storage = self.key_string_value_json_storage_cls( + namespace="json_doc_status_storage", embedding_func=None, ) - self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( - self.embedding_func + self.llm_response_cache = self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + embedding_func=None, ) #### @@ -232,17 +240,14 @@ class LightRAG: #### self.full_docs = self.key_string_value_json_storage_cls( namespace="full_docs", - global_config=asdict(self), embedding_func=self.embedding_func, ) self.text_chunks = self.key_string_value_json_storage_cls( namespace="text_chunks", - global_config=asdict(self), embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph = self.graph_storage_cls( namespace="chunk_entity_relation", - global_config=asdict(self), embedding_func=self.embedding_func, ) #### @@ -251,72 +256,69 @@ class LightRAG: self.entities_vdb = self.vector_db_storage_cls( namespace="entities", - global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"entity_name"}, ) self.relationships_vdb = self.vector_db_storage_cls( namespace="relationships", - global_config=asdict(self), embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"}, ) self.chunks_vdb = self.vector_db_storage_cls( namespace="chunks", - global_config=asdict(self), embedding_func=self.embedding_func, ) + if self.llm_response_cache and hasattr( + self.llm_response_cache, "global_config" + ): + hashing_kv = self.llm_response_cache + else: + hashing_kv = self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + embedding_func=None, + ) + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, - hashing_kv=self.llm_response_cache - if self.llm_response_cache - and hasattr(self.llm_response_cache, "global_config") - else self.key_string_value_json_storage_cls( - namespace="llm_response_cache", - global_config=asdict(self), - embedding_func=None, - ), + hashing_kv=hashing_kv, **self.llm_model_kwargs, ) ) # Initialize document status storage - self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage] + self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) self.doc_status = self.doc_status_storage_cls( namespace="doc_status", - global_config=asdict(self), + global_config=global_config, embedding_func=None, ) - def _get_storage_class(self) -> dict: - return { - # kv storage - "JsonKVStorage": JsonKVStorage, - "OracleKVStorage": OracleKVStorage, - "MongoKVStorage": MongoKVStorage, - "TiDBKVStorage": TiDBKVStorage, - # vector storage - "NanoVectorDBStorage": NanoVectorDBStorage, - "OracleVectorDBStorage": OracleVectorDBStorage, - "MilvusVectorDBStorge": MilvusVectorDBStorge, - "ChromaVectorDBStorage": ChromaVectorDBStorage, - "TiDBVectorDBStorage": TiDBVectorDBStorage, - # graph storage - "NetworkXStorage": NetworkXStorage, - "Neo4JStorage": Neo4JStorage, - "OracleGraphStorage": OracleGraphStorage, - "AGEStorage": AGEStorage, - "PGGraphStorage": PGGraphStorage, - "PGKVStorage": PGKVStorage, - "PGDocStatusStorage": PGDocStatusStorage, - "PGVectorStorage": PGVectorStorage, - "TiDBGraphStorage": TiDBGraphStorage, - "GremlinStorage": GremlinStorage, - # "ArangoDBStorage": ArangoDBStorage - "JsonDocStatusStorage": JsonDocStatusStorage, - } + def _get_storage_class(self, storage_name: str) -> dict: + import_path = STORAGES[storage_name] + storage_class = lazy_external_import(import_path, storage_name) + return storage_class + + def set_storage_client(self, db_client): + # Now only tested on Oracle Database + for storage in [ + self.vector_db_storage_cls, + self.graph_storage_cls, + self.doc_status, + self.full_docs, + self.text_chunks, + self.llm_response_cache, + self.key_string_value_json_storage_cls, + self.chunks_vdb, + self.relationships_vdb, + self.entities_vdb, + self.graph_storage_cls, + self.chunk_entity_relation_graph, + self.llm_response_cache, + ]: + # set client + storage.db = db_client def insert( self, string_or_strings, split_by_character=None, split_by_character_only=False @@ -538,6 +540,195 @@ class LightRAG: if update_storage: await self._insert_done() + async def apipeline_process_documents(self, string_or_strings): + """Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs + Args: + string_or_strings: Single document string or list of document strings + """ + if isinstance(string_or_strings, str): + string_or_strings = [string_or_strings] + + # 1. Remove duplicate contents from the list + unique_contents = list(set(doc.strip() for doc in string_or_strings)) + + logger.info( + f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents" + ) + + # 2. Generate document IDs and initial status + new_docs = { + compute_mdhash_id(content, prefix="doc-"): { + "content": content, + "content_summary": self._get_content_summary(content), + "content_length": len(content), + "status": DocStatus.PENDING, + "created_at": datetime.now().isoformat(), + "updated_at": None, + } + for content in unique_contents + } + + # 3. Filter out already processed documents + _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) + if len(_not_stored_doc_keys) < len(new_docs): + logger.info( + f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents" + ) + new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys} + + if not new_docs: + logger.info("All documents have been processed or are duplicates") + return None + + # 4. Store original document + for doc_id, doc in new_docs.items(): + await self.full_docs.upsert({doc_id: {"content": doc["content"]}}) + await self.full_docs.change_status(doc_id, DocStatus.PENDING) + logger.info(f"Stored {len(new_docs)} new unique documents") + + async def apipeline_process_chunks(self): + """Get pendding documents, split into chunks,insert chunks""" + # 1. get all pending and failed documents + _todo_doc_keys = [] + _failed_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_doc = await self.full_docs.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) + if _failed_doc: + _todo_doc_keys.extend([doc["id"] for doc in _failed_doc]) + if _pendding_doc: + _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc]) + if not _todo_doc_keys: + logger.info("All documents have been processed or are duplicates") + return None + else: + logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents") + + new_docs = { + doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys) + } + + # 2. split docs into chunks, insert chunks, update doc status + chunk_cnt = 0 + batch_size = self.addon_params.get("insert_batch_size", 10) + for i in range(0, len(new_docs), batch_size): + batch_docs = dict(list(new_docs.items())[i : i + batch_size]) + for doc_id, doc in tqdm_async( + batch_docs.items(), + desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}", + ): + try: + # Generate chunks from document + chunks = { + compute_mdhash_id(dp["content"], prefix="chunk-"): { + **dp, + "full_doc_id": doc_id, + "status": DocStatus.PENDING, + } + for dp in chunking_by_token_size( + doc["content"], + overlap_token_size=self.chunk_overlap_token_size, + max_token_size=self.chunk_token_size, + tiktoken_model=self.tiktoken_model_name, + ) + } + chunk_cnt += len(chunks) + await self.text_chunks.upsert(chunks) + await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED) + + try: + # Store chunks in vector database + await self.chunks_vdb.upsert(chunks) + # Update doc status + await self.full_docs.change_status(doc_id, DocStatus.PROCESSED) + except Exception as e: + # Mark as failed if any step fails + await self.full_docs.change_status(doc_id, DocStatus.FAILED) + raise e + except Exception as e: + import traceback + + error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + continue + logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents") + + async def apipeline_process_extract_graph(self): + """Get pendding or failed chunks, extract entities and relationships from each chunk""" + # 1. get all pending and failed chunks + _todo_chunk_keys = [] + _failed_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.FAILED, ids=None + ) + _pendding_chunks = await self.text_chunks.get_by_status_and_ids( + status=DocStatus.PENDING, ids=None + ) + if _failed_chunks: + _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks]) + if _pendding_chunks: + _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks]) + if not _todo_chunk_keys: + logger.info("All chunks have been processed or are duplicates") + return None + + # Process documents in batches + batch_size = self.addon_params.get("insert_batch_size", 10) + + semaphore = asyncio.Semaphore( + batch_size + ) # Control the number of tasks that are processed simultaneously + + async def process_chunk(chunk_id): + async with semaphore: + chunks = { + i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id]) + } + # Extract and store entities and relationships + try: + maybe_new_kg = await extract_entities( + chunks, + knowledge_graph_inst=self.chunk_entity_relation_graph, + entity_vdb=self.entities_vdb, + relationships_vdb=self.relationships_vdb, + llm_response_cache=self.llm_response_cache, + global_config=asdict(self), + ) + if maybe_new_kg is None: + logger.info("No entities or relationships extracted!") + # Update status to processed + await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED) + except Exception as e: + logger.error("Failed to extract entities and relationships") + # Mark as failed if any step fails + await self.text_chunks.change_status(chunk_id, DocStatus.FAILED) + raise e + + with tqdm_async( + total=len(_todo_chunk_keys), + desc="\nLevel 1 - Processing chunks", + unit="chunk", + position=0, + ) as progress: + tasks = [] + for chunk_id in _todo_chunk_keys: + task = asyncio.create_task(process_chunk(chunk_id)) + tasks.append(task) + + for future in asyncio.as_completed(tasks): + await future + progress.update(1) + progress.set_postfix( + { + "LLM call": statistic_data["llm_call"], + "LLM cache": statistic_data["llm_cache"], + } + ) + + # Ensure all indexes are updated after each document + await self._insert_done() + async def _insert_done(self): tasks = [] for storage_inst in [ @@ -753,6 +944,114 @@ class LightRAG: await self._query_done() return response + def query_with_separate_keyword_extraction( + self, query: str, prompt: str, param: QueryParam = QueryParam() + ): + """ + 1. Extract keywords from the 'query' using new function in operate.py. + 2. Then run the standard aquery() flow with the final prompt (formatted_question). + """ + + loop = always_get_an_event_loop() + return loop.run_until_complete( + self.aquery_with_separate_keyword_extraction(query, prompt, param) + ) + + async def aquery_with_separate_keyword_extraction( + self, query: str, prompt: str, param: QueryParam = QueryParam() + ): + """ + 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. + 2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed. + """ + + # --------------------- + # STEP 1: Keyword Extraction + # --------------------- + # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords). + hl_keywords, ll_keywords = await extract_keywords_only( + text=query, + param=param, + global_config=asdict(self), + hashing_kv=self.llm_response_cache + or self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + + param.hl_keywords = (hl_keywords,) + param.ll_keywords = (ll_keywords,) + + # --------------------- + # STEP 2: Final Query Logic + # --------------------- + + # Create a new string with the prompt and the keywords + ll_keywords_str = ", ".join(ll_keywords) + hl_keywords_str = ", ".join(hl_keywords) + formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}" + + if param.mode in ["local", "global", "hybrid"]: + response = await kg_query_with_keywords( + formatted_question, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + elif param.mode == "naive": + response = await naive_query( + formatted_question, + self.chunks_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + elif param.mode == "mix": + response = await mix_kg_vector_query( + formatted_question, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.chunks_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + else: + raise ValueError(f"Unknown mode {param.mode}") + + await self._query_done() + return response + async def _query_done(self): tasks = [] for storage_inst in [self.llm_response_cache]: diff --git a/lightrag/operate.py b/lightrag/operate.py index 7216c07f..e1406904 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -20,6 +20,7 @@ from .utils import ( handle_cache, save_to_cache, CacheData, + statistic_data, ) from .base import ( BaseGraphStorage, @@ -96,6 +97,10 @@ async def _handle_entity_relation_summary( description: str, global_config: dict, ) -> str: + """Handle entity relation summary + For each entity or relation, input is the combined description of already existing description and new description. + If too long, use LLM to summarize. + """ use_llm_func: callable = global_config["llm_model_func"] llm_max_tokens = global_config["llm_model_max_token_size"] tiktoken_model_name = global_config["tiktoken_model_name"] @@ -176,6 +181,7 @@ async def _merge_nodes_then_upsert( knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): + """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" already_entity_types = [] already_source_ids = [] already_description = [] @@ -356,7 +362,7 @@ async def extract_entities( llm_response_cache.global_config = new_config need_to_restore = True if history_messages: - history = json.dumps(history_messages) + history = json.dumps(history_messages, ensure_ascii=False) _prompt = history + "\n" + input_text else: _prompt = input_text @@ -368,8 +374,10 @@ async def extract_entities( if need_to_restore: llm_response_cache.global_config = global_config if cached_return: + logger.debug(f"Found cache for {arg_hash}") + statistic_data["llm_cache"] += 1 return cached_return - + statistic_data["llm_call"] += 1 if history_messages: res: str = await use_llm_func( input_text, history_messages=history_messages @@ -388,6 +396,11 @@ async def extract_entities( return await use_llm_func(input_text) async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): + """ "Prpocess a single chunk + Args: + chunk_key_dp (tuple[str, TextChunkSchema]): + ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) + """ nonlocal already_processed, already_entities, already_relations chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] @@ -451,10 +464,8 @@ async def extract_entities( now_ticks = PROMPTS["process_tickers"][ already_processed % len(PROMPTS["process_tickers"]) ] - print( + logger.debug( f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", - end="", - flush=True, ) return dict(maybe_nodes), dict(maybe_edges) @@ -462,8 +473,10 @@ async def extract_entities( for result in tqdm_async( asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]), total=len(ordered_chunks), - desc="Extracting entities from chunks", + desc="Level 2 - Extracting entities and relationships", unit="chunk", + position=1, + leave=False, ): results.append(await result) @@ -474,7 +487,7 @@ async def extract_entities( maybe_nodes[k].extend(v) for k, v in m_edges.items(): maybe_edges[tuple(sorted(k))].extend(v) - logger.info("Inserting entities into storage...") + logger.debug("Inserting entities into storage...") all_entities_data = [] for result in tqdm_async( asyncio.as_completed( @@ -484,12 +497,14 @@ async def extract_entities( ] ), total=len(maybe_nodes), - desc="Inserting entities", + desc="Level 3 - Inserting entities", unit="entity", + position=2, + leave=False, ): all_entities_data.append(await result) - logger.info("Inserting relationships into storage...") + logger.debug("Inserting relationships into storage...") all_relationships_data = [] for result in tqdm_async( asyncio.as_completed( @@ -501,8 +516,10 @@ async def extract_entities( ] ), total=len(maybe_edges), - desc="Inserting relationships", + desc="Level 3 - Inserting relationships", unit="relationship", + position=3, + leave=False, ): all_relationships_data.append(await result) @@ -681,6 +698,219 @@ async def kg_query( return response +async def kg_query_with_keywords( + query: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, +) -> str: + """ + Refactored kg_query that does NOT extract keywords by itself. + It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. + Then it uses those to build context and produce a final LLM response. + """ + + # --------------------------- + # 0) Handle potential cache + # --------------------------- + use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + + # --------------------------- + # 1) RETRIEVE KEYWORDS FROM query_param + # --------------------------- + + # If these fields don't exist, default to empty lists/strings. + hl_keywords = getattr(query_param, "hl_keywords", []) or [] + ll_keywords = getattr(query_param, "ll_keywords", []) or [] + + # If neither has any keywords, you could handle that logic here. + if not hl_keywords and not ll_keywords: + logger.warning( + "No keywords found in query_param. Could default to global mode or fail." + ) + return PROMPTS["fail_response"] + if not ll_keywords and query_param.mode in ["local", "hybrid"]: + logger.warning("low_level_keywords is empty, switching to global mode.") + query_param.mode = "global" + if not hl_keywords and query_param.mode in ["global", "hybrid"]: + logger.warning("high_level_keywords is empty, switching to local mode.") + query_param.mode = "local" + + # Flatten low-level and high-level keywords if needed + ll_keywords_flat = ( + [item for sublist in ll_keywords for item in sublist] + if any(isinstance(i, list) for i in ll_keywords) + else ll_keywords + ) + hl_keywords_flat = ( + [item for sublist in hl_keywords for item in sublist] + if any(isinstance(i, list) for i in hl_keywords) + else hl_keywords + ) + + # Join the flattened lists + ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else "" + hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else "" + + keywords = [ll_keywords_str, hl_keywords_str] + + logger.info("Using %s mode for query processing", query_param.mode) + + # --------------------------- + # 2) BUILD CONTEXT + # --------------------------- + context = await _build_query_context( + keywords, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + text_chunks_db, + query_param, + ) + if not context: + return PROMPTS["fail_response"] + + # If only context is needed, return it + if query_param.only_need_context: + return context + + # --------------------------- + # 3) BUILD THE SYSTEM PROMPT + CALL LLM + # --------------------------- + sys_prompt_temp = PROMPTS["rag_response"] + sys_prompt = sys_prompt_temp.format( + context_data=context, response_type=query_param.response_type + ) + + if query_param.only_need_prompt: + return sys_prompt + + # Now call the LLM with the final system prompt + response = await use_model_func( + query, + system_prompt=sys_prompt, + stream=query_param.stream, + ) + + # Clean up the response + if isinstance(response, str) and len(response) > len(sys_prompt): + response = ( + response.replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + + # --------------------------- + # 4) SAVE TO CACHE + # --------------------------- + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) + return response + + +async def extract_keywords_only( + text: str, + param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, +) -> tuple[list[str], list[str]]: + """ + Extract high-level and low-level keywords from the given 'text' using the LLM. + This method does NOT build the final RAG context or provide a final answer. + It ONLY extracts keywords (hl_keywords, ll_keywords). + """ + + # 1. Handle cache if needed + args_hash = compute_args_hash(param.mode, text) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, text, param.mode + ) + if cached_response is not None: + # parse the cached_response if it’s JSON containing keywords + # or simply return (hl_keywords, ll_keywords) from cached + # Assuming cached_response is in the same JSON structure: + match = re.search(r"\{.*\}", cached_response, re.DOTALL) + if match: + keywords_data = json.loads(match.group(0)) + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + return hl_keywords, ll_keywords + return [], [] + + # 2. Build the examples + example_number = global_config["addon_params"].get("example_number", None) + if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): + examples = "\n".join( + PROMPTS["keywords_extraction_examples"][: int(example_number)] + ) + else: + examples = "\n".join(PROMPTS["keywords_extraction_examples"]) + language = global_config["addon_params"].get( + "language", PROMPTS["DEFAULT_LANGUAGE"] + ) + + # 3. Build the keyword-extraction prompt + kw_prompt_temp = PROMPTS["keywords_extraction"] + kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language) + + # 4. Call the LLM for keyword extraction + use_model_func = global_config["llm_model_func"] + result = await use_model_func(kw_prompt, keyword_extraction=True) + + # 5. Parse out JSON from the LLM response + match = re.search(r"\{.*\}", result, re.DOTALL) + if not match: + logger.error("No JSON-like structure found in the result.") + return [], [] + try: + keywords_data = json.loads(match.group(0)) + except json.JSONDecodeError as e: + logger.error(f"JSON parsing error: {e}") + return [], [] + + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + + # 6. Cache the result if needed + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=result, + prompt=text, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=param.mode, + ), + ) + return hl_keywords, ll_keywords + + async def _build_query_context( query: list, knowledge_graph_inst: BaseGraphStorage, diff --git a/lightrag/utils.py b/lightrag/utils.py index 1f6bf405..ce556ab2 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -30,13 +30,18 @@ class UnlimitedSemaphore: ENCODER = None +statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} + logger = logging.getLogger("lightrag") +# Set httpx logging level to WARNING +logging.getLogger("httpx").setLevel(logging.WARNING) + def set_logger(log_file: str): logger.setLevel(logging.DEBUG) - file_handler = logging.FileHandler(log_file) + file_handler = logging.FileHandler(log_file, encoding="utf-8") file_handler.setLevel(logging.DEBUG) formatter = logging.Formatter( @@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): return None, None, None, None # For naive mode, only use simple cache matching - if mode == "naive": + # if mode == "naive": + if mode == "default": if exists_func(hashing_kv, "get_by_mode_and_id"): mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} else: @@ -473,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): quantized = min_val = max_val = None if is_embedding_cache_enabled: # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + embedding_model_func = hashing_kv.global_config[ + "embedding_func" + ].func # ["func"] llm_model_func = hashing_kv.global_config.get("llm_model_func") current_embedding = await embedding_model_func([prompt])