From 7ab699955e05d35ea89d4b46fb72138e15dcc877 Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Thu, 17 Oct 2024 10:29:08 +0800 Subject: [PATCH 01/25] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index fb29945b..7ad8dd26 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## Install -* Install from source +* Install from source (Recommended) ```bash cd LightRAG @@ -142,7 +142,7 @@ rag = LightRAG(
- Using Ollama Models + Using Ollama Models (There are some bugs. I'll fix them ASAP.) If you want to use Ollama models, you only need to set LightRAG as follows: ```python From 0e0a037a1d15743798286146c998e6cfa29ddc1e Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Thu, 17 Oct 2024 14:39:11 +0800 Subject: [PATCH 02/25] Add Discord channel link --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ad8dd26..ff6fe44a 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,10 @@

- +

+ @@ -20,6 +21,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News +- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉 - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! From a2f1654f4cc2eeb73b38ca6e1d2ff787bc514a34 Mon Sep 17 00:00:00 2001 From: LarFii <834462287@qq.com> Date: Thu, 17 Oct 2024 16:02:43 +0800 Subject: [PATCH 03/25] fix Ollama bugs --- README.md | 2 +- lightrag/operate.py | 81 ++++++++++++++++++++++++++------------------- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index ff6fe44a..fd85141b 100644 --- a/README.md +++ b/README.md @@ -144,7 +144,7 @@ rag = LightRAG(

- Using Ollama Models (There are some bugs. I'll fix them ASAP.) + Using Ollama Models If you want to use Ollama models, you only need to set LightRAG as follows: ```python diff --git a/lightrag/operate.py b/lightrag/operate.py index 3d388cb6..3a17810a 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -387,6 +387,7 @@ async def local_query( query_param: QueryParam, global_config: dict, ) -> str: + context = None use_model_func = global_config["llm_model_func"] kw_prompt_temp = PROMPTS["keywords_extraction"] @@ -399,7 +400,9 @@ async def local_query( keywords = ', '.join(keywords) except json.JSONDecodeError as e: try: - result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json') + result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip() + result = '{' + result.split('{')[1].split('}')[0] + '}' + keywords_data = json.loads(result) keywords = keywords_data.get("low_level_keywords", []) keywords = ', '.join(keywords) @@ -407,13 +410,14 @@ async def local_query( except json.JSONDecodeError as e: print(f"JSON parsing error: {e}") return PROMPTS["fail_response"] - context = await _build_local_query_context( - keywords, - knowledge_graph_inst, - entities_vdb, - text_chunks_db, - query_param, - ) + if keywords: + context = await _build_local_query_context( + keywords, + knowledge_graph_inst, + entities_vdb, + text_chunks_db, + query_param, + ) if query_param.only_need_context: return context if context is None: @@ -614,6 +618,7 @@ async def global_query( query_param: QueryParam, global_config: dict, ) -> str: + context = None use_model_func = global_config["llm_model_func"] kw_prompt_temp = PROMPTS["keywords_extraction"] @@ -626,7 +631,9 @@ async def global_query( keywords = ', '.join(keywords) except json.JSONDecodeError as e: try: - result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json') + result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip() + result = '{' + result.split('{')[1].split('}')[0] + '}' + keywords_data = json.loads(result) keywords = keywords_data.get("high_level_keywords", []) keywords = ', '.join(keywords) @@ -635,15 +642,15 @@ async def global_query( # Handle parsing error print(f"JSON parsing error: {e}") return PROMPTS["fail_response"] - - context = await _build_global_query_context( - keywords, - knowledge_graph_inst, - entities_vdb, - relationships_vdb, - text_chunks_db, - query_param, - ) + if keywords: + context = await _build_global_query_context( + keywords, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + text_chunks_db, + query_param, + ) if query_param.only_need_context: return context @@ -836,6 +843,8 @@ async def hybrid_query( query_param: QueryParam, global_config: dict, ) -> str: + low_level_context = None + high_level_context = None use_model_func = global_config["llm_model_func"] kw_prompt_temp = PROMPTS["keywords_extraction"] @@ -850,7 +859,9 @@ async def hybrid_query( ll_keywords = ', '.join(ll_keywords) except json.JSONDecodeError as e: try: - result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json') + result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip() + result = '{' + result.split('{')[1].split('}')[0] + '}' + keywords_data = json.loads(result) hl_keywords = keywords_data.get("high_level_keywords", []) ll_keywords = 
keywords_data.get("low_level_keywords", []) @@ -861,22 +872,24 @@ async def hybrid_query( print(f"JSON parsing error: {e}") return PROMPTS["fail_response"] - low_level_context = await _build_local_query_context( - ll_keywords, - knowledge_graph_inst, - entities_vdb, - text_chunks_db, - query_param, - ) + if ll_keywords: + low_level_context = await _build_local_query_context( + ll_keywords, + knowledge_graph_inst, + entities_vdb, + text_chunks_db, + query_param, + ) - high_level_context = await _build_global_query_context( - hl_keywords, - knowledge_graph_inst, - entities_vdb, - relationships_vdb, - text_chunks_db, - query_param, - ) + if hl_keywords: + high_level_context = await _build_global_query_context( + hl_keywords, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + text_chunks_db, + query_param, + ) context = combine_contexts(high_level_context, low_level_context) From 70dbca190e296fd1aeeb45d384af06eeaced3285 Mon Sep 17 00:00:00 2001 From: KIM Jae Boum Date: Fri, 18 Oct 2024 06:06:47 +0800 Subject: [PATCH 04/25] update Step_3.py and openai compatible script --- reproduce/Step_1_openai_compatible.py | 66 ++++++++++++++++++ reproduce/Step_3.py | 4 +- reproduce/Step_3_openai_compatible.py | 99 +++++++++++++++++++++++++++ 3 files changed, 167 insertions(+), 2 deletions(-) create mode 100644 reproduce/Step_1_openai_compatible.py create mode 100644 reproduce/Step_3_openai_compatible.py diff --git a/reproduce/Step_1_openai_compatible.py b/reproduce/Step_1_openai_compatible.py new file mode 100644 index 00000000..b5c6aef3 --- /dev/null +++ b/reproduce/Step_1_openai_compatible.py @@ -0,0 +1,66 @@ +import os +import json +import time +import numpy as np + +from lightrag import LightRAG +from lightrag.utils import EmbeddingFunc +from lightrag.llm import openai_complete_if_cache, openai_embedding + +## For Upstage API +# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + return await openai_complete_if_cache( + "solar-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("UPSTAGE_API_KEY"), + base_url="https://api.upstage.ai/v1/solar", + **kwargs + ) + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embedding( + texts, + model="solar-embedding-1-large-query", + api_key=os.getenv("UPSTAGE_API_KEY"), + base_url="https://api.upstage.ai/v1/solar" + ) +## /For Upstage API + +def insert_text(rag, file_path): + with open(file_path, mode='r') as f: + unique_contexts = json.load(f) + + retries = 0 + max_retries = 3 + while retries < max_retries: + try: + rag.insert(unique_contexts) + break + except Exception as e: + retries += 1 + print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}") + time.sleep(10) + if retries == max_retries: + print("Insertion failed after exceeding the maximum number of retries") + +cls = "mix" +WORKING_DIR = f"../{cls}" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +rag = LightRAG(working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=4096, + max_token_size=8192, + func=embedding_func + ) + ) + +insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json") diff --git a/reproduce/Step_3.py b/reproduce/Step_3.py index e97e2af6..a79ebd17 100644 --- a/reproduce/Step_3.py +++ b/reproduce/Step_3.py @@ -53,10 +53,10 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file if __name__ == "__main__": cls = "agriculture" mode = "hybrid" - WORKING_DIR = "../{cls}" + WORKING_DIR = f"../{cls}" rag = LightRAG(working_dir=WORKING_DIR) query_param = QueryParam(mode=mode) queries = extract_queries(f"../datasets/questions/{cls}_questions.txt") - run_queries_and_save_to_json(queries, rag, query_param, "result.json", "errors.json") \ No newline at end of file + run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json") diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py new file mode 100644 index 00000000..7b3079a9 --- /dev/null +++ b/reproduce/Step_3_openai_compatible.py @@ -0,0 +1,99 @@ +import os +import re +import json +import asyncio +from lightrag import LightRAG, QueryParam +from tqdm import tqdm +from lightrag.llm import openai_complete_if_cache, openai_embedding +from lightrag.utils import EmbeddingFunc +import numpy as np + +## For Upstage API +# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + return await openai_complete_if_cache( + "solar-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("UPSTAGE_API_KEY"), + base_url="https://api.upstage.ai/v1/solar", + **kwargs + ) + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embedding( + texts, + model="solar-embedding-1-large-query", + api_key=os.getenv("UPSTAGE_API_KEY"), + base_url="https://api.upstage.ai/v1/solar" + ) +## /For Upstage API + +def extract_queries(file_path): + with open(file_path, 'r') as f: + data = f.read() + + data = data.replace('**', '') + + queries = re.findall(r'- Question \d+: (.+)', data) + + return queries + +async def process_query(query_text, rag_instance, query_param): + try: + result, context = await rag_instance.aquery(query_text, param=query_param) + return {"query": query_text, "result": result, "context": context}, None + except Exception as e: + return None, {"query": query_text, "error": str(e)} + +def always_get_an_event_loop() -> asyncio.AbstractEventLoop: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + +def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file): + loop = always_get_an_event_loop() + + with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file: + result_file.write("[\n") + first_entry = True + + for query_text in tqdm(queries, desc="Processing queries", unit="query"): + result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param)) + + if result: + if not first_entry: + result_file.write(",\n") + json.dump(result, result_file, ensure_ascii=False, indent=4) + first_entry = False + elif error: + json.dump(error, err_file, ensure_ascii=False, indent=4) + err_file.write("\n") + + result_file.write("\n]") + +if __name__ == "__main__": + cls = "mix" + mode = "hybrid" + WORKING_DIR = f"../{cls}" + + rag = LightRAG(working_dir=WORKING_DIR) + rag = LightRAG(working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=4096, + max_token_size=8192, + func=embedding_func + ) + ) + query_param = QueryParam(mode=mode) + + base_dir='../datasets/questions' + queries = 
extract_queries(f"{base_dir}/{cls}_questions.txt") + run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json") From 996c9543a55d773ef37930d7569c63c742b925fd Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Fri, 18 Oct 2024 12:14:14 +0800 Subject: [PATCH 05/25] Add a link to a LightRAG explanatory video --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index fd85141b..2987507d 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@

+

@@ -21,6 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
## 🎉 News +- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG explanatory video](https://youtu.be/oageL-1I0GE). Thanks to the author! - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉 - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! From d04f70d4254eb024e8bd2347594d29149413363f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=9C=A8Data=20Intelligence=20Lab=40HKU=E2=9C=A8?= <118165258+HKUDS@users.noreply.github.com> Date: Fri, 18 Oct 2024 12:45:30 +0800 Subject: [PATCH 06/25] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2987507d..d0ed8a35 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based ## 🎉 News -- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG explanatory video](https://youtu.be/oageL-1I0GE). Thanks to the author! +- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author! - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉 - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
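Before the next patch, the keyword-extraction hardening in [PATCH 03/25] above deserves a closer look: `local_query`, `global_query`, and `hybrid_query` all ask the LLM for keywords as JSON, and the patch recovers from replies that wrap the JSON in chatter by slicing out the first `{...}` block, then skips context building when no keywords survive. Below is a minimal, self-contained sketch of that recovery strategy; the helper name `parse_keywords_reply` is illustrative, not part of the LightRAG API, and, like the patch, it assumes the reply contains at most one brace-delimited object with no nested braces:

```python
import json

def parse_keywords_reply(reply: str) -> dict:
    """Best-effort JSON recovery, mirroring the [PATCH 03/25] approach."""
    try:
        return json.loads(reply)  # happy path: the reply is pure JSON
    except json.JSONDecodeError:
        try:
            # Keep only the text between the first '{' and the next '}'.
            candidate = "{" + reply.split("{")[1].split("}")[0] + "}"
            return json.loads(candidate)
        except (IndexError, json.JSONDecodeError):
            return {}  # callers treat empty keywords as "skip context building"

# A chatty, Ollama-style reply still yields usable keywords:
data = parse_keywords_reply('Sure! {"low_level_keywords": ["ollama", "bug fix"]} Hope that helps.')
keywords = ", ".join(data.get("low_level_keywords", []))
```

As in the patch, an empty result short-circuits: `context` stays `None`, the `_build_*_query_context` calls are skipped, and the query falls through to `PROMPTS["fail_response"]`.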
From f576a28e0d66904a382f3eae076f1ff2699a6239 Mon Sep 17 00:00:00 2001 From: zrguo Date: Fri, 18 Oct 2024 15:32:58 +0800 Subject: [PATCH 07/25] Create lightrag_azure_openai_demo.py --- examples/lightrag_azure_openai_demo.py | 125 +++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 examples/lightrag_azure_openai_demo.py diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py new file mode 100644 index 00000000..62282a25 --- /dev/null +++ b/examples/lightrag_azure_openai_demo.py @@ -0,0 +1,125 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.utils import EmbeddingFunc +import numpy as np +from dotenv import load_dotenv +import aiohttp +import logging + +logging.basicConfig(level=logging.INFO) + +load_dotenv() + +AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") +AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT") +AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") +AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") + +AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT") +AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION") + +WORKING_DIR = "./dickens" + +if os.path.exists(WORKING_DIR): + import shutil + + shutil.rmtree(WORKING_DIR) + +os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + headers = { + "Content-Type": "application/json", + "api-key": AZURE_OPENAI_API_KEY, + } + endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}" + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + if history_messages: + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + payload = { + "messages": messages, + "temperature": kwargs.get("temperature", 0), + "top_p": kwargs.get("top_p", 1), + "n": kwargs.get("n", 1), + } + + async with aiohttp.ClientSession() as session: + async with session.post(endpoint, headers=headers, json=payload) as response: + if response.status != 200: + raise ValueError( + f"Request failed with status {response.status}: {await response.text()}" + ) + result = await response.json() + return result["choices"][0]["message"]["content"] + + +async def embedding_func(texts: list[str]) -> np.ndarray: + headers = { + "Content-Type": "application/json", + "api-key": AZURE_OPENAI_API_KEY, + } + endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}" + + payload = {"input": texts} + + async with aiohttp.ClientSession() as session: + async with session.post(endpoint, headers=headers, json=payload) as response: + if response.status != 200: + raise ValueError( + f"Request failed with status {response.status}: {await response.text()}" + ) + result = await response.json() + embeddings = [item["embedding"] for item in result["data"]] + return np.array(embeddings) + + +async def test_funcs(): + result = await llm_model_func("How are you?") + print("llm_model_func response: ", result) + + result = await embedding_func(["How are you?"]) + print("embedding_func result: ", result.shape) + print("Embedding dimension: ", result.shape[1]) + + +asyncio.run(test_funcs()) + +embedding_dimension = 3072 + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + 
embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, + ), +) + +book1 = open("./book_1.txt", encoding="utf-8") +book2 = open("./book_2.txt", encoding="utf-8") + +rag.insert([book1.read(), book2.read()]) + +query_text = "What are the main themes?" + +print("Result (Naive):") +print(rag.query(query_text, param=QueryParam(mode="naive"))) + +print("\nResult (Local):") +print(rag.query(query_text, param=QueryParam(mode="local"))) + +print("\nResult (Global):") +print(rag.query(query_text, param=QueryParam(mode="global"))) + +print("\nResult (Hybrid):") +print(rag.query(query_text, param=QueryParam(mode="hybrid"))) \ No newline at end of file From e7a7ff62b264ae7dde437c8dac3e32847090805a Mon Sep 17 00:00:00 2001 From: zrguo Date: Fri, 18 Oct 2024 15:33:11 +0800 Subject: [PATCH 08/25] Update operate.py --- lightrag/operate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index 3a17810a..930ceb2a 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -76,7 +76,7 @@ async def _handle_single_entity_extraction( record_attributes: list[str], chunk_key: str, ): - if record_attributes[0] != '"entity"' or len(record_attributes) < 4: + if len(record_attributes) < 4 or record_attributes[0] != '"entity"': return None # add this record as a node in the G entity_name = clean_str(record_attributes[1].upper()) @@ -97,7 +97,7 @@ async def _handle_single_relationship_extraction( record_attributes: list[str], chunk_key: str, ): - if record_attributes[0] != '"relationship"' or len(record_attributes) < 5: + if len(record_attributes) < 5 or record_attributes[0] != '"relationship"': return None # add this record as edge source = clean_str(record_attributes[1].upper()) From 705087529524ec96602435cd5eb736f0632e1d89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Galego?= Date: Fri, 18 Oct 2024 14:17:14 +0100 Subject: [PATCH 09/25] Added support for Amazon Bedrock models --- .gitignore | 4 + examples/lightrag_bedrock_demo.py | 48 +++++++++++ lightrag/llm.py | 128 ++++++++++++++++++++++++++++++ requirements.txt | 1 + 4 files changed, 181 insertions(+) create mode 100644 .gitignore create mode 100644 examples/lightrag_bedrock_demo.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..cb457220 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +*.egg-info +dickens/ +book.txt \ No newline at end of file diff --git a/examples/lightrag_bedrock_demo.py b/examples/lightrag_bedrock_demo.py new file mode 100644 index 00000000..36ec3857 --- /dev/null +++ b/examples/lightrag_bedrock_demo.py @@ -0,0 +1,48 @@ +""" +LightRAG meets Amazon Bedrock ⛰️ +""" + +import os + +from lightrag import LightRAG, QueryParam +from lightrag.llm import bedrock_complete, bedrock_embedding +from lightrag.utils import EmbeddingFunc + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=bedrock_complete, + llm_model_name="anthropic.claude-3-haiku-20240307-v1:0", + node2vec_params = { + 'dimensions': 1024, + 'num_walks': 10, + 'walk_length': 40, + 'window_size': 2, + 'iterations': 3, + 'random_seed': 3 + }, + embedding_func=EmbeddingFunc( + embedding_dim=1024, + max_token_size=8192, + func=lambda texts: bedrock_embedding(texts) + ) +) + +with open("./book.txt") as f: + rag.insert(f.read()) + +# Naive search +print(rag.query("What are the top themes in this story?", 
param=QueryParam(mode="naive"))) + +# Local search +print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) + +# Global search +print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) + +# Hybrid search +print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) diff --git a/lightrag/llm.py b/lightrag/llm.py index 7328a583..8fc0da2e 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -1,4 +1,6 @@ import os +import json +import aioboto3 import numpy as np import ollama from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout @@ -48,6 +50,54 @@ async def openai_complete_if_cache( ) return response.choices[0].message.content +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) +async def bedrock_complete_if_cache( + model, prompt, system_prompt=None, history_messages=[], base_url=None, + aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs +) -> str: + os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id) + os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key) + os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + + messages = [] + messages.extend(history_messages) + messages.append({'role': "user", 'content': [{'text': prompt}]}) + + args = { + 'modelId': model, + 'messages': messages + } + + if system_prompt: + args['system'] = [{'text': system_prompt}] + + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + session = aioboto3.Session() + async with session.client("bedrock-runtime") as bedrock_async_client: + + response = await bedrock_async_client.converse(**args, **kwargs) + + if hashing_kv is not None: + await hashing_kv.upsert({ + args_hash: { + 'return': response['output']['message']['content'][0]['text'], + 'model': model + } + }) + + return response['output']['message']['content'][0]['text'] + async def hf_model_if_cache( model, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -145,6 +195,19 @@ async def gpt_4o_mini_complete( **kwargs, ) + +async def bedrock_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await bedrock_complete_if_cache( + "anthropic.claude-3-sonnet-20240229-v1:0", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + async def hf_model_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -186,6 +249,71 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal return np.array([dp.embedding for dp in response.data]) +# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) +# @retry( +# stop=stop_after_attempt(3), +# wait=wait_exponential(multiplier=1, min=4, max=10), +# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions +# ) +async def bedrock_embedding( + texts: list[str], model: str = "amazon.titan-embed-text-v2:0", + aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray: + 
os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id) + os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key) + os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token) + + session = aioboto3.Session() + async with session.client("bedrock-runtime") as bedrock_async_client: + + if (model_provider := model.split(".")[0]) == "amazon": + embed_texts = [] + for text in texts: + if "v2" in model: + body = json.dumps({ + 'inputText': text, + # 'dimensions': embedding_dim, + 'embeddingTypes': ["float"] + }) + elif "v1" in model: + body = json.dumps({ + 'inputText': text + }) + else: + raise ValueError(f"Model {model} is not supported!") + + response = await bedrock_async_client.invoke_model( + modelId=model, + body=body, + accept="application/json", + contentType="application/json" + ) + + response_body = await response.get('body').json() + + embed_texts.append(response_body['embedding']) + elif model_provider == "cohere": + body = json.dumps({ + 'texts': texts, + 'input_type': "search_document", + 'truncate': "NONE" + }) + + response = await bedrock_async_client.invoke_model( + model=model, + body=body, + accept="application/json", + contentType="application/json" + ) + + response_body = json.loads(response.get('body').read()) + + embed_texts = response_body['embeddings'] + else: + raise ValueError(f"Model provider '{model_provider}' is not supported!") + + return np.array(embed_texts) + + async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray: input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids with torch.no_grad(): diff --git a/requirements.txt b/requirements.txt index f7dcd787..a1054692 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aioboto3 openai tiktoken networkx From 75a91d9300aa62cf0e918003e430e391c8d69ccc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Galego?= Date: Fri, 18 Oct 2024 16:50:02 +0100 Subject: [PATCH 10/25] Fixed retry strategy, message history and inference params; Cleaned up Bedrock example --- examples/lightrag_bedrock_demo.py | 39 +++++++++++-------------- lightrag/llm.py | 48 +++++++++++++++++++++++++------ 2 files changed, 55 insertions(+), 32 deletions(-) diff --git a/examples/lightrag_bedrock_demo.py b/examples/lightrag_bedrock_demo.py index 36ec3857..c515922e 100644 --- a/examples/lightrag_bedrock_demo.py +++ b/examples/lightrag_bedrock_demo.py @@ -3,46 +3,39 @@ LightRAG meets Amazon Bedrock โ›ฐ๏ธ """ import os +import logging from lightrag import LightRAG, QueryParam from lightrag.llm import bedrock_complete, bedrock_embedding from lightrag.utils import EmbeddingFunc -WORKING_DIR = "./dickens" +logging.getLogger("aiobotocore").setLevel(logging.WARNING) +WORKING_DIR = "./dickens" if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=bedrock_complete, - llm_model_name="anthropic.claude-3-haiku-20240307-v1:0", - node2vec_params = { - 'dimensions': 1024, - 'num_walks': 10, - 'walk_length': 40, - 'window_size': 2, - 'iterations': 3, - 'random_seed': 3 - }, + llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock", embedding_func=EmbeddingFunc( embedding_dim=1024, max_token_size=8192, - func=lambda texts: bedrock_embedding(texts) + func=bedrock_embedding ) ) -with open("./book.txt") as f: +with open("./book.txt", 'r', encoding='utf-8') as f: rag.insert(f.read()) -# Naive search -print(rag.query("What are 
the top themes in this story?", param=QueryParam(mode="naive"))) - -# Local search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) - -# Global search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) - -# Hybrid search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) +for mode in ["naive", "local", "global", "hybrid"]: + print("\n+-" + "-" * len(mode) + "-+") + print(f"| {mode.capitalize()} |") + print("+-" + "-" * len(mode) + "-+\n") + print( + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode=mode) + ) + ) diff --git a/lightrag/llm.py b/lightrag/llm.py index 8fc0da2e..48defb4d 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -1,6 +1,9 @@ import os +import copy import json +import botocore import aioboto3 +import botocore.errorfactory import numpy as np import ollama from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout @@ -50,43 +53,70 @@ async def openai_complete_if_cache( ) return response.choices[0].message.content + +class BedrockError(Exception): + """Generic error for issues related to Amazon Bedrock""" + + @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, max=60), + retry=retry_if_exception_type((BedrockError)), ) async def bedrock_complete_if_cache( - model, prompt, system_prompt=None, history_messages=[], base_url=None, + model, prompt, system_prompt=None, history_messages=[], aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs ) -> str: os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id) os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key) os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - + # Fix message history format messages = [] - messages.extend(history_messages) + for history_message in history_messages: + message = copy.copy(history_message) + message['content'] = [{'text': message['content']}] + messages.append(message) + + # Add user prompt messages.append({'role': "user", 'content': [{'text': prompt}]}) + # Initialize Converse API arguments args = { 'modelId': model, 'messages': messages } + # Define system prompt if system_prompt: args['system'] = [{'text': system_prompt}] + # Map and set up inference parameters + inference_params_map = { + 'max_tokens': "maxTokens", + 'top_p': "topP", + 'stop_sequences': "stopSequences" + } + if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))): + args['inferenceConfig'] = {} + for param in inference_params: + args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: args_hash = compute_args_hash(model, messages) if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] + # Call model via Converse API session = aioboto3.Session() async with session.client("bedrock-runtime") as bedrock_async_client: - response = await bedrock_async_client.converse(**args, **kwargs) + try: + response = await 
bedrock_async_client.converse(**args, **kwargs) + except Exception as e: + raise BedrockError(e) if hashing_kv is not None: await hashing_kv.upsert({ @@ -200,7 +230,7 @@ async def bedrock_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: return await bedrock_complete_if_cache( - "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-3-haiku-20240307-v1:0", prompt, system_prompt=system_prompt, history_messages=history_messages, From a7b43d27dbe2e77c7cf666ba0327e08ec60815b9 Mon Sep 17 00:00:00 2001 From: Wade Rosko <7385473+wrosko@users.noreply.github.com> Date: Fri, 18 Oct 2024 18:09:48 -0600 Subject: [PATCH 11/25] Add comment specifying jupyter req Add lines that can be uncommented if running in a jupyter notebook --- README.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d0ed8a35..bd226582 100644 --- a/README.md +++ b/README.md @@ -47,12 +47,21 @@ pip install lightrag-hku ```bash curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt ``` -Use the below Python snippet to initialize LightRAG and perform queries: +Use the below Python snippet (in a script) to initialize LightRAG and perform queries: ```python from lightrag import LightRAG, QueryParam from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete +######### +# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() +# import nest_asyncio +# nest_asyncio.apply() +######### + +WORKING_DIR = "./dickens" + + WORKING_DIR = "./dickens" if not os.path.exists(WORKING_DIR): From e2db7b6c45ac4b48d7026d69b3a770b42bad4dbe Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Sat, 19 Oct 2024 11:46:03 +0800 Subject: [PATCH 12/25] fix prompt.py --- lightrag/prompt.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 5d28e49c..67d52d63 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -163,25 +163,10 @@ Do not include information where the supporting evidence for it is not provided. {response_type} - ---Data tables--- {context_data} - ----Goal--- - -Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. - -If you don't know the answer, just say so. Do not make anything up. - -Do not include information where the supporting evidence for it is not provided. - - ----Target response length and format--- - -{response_type} - Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 
""" From 744dad339d6b06505659ab5b1091180aecdc4c3b Mon Sep 17 00:00:00 2001 From: Sanketh Kumar Date: Sat, 19 Oct 2024 09:43:17 +0530 Subject: [PATCH 13/25] chore: added pre-commit-hooks and ruff formatting for commit-hooks --- .gitignore | 3 +- .pre-commit-config.yaml | 22 ++ README.md | 50 ++--- examples/batch_eval.py | 38 ++-- examples/generate_query.py | 9 +- examples/lightrag_azure_openai_demo.py | 2 +- examples/lightrag_bedrock_demo.py | 13 +- examples/lightrag_hf_demo.py | 35 ++- examples/lightrag_ollama_demo.py | 25 ++- examples/lightrag_openai_compatible_demo.py | 32 ++- examples/lightrag_openai_demo.py | 22 +- lightrag/__init__.py | 2 +- lightrag/base.py | 11 +- lightrag/lightrag.py | 65 +++--- lightrag/llm.py | 223 ++++++++++++------- lightrag/operate.py | 229 +++++++++++++------- lightrag/prompt.py | 14 +- lightrag/storage.py | 15 +- lightrag/utils.py | 28 ++- reproduce/Step_0.py | 24 +- reproduce/Step_1.py | 8 +- reproduce/Step_1_openai_compatible.py | 29 ++- reproduce/Step_2.py | 20 +- reproduce/Step_3.py | 33 ++- reproduce/Step_3_openai_compatible.py | 58 +++-- requirements.txt | 18 +- 26 files changed, 635 insertions(+), 393 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.gitignore b/.gitignore index cb457220..50f384ec 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__ *.egg-info dickens/ -book.txt \ No newline at end of file +book.txt +lightrag-dev/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..db531bb6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,22 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: requirements-txt-fixer + + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.4 + hooks: + - id: ruff-format + - id: ruff + args: [--fix] + + + - repo: https://github.com/mgedmin/check-manifest + rev: "0.49" + hooks: + - id: check-manifest + stages: [manual] diff --git a/README.md b/README.md index d0ed8a35..b3a04957 100644 --- a/README.md +++ b/README.md @@ -16,16 +16,16 @@

- + This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag). ![请添加图片描述](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png) -## 🎉 News +## 🎉 News - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author! - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉 -- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! -- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! +- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! +- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! ## Install @@ -83,7 +83,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
Using Open AI-like APIs -LightRAG also support Open AI-like chat/embeddings APIs: +LightRAG also supports Open AI-like chat/embeddings APIs: ```python async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs @@ -120,7 +120,7 @@ rag = LightRAG(
Using Hugging Face Models - + If you want to use Hugging Face models, you only need to set LightRAG as follows: ```python from lightrag.llm import hf_model_complete, hf_embedding @@ -136,7 +136,7 @@ rag = LightRAG( embedding_dim=384, max_token_size=5000, func=lambda texts: hf_embedding( - texts, + texts, tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"), embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") ) @@ -148,7 +148,7 @@ rag = LightRAG(
Using Ollama Models If you want to use Ollama models, you only need to set LightRAG as follows: - + ```python from lightrag.llm import ollama_model_complete, ollama_embedding @@ -162,7 +162,7 @@ rag = LightRAG( embedding_dim=768, max_token_size=8192, func=lambda texts: ollama_embedding( - texts, + texts, embed_model="nomic-embed-text" ) ), @@ -187,14 +187,14 @@ with open("./newText.txt") as f: ``` ## Evaluation ### Dataset -The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain). +The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain). ### Generate Query -LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`. +LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.
Prompt - + ```python Given the following description of a dataset: @@ -219,18 +219,18 @@ Output the results in the following structure: ... ```
- + ### Batch Eval To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
Prompt - + ```python ---Role--- You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. ---Goal--- -You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. +You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question? - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question? @@ -294,7 +294,7 @@ Output your evaluation in the following JSON format: | **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% | | **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% | -## Reproduce +## Reproduce All the code can be found in the `./reproduce` directory. ### Step-0 Extract Unique Contexts @@ -302,7 +302,7 @@ First, we need to extract unique contexts in the datasets.
Code - + ```python def extract_unique_contexts(input_directory, output_directory): @@ -361,12 +361,12 @@ For the extracted contexts, we insert them into the LightRAG system.
Code - + ```python def insert_text(rag, file_path): with open(file_path, mode='r') as f: unique_contexts = json.load(f) - + retries = 0 max_retries = 3 while retries < max_retries: @@ -384,11 +384,11 @@ def insert_text(rag, file_path): ### Step-2 Generate Queries -We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries. +We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.
Code - + ```python tokenizer = GPT2Tokenizer.from_pretrained('gpt2') @@ -401,7 +401,7 @@ def get_summary(context, tot_tokens=2000): summary_tokens = start_tokens + end_tokens summary = tokenizer.convert_tokens_to_string(summary_tokens) - + return summary ```
@@ -411,12 +411,12 @@ For the queries generated in Step-2, we will extract them and query LightRAG.
Code - + ```python def extract_queries(file_path): with open(file_path, 'r') as f: data = f.read() - + data = data.replace('**', '') queries = re.findall(r'- Question \d+: (.+)', data) @@ -470,7 +470,7 @@ def extract_queries(file_path): ```python @article{guo2024lightrag, -title={LightRAG: Simple and Fast Retrieval-Augmented Generation}, +title={LightRAG: Simple and Fast Retrieval-Augmented Generation}, author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang}, year={2024}, eprint={2410.05779}, diff --git a/examples/batch_eval.py b/examples/batch_eval.py index 4601d267..a85e1ede 100644 --- a/examples/batch_eval.py +++ b/examples/batch_eval.py @@ -1,4 +1,3 @@ -import os import re import json import jsonlines @@ -9,28 +8,28 @@ from openai import OpenAI def batch_eval(query_file, result1_file, result2_file, output_file_path): client = OpenAI() - with open(query_file, 'r') as f: + with open(query_file, "r") as f: data = f.read() - queries = re.findall(r'- Question \d+: (.+)', data) + queries = re.findall(r"- Question \d+: (.+)", data) - with open(result1_file, 'r') as f: + with open(result1_file, "r") as f: answers1 = json.load(f) - answers1 = [i['result'] for i in answers1] + answers1 = [i["result"] for i in answers1] - with open(result2_file, 'r') as f: + with open(result2_file, "r") as f: answers2 = json.load(f) - answers2 = [i['result'] for i in answers2] + answers2 = [i["result"] for i in answers2] requests = [] for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)): - sys_prompt = f""" + sys_prompt = """ ---Role--- You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. """ prompt = f""" - You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. + You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question? - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question? 
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path): }} """ - request_data = { "custom_id": f"request-{i+1}", "method": "POST", @@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path): "model": "gpt-4o-mini", "messages": [ {"role": "system", "content": sys_prompt}, - {"role": "user", "content": prompt} + {"role": "user", "content": prompt}, ], - } + }, } - + requests.append(request_data) - with jsonlines.open(output_file_path, mode='w') as writer: + with jsonlines.open(output_file_path, mode="w") as writer: for request in requests: writer.write(request) print(f"Batch API requests written to {output_file_path}") batch_input_file = client.files.create( - file=open(output_file_path, "rb"), - purpose="batch" + file=open(output_file_path, "rb"), purpose="batch" ) batch_input_file_id = batch_input_file.id @@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path): input_file_id=batch_input_file_id, endpoint="/v1/chat/completions", completion_window="24h", - metadata={ - "description": "nightly eval job" - } + metadata={"description": "nightly eval job"}, ) - print(f'Batch {batch.id} has been created.') + print(f"Batch {batch.id} has been created.") + if __name__ == "__main__": - batch_eval() \ No newline at end of file + batch_eval() diff --git a/examples/generate_query.py b/examples/generate_query.py index 0ae82f40..705b23d3 100644 --- a/examples/generate_query.py +++ b/examples/generate_query.py @@ -1,9 +1,8 @@ -import os - from openai import OpenAI # os.environ["OPENAI_API_KEY"] = "" + def openai_complete_if_cache( model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -47,10 +46,10 @@ if __name__ == "__main__": ... 
""" - result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt) + result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt) - file_path = f"./queries.txt" + file_path = "./queries.txt" with open(file_path, "w") as file: file.write(result) - print(f"Queries written to {file_path}") \ No newline at end of file + print(f"Queries written to {file_path}") diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py index 62282a25..e29a6a9d 100644 --- a/examples/lightrag_azure_openai_demo.py +++ b/examples/lightrag_azure_openai_demo.py @@ -122,4 +122,4 @@ print("\nResult (Global):") print(rag.query(query_text, param=QueryParam(mode="global"))) print("\nResult (Hybrid):") -print(rag.query(query_text, param=QueryParam(mode="hybrid"))) \ No newline at end of file +print(rag.query(query_text, param=QueryParam(mode="hybrid"))) diff --git a/examples/lightrag_bedrock_demo.py b/examples/lightrag_bedrock_demo.py index c515922e..7e18ea57 100644 --- a/examples/lightrag_bedrock_demo.py +++ b/examples/lightrag_bedrock_demo.py @@ -20,13 +20,11 @@ rag = LightRAG( llm_model_func=bedrock_complete, llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock", embedding_func=EmbeddingFunc( - embedding_dim=1024, - max_token_size=8192, - func=bedrock_embedding - ) + embedding_dim=1024, max_token_size=8192, func=bedrock_embedding + ), ) -with open("./book.txt", 'r', encoding='utf-8') as f: +with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) for mode in ["naive", "local", "global", "hybrid"]: @@ -34,8 +32,5 @@ for mode in ["naive", "local", "global", "hybrid"]: print(f"| {mode.capitalize()} |") print("+-" + "-" * len(mode) + "-+\n") print( - rag.query( - "What are the top themes in this story?", - param=QueryParam(mode=mode) - ) + rag.query("What are the top themes in this story?", param=QueryParam(mode=mode)) ) diff --git a/examples/lightrag_hf_demo.py b/examples/lightrag_hf_demo.py index baf62bdb..87312307 100644 --- a/examples/lightrag_hf_demo.py +++ b/examples/lightrag_hf_demo.py @@ -1,10 +1,9 @@ import os -import sys from lightrag import LightRAG, QueryParam from lightrag.llm import hf_model_complete, hf_embedding from lightrag.utils import EmbeddingFunc -from transformers import AutoModel,AutoTokenizer +from transformers import AutoModel, AutoTokenizer WORKING_DIR = "./dickens" @@ -13,16 +12,20 @@ if not os.path.exists(WORKING_DIR): rag = LightRAG( working_dir=WORKING_DIR, - llm_model_func=hf_model_complete, - llm_model_name='meta-llama/Llama-3.1-8B-Instruct', + llm_model_func=hf_model_complete, + llm_model_name="meta-llama/Llama-3.1-8B-Instruct", embedding_func=EmbeddingFunc( embedding_dim=384, max_token_size=5000, func=lambda texts: hf_embedding( - texts, - tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"), - embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") - ) + texts, + tokenizer=AutoTokenizer.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ), + embed_model=AutoModel.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ), + ), ), ) @@ -31,13 +34,21 @@ with open("./book.txt") as f: rag.insert(f.read()) # Perform naive search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) # Perform local search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) +print( + rag.query("What 
are the top themes in this story?", param=QueryParam(mode="local")) +) # Perform global search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) # Perform hybrid search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py index a2d04aa6..c61b71c0 100644 --- a/examples/lightrag_ollama_demo.py +++ b/examples/lightrag_ollama_demo.py @@ -11,15 +11,12 @@ if not os.path.exists(WORKING_DIR): rag = LightRAG( working_dir=WORKING_DIR, - llm_model_func=ollama_model_complete, - llm_model_name='your_model_name', + llm_model_func=ollama_model_complete, + llm_model_name="your_model_name", embedding_func=EmbeddingFunc( embedding_dim=768, max_token_size=8192, - func=lambda texts: ollama_embedding( - texts, - embed_model="nomic-embed-text" - ) + func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"), ), ) @@ -28,13 +25,21 @@ with open("./book.txt") as f: rag.insert(f.read()) # Perform naive search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) # Perform local search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) # Perform global search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) # Perform hybrid search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py index 75ecc118..fbad1190 100644 --- a/examples/lightrag_openai_compatible_demo.py +++ b/examples/lightrag_openai_compatible_demo.py @@ -6,10 +6,11 @@ from lightrag.utils import EmbeddingFunc import numpy as np WORKING_DIR = "./dickens" - + if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) + async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -20,17 +21,19 @@ async def llm_model_func( history_messages=history_messages, api_key=os.getenv("UPSTAGE_API_KEY"), base_url="https://api.upstage.ai/v1/solar", - **kwargs + **kwargs, ) + async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embedding( texts, model="solar-embedding-1-large-query", api_key=os.getenv("UPSTAGE_API_KEY"), - base_url="https://api.upstage.ai/v1/solar" + base_url="https://api.upstage.ai/v1/solar", ) + # function test async def test_funcs(): result = await llm_model_func("How are you?") @@ -39,6 +42,7 @@ async def test_funcs(): result = await embedding_func(["How are you?"]) print("embedding_func: ", result) + asyncio.run(test_funcs()) @@ -46,10 +50,8 @@ rag = LightRAG( working_dir=WORKING_DIR, llm_model_func=llm_model_func, embedding_func=EmbeddingFunc( - embedding_dim=4096, - max_token_size=8192, - func=embedding_func - ) + embedding_dim=4096, max_token_size=8192, func=embedding_func + 
), ) @@ -57,13 +59,21 @@ with open("./book.txt") as f: rag.insert(f.read()) # Perform naive search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) # Perform local search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) # Perform global search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) # Perform hybrid search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index fb1f055c..a6e7f3b2 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -1,9 +1,7 @@ import os -import sys from lightrag import LightRAG, QueryParam -from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete -from transformers import AutoModel,AutoTokenizer +from lightrag.llm import gpt_4o_mini_complete WORKING_DIR = "./dickens" @@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR): rag = LightRAG( working_dir=WORKING_DIR, - llm_model_func=gpt_4o_mini_complete + llm_model_func=gpt_4o_mini_complete, # llm_model_func=gpt_4o_complete ) @@ -21,13 +19,21 @@ with open("./book.txt") as f: rag.insert(f.read()) # Perform naive search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) # Perform local search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) # Perform global search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) # Perform hybrid search -print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index b6b953f1..f208177f 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,4 +1,4 @@ -from .lightrag import LightRAG, QueryParam +from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam __version__ = "0.0.6" __author__ = "Zirui Guo" diff --git a/lightrag/base.py b/lightrag/base.py index d677c406..50be4f62 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -12,15 +12,16 @@ TextChunkSchema = TypedDict( T = TypeVar("T") + @dataclass class QueryParam: mode: Literal["local", "global", "hybrid", "naive"] = "global" only_need_context: bool = False response_type: str = "Multiple Paragraphs" top_k: int = 60 - max_token_for_text_unit: int = 4000 + max_token_for_text_unit: int = 4000 max_token_for_global_context: int = 4000 - max_token_for_local_context: int = 4000 + max_token_for_local_context: int = 4000 @dataclass @@ -36,6 +37,7 @@ class StorageNameSpace: """commit the storage operations after querying""" pass + @dataclass class 
BaseVectorStorage(StorageNameSpace): embedding_func: EmbeddingFunc @@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace): """ raise NotImplementedError + @dataclass class BaseKVStorage(Generic[T], StorageNameSpace): async def all_keys(self) -> list[str]: @@ -72,7 +75,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace): async def drop(self): raise NotImplementedError - + @dataclass class BaseGraphStorage(StorageNameSpace): @@ -113,4 +116,4 @@ class BaseGraphStorage(StorageNameSpace): raise NotImplementedError async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]: - raise NotImplementedError("Node embedding is not used in lightrag.") \ No newline at end of file + raise NotImplementedError("Node embedding is not used in lightrag.") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 83312ef6..5137af42 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -3,10 +3,12 @@ import os from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Type, cast, Any -from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM +from typing import Type, cast -from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding +from .llm import ( + gpt_4o_mini_complete, + openai_embedding, +) from .operate import ( chunking_by_token_size, extract_entities, @@ -37,6 +39,7 @@ from .base import ( QueryParam, ) + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: try: loop = asyncio.get_running_loop() @@ -69,7 +72,6 @@ class LightRAG: "dimensions": 1536, "num_walks": 10, "walk_length": 40, - "num_walks": 10, "window_size": 2, "iterations": 3, "random_seed": 3, @@ -77,13 +79,13 @@ class LightRAG: ) # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding) - embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding) + embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) embedding_batch_num: int = 32 embedding_func_max_async: int = 16 # LLM - llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete# - llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' + llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete# + llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' llm_model_max_token_size: int = 32768 llm_model_max_async: int = 16 @@ -98,11 +100,11 @@ class LightRAG: addon_params: dict = field(default_factory=dict) convert_response_to_json_func: callable = convert_response_to_json - def __post_init__(self): + def __post_init__(self): log_file = os.path.join(self.working_dir, "lightrag.log") set_logger(log_file) logger.info(f"Logger initialized for working directory: {self.working_dir}") - + _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()]) logger.debug(f"LightRAG init with param:\n {_print_config}\n") @@ -133,30 +135,24 @@ class LightRAG: self.embedding_func ) - self.entities_vdb = ( - self.vector_db_storage_cls( - namespace="entities", - global_config=asdict(self), - embedding_func=self.embedding_func, - meta_fields={"entity_name"} - ) + self.entities_vdb = self.vector_db_storage_cls( + namespace="entities", + global_config=asdict(self), + embedding_func=self.embedding_func, + meta_fields={"entity_name"}, ) - self.relationships_vdb = ( - self.vector_db_storage_cls( - namespace="relationships", - 
global_config=asdict(self), - embedding_func=self.embedding_func, - meta_fields={"src_id", "tgt_id"} - ) + self.relationships_vdb = self.vector_db_storage_cls( + namespace="relationships", + global_config=asdict(self), + embedding_func=self.embedding_func, + meta_fields={"src_id", "tgt_id"}, ) - self.chunks_vdb = ( - self.vector_db_storage_cls( - namespace="chunks", - global_config=asdict(self), - embedding_func=self.embedding_func, - ) + self.chunks_vdb = self.vector_db_storage_cls( + namespace="chunks", + global_config=asdict(self), + embedding_func=self.embedding_func, ) - + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial(self.llm_model_func, hashing_kv=self.llm_response_cache) ) @@ -177,7 +173,7 @@ class LightRAG: _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} if not len(new_docs): - logger.warning(f"All docs are already in the storage") + logger.warning("All docs are already in the storage") return logger.info(f"[New Docs] inserting {len(new_docs)} docs") @@ -203,7 +199,7 @@ class LightRAG: k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys } if not len(inserting_chunks): - logger.warning(f"All chunks are already in the storage") + logger.warning("All chunks are already in the storage") return logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") @@ -246,7 +242,7 @@ class LightRAG: def query(self, query: str, param: QueryParam = QueryParam()): loop = always_get_an_event_loop() return loop.run_until_complete(self.aquery(query, param)) - + async def aquery(self, query: str, param: QueryParam = QueryParam()): if param.mode == "local": response = await local_query( @@ -290,7 +286,6 @@ class LightRAG: raise ValueError(f"Unknown mode {param.mode}") await self._query_done() return response - async def _query_done(self): tasks = [] @@ -299,5 +294,3 @@ class LightRAG: continue tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) await asyncio.gather(*tasks) - - diff --git a/lightrag/llm.py b/lightrag/llm.py index 48defb4d..be801e0c 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -1,9 +1,7 @@ import os import copy import json -import botocore import aioboto3 -import botocore.errorfactory import numpy as np import ollama from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout @@ -13,24 +11,34 @@ from tenacity import ( wait_exponential, retry_if_exception_type, ) -from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM +from transformers import AutoTokenizer, AutoModelForCausalLM import torch from .base import BaseKVStorage from .utils import compute_args_hash, wrap_embedding_func_with_attrs -import copy + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), ) async def openai_complete_if_cache( - model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs + model, + prompt, + system_prompt=None, + history_messages=[], + base_url=None, + api_key=None, + **kwargs, ) -> str: if api_key: os.environ["OPENAI_API_KEY"] = api_key - openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + openai_async_client = ( + AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + ) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if 
system_prompt: @@ -64,43 +72,56 @@ class BedrockError(Exception): retry=retry_if_exception_type((BedrockError)), ) async def bedrock_complete_if_cache( - model, prompt, system_prompt=None, history_messages=[], - aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs + model, + prompt, + system_prompt=None, + history_messages=[], + aws_access_key_id=None, + aws_secret_access_key=None, + aws_session_token=None, + **kwargs, ) -> str: - os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id) - os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key) - os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token) + os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( + "AWS_ACCESS_KEY_ID", aws_access_key_id + ) + os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( + "AWS_SECRET_ACCESS_KEY", aws_secret_access_key + ) + os.environ["AWS_SESSION_TOKEN"] = os.environ.get( + "AWS_SESSION_TOKEN", aws_session_token + ) # Fix message history format messages = [] for history_message in history_messages: message = copy.copy(history_message) - message['content'] = [{'text': message['content']}] + message["content"] = [{"text": message["content"]}] messages.append(message) # Add user prompt - messages.append({'role': "user", 'content': [{'text': prompt}]}) + messages.append({"role": "user", "content": [{"text": prompt}]}) # Initialize Converse API arguments - args = { - 'modelId': model, - 'messages': messages - } + args = {"modelId": model, "messages": messages} # Define system prompt if system_prompt: - args['system'] = [{'text': system_prompt}] + args["system"] = [{"text": system_prompt}] # Map and set up inference parameters inference_params_map = { - 'max_tokens': "maxTokens", - 'top_p': "topP", - 'stop_sequences': "stopSequences" + "max_tokens": "maxTokens", + "top_p": "topP", + "stop_sequences": "stopSequences", } - if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))): - args['inferenceConfig'] = {} + if inference_params := list( + set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) + ): + args["inferenceConfig"] = {} for param in inference_params: - args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param) + args["inferenceConfig"][inference_params_map.get(param, param)] = ( + kwargs.pop(param) + ) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) if hashing_kv is not None: @@ -112,31 +133,33 @@ async def bedrock_complete_if_cache( # Call model via Converse API session = aioboto3.Session() async with session.client("bedrock-runtime") as bedrock_async_client: - try: response = await bedrock_async_client.converse(**args, **kwargs) except Exception as e: raise BedrockError(e) if hashing_kv is not None: - await hashing_kv.upsert({ - args_hash: { - 'return': response['output']['message']['content'][0]['text'], - 'model': model + await hashing_kv.upsert( + { + args_hash: { + "return": response["output"]["message"]["content"][0]["text"], + "model": model, + } } - }) + ) + + return response["output"]["message"]["content"][0]["text"] - return response['output']['message']['content'][0]['text'] async def hf_model_if_cache( model, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: model_name = model - hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto') - if hf_tokenizer.pad_token == None: + hf_tokenizer = 
AutoTokenizer.from_pretrained(model_name, device_map="auto") + if hf_tokenizer.pad_token is None: # print("use eos token") hf_tokenizer.pad_token = hf_tokenizer.eos_token - hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto') + hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: @@ -149,30 +172,51 @@ async def hf_model_if_cache( if_cache_return = await hashing_kv.get_by_id(args_hash) if if_cache_return is not None: return if_cache_return["return"] - input_prompt = '' + input_prompt = "" try: - input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - except: + input_prompt = hf_tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except Exception: try: ori_message = copy.deepcopy(messages) - if messages[0]['role'] == "system": - messages[1]['content'] = "" + messages[0]['content'] + "\n" + messages[1]['content'] + if messages[0]["role"] == "system": + messages[1]["content"] = ( + "" + + messages[0]["content"] + + "\n" + + messages[1]["content"] + ) messages = messages[1:] - input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - except: + input_prompt = hf_tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except Exception: len_message = len(ori_message) for msgid in range(len_message): - input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'\n' - - input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda") - output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True) + input_prompt = ( + input_prompt + + "<" + + ori_message[msgid]["role"] + + ">" + + ori_message[msgid]["content"] + + "\n" + ) + + input_ids = hf_tokenizer( + input_prompt, return_tensors="pt", padding=True, truncation=True + ).to("cuda") + output = hf_model.generate( + **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True + ) response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True) if hashing_kv is not None: - await hashing_kv.upsert( - {args_hash: {"return": response_text, "model": model}} - ) + await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) return response_text + async def ollama_model_if_cache( model, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -202,6 +246,7 @@ async def ollama_model_if_cache( return result + async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -241,7 +286,7 @@ async def bedrock_complete( async def hf_model_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: - model_name = kwargs['hashing_kv'].global_config['llm_model_name'] + model_name = kwargs["hashing_kv"].global_config["llm_model_name"] return await hf_model_if_cache( model_name, prompt, @@ -250,10 +295,11 @@ async def hf_model_complete( **kwargs, ) + async def ollama_model_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: - model_name = kwargs['hashing_kv'].global_config['llm_model_name'] + model_name = kwargs["hashing_kv"].global_config["llm_model_name"] return await ollama_model_if_cache( model_name, prompt, @@ -262,17 +308,25 @@ async def ollama_model_complete( **kwargs, ) + 
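[Editor's note] The `*_if_cache` helpers above all share one pattern: hash the call arguments, look the hash up in a KV store, and only hit the model on a miss. A minimal, self-contained sketch of that pattern follows; `SimpleKV` and `fake_model` are illustrative stand-ins, not LightRAG APIs.

```python
import asyncio
from hashlib import md5


def compute_args_hash(*args):
    # Same scheme as lightrag.utils.compute_args_hash
    return md5(str(args).encode()).hexdigest()


class SimpleKV:
    """In-memory stand-in for LightRAG's BaseKVStorage (illustration only)."""

    def __init__(self):
        self._data = {}

    async def get_by_id(self, key):
        return self._data.get(key)

    async def upsert(self, data: dict):
        self._data.update(data)


async def fake_model(model: str, prompt: str) -> str:
    # Placeholder for the real backend call (OpenAI, Bedrock, HF, Ollama, ...)
    return f"[{model}] response to: {prompt}"


async def complete_if_cache(model: str, prompt: str, hashing_kv=None) -> str:
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, prompt)
        cached = await hashing_kv.get_by_id(args_hash)
        if cached is not None:
            return cached["return"]  # cache hit: the model is never called
    response = await fake_model(model, prompt)
    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
    return response


async def main():
    kv = SimpleKV()
    print(await complete_if_cache("demo", "How are you?", hashing_kv=kv))
    print(await complete_if_cache("demo", "How are you?", hashing_kv=kv))  # served from cache


asyncio.run(main())
```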
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), ) -async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray: +async def openai_embedding( + texts: list[str], + model: str = "text-embedding-3-small", + base_url: str = None, + api_key: str = None, +) -> np.ndarray: if api_key: os.environ["OPENAI_API_KEY"] = api_key - openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + openai_async_client = ( + AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) + ) response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" ) @@ -286,28 +340,37 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions # ) async def bedrock_embedding( - texts: list[str], model: str = "amazon.titan-embed-text-v2:0", - aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray: - os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id) - os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key) - os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token) + texts: list[str], + model: str = "amazon.titan-embed-text-v2:0", + aws_access_key_id=None, + aws_secret_access_key=None, + aws_session_token=None, +) -> np.ndarray: + os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( + "AWS_ACCESS_KEY_ID", aws_access_key_id + ) + os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( + "AWS_SECRET_ACCESS_KEY", aws_secret_access_key + ) + os.environ["AWS_SESSION_TOKEN"] = os.environ.get( + "AWS_SESSION_TOKEN", aws_session_token + ) session = aioboto3.Session() async with session.client("bedrock-runtime") as bedrock_async_client: - if (model_provider := model.split(".")[0]) == "amazon": embed_texts = [] for text in texts: if "v2" in model: - body = json.dumps({ - 'inputText': text, - # 'dimensions': embedding_dim, - 'embeddingTypes': ["float"] - }) + body = json.dumps( + { + "inputText": text, + # 'dimensions': embedding_dim, + "embeddingTypes": ["float"], + } + ) elif "v1" in model: - body = json.dumps({ - 'inputText': text - }) + body = json.dumps({"inputText": text}) else: raise ValueError(f"Model {model} is not supported!") @@ -315,29 +378,27 @@ async def bedrock_embedding( modelId=model, body=body, accept="application/json", - contentType="application/json" + contentType="application/json", ) - response_body = await response.get('body').json() + response_body = await response.get("body").json() - embed_texts.append(response_body['embedding']) + embed_texts.append(response_body["embedding"]) elif model_provider == "cohere": - body = json.dumps({ - 'texts': texts, - 'input_type': "search_document", - 'truncate': "NONE" - }) + body = json.dumps( + {"texts": texts, "input_type": "search_document", "truncate": "NONE"} + ) response = await bedrock_async_client.invoke_model( model=model, body=body, accept="application/json", - contentType="application/json" + contentType="application/json", ) - response_body = json.loads(response.get('body').read()) + response_body = json.loads(response.get("body").read()) - 
embed_texts = response_body['embeddings'] + embed_texts = response_body["embeddings"] else: raise ValueError(f"Model provider '{model_provider}' is not supported!") @@ -345,12 +406,15 @@ async def bedrock_embedding( async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray: - input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids + input_ids = tokenizer( + texts, return_tensors="pt", padding=True, truncation=True + ).input_ids with torch.no_grad(): outputs = embed_model(input_ids) embeddings = outputs.last_hidden_state.mean(dim=1) return embeddings.detach().numpy() + async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray: embed_text = [] for text in texts: @@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray: return embed_text + if __name__ == "__main__": import asyncio async def main(): - result = await gpt_4o_mini_complete('How are you?') + result = await gpt_4o_mini_complete("How are you?") print(result) asyncio.run(main()) diff --git a/lightrag/operate.py b/lightrag/operate.py index 930ceb2a..a0729cd8 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -25,6 +25,7 @@ from .base import ( ) from .prompt import GRAPH_FIELD_SEP, PROMPTS + def chunking_by_token_size( content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o" ): @@ -45,6 +46,7 @@ def chunking_by_token_size( ) return results + async def _handle_entity_relation_summary( entity_or_relation_name: str, description: str, @@ -229,9 +231,10 @@ async def _merge_edges_then_upsert( description=description, keywords=keywords, ) - + return edge_data + async def extract_entities( chunks: dict[str, TextChunkSchema], knwoledge_graph_inst: BaseGraphStorage, @@ -352,7 +355,9 @@ async def extract_entities( logger.warning("Didn't extract any entities, maybe your LLM is not working") return None if not len(all_relationships_data): - logger.warning("Didn't extract any relationships, maybe your LLM is not working") + logger.warning( + "Didn't extract any relationships, maybe your LLM is not working" + ) return None if entity_vdb is not None: @@ -370,7 +375,10 @@ async def extract_entities( compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): { "src_id": dp["src_id"], "tgt_id": dp["tgt_id"], - "content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"], + "content": dp["keywords"] + + dp["src_id"] + + dp["tgt_id"] + + dp["description"], } for dp in all_relationships_data } @@ -378,6 +386,7 @@ async def extract_entities( return knwoledge_graph_inst + async def local_query( query, knowledge_graph_inst: BaseGraphStorage, @@ -393,19 +402,24 @@ async def local_query( kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) - + try: keywords_data = json.loads(result) keywords = keywords_data.get("low_level_keywords", []) - keywords = ', '.join(keywords) - except json.JSONDecodeError as e: + keywords = ", ".join(keywords) + except json.JSONDecodeError: try: - result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip() - result = '{' + result.split('{')[1].split('}')[0] + '}' + result = ( + result.replace(kw_prompt[:-1], "") + .replace("user", "") + .replace("model", "") + .strip() + ) + result = "{" + result.split("{")[1].split("}")[0] + "}" keywords_data = json.loads(result) keywords = keywords_data.get("low_level_keywords", []) - keywords = ', '.join(keywords) + keywords = 
", ".join(keywords) # Handle parsing error except json.JSONDecodeError as e: print(f"JSON parsing error: {e}") @@ -430,11 +444,20 @@ async def local_query( query, system_prompt=sys_prompt, ) - if len(response)>len(sys_prompt): - response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip() - + if len(response) > len(sys_prompt): + response = ( + response.replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + return response + async def _build_local_query_context( query, knowledge_graph_inst: BaseGraphStorage, @@ -516,6 +539,7 @@ async def _build_local_query_context( ``` """ + async def _find_most_related_text_unit_from_entities( node_datas: list[dict], query_param: QueryParam, @@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities( all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units] return all_text_units + async def _find_most_related_edges_from_entities( node_datas: list[dict], query_param: QueryParam, @@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities( ) return all_edges_data + async def global_query( query, knowledge_graph_inst: BaseGraphStorage, @@ -624,20 +650,25 @@ async def global_query( kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) result = await use_model_func(kw_prompt) - + try: keywords_data = json.loads(result) keywords = keywords_data.get("high_level_keywords", []) - keywords = ', '.join(keywords) - except json.JSONDecodeError as e: + keywords = ", ".join(keywords) + except json.JSONDecodeError: try: - result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip() - result = '{' + result.split('{')[1].split('}')[0] + '}' + result = ( + result.replace(kw_prompt[:-1], "") + .replace("user", "") + .replace("model", "") + .strip() + ) + result = "{" + result.split("{")[1].split("}")[0] + "}" keywords_data = json.loads(result) keywords = keywords_data.get("high_level_keywords", []) - keywords = ', '.join(keywords) - + keywords = ", ".join(keywords) + except json.JSONDecodeError as e: # Handle parsing error print(f"JSON parsing error: {e}") @@ -651,12 +682,12 @@ async def global_query( text_chunks_db, query_param, ) - + if query_param.only_need_context: return context if context is None: return PROMPTS["fail_response"] - + sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type @@ -665,11 +696,20 @@ async def global_query( query, system_prompt=sys_prompt, ) - if len(response)>len(sys_prompt): - response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip() - + if len(response) > len(sys_prompt): + response = ( + response.replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + return response + async def _build_global_query_context( keywords, knowledge_graph_inst: BaseGraphStorage, @@ -679,14 +719,14 @@ async def _build_global_query_context( query_param: QueryParam, ): results = await relationships_vdb.query(keywords, top_k=query_param.top_k) - + if not len(results): return None - + edge_datas = await asyncio.gather( *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results] ) - + if not all([n is not None for n in 
edge_datas]): logger.warning("Some edges are missing, maybe the storage is damaged") edge_degree = await asyncio.gather( @@ -765,6 +805,7 @@ async def _build_global_query_context( ``` """ + async def _find_most_related_entities_from_relationships( edge_datas: list[dict], query_param: QueryParam, @@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships( for e in edge_datas: entity_names.add(e["src_id"]) entity_names.add(e["tgt_id"]) - + node_datas = await asyncio.gather( *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names] ) @@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships( return node_datas + async def _find_related_text_unit_from_relationships( edge_datas: list[dict], query_param: QueryParam, text_chunks_db: BaseKVStorage[TextChunkSchema], knowledge_graph_inst: BaseGraphStorage, ): - text_units = [ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP]) for dp in edge_datas @@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships( "data": await text_chunks_db.get_by_id(c_id), "order": index, } - + if any([v is None for v in all_text_units_lookup.values()]): logger.warning("Text chunks are missing, maybe the storage is damaged") all_text_units = [ {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None ] - all_text_units = sorted( - all_text_units, key=lambda x: x["order"] - ) + all_text_units = sorted(all_text_units, key=lambda x: x["order"]) all_text_units = truncate_list_by_token_size( all_text_units, key=lambda x: x["data"]["content"], @@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships( return all_text_units + async def hybrid_query( query, knowledge_graph_inst: BaseGraphStorage, @@ -849,24 +889,29 @@ async def hybrid_query( kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query) - + result = await use_model_func(kw_prompt) try: keywords_data = json.loads(result) hl_keywords = keywords_data.get("high_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", []) - hl_keywords = ', '.join(hl_keywords) - ll_keywords = ', '.join(ll_keywords) - except json.JSONDecodeError as e: + hl_keywords = ", ".join(hl_keywords) + ll_keywords = ", ".join(ll_keywords) + except json.JSONDecodeError: try: - result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip() - result = '{' + result.split('{')[1].split('}')[0] + '}' + result = ( + result.replace(kw_prompt[:-1], "") + .replace("user", "") + .replace("model", "") + .strip() + ) + result = "{" + result.split("{")[1].split("}")[0] + "}" keywords_data = json.loads(result) hl_keywords = keywords_data.get("high_level_keywords", []) ll_keywords = keywords_data.get("low_level_keywords", []) - hl_keywords = ', '.join(hl_keywords) - ll_keywords = ', '.join(ll_keywords) + hl_keywords = ", ".join(hl_keywords) + ll_keywords = ", ".join(ll_keywords) # Handle parsing error except json.JSONDecodeError as e: print(f"JSON parsing error: {e}") @@ -897,7 +942,7 @@ async def hybrid_query( return context if context is None: return PROMPTS["fail_response"] - + sys_prompt_temp = PROMPTS["rag_response"] sys_prompt = sys_prompt_temp.format( context_data=context, response_type=query_param.response_type @@ -906,53 +951,78 @@ async def hybrid_query( query, system_prompt=sys_prompt, ) - if len(response)>len(sys_prompt): - response = 
response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+    if len(response) > len(sys_prompt):
+        response = (
+            response.replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )
 
     return response
 
+
 def combine_contexts(high_level_context, low_level_context):
     # Function to extract entities, relationships, and sources from context strings
     def extract_sections(context):
-        entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
-        relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
-        sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
-
-        entities = entities_match.group(1) if entities_match else ''
-        relationships = relationships_match.group(1) if relationships_match else ''
-        sources = sources_match.group(1) if sources_match else ''
-
+        entities_match = re.search(
+            r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+        )
+        relationships_match = re.search(
+            r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+        )
+        sources_match = re.search(
+            r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+        )
+
+        entities = entities_match.group(1) if entities_match else ""
+        relationships = relationships_match.group(1) if relationships_match else ""
+        sources = sources_match.group(1) if sources_match else ""
+
         return entities, relationships, sources
-
+
     # Extract sections from both contexts
-    if high_level_context==None:
-        warnings.warn("High Level context is None. Return empty High entity/relationship/source")
-        hl_entities, hl_relationships, hl_sources = '','',''
+    if high_level_context is None:
+        warnings.warn(
+            "High Level context is None. Return empty High entity/relationship/source"
+        )
+        hl_entities, hl_relationships, hl_sources = "", "", ""
     else:
         hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
 
-    if low_level_context==None:
-        warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
-        ll_entities, ll_relationships, ll_sources = '','',''
+    if low_level_context is None:
+        warnings.warn(
+            "Low Level context is None. Return empty Low entity/relationship/source"
+        )
+        ll_entities, ll_relationships, ll_sources = "", "", ""
     else:
         ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
 
     # Combine and deduplicate the entities
-    combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
-    combined_entities = '\n'.join(combined_entities_set)
-
+    combined_entities_set = set(
+        filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
+    )
+    combined_entities = "\n".join(combined_entities_set)
+
     # Combine and deduplicate the relationships
-    combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
-    combined_relationships = '\n'.join(combined_relationships_set)
-
+    combined_relationships_set = set(
+        filter(
+            None,
+            hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
+        )
+    )
+    combined_relationships = "\n".join(combined_relationships_set)
+
     # Combine and deduplicate the sources
-    combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
-    combined_sources = '\n'.join(combined_sources_set)
-
+    combined_sources_set = set(
+        filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
+    )
+    combined_sources = "\n".join(combined_sources_set)
+
     # Format the combined context
     return f"""
 -----Entities-----
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
 {combined_sources}
 """
 
+
 async def naive_query(
     query,
     chunks_vdb: BaseVectorStorage,
@@ -996,8 +1067,16 @@ async def naive_query(
         system_prompt=sys_prompt,
     )
 
-    if len(response)>len(sys_prompt):
-        response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
-
-    return response
+    if len(response) > len(sys_prompt):
+        response = (
+            response[len(sys_prompt) :]
+            .replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )
 
+    return response
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 5d28e49c..6bd9b638 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
 
 PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
 
-PROMPTS[
-    "entity_extraction"
-] = """-Goal-
+PROMPTS["entity_extraction"] = """-Goal-
 Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
 
 -Steps-
@@ -32,7 +30,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
 Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
-
+
 4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
 
 5. When finished, output {completion_delimiter}
 
@@ -146,9 +144,7 @@
 
 PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
 
-PROMPTS[ - "rag_response" -] = """---Role--- +PROMPTS["rag_response"] = """---Role--- You are a helpful assistant responding to questions about data in the tables provided. @@ -241,9 +237,7 @@ Output: """ -PROMPTS[ - "naive_rag_response" -] = """You're a helpful assistant +PROMPTS["naive_rag_response"] = """You're a helpful assistant Below are the knowledge you know: {content_data} --- diff --git a/lightrag/storage.py b/lightrag/storage.py index 2f2bb7d8..1f22fc56 100644 --- a/lightrag/storage.py +++ b/lightrag/storage.py @@ -1,16 +1,11 @@ import asyncio import html -import json import os -from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Union, cast -import pickle -import hnswlib import networkx as nx import numpy as np from nano_vectordb import NanoVectorDB -import xxhash from .utils import load_json, logger, write_json from .base import ( @@ -19,6 +14,7 @@ from .base import ( BaseVectorStorage, ) + @dataclass class JsonKVStorage(BaseKVStorage): def __post_init__(self): @@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage): async def drop(self): self._data = {} + @dataclass class NanoVectorDBStorage(BaseVectorStorage): cosine_better_than_threshold: float = 0.2 def __post_init__(self): - self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) @@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage): async def index_done_callback(self): self._client.save() + @dataclass class NetworkXStorage(BaseGraphStorage): @staticmethod @@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage): graph = graph.copy() graph = cast(nx.Graph, largest_connected_component(graph)) - node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore + node_mapping = { + node: html.unescape(node.upper().strip()) for node in graph.nodes() + } # type: ignore graph = nx.relabel_nodes(graph, node_mapping) return NetworkXStorage._stabilize_graph(graph) diff --git a/lightrag/utils.py b/lightrag/utils.py index 9496cf34..67d094c6 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -16,18 +16,22 @@ ENCODER = None logger = logging.getLogger("lightrag") + def set_logger(log_file: str): logger.setLevel(logging.DEBUG) file_handler = logging.FileHandler(log_file) file_handler.setLevel(logging.DEBUG) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) file_handler.setFormatter(formatter) if not logger.handlers: logger.addHandler(file_handler) + @dataclass class EmbeddingFunc: embedding_dim: int @@ -36,7 +40,8 @@ class EmbeddingFunc: async def __call__(self, *args, **kwargs) -> np.ndarray: return await self.func(*args, **kwargs) - + + def locate_json_string_body_from_string(content: str) -> Union[str, None]: """Locate the JSON string body from a string""" maybe_json_str = re.search(r"{.*}", content, re.DOTALL) @@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]: else: return None + def convert_response_to_json(response: str) -> dict: json_str = locate_json_string_body_from_string(response) assert json_str is not None, f"Unable to parse JSON from response: {response}" @@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict: logger.error(f"Failed to parse JSON: {json_str}") raise e from None + def compute_args_hash(*args): return md5(str(args).encode()).hexdigest() 
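[Editor's note] The keyword-extraction fixes in operate.py above and `locate_json_string_body_from_string` here attack the same failure: chat models often wrap their JSON answer in prose or markdown fences. A short standalone sketch of that salvage step (not the exact LightRAG implementation):

```python
import json
import re
from typing import Union


def locate_json_body(content: str) -> Union[str, None]:
    # Same idea as lightrag.utils.locate_json_string_body_from_string:
    # grab the outermost {...} span from the raw model output.
    maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
    return maybe_json_str.group(0) if maybe_json_str else None


def parse_keywords(result: str) -> dict:
    """Parse LLM keyword output, salvaging JSON wrapped in prose."""
    try:
        return json.loads(result)
    except json.JSONDecodeError:
        # Mirrors the fallback patched into operate.py: keep only the span
        # from the first '{' to the first '}' and try again.
        salvaged = "{" + result.split("{")[1].split("}")[0] + "}"
        return json.loads(salvaged)


raw = 'Here you go:\n{"high_level_keywords": ["narrative themes"], "low_level_keywords": ["Scrooge"]}\nHope that helps!'
print(locate_json_body(raw))   # the bare {...} span
print(parse_keywords(raw))     # parsed dict despite the surrounding prose
```

Note the split-based fallback assumes the first brace pair encloses the whole object; a response with nested objects before the closing brace would need the regex approach instead.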
+ def compute_mdhash_id(content, prefix: str = ""): return prefix + md5(content.encode()).hexdigest() + def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): """Add restriction of maximum async calling times for a async func""" @@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): return final_decro + def wrap_embedding_func_with_attrs(**kwargs): """Wrap a function with attributes""" @@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs): return final_decro + def load_json(file_name): if not os.path.exists(file_name): return None with open(file_name, encoding="utf-8") as f: return json.load(f) + def write_json(json_obj, file_name): with open(file_name, "w", encoding="utf-8") as f: json.dump(json_obj, f, indent=2, ensure_ascii=False) + def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"): global ENCODER if ENCODER is None: @@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"): content = ENCODER.decode(tokens) return content + def pack_user_ass_to_openai_messages(*args: str): roles = ["user", "assistant"] return [ {"role": roles[i % 2], "content": content} for i, content in enumerate(args) ] + def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: """Split a string by multiple markers""" if not markers: @@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str] results = re.split("|".join(re.escape(marker) for marker in markers), content) return [r.strip() for r in results if r.strip()] + # Refer the utils functions of the official GraphRAG implementation: # https://github.com/microsoft/graphrag def clean_str(input: Any) -> str: @@ -141,9 +157,11 @@ def clean_str(input: Any) -> str: # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) + def is_float_regex(value): return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) + def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): """Truncate a list of data by token size""" if max_token_size <= 0: @@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: return list_data[:i] return list_data + def list_of_list_to_csv(data: list[list]): return "\n".join( [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data] ) + def save_data_to_file(data, file_name): - with open(file_name, 'w', encoding='utf-8') as f: - json.dump(data, f, ensure_ascii=False, indent=4) \ No newline at end of file + with open(file_name, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=4) diff --git a/reproduce/Step_0.py b/reproduce/Step_0.py index 9053aa40..2d97bd14 100644 --- a/reproduce/Step_0.py +++ b/reproduce/Step_0.py @@ -3,11 +3,11 @@ import json import glob import argparse -def extract_unique_contexts(input_directory, output_directory): +def extract_unique_contexts(input_directory, output_directory): os.makedirs(output_directory, exist_ok=True) - jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl')) + jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl")) print(f"Found {len(jsonl_files)} JSONL files.") for file_path in jsonl_files: @@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory): print(f"Processing file: {filename}") try: - with open(file_path, 'r', encoding='utf-8') as infile: + with open(file_path, "r", 
encoding="utf-8") as infile: for line_number, line in enumerate(infile, start=1): line = line.strip() if not line: continue try: json_obj = json.loads(line) - context = json_obj.get('context') + context = json_obj.get("context") if context and context not in unique_contexts_dict: unique_contexts_dict[context] = None except json.JSONDecodeError as e: - print(f"JSON decoding error in file {filename} at line {line_number}: {e}") + print( + f"JSON decoding error in file {filename} at line {line_number}: {e}" + ) except FileNotFoundError: print(f"File not found: {filename}") continue @@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory): continue unique_contexts_list = list(unique_contexts_dict.keys()) - print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.") + print( + f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}." + ) try: - with open(output_path, 'w', encoding='utf-8') as outfile: + with open(output_path, "w", encoding="utf-8") as outfile: json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4) print(f"Unique `context` entries have been saved to: {output_filename}") except Exception as e: @@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input_dir', type=str, default='../datasets') - parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts') + parser.add_argument("-i", "--input_dir", type=str, default="../datasets") + parser.add_argument( + "-o", "--output_dir", type=str, default="../datasets/unique_contexts" + ) args = parser.parse_args() diff --git a/reproduce/Step_1.py b/reproduce/Step_1.py index 08e497cb..43c44056 100644 --- a/reproduce/Step_1.py +++ b/reproduce/Step_1.py @@ -4,10 +4,11 @@ import time from lightrag import LightRAG + def insert_text(rag, file_path): - with open(file_path, mode='r') as f: + with open(file_path, mode="r") as f: unique_contexts = json.load(f) - + retries = 0 max_retries = 3 while retries < max_retries: @@ -21,6 +22,7 @@ def insert_text(rag, file_path): if retries == max_retries: print("Insertion failed after exceeding the maximum number of retries") + cls = "agriculture" WORKING_DIR = "../{cls}" @@ -29,4 +31,4 @@ if not os.path.exists(WORKING_DIR): rag = LightRAG(working_dir=WORKING_DIR) -insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json") \ No newline at end of file +insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json") diff --git a/reproduce/Step_1_openai_compatible.py b/reproduce/Step_1_openai_compatible.py index b5c6aef3..8e67cfb8 100644 --- a/reproduce/Step_1_openai_compatible.py +++ b/reproduce/Step_1_openai_compatible.py @@ -7,6 +7,7 @@ from lightrag import LightRAG from lightrag.utils import EmbeddingFunc from lightrag.llm import openai_complete_if_cache, openai_embedding + ## For Upstage API # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry async def llm_model_func( @@ -19,22 +20,26 @@ async def llm_model_func( history_messages=history_messages, api_key=os.getenv("UPSTAGE_API_KEY"), base_url="https://api.upstage.ai/v1/solar", - **kwargs + **kwargs, ) + async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embedding( texts, model="solar-embedding-1-large-query", api_key=os.getenv("UPSTAGE_API_KEY"), - base_url="https://api.upstage.ai/v1/solar" + 
base_url="https://api.upstage.ai/v1/solar", ) + + ## /For Upstage API + def insert_text(rag, file_path): - with open(file_path, mode='r') as f: + with open(file_path, mode="r") as f: unique_contexts = json.load(f) - + retries = 0 max_retries = 3 while retries < max_retries: @@ -48,19 +53,19 @@ def insert_text(rag, file_path): if retries == max_retries: print("Insertion failed after exceeding the maximum number of retries") + cls = "mix" WORKING_DIR = f"../{cls}" if not os.path.exists(WORKING_DIR): os.mkdir(WORKING_DIR) -rag = LightRAG(working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=4096, - max_token_size=8192, - func=embedding_func - ) - ) +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=4096, max_token_size=8192, func=embedding_func + ), +) insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json") diff --git a/reproduce/Step_2.py b/reproduce/Step_2.py index b00c19b8..557c7714 100644 --- a/reproduce/Step_2.py +++ b/reproduce/Step_2.py @@ -1,8 +1,8 @@ -import os import json from openai import OpenAI from transformers import GPT2Tokenizer + def openai_complete_if_cache( model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -19,24 +19,26 @@ def openai_complete_if_cache( ) return response.choices[0].message.content -tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + def get_summary(context, tot_tokens=2000): tokens = tokenizer.tokenize(context) half_tokens = tot_tokens // 2 - start_tokens = tokens[1000:1000 + half_tokens] - end_tokens = tokens[-(1000 + half_tokens):1000] + start_tokens = tokens[1000 : 1000 + half_tokens] + end_tokens = tokens[-(1000 + half_tokens) : 1000] summary_tokens = start_tokens + end_tokens summary = tokenizer.convert_tokens_to_string(summary_tokens) - + return summary -clses = ['agriculture'] +clses = ["agriculture"] for cls in clses: - with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f: + with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f: unique_contexts = json.load(f) summaries = [get_summary(context) for context in unique_contexts] @@ -67,10 +69,10 @@ for cls in clses: ... 
""" - result = openai_complete_if_cache(model='gpt-4o', prompt=prompt) + result = openai_complete_if_cache(model="gpt-4o", prompt=prompt) file_path = f"../datasets/questions/{cls}_questions.txt" with open(file_path, "w") as file: file.write(result) - print(f"{cls}_questions written to {file_path}") \ No newline at end of file + print(f"{cls}_questions written to {file_path}") diff --git a/reproduce/Step_3.py b/reproduce/Step_3.py index a79ebd17..a56190fc 100644 --- a/reproduce/Step_3.py +++ b/reproduce/Step_3.py @@ -4,16 +4,18 @@ import asyncio from lightrag import LightRAG, QueryParam from tqdm import tqdm -def extract_queries(file_path): - with open(file_path, 'r') as f: - data = f.read() - - data = data.replace('**', '') - queries = re.findall(r'- Question \d+: (.+)', data) +def extract_queries(file_path): + with open(file_path, "r") as f: + data = f.read() + + data = data.replace("**", "") + + queries = re.findall(r"- Question \d+: (.+)", data) return queries + async def process_query(query_text, rag_instance, query_param): try: result, context = await rag_instance.aquery(query_text, param=query_param) @@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param): except Exception as e: return None, {"query": query_text, "error": str(e)} + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: try: loop = asyncio.get_event_loop() @@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: asyncio.set_event_loop(loop) return loop -def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file): + +def run_queries_and_save_to_json( + queries, rag_instance, query_param, output_file, error_file +): loop = always_get_an_event_loop() - with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file: + with open(output_file, "a", encoding="utf-8") as result_file, open( + error_file, "a", encoding="utf-8" + ) as err_file: result_file.write("[\n") first_entry = True for query_text in tqdm(queries, desc="Processing queries", unit="query"): - result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param)) + result, error = loop.run_until_complete( + process_query(query_text, rag_instance, query_param) + ) if result: if not first_entry: @@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file result_file.write("\n]") + if __name__ == "__main__": cls = "agriculture" mode = "hybrid" @@ -59,4 +70,6 @@ if __name__ == "__main__": query_param = QueryParam(mode=mode) queries = extract_queries(f"../datasets/questions/{cls}_questions.txt") - run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json") + run_queries_and_save_to_json( + queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json" + ) diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py index 7b3079a9..2be5ea5c 100644 --- a/reproduce/Step_3_openai_compatible.py +++ b/reproduce/Step_3_openai_compatible.py @@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding from lightrag.utils import EmbeddingFunc import numpy as np + ## For Upstage API # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry async def llm_model_func( @@ -20,28 +21,33 @@ async def llm_model_func( history_messages=history_messages, api_key=os.getenv("UPSTAGE_API_KEY"), base_url="https://api.upstage.ai/v1/solar", - **kwargs + **kwargs, ) + 
async def embedding_func(texts: list[str]) -> np.ndarray: return await openai_embedding( texts, model="solar-embedding-1-large-query", api_key=os.getenv("UPSTAGE_API_KEY"), - base_url="https://api.upstage.ai/v1/solar" + base_url="https://api.upstage.ai/v1/solar", ) + + ## /For Upstage API -def extract_queries(file_path): - with open(file_path, 'r') as f: - data = f.read() - - data = data.replace('**', '') - queries = re.findall(r'- Question \d+: (.+)', data) +def extract_queries(file_path): + with open(file_path, "r") as f: + data = f.read() + + data = data.replace("**", "") + + queries = re.findall(r"- Question \d+: (.+)", data) return queries + async def process_query(query_text, rag_instance, query_param): try: result, context = await rag_instance.aquery(query_text, param=query_param) @@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param): except Exception as e: return None, {"query": query_text, "error": str(e)} + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: try: loop = asyncio.get_event_loop() @@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: asyncio.set_event_loop(loop) return loop -def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file): + +def run_queries_and_save_to_json( + queries, rag_instance, query_param, output_file, error_file +): loop = always_get_an_event_loop() - with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file: + with open(output_file, "a", encoding="utf-8") as result_file, open( + error_file, "a", encoding="utf-8" + ) as err_file: result_file.write("[\n") first_entry = True for query_text in tqdm(queries, desc="Processing queries", unit="query"): - result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param)) + result, error = loop.run_until_complete( + process_query(query_text, rag_instance, query_param) + ) if result: if not first_entry: @@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file result_file.write("\n]") + if __name__ == "__main__": cls = "mix" mode = "hybrid" WORKING_DIR = f"../{cls}" rag = LightRAG(working_dir=WORKING_DIR) - rag = LightRAG(working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=4096, - max_token_size=8192, - func=embedding_func - ) - ) + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=4096, max_token_size=8192, func=embedding_func + ), + ) query_param = QueryParam(mode=mode) - base_dir='../datasets/questions' + base_dir = "../datasets/questions" queries = extract_queries(f"{base_dir}/{cls}_questions.txt") - run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json") + run_queries_and_save_to_json( + queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json" + ) diff --git a/requirements.txt b/requirements.txt index a1054692..d5479dab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ +accelerate aioboto3 -openai -tiktoken -networkx graspologic -nano-vectordb hnswlib -xxhash -tenacity -transformers -torch +nano-vectordb +networkx ollama -accelerate \ No newline at end of file +openai +tenacity +tiktoken +torch +transformers +xxhash From 4945027dc025c73763ecc271017152273a81d86d Mon Sep 17 00:00:00 2001 From: zrguo <49157727+LarFii@users.noreply.github.com> Date: Sat, 19 Oct 
2024 21:35:50 +0800 Subject: [PATCH 14/25] Update README.md --- README.md | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a3e5c1b4..e2f7e81a 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
Using Open AI-like APIs -LightRAG also supports Open AI-like chat/embeddings APIs: +* LightRAG also supports Open AI-like chat/embeddings APIs: ```python async def llm_model_func( prompt, system_prompt=None, history_messages=[], **kwargs @@ -130,7 +130,7 @@ rag = LightRAG(
Using Hugging Face Models -If you want to use Hugging Face models, you only need to set LightRAG as follows: +* If you want to use Hugging Face models, you only need to set LightRAG as follows: ```python from lightrag.llm import hf_model_complete, hf_embedding from transformers import AutoModel, AutoTokenizer @@ -156,7 +156,8 @@ rag = LightRAG(
Using Ollama Models
 
-If you want to use Ollama models, you only need to set LightRAG as follows:
+
+* If you want to use Ollama models, you only need to set LightRAG as follows:
 
 ```python
 from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -177,6 +178,29 @@ rag = LightRAG(
     ),
 )
 ```
+
+* Increasing the `num_ctx` parameter:
+
+1. Pull the model:
+```bash
+ollama pull qwen2
+```
+
+2. Display the model file:
+```bash
+ollama show --modelfile qwen2 > Modelfile
+```
+
+3. Edit the Modelfile by adding the following line:
+```
+PARAMETER num_ctx 32768
+```
+
+4. Create the modified model:
+```bash
+ollama create -f Modelfile qwen2m
+```
+
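[Editor's note] Once the larger-context model is rebuilt, the demos above suggest it only needs to be referenced by name. A minimal sketch, assuming the modified model was created as `qwen2m` in step 4 (adjust the name to whatever you used):

```python
from lightrag import LightRAG
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./dickens",
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2m",  # hypothetical: the model rebuilt with num_ctx 32768
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
    ),
)
```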
### Batch Insert @@ -441,6 +465,8 @@ def extract_queries(file_path): โ”œโ”€โ”€ examples โ”‚ โ”œโ”€โ”€ batch_eval.py โ”‚ โ”œโ”€โ”€ generate_query.py +โ”‚ โ”œโ”€โ”€ lightrag_azure_openai_demo.py +โ”‚ โ”œโ”€โ”€ lightrag_bedrock_demo.py โ”‚ โ”œโ”€โ”€ lightrag_hf_demo.py โ”‚ โ”œโ”€โ”€ lightrag_ollama_demo.py โ”‚ โ”œโ”€โ”€ lightrag_openai_compatible_demo.py @@ -459,6 +485,8 @@ def extract_queries(file_path): โ”‚ โ”œโ”€โ”€ Step_1.py โ”‚ โ”œโ”€โ”€ Step_2.py โ”‚ โ””โ”€โ”€ Step_3.py +โ”œโ”€โ”€ .gitignore +โ”œโ”€โ”€ .pre-commit-config.yaml โ”œโ”€โ”€ LICENSE โ”œโ”€โ”€ README.md โ”œโ”€โ”€ requirements.txt From 263cde887156fa2d6108fa8463fdfd16b4b52fb1 Mon Sep 17 00:00:00 2001 From: nongbin Date: Sun, 20 Oct 2024 09:55:52 +0800 Subject: [PATCH 15/25] add visualizing graph --- .gitignore | 1 + .idea/.gitignore | 8 ++++ .idea/LightRAG.iml | 12 ++++++ .idea/inspectionProfiles/Project_Default.xml | 38 +++++++++++++++++++ .../inspectionProfiles/profiles_settings.xml | 6 +++ .idea/misc.xml | 7 ++++ .idea/modules.xml | 8 ++++ .idea/vcs.xml | 6 +++ examples/graph_visual.py | 14 +++++++ requirements.txt | 1 + 10 files changed, 101 insertions(+) create mode 100644 .idea/.gitignore create mode 100644 .idea/LightRAG.iml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 examples/graph_visual.py diff --git a/.gitignore b/.gitignore index 50f384ec..208668c5 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ __pycache__ dickens/ book.txt lightrag-dev/ +*.idea \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 00000000..13566b81 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/LightRAG.iml b/.idea/LightRAG.iml new file mode 100644 index 00000000..8b8c3954 --- /dev/null +++ b/.idea/LightRAG.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 00000000..c41eaf20 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,38 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 00000000..105ce2da --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 00000000..676ac0f0 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 00000000..145d7086 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 00000000..35eb1ddf --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/examples/graph_visual.py b/examples/graph_visual.py new file mode 100644 index 00000000..72c72bad --- /dev/null +++ b/examples/graph_visual.py @@ -0,0 +1,14 @@ +import networkx as nx +from pyvis.network import Network 
From a7e43406a5d6113a5a0483b187652c74868a21b2 Mon Sep 17 00:00:00 2001
From: nongbin
Date: Sun, 20 Oct 2024 09:57:14 +0800
Subject: [PATCH 16/25] delete not used files

---
 .idea/.gitignore                             |  8 ----
 .idea/LightRAG.iml                           | 12 ------
 .idea/inspectionProfiles/Project_Default.xml | 38 -------------------
 .../inspectionProfiles/profiles_settings.xml |  6 ---
 .idea/misc.xml                               |  7 ----
 .idea/modules.xml                            |  8 ----
 .idea/vcs.xml                                |  6 ---
 7 files changed, 85 deletions(-)
 delete mode 100644 .idea/.gitignore
 delete mode 100644 .idea/LightRAG.iml
 delete mode 100644 .idea/inspectionProfiles/Project_Default.xml
 delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml
 delete mode 100644 .idea/misc.xml
 delete mode 100644 .idea/modules.xml
 delete mode 100644 .idea/vcs.xml

diff --git a/.idea/.gitignore b/.idea/.gitignore
deleted file mode 100644
index 13566b81..00000000
--- a/.idea/.gitignore
+++ /dev/null
@@ -1,8 +0,0 @@
-# Default ignored files
-/shelf/
-/workspace.xml
-# Editor-based HTTP Client requests
-/httpRequests/
-# Datasource local storage ignored files
-/dataSources/
-/dataSources.local.xml

diff --git a/.idea/LightRAG.iml b/.idea/LightRAG.iml
deleted file mode 100644
index 8b8c3954..00000000
--- a/.idea/LightRAG.iml
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
-
-
-
-
-
-
-
\ No newline at end of file

diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
deleted file mode 100644
index c41eaf20..00000000
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ /dev/null
@@ -1,38 +0,0 @@
-
-
-
\ No newline at end of file

diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index 105ce2da..00000000
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
\ No newline at end of file

diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index 676ac0f0..00000000
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,7 +0,0 @@
-
-
-
-
-
\ No newline at end of file

diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index 145d7086..00000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
\ No newline at end of file

diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 35eb1ddf..00000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
\ No newline at end of file

From c6585ff89f858b8d39de3eb5d4b71d59a0771a47 Mon Sep 17 00:00:00 2001
From: nongbin
Date: Sun, 20 Oct 2024 10:04:34 +0800
Subject: [PATCH 17/25] ignore idea files

---
 .gitignore | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 208668c5..edfbfbfc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,4 @@ __pycache__
 dickens/
 book.txt
 lightrag-dev/
-*.idea
\ No newline at end of file
+.idea/
\ No newline at end of file
From 347e8a97be3ee4e1b87ad0b16f7060e4643132a6 Mon Sep 17 00:00:00 2001
From: hanbin49 <554066527@qq.com>
Date: Sun, 20 Oct 2024 11:27:47 +0800
Subject: [PATCH 18/25] 'update'

---
 examples/vram_management_demo.py | 82 ++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)
 create mode 100644 examples/vram_management_demo.py

diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py
new file mode 100644
index 00000000..505e4761
--- /dev/null
+++ b/examples/vram_management_demo.py
@@ -0,0 +1,82 @@
+import os
+import time
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import ollama_model_complete, ollama_embedding
+from lightrag.utils import EmbeddingFunc
+
+# 工作目录和文本文件目录路径
+WORKING_DIR = "./dickens"
+TEXT_FILES_DIR = "/llm/mt"
+
+# 如果工作目录不存在，则创建该目录
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+# 初始化 LightRAG
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="qwen2.5:3b-instruct-max-context",
+    embedding_func=EmbeddingFunc(
+        embedding_dim=768,
+        max_token_size=8192,
+        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
+    ),
+)
+
+# 读取 TEXT_FILES_DIR 目录下所有的 .txt 文件
+texts = []
+for filename in os.listdir(TEXT_FILES_DIR):
+    if filename.endswith('.txt'):
+        file_path = os.path.join(TEXT_FILES_DIR, filename)
+        with open(file_path, 'r', encoding='utf-8') as file:
+            texts.append(file.read())
+
+# 批量插入文本到 LightRAG，带有重试机制
+def insert_texts_with_retry(rag, texts, retries=3, delay=5):
+    for _ in range(retries):
+        try:
+            rag.insert(texts)
+            return
+        except Exception as e:
+            print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
+            time.sleep(delay)
+    raise RuntimeError("Failed to insert texts after multiple retries.")
+
+insert_texts_with_retry(rag, texts)
+
+# 执行不同类型的查询，并处理潜在的错误
+try:
+    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+except Exception as e:
+    print(f"Error performing naive search: {e}")
+
+try:
+    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+except Exception as e:
+    print(f"Error performing local search: {e}")
+
+try:
+    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+except Exception as e:
+    print(f"Error performing global search: {e}")
+
+try:
+    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+except Exception as e:
+    print(f"Error performing hybrid search: {e}")
+
+# 清理 VRAM 资源的函数
+def clear_vram():
+    os.system("sudo nvidia-smi --gpu-reset")
+
+# 定期清理 VRAM 以防止溢出
+clear_vram_interval = 3600  # 每小时清理一次
+start_time = time.time()
+
+while True:
+    current_time = time.time()
+    if current_time - start_time > clear_vram_interval:
+        clear_vram()
+        start_time = current_time
+    time.sleep(60)  # 每分钟检查一次时间
\ No newline at end of file
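One caution on the demo above: `sudo nvidia-smi --gpu-reset` resets the whole device and typically fails while any process, including this script or the Ollama server, still holds the GPU. If the intent is only to release cached memory between query batches, a gentler sketch, assuming PyTorch is the framework holding the VRAM (as it is for the Hugging Face paths in this repo):

```python
import gc

import torch

def release_cached_vram():
    # Drop unreferenced Python objects first so their tensors become freeable
    gc.collect()
    if torch.cuda.is_available():
        # Return cached allocator blocks to the driver without a device reset
        torch.cuda.empty_cache()
```

Unlike a GPU reset, this leaves loaded models intact and only trims the allocator's cache.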
From a716e628e370719e0fdcb847e4cd9b4212cc72eb Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Sun, 20 Oct 2024 18:08:49 +0800
Subject: [PATCH 19/25] Add vram_management_demo.py

---
 examples/vram_management_demo.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py
index 505e4761..ec750254 100644
--- a/examples/vram_management_demo.py
+++ b/examples/vram_management_demo.py
@@ -4,15 +4,15 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm import ollama_model_complete, ollama_embedding
 from lightrag.utils import EmbeddingFunc
 
-# 工作目录和文本文件目录路径
+# Working directory and the directory path for text files
 WORKING_DIR = "./dickens"
 TEXT_FILES_DIR = "/llm/mt"
 
-# 如果工作目录不存在，则创建该目录
+# Create the working directory if it doesn't exist
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
-# 初始化 LightRAG
+# Initialize LightRAG
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=ollama_model_complete,
@@ -24,7 +24,7 @@ rag = LightRAG(
     ),
 )
 
-# 读取 TEXT_FILES_DIR 目录下所有的 .txt 文件
+# Read all .txt files from the TEXT_FILES_DIR directory
 texts = []
 for filename in os.listdir(TEXT_FILES_DIR):
     if filename.endswith('.txt'):
@@ -32,7 +32,7 @@ for filename in os.listdir(TEXT_FILES_DIR):
         with open(file_path, 'r', encoding='utf-8') as file:
             texts.append(file.read())
 
-# 批量插入文本到 LightRAG，带有重试机制
+# Batch insert texts into LightRAG with a retry mechanism
 def insert_texts_with_retry(rag, texts, retries=3, delay=5):
     for _ in range(retries):
         try:
@@ -45,7 +45,7 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
 
 insert_texts_with_retry(rag, texts)
 
-# 执行不同类型的查询，并处理潜在的错误
+# Perform different types of queries and handle potential errors
 try:
     print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 except Exception as e:
@@ -66,12 +66,12 @@ try:
 except Exception as e:
     print(f"Error performing hybrid search: {e}")
 
-# 清理 VRAM 资源的函数
+# Function to clear VRAM resources
 def clear_vram():
     os.system("sudo nvidia-smi --gpu-reset")
 
-# 定期清理 VRAM 以防止溢出
-clear_vram_interval = 3600  # 每小时清理一次
+# Regularly clear VRAM to prevent overflow
+clear_vram_interval = 3600  # Clear once every hour
 start_time = time.time()
 
 while True:
@@ -79,4 +79,4 @@ while True:
     if current_time - start_time > clear_vram_interval:
         clear_vram()
         start_time = current_time
-    time.sleep(60)  # 每分钟检查一次时间
\ No newline at end of file
+    time.sleep(60)  # Check the time every minute

From ae4aafb525b2366499b1d9cf5dd2e92731464569 Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Sun, 20 Oct 2024 18:10:00 +0800
Subject: [PATCH 20/25] Update README.md

---
 README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index e2f7e81a..bf996f82 100644
--- a/README.md
+++ b/README.md
@@ -470,7 +470,8 @@ def extract_queries(file_path):
 │   ├── lightrag_hf_demo.py
 │   ├── lightrag_ollama_demo.py
 │   ├── lightrag_openai_compatible_demo.py
-│   └── lightrag_openai_demo.py
+│   ├── lightrag_openai_demo.py
+│   └── vram_management_demo.py
 ├── lightrag
 │   ├── __init__.py
 │   ├── base.py
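The `insert_texts_with_retry` helper above retries at a fixed five-second interval. Where failures come from a rate-limited endpoint or a briefly overloaded Ollama server, exponential backoff is the usual refinement; a small sketch of that variant (`insert_with_backoff` is a hypothetical name, not part of the repo):

```python
import time

def insert_with_backoff(rag, texts, retries=3, base_delay=5):
    """Retry rag.insert with exponentially growing delays: 5s, 10s, 20s, ..."""
    for attempt in range(retries):
        try:
            rag.insert(texts)
            return
        except Exception as e:
            delay = base_delay * (2 ** attempt)
            print(f"Insertion failed ({e}); retrying in {delay} seconds...")
            time.sleep(delay)
    raise RuntimeError("Failed to insert texts after multiple retries.")
```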
From c800fa48435fab8d2aca945e68d5f9f52c988f9e Mon Sep 17 00:00:00 2001
From: zrguo <49157727+LarFii@users.noreply.github.com>
Date: Sun, 20 Oct 2024 18:22:43 +0800
Subject: [PATCH 21/25] Update README.md

---
 README.md | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/README.md b/README.md
index bf996f82..c8d6e312 100644
--- a/README.md
+++ b/README.md
@@ -218,6 +218,26 @@ rag = LightRAG(working_dir="./dickens")
 with open("./newText.txt") as f:
     rag.insert(f.read())
 ```
+
+### Graph Visualization
+
+* Generate html file
+```python
+import networkx as nx
+from pyvis.network import Network
+
+# Load the GraphML file
+G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
+
+# Create a Pyvis network
+net = Network(notebook=True)
+
+# Convert NetworkX graph to Pyvis network
+net.from_nx(G)
+
+# Save and display the network
+net.show('knowledge_graph.html')
+```
 ## Evaluation
 ### Dataset
 The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -465,6 +485,7 @@ def extract_queries(file_path):
 ├── examples
 │   ├── batch_eval.py
 │   ├── generate_query.py
+│   ├── graph_visual.py
 │   ├── lightrag_azure_openai_demo.py
 │   ├── lightrag_bedrock_demo.py
 │   ├── lightrag_hf_demo.py

From f400b02b0f23401907a1aab004ab7bbc39615364 Mon Sep 17 00:00:00 2001
From: nongbin
Date: Sun, 20 Oct 2024 21:17:09 +0800
Subject: [PATCH 22/25] make graph visualization become colorful

---
 examples/graph_visual.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/examples/graph_visual.py b/examples/graph_visual.py
index 72c72bad..b455e6de 100644
--- a/examples/graph_visual.py
+++ b/examples/graph_visual.py
@@ -1,5 +1,6 @@
 import networkx as nx
 from pyvis.network import Network
+import random
 
 # Load the GraphML file
 G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
@@ -10,5 +11,9 @@ net = Network(notebook=True)
 # Convert NetworkX graph to Pyvis network
 net.from_nx(G)
 
+# Add colors to nodes
+for node in net.nodes:
+    node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
+
 # Save and display the network
 net.show('knowledge_graph.html')
\ No newline at end of file
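The coloring in PATCH 22 assigns each node an independent random value, so the rendering changes on every run and carries no meaning. If stable, type-based colors are preferred, one option is to hash the `entity_type` attribute that LightRAG writes into the GraphML; a sketch, assuming that attribute is copied through onto the node dicts by `net.from_nx`:

```python
import hashlib

import networkx as nx
from pyvis.network import Network

G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
net = Network(notebook=True)
net.from_nx(G)

# Hash each node's entity type into a stable six-digit hex color
for node in net.nodes:
    entity_type = str(node.get('entity_type', ''))
    digest = hashlib.md5(entity_type.encode('utf-8')).hexdigest()
    node['color'] = f"#{digest[:6]}"  # same type, same color, every run

net.show('knowledge_graph.html')
```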
From 8e9005baad5a3fba1324ddc9e11060f00e9a1b29 Mon Sep 17 00:00:00 2001
From: LarFii <834462287@qq.com>
Date: Sun, 20 Oct 2024 23:08:26 +0800
Subject: [PATCH 23/25] Add visualization methods

---
 .gitignore                                    |   3 +-
 README.md                                     | 141 +++++++++++++++++-
 ...ph_visual.py => graph_visual_with_html.py} |   0
 examples/graph_visual_with_neo4j.py           | 118 +++++++++++++++
 lightrag/utils.py                             |  49 ++++++
 5 files changed, 308 insertions(+), 3 deletions(-)
 rename examples/{graph_visual.py => graph_visual_with_html.py} (100%)
 create mode 100644 examples/graph_visual_with_neo4j.py

diff --git a/.gitignore b/.gitignore
index edfbfbfc..5a41ae32 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@ __pycache__
 dickens/
 book.txt
 lightrag-dev/
-.idea/
\ No newline at end of file
+.idea/
+dist/
\ No newline at end of file

diff --git a/README.md b/README.md
index c8d6e312..89e50aa0 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
 ## 🎉 News
+- [x] [2024.10.20]🎯🎯📢📢We add two methods to visualize the graph.
 - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
 - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
 - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
@@ -221,7 +222,11 @@ with open("./newText.txt") as f:
     rag.insert(f.read())
 ```
 
 ### Graph Visualization
 
-* Generate html file
+<details>
+<summary> Graph visualization with html </summary>
+
+* The following code can be found in `examples/graph_visual_with_html.py`
+
 ```python
 import networkx as nx
 from pyvis.network import Network
 
 # Load the GraphML file
 G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
 
 # Create a Pyvis network
 net = Network(notebook=True)
 
 # Convert NetworkX graph to Pyvis network
 net.from_nx(G)
 
 # Save and display the network
 net.show('knowledge_graph.html')
 ```
+
+</details>
+
+<details>
+<summary> Graph visualization with Neo4j </summary>
+
+* The following code can be found in `examples/graph_visual_with_neo4j.py`
+
+```python
+import os
+import json
+from lightrag.utils import xml_to_json
+from neo4j import GraphDatabase
+
+# Constants
+WORKING_DIR = "./dickens"
+BATCH_SIZE_NODES = 500
+BATCH_SIZE_EDGES = 100
+
+# Neo4j connection credentials
+NEO4J_URI = "bolt://localhost:7687"
+NEO4J_USERNAME = "neo4j"
+NEO4J_PASSWORD = "your_password"
+
+def convert_xml_to_json(xml_path, output_path):
+    """Converts XML file to JSON and saves the output."""
+    if not os.path.exists(xml_path):
+        print(f"Error: File not found - {xml_path}")
+        return None
+
+    json_data = xml_to_json(xml_path)
+    if json_data:
+        with open(output_path, 'w', encoding='utf-8') as f:
+            json.dump(json_data, f, ensure_ascii=False, indent=2)
+        print(f"JSON file created: {output_path}")
+        return json_data
+    else:
+        print("Failed to create JSON data")
+        return None
+
+def process_in_batches(tx, query, data, batch_size):
+    """Process data in batches and execute the given query."""
+    for i in range(0, len(data), batch_size):
+        batch = data[i:i + batch_size]
+        tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
+
+def main():
+    # Paths
+    xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
+    json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+
+    # Convert XML to JSON
+    json_data = convert_xml_to_json(xml_file, json_file)
+    if json_data is None:
+        return
+
+    # Load nodes and edges
+    nodes = json_data.get('nodes', [])
+    edges = json_data.get('edges', [])
+
+    # Neo4j queries
+    create_nodes_query = """
+    UNWIND $nodes AS node
+    MERGE (e:Entity {id: node.id})
+    SET e.entity_type = node.entity_type,
+        e.description = node.description,
+        e.source_id = node.source_id,
+        e.displayName = node.id
+    REMOVE e:Entity
+    WITH e, node
+    CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
+    RETURN count(*)
+    """
+
+    create_edges_query = """
+    UNWIND $edges AS edge
+    MATCH (source {id: edge.source})
+    MATCH (target {id: edge.target})
+    WITH source, target, edge,
+         CASE
+             WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
+             WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
+             WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
+             WHEN edge.keywords CONTAINS 'located' THEN 'located'
+             WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
+             ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
+         END AS relType
+    CALL apoc.create.relationship(source, relType, {
+        weight: edge.weight,
+        description: edge.description,
+        keywords: edge.keywords,
+        source_id: edge.source_id
+    }, target) YIELD rel
+    RETURN count(*)
+    """
+
+    set_displayname_and_labels_query = """
+    MATCH (n)
+    SET n.displayName = n.id
+    WITH n
+    CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
+    RETURN count(*)
+    """
+
+    # Create a Neo4j driver
+    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
+
+    try:
+        # Execute queries in batches
+        with driver.session() as session:
+            # Insert nodes in batches
+            session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+
+            # Insert edges in batches
+            session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+
+            # Set displayName and labels
+            session.run(set_displayname_and_labels_query)
+
+    except Exception as e:
+        print(f"Error occurred: {e}")
+
+    finally:
+        driver.close()
+
+if __name__ == "__main__":
+    main()
+```
+
+</details>
+
 ## Evaluation
 ### Dataset
 The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -484,8 +620,9 @@ def extract_queries(file_path):
 .
 ├── examples
 │   ├── batch_eval.py
+│   ├── graph_visual_with_html.py
+│   ├── graph_visual_with_neo4j.py
 │   ├── generate_query.py
-│   ├── graph_visual.py
 │   ├── lightrag_azure_openai_demo.py
 │   ├── lightrag_bedrock_demo.py
 │   ├── lightrag_hf_demo.py

diff --git a/examples/graph_visual.py b/examples/graph_visual_with_html.py
similarity index 100%
rename from examples/graph_visual.py
rename to examples/graph_visual_with_html.py

diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py
new file mode 100644
index 00000000..22dde368
--- /dev/null
+++ b/examples/graph_visual_with_neo4j.py
@@ -0,0 +1,118 @@
+import os
+import json
+from lightrag.utils import xml_to_json
+from neo4j import GraphDatabase
+
+# Constants
+WORKING_DIR = "./dickens"
+BATCH_SIZE_NODES = 500
+BATCH_SIZE_EDGES = 100
+
+# Neo4j connection credentials
+NEO4J_URI = "bolt://localhost:7687"
+NEO4J_USERNAME = "neo4j"
+NEO4J_PASSWORD = "your_password"
+
+def convert_xml_to_json(xml_path, output_path):
+    """Converts XML file to JSON and saves the output."""
+    if not os.path.exists(xml_path):
+        print(f"Error: File not found - {xml_path}")
+        return None
+
+    json_data = xml_to_json(xml_path)
+    if json_data:
+        with open(output_path, 'w', encoding='utf-8') as f:
+            json.dump(json_data, f, ensure_ascii=False, indent=2)
+        print(f"JSON file created: {output_path}")
+        return json_data
+    else:
+        print("Failed to create JSON data")
+        return None
+
+def process_in_batches(tx, query, data, batch_size):
+    """Process data in batches and execute the given query."""
+    for i in range(0, len(data), batch_size):
+        batch = data[i:i + batch_size]
+        tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
+
+def main():
+    # Paths
+    xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
+    json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+
+    # Convert XML to JSON
+    json_data = convert_xml_to_json(xml_file, json_file)
+    if json_data is None:
+        return
+
+    # Load nodes and edges
+    nodes = json_data.get('nodes', [])
+    edges = json_data.get('edges', [])
+
+    # Neo4j queries
+    create_nodes_query = """
+    UNWIND $nodes AS node
+    MERGE (e:Entity {id: node.id})
+    SET e.entity_type = node.entity_type,
+        e.description = node.description,
+        e.source_id = node.source_id,
+        e.displayName = node.id
+    REMOVE e:Entity
+    WITH e, node
+    CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
+    RETURN count(*)
+    """
+
+    create_edges_query = """
+    UNWIND $edges AS edge
+    MATCH (source {id: edge.source})
+    MATCH (target {id: edge.target})
+    WITH source, target, edge,
+         CASE
+             WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
+             WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
+             WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
+             WHEN edge.keywords CONTAINS 'located' THEN 'located'
+             WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
+             ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
+         END AS relType
+    CALL apoc.create.relationship(source, relType, {
+        weight: edge.weight,
+        description: edge.description,
+        keywords: edge.keywords,
+        source_id: edge.source_id
+    }, target) YIELD rel
+    RETURN count(*)
+    """
+
+    set_displayname_and_labels_query = """
+    MATCH (n)
+    SET n.displayName = n.id
+    WITH n
+    CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
+    RETURN count(*)
+    """
+
+    # Create a Neo4j driver
+    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
+
+    try:
+        # Execute queries in batches
+        with driver.session() as session:
+            # Insert nodes in batches
+            session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+
+            # Insert edges in batches
+            session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+
+            # Set displayName and labels
+            session.run(set_displayname_and_labels_query)
+
+    except Exception as e:
+        print(f"Error occurred: {e}")
+
+    finally:
+        driver.close()
+
+if __name__ == "__main__":
+    main()

diff --git a/lightrag/utils.py b/lightrag/utils.py
index 67d094c6..9a68c16b 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
 from typing import Any, Union
+import xml.etree.ElementTree as ET
 
 import numpy as np
 import tiktoken
@@ -183,3 +184,51 @@ def list_of_list_to_csv(data: list[list]):
 def save_data_to_file(data, file_name):
     with open(file_name, "w", encoding="utf-8") as f:
         json.dump(data, f, ensure_ascii=False, indent=4)
+
+def xml_to_json(xml_file):
+    try:
+        tree = ET.parse(xml_file)
+        root = tree.getroot()
+
+        # Print the root element's tag and attributes to confirm the file has been correctly loaded
+        print(f"Root element: {root.tag}")
+        print(f"Root attributes: {root.attrib}")
+
+        data = {
+            "nodes": [],
+            "edges": []
+        }
+
+        # Use namespace
+        namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
+
+        for node in root.findall('.//node', namespace):
+            node_data = {
+                "id": node.get('id').strip('"'),
+                "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
+                "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
+                "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
+            }
+            data["nodes"].append(node_data)
+
+        for edge in root.findall('.//edge', namespace):
+            edge_data = {
+                "source": edge.get('source').strip('"'),
+                "target": edge.get('target').strip('"'),
+                "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
+                "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
+                "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
+                "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
+            }
+            data["edges"].append(edge_data)
+
+        # Print the number of nodes and edges found
+        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
+
+        return data
+    except ET.ParseError as e:
+        print(f"Error parsing XML file: {e}")
+        return None
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        return None

From 95c5ffef5a130a5949924d2c33ba9cf7e559fd97 Mon Sep 17 00:00:00 2001
From: LarFii <834462287@qq.com>
Date: Sun, 20 Oct 2024 23:10:07 +0800
Subject: [PATCH 24/25] Update __init__.py

---
 lightrag/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index f208177f..db81e005 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "0.0.6"
+__version__ = "0.0.7"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
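For reference, the `xml_to_json` helper that PATCH 23 added to `lightrag/utils.py` can also be used on its own, outside the Neo4j pipeline; a minimal sketch (output path illustrative):

```python
import json

from lightrag.utils import xml_to_json

# Flatten the GraphML knowledge graph into plain {nodes, edges} JSON
data = xml_to_json('./dickens/graph_chunk_entity_relation.graphml')
if data is not None:
    with open('./dickens/graph_data.json', 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    print(f"Wrote {len(data['nodes'])} nodes and {len(data['edges'])} edges")
```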
From 57e9604ce6526a48a7f60281962c2f14c0cbea76 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=9C=A8Data=20Intelligence=20Lab=40HKU=E2=9C=A8?=
 <118165258+HKUDS@users.noreply.github.com>
Date: Mon, 21 Oct 2024 01:18:46 +0800
Subject: [PATCH 25/25] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 89e50aa0..b345c1d1 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
 ## 🎉 News
-- [x] [2024.10.20]🎯🎯📢📢We add two methods to visualize the graph.
+- [x] [2024.10.20]🎯🎯📢📢We’ve added a new feature to LightRAG: Graph Visualization.
 - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
 - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
 - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!