From 32464fab4e3b30af41ddf2928642a1a12b2b2c5b Mon Sep 17 00:00:00 2001
From: Sanketh Kumar
Date: Sat, 19 Oct 2024 09:43:17 +0530
Subject: [PATCH] chore: add pre-commit hooks and ruff formatting
---
.gitignore | 3 +-
.pre-commit-config.yaml | 22 ++
README.md | 50 ++---
examples/batch_eval.py | 38 ++--
examples/generate_query.py | 9 +-
examples/lightrag_azure_openai_demo.py | 2 +-
examples/lightrag_bedrock_demo.py | 13 +-
examples/lightrag_hf_demo.py | 35 ++-
examples/lightrag_ollama_demo.py | 25 ++-
examples/lightrag_openai_compatible_demo.py | 32 ++-
examples/lightrag_openai_demo.py | 22 +-
lightrag/__init__.py | 2 +-
lightrag/base.py | 11 +-
lightrag/lightrag.py | 65 +++---
lightrag/llm.py | 223 ++++++++++++-------
lightrag/operate.py | 229 +++++++++++++-------
lightrag/prompt.py | 14 +-
lightrag/storage.py | 15 +-
lightrag/utils.py | 28 ++-
reproduce/Step_0.py | 24 +-
reproduce/Step_1.py | 8 +-
reproduce/Step_1_openai_compatible.py | 29 ++-
reproduce/Step_2.py | 20 +-
reproduce/Step_3.py | 33 ++-
reproduce/Step_3_openai_compatible.py | 58 +++--
requirements.txt | 18 +-
26 files changed, 635 insertions(+), 393 deletions(-)
create mode 100644 .pre-commit-config.yaml
diff --git a/.gitignore b/.gitignore
index cb457220..50f384ec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
__pycache__
*.egg-info
dickens/
-book.txt
\ No newline at end of file
+book.txt
+lightrag-dev/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..db531bb6
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,22 @@
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v5.0.0
+ hooks:
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ - id: requirements-txt-fixer
+
+
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.6.4
+ hooks:
+ - id: ruff-format
+ - id: ruff
+ args: [--fix]
+
+
+ - repo: https://github.com/mgedmin/check-manifest
+ rev: "0.49"
+ hooks:
+ - id: check-manifest
+ stages: [manual]
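For context, the hooks declared above are normally installed into the local git checkout and run across the whole repository before committing. A minimal sketch of doing that programmatically, assuming the `pre-commit` CLI (https://pre-commit.com) is installed and this runs from the repository root (not part of this patch):

```python
# Hedged sketch, not part of this patch: install and exercise the hooks above.
# Assumes the `pre-commit` CLI is available, e.g. via `pip install pre-commit`,
# and that this is run from the repository root.
import subprocess

subprocess.run(["pre-commit", "install"], check=True)             # register the git hook
subprocess.run(["pre-commit", "run", "--all-files"], check=True)  # ruff, whitespace fixers, etc.
```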
diff --git a/README.md b/README.md
index d0ed8a35..b3a04957 100644
--- a/README.md
+++ b/README.md
@@ -16,16 +16,16 @@
-
+
This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).

-## 🎉 News
+## 🎉 News
- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
-- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
-- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
## Install
@@ -83,7 +83,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
Using Open AI-like APIs
-LightRAG also support Open AI-like chat/embeddings APIs:
+LightRAG also supports Open AI-like chat/embeddings APIs:
```python
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
@@ -120,7 +120,7 @@ rag = LightRAG(
Using Hugging Face Models
-
+
If you want to use Hugging Face models, you only need to set LightRAG as follows:
```python
from lightrag.llm import hf_model_complete, hf_embedding
@@ -136,7 +136,7 @@ rag = LightRAG(
embedding_dim=384,
max_token_size=5000,
func=lambda texts: hf_embedding(
- texts,
+ texts,
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
)
@@ -148,7 +148,7 @@ rag = LightRAG(
Using Ollama Models
If you want to use Ollama models, you only need to set LightRAG as follows:
-
+
```python
from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -162,7 +162,7 @@ rag = LightRAG(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
- texts,
+ texts,
embed_model="nomic-embed-text"
)
),
@@ -187,14 +187,14 @@ with open("./newText.txt") as f:
```
## Evaluation
### Dataset
-The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
+The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
### Generate Query
-LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`.
+LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.
Prompt
-
+
```python
Given the following description of a dataset:
@@ -219,18 +219,18 @@ Output the results in the following structure:
...
```
-
+
### Batch Eval
To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
Prompt
-
+
```python
---Role---
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
---Goal---
-You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
+You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
@@ -294,7 +294,7 @@ Output your evaluation in the following JSON format:
| **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
| **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
-## Reproduce
+## Reproduce
All the code can be found in the `./reproduce` directory.
### Step-0 Extract Unique Contexts
@@ -302,7 +302,7 @@ First, we need to extract unique contexts in the datasets.
Code
-
+
```python
def extract_unique_contexts(input_directory, output_directory):
@@ -361,12 +361,12 @@ For the extracted contexts, we insert them into the LightRAG system.
Code
-
+
```python
def insert_text(rag, file_path):
with open(file_path, mode='r') as f:
unique_contexts = json.load(f)
-
+
retries = 0
max_retries = 3
while retries < max_retries:
@@ -384,11 +384,11 @@ def insert_text(rag, file_path):
### Step-2 Generate Queries
-We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
+We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.
Code
-
+
```python
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
@@ -401,7 +401,7 @@ def get_summary(context, tot_tokens=2000):
summary_tokens = start_tokens + end_tokens
summary = tokenizer.convert_tokens_to_string(summary_tokens)
-
+
return summary
```
@@ -411,12 +411,12 @@ For the queries generated in Step-2, we will extract them and query LightRAG.
Code
-
+
```python
def extract_queries(file_path):
with open(file_path, 'r') as f:
data = f.read()
-
+
data = data.replace('**', '')
queries = re.findall(r'- Question \d+: (.+)', data)
@@ -470,7 +470,7 @@ def extract_queries(file_path):
```python
@article{guo2024lightrag,
-title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
+title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
year={2024},
eprint={2410.05779},
diff --git a/examples/batch_eval.py b/examples/batch_eval.py
index 4601d267..a85e1ede 100644
--- a/examples/batch_eval.py
+++ b/examples/batch_eval.py
@@ -1,4 +1,3 @@
-import os
import re
import json
import jsonlines
@@ -9,28 +8,28 @@ from openai import OpenAI
def batch_eval(query_file, result1_file, result2_file, output_file_path):
client = OpenAI()
- with open(query_file, 'r') as f:
+ with open(query_file, "r") as f:
data = f.read()
- queries = re.findall(r'- Question \d+: (.+)', data)
+ queries = re.findall(r"- Question \d+: (.+)", data)
- with open(result1_file, 'r') as f:
+ with open(result1_file, "r") as f:
answers1 = json.load(f)
- answers1 = [i['result'] for i in answers1]
+ answers1 = [i["result"] for i in answers1]
- with open(result2_file, 'r') as f:
+ with open(result2_file, "r") as f:
answers2 = json.load(f)
- answers2 = [i['result'] for i in answers2]
+ answers2 = [i["result"] for i in answers2]
requests = []
for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
- sys_prompt = f"""
+ sys_prompt = """
---Role---
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
"""
prompt = f"""
- You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
+ You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
}}
"""
-
request_data = {
"custom_id": f"request-{i+1}",
"method": "POST",
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
"model": "gpt-4o-mini",
"messages": [
{"role": "system", "content": sys_prompt},
- {"role": "user", "content": prompt}
+ {"role": "user", "content": prompt},
],
- }
+ },
}
-
+
requests.append(request_data)
- with jsonlines.open(output_file_path, mode='w') as writer:
+ with jsonlines.open(output_file_path, mode="w") as writer:
for request in requests:
writer.write(request)
print(f"Batch API requests written to {output_file_path}")
batch_input_file = client.files.create(
- file=open(output_file_path, "rb"),
- purpose="batch"
+ file=open(output_file_path, "rb"), purpose="batch"
)
batch_input_file_id = batch_input_file.id
@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
input_file_id=batch_input_file_id,
endpoint="/v1/chat/completions",
completion_window="24h",
- metadata={
- "description": "nightly eval job"
- }
+ metadata={"description": "nightly eval job"},
)
- print(f'Batch {batch.id} has been created.')
+ print(f"Batch {batch.id} has been created.")
+
if __name__ == "__main__":
- batch_eval()
\ No newline at end of file
+ batch_eval()
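Note that `batch_eval()` only submits the evaluation batch and prints its id. A hypothetical follow-up for polling that job and reading its output with the same `openai` client (the batch id and variable names below are illustrative, not part of this patch) could look like:

```python
# Hypothetical follow-up to batch_eval(), not part of this patch: poll the batch
# job and read its output. Assumes an openai>=1.x client; "batch_abc123" stands
# in for the id printed by batch_eval().
import json
from openai import OpenAI

client = OpenAI()
batch = client.batches.retrieve("batch_abc123")
if batch.status == "completed" and batch.output_file_id:
    raw = client.files.content(batch.output_file_id).text
    results = [json.loads(line) for line in raw.splitlines() if line]
    print(f"Fetched {len(results)} evaluation responses")
else:
    print(f"Batch status: {batch.status}")
```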
diff --git a/examples/generate_query.py b/examples/generate_query.py
index 0ae82f40..705b23d3 100644
--- a/examples/generate_query.py
+++ b/examples/generate_query.py
@@ -1,9 +1,8 @@
-import os
-
from openai import OpenAI
# os.environ["OPENAI_API_KEY"] = ""
+
def openai_complete_if_cache(
model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -47,10 +46,10 @@ if __name__ == "__main__":
...
"""
- result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt)
+ result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)
- file_path = f"./queries.txt"
+ file_path = "./queries.txt"
with open(file_path, "w") as file:
file.write(result)
- print(f"Queries written to {file_path}")
\ No newline at end of file
+ print(f"Queries written to {file_path}")
diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py
index 62282a25..e29a6a9d 100644
--- a/examples/lightrag_azure_openai_demo.py
+++ b/examples/lightrag_azure_openai_demo.py
@@ -122,4 +122,4 @@ print("\nResult (Global):")
print(rag.query(query_text, param=QueryParam(mode="global")))
print("\nResult (Hybrid):")
-print(rag.query(query_text, param=QueryParam(mode="hybrid")))
\ No newline at end of file
+print(rag.query(query_text, param=QueryParam(mode="hybrid")))
diff --git a/examples/lightrag_bedrock_demo.py b/examples/lightrag_bedrock_demo.py
index c515922e..7e18ea57 100644
--- a/examples/lightrag_bedrock_demo.py
+++ b/examples/lightrag_bedrock_demo.py
@@ -20,13 +20,11 @@ rag = LightRAG(
llm_model_func=bedrock_complete,
llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
embedding_func=EmbeddingFunc(
- embedding_dim=1024,
- max_token_size=8192,
- func=bedrock_embedding
- )
+ embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
+ ),
)
-with open("./book.txt", 'r', encoding='utf-8') as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
for mode in ["naive", "local", "global", "hybrid"]:
@@ -34,8 +32,5 @@ for mode in ["naive", "local", "global", "hybrid"]:
print(f"| {mode.capitalize()} |")
print("+-" + "-" * len(mode) + "-+\n")
print(
- rag.query(
- "What are the top themes in this story?",
- param=QueryParam(mode=mode)
- )
+ rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
)
diff --git a/examples/lightrag_hf_demo.py b/examples/lightrag_hf_demo.py
index baf62bdb..87312307 100644
--- a/examples/lightrag_hf_demo.py
+++ b/examples/lightrag_hf_demo.py
@@ -1,10 +1,9 @@
import os
-import sys
from lightrag import LightRAG, QueryParam
from lightrag.llm import hf_model_complete, hf_embedding
from lightrag.utils import EmbeddingFunc
-from transformers import AutoModel,AutoTokenizer
+from transformers import AutoModel, AutoTokenizer
WORKING_DIR = "./dickens"
@@ -13,16 +12,20 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
- llm_model_func=hf_model_complete,
- llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
+ llm_model_func=hf_model_complete,
+ llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
func=lambda texts: hf_embedding(
- texts,
- tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
- embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
- )
+ texts,
+ tokenizer=AutoTokenizer.from_pretrained(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ ),
+ embed_model=AutoModel.from_pretrained(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ ),
+ ),
),
)
@@ -31,13 +34,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
# Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
# Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
# Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py
index a2d04aa6..c61b71c0 100644
--- a/examples/lightrag_ollama_demo.py
+++ b/examples/lightrag_ollama_demo.py
@@ -11,15 +11,12 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
- llm_model_func=ollama_model_complete,
- llm_model_name='your_model_name',
+ llm_model_func=ollama_model_complete,
+ llm_model_name="your_model_name",
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
- func=lambda texts: ollama_embedding(
- texts,
- embed_model="nomic-embed-text"
- )
+ func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
),
)
@@ -28,13 +25,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
# Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
# Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
# Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
index 75ecc118..fbad1190 100644
--- a/examples/lightrag_openai_compatible_demo.py
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -6,10 +6,11 @@ from lightrag.utils import EmbeddingFunc
import numpy as np
WORKING_DIR = "./dickens"
-
+
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
+
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -20,17 +21,19 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
- **kwargs
+ **kwargs,
)
+
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
- base_url="https://api.upstage.ai/v1/solar"
+ base_url="https://api.upstage.ai/v1/solar",
)
+
# function test
async def test_funcs():
result = await llm_model_func("How are you?")
@@ -39,6 +42,7 @@ async def test_funcs():
result = await embedding_func(["How are you?"])
print("embedding_func: ", result)
+
asyncio.run(test_funcs())
@@ -46,10 +50,8 @@ rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
- embedding_dim=4096,
- max_token_size=8192,
- func=embedding_func
- )
+ embedding_dim=4096, max_token_size=8192, func=embedding_func
+ ),
)
@@ -57,13 +59,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
# Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
# Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
# Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py
index fb1f055c..a6e7f3b2 100644
--- a/examples/lightrag_openai_demo.py
+++ b/examples/lightrag_openai_demo.py
@@ -1,9 +1,7 @@
import os
-import sys
from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
-from transformers import AutoModel,AutoTokenizer
+from lightrag.llm import gpt_4o_mini_complete
WORKING_DIR = "./dickens"
@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(
working_dir=WORKING_DIR,
- llm_model_func=gpt_4o_mini_complete
+ llm_model_func=gpt_4o_mini_complete,
# llm_model_func=gpt_4o_complete
)
@@ -21,13 +19,21 @@ with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
# Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
# Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
# Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index b6b953f1..f208177f 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,4 +1,4 @@
-from .lightrag import LightRAG, QueryParam
+from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "0.0.6"
__author__ = "Zirui Guo"
diff --git a/lightrag/base.py b/lightrag/base.py
index d677c406..50be4f62 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -12,15 +12,16 @@ TextChunkSchema = TypedDict(
T = TypeVar("T")
+
@dataclass
class QueryParam:
mode: Literal["local", "global", "hybrid", "naive"] = "global"
only_need_context: bool = False
response_type: str = "Multiple Paragraphs"
top_k: int = 60
- max_token_for_text_unit: int = 4000
+ max_token_for_text_unit: int = 4000
max_token_for_global_context: int = 4000
- max_token_for_local_context: int = 4000
+ max_token_for_local_context: int = 4000
@dataclass
@@ -36,6 +37,7 @@ class StorageNameSpace:
"""commit the storage operations after querying"""
pass
+
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
"""
raise NotImplementedError
+
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
async def all_keys(self) -> list[str]:
@@ -72,7 +75,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
async def drop(self):
raise NotImplementedError
-
+
@dataclass
class BaseGraphStorage(StorageNameSpace):
@@ -113,4 +116,4 @@ class BaseGraphStorage(StorageNameSpace):
raise NotImplementedError
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
- raise NotImplementedError("Node embedding is not used in lightrag.")
\ No newline at end of file
+ raise NotImplementedError("Node embedding is not used in lightrag.")
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 83312ef6..5137af42 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -3,10 +3,12 @@ import os
from dataclasses import asdict, dataclass, field
from datetime import datetime
from functools import partial
-from typing import Type, cast, Any
-from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
+from typing import Type, cast
-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
+from .llm import (
+ gpt_4o_mini_complete,
+ openai_embedding,
+)
from .operate import (
chunking_by_token_size,
extract_entities,
@@ -37,6 +39,7 @@ from .base import (
QueryParam,
)
+
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_running_loop()
@@ -69,7 +72,6 @@ class LightRAG:
"dimensions": 1536,
"num_walks": 10,
"walk_length": 40,
- "num_walks": 10,
"window_size": 2,
"iterations": 3,
"random_seed": 3,
@@ -77,13 +79,13 @@ class LightRAG:
)
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
- embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
+ embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_batch_num: int = 32
embedding_func_max_async: int = 16
# LLM
- llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
- llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+ llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
+ llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
@@ -98,11 +100,11 @@ class LightRAG:
addon_params: dict = field(default_factory=dict)
convert_response_to_json_func: callable = convert_response_to_json
- def __post_init__(self):
+ def __post_init__(self):
log_file = os.path.join(self.working_dir, "lightrag.log")
set_logger(log_file)
logger.info(f"Logger initialized for working directory: {self.working_dir}")
-
+
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -133,30 +135,24 @@ class LightRAG:
self.embedding_func
)
- self.entities_vdb = (
- self.vector_db_storage_cls(
- namespace="entities",
- global_config=asdict(self),
- embedding_func=self.embedding_func,
- meta_fields={"entity_name"}
- )
+ self.entities_vdb = self.vector_db_storage_cls(
+ namespace="entities",
+ global_config=asdict(self),
+ embedding_func=self.embedding_func,
+ meta_fields={"entity_name"},
)
- self.relationships_vdb = (
- self.vector_db_storage_cls(
- namespace="relationships",
- global_config=asdict(self),
- embedding_func=self.embedding_func,
- meta_fields={"src_id", "tgt_id"}
- )
+ self.relationships_vdb = self.vector_db_storage_cls(
+ namespace="relationships",
+ global_config=asdict(self),
+ embedding_func=self.embedding_func,
+ meta_fields={"src_id", "tgt_id"},
)
- self.chunks_vdb = (
- self.vector_db_storage_cls(
- namespace="chunks",
- global_config=asdict(self),
- embedding_func=self.embedding_func,
- )
+ self.chunks_vdb = self.vector_db_storage_cls(
+ namespace="chunks",
+ global_config=asdict(self),
+ embedding_func=self.embedding_func,
)
-
+
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
)
@@ -177,7 +173,7 @@ class LightRAG:
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs):
- logger.warning(f"All docs are already in the storage")
+ logger.warning("All docs are already in the storage")
return
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
@@ -203,7 +199,7 @@ class LightRAG:
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if not len(inserting_chunks):
- logger.warning(f"All chunks are already in the storage")
+ logger.warning("All chunks are already in the storage")
return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
@@ -246,7 +242,7 @@ class LightRAG:
def query(self, query: str, param: QueryParam = QueryParam()):
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param))
-
+
async def aquery(self, query: str, param: QueryParam = QueryParam()):
if param.mode == "local":
response = await local_query(
@@ -290,7 +286,6 @@ class LightRAG:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
-
async def _query_done(self):
tasks = []
@@ -299,5 +294,3 @@ class LightRAG:
continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks)
-
-
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 48defb4d..be801e0c 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,9 +1,7 @@
import os
import copy
import json
-import botocore
import aioboto3
-import botocore.errorfactory
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -13,24 +11,34 @@ from tenacity import (
wait_exponential,
retry_if_exception_type,
)
-from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
-import copy
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
- model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ base_url=None,
+ api_key=None,
+ **kwargs,
) -> str:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
- openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+ openai_async_client = (
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+ )
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
@@ -64,43 +72,56 @@ class BedrockError(Exception):
retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
- model, prompt, system_prompt=None, history_messages=[],
- aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
+ aws_session_token=None,
+ **kwargs,
) -> str:
- os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
- os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
- os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
+ )
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+ )
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+ "AWS_SESSION_TOKEN", aws_session_token
+ )
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
- message['content'] = [{'text': message['content']}]
+ message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
- messages.append({'role': "user", 'content': [{'text': prompt}]})
+ messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
- args = {
- 'modelId': model,
- 'messages': messages
- }
+ args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
- args['system'] = [{'text': system_prompt}]
+ args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
- 'max_tokens': "maxTokens",
- 'top_p': "topP",
- 'stop_sequences': "stopSequences"
+ "max_tokens": "maxTokens",
+ "top_p": "topP",
+ "stop_sequences": "stopSequences",
}
- if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
- args['inferenceConfig'] = {}
+ if inference_params := list(
+ set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
+ ):
+ args["inferenceConfig"] = {}
for param in inference_params:
- args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
+ args["inferenceConfig"][inference_params_map.get(param, param)] = (
+ kwargs.pop(param)
+ )
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
@@ -112,31 +133,33 @@ async def bedrock_complete_if_cache(
# Call model via Converse API
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
-
try:
response = await bedrock_async_client.converse(**args, **kwargs)
except Exception as e:
raise BedrockError(e)
if hashing_kv is not None:
- await hashing_kv.upsert({
- args_hash: {
- 'return': response['output']['message']['content'][0]['text'],
- 'model': model
+ await hashing_kv.upsert(
+ {
+ args_hash: {
+ "return": response["output"]["message"]["content"][0]["text"],
+ "model": model,
+ }
}
- })
+ )
+
+ return response["output"]["message"]["content"][0]["text"]
- return response['output']['message']['content'][0]['text']
async def hf_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = model
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
- if hf_tokenizer.pad_token == None:
+ hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+ if hf_tokenizer.pad_token is None:
# print("use eos token")
hf_tokenizer.pad_token = hf_tokenizer.eos_token
- hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
+ hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
@@ -149,30 +172,51 @@ async def hf_model_if_cache(
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
- input_prompt = ''
+ input_prompt = ""
try:
- input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- except:
+ input_prompt = hf_tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
try:
ori_message = copy.deepcopy(messages)
- if messages[0]['role'] == "system":
- messages[1]['content'] = "" + messages[0]['content'] + "\n" + messages[1]['content']
+ if messages[0]["role"] == "system":
+ messages[1]["content"] = (
+ ""
+ + messages[0]["content"]
+ + "\n"
+ + messages[1]["content"]
+ )
messages = messages[1:]
- input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- except:
+ input_prompt = hf_tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
len_message = len(ori_message)
for msgid in range(len_message):
- input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
-
- input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
- output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
+ input_prompt = (
+ input_prompt
+ + "<"
+ + ori_message[msgid]["role"]
+ + ">"
+ + ori_message[msgid]["content"]
+ "</"
+ + ori_message[msgid]["role"]
+ + ">\n"
+ )
+
+ input_ids = hf_tokenizer(
+ input_prompt, return_tensors="pt", padding=True, truncation=True
+ ).to("cuda")
+ output = hf_model.generate(
+ **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
+ )
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None:
- await hashing_kv.upsert(
- {args_hash: {"return": response_text, "model": model}}
- )
+ await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
+
async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -202,6 +246,7 @@ async def ollama_model_if_cache(
return result
+
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -241,7 +286,7 @@ async def bedrock_complete(
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
- model_name = kwargs['hashing_kv'].global_config['llm_model_name']
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await hf_model_if_cache(
model_name,
prompt,
@@ -250,10 +295,11 @@ async def hf_model_complete(
**kwargs,
)
+
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
- model_name = kwargs['hashing_kv'].global_config['llm_model_name']
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache(
model_name,
prompt,
@@ -262,17 +308,25 @@ async def ollama_model_complete(
**kwargs,
)
+
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
-async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
+async def openai_embedding(
+ texts: list[str],
+ model: str = "text-embedding-3-small",
+ base_url: str = None,
+ api_key: str = None,
+) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
- openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+ openai_async_client = (
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+ )
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
@@ -286,28 +340,37 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embedding(
- texts: list[str], model: str = "amazon.titan-embed-text-v2:0",
- aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray:
- os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
- os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
- os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
+ texts: list[str],
+ model: str = "amazon.titan-embed-text-v2:0",
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
+ aws_session_token=None,
+) -> np.ndarray:
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
+ )
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+ )
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+ "AWS_SESSION_TOKEN", aws_session_token
+ )
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
-
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
- body = json.dumps({
- 'inputText': text,
- # 'dimensions': embedding_dim,
- 'embeddingTypes': ["float"]
- })
+ body = json.dumps(
+ {
+ "inputText": text,
+ # 'dimensions': embedding_dim,
+ "embeddingTypes": ["float"],
+ }
+ )
elif "v1" in model:
- body = json.dumps({
- 'inputText': text
- })
+ body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
@@ -315,29 +378,27 @@ async def bedrock_embedding(
modelId=model,
body=body,
accept="application/json",
- contentType="application/json"
+ contentType="application/json",
)
- response_body = await response.get('body').json()
+ response_body = await response.get("body").json()
- embed_texts.append(response_body['embedding'])
+ embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
- body = json.dumps({
- 'texts': texts,
- 'input_type': "search_document",
- 'truncate': "NONE"
- })
+ body = json.dumps(
+ {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
+ )
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
- contentType="application/json"
+ contentType="application/json",
)
- response_body = json.loads(response.get('body').read())
+ response_body = json.loads(response.get("body").read())
- embed_texts = response_body['embeddings']
+ embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
@@ -345,12 +406,15 @@ async def bedrock_embedding(
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
- input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
+ input_ids = tokenizer(
+ texts, return_tensors="pt", padding=True, truncation=True
+ ).input_ids
with torch.no_grad():
outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1)
return embeddings.detach().numpy()
+
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
embed_text = []
for text in texts:
@@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
return embed_text
+
if __name__ == "__main__":
import asyncio
async def main():
- result = await gpt_4o_mini_complete('How are you?')
+ result = await gpt_4o_mini_complete("How are you?")
print(result)
asyncio.run(main())
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 930ceb2a..a0729cd8 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -25,6 +25,7 @@ from .base import (
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS
+
def chunking_by_token_size(
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
@@ -45,6 +46,7 @@ def chunking_by_token_size(
)
return results
+
async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
@@ -229,9 +231,10 @@ async def _merge_edges_then_upsert(
description=description,
keywords=keywords,
)
-
+
return edge_data
+
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knwoledge_graph_inst: BaseGraphStorage,
@@ -352,7 +355,9 @@ async def extract_entities(
logger.warning("Didn't extract any entities, maybe your LLM is not working")
return None
if not len(all_relationships_data):
- logger.warning("Didn't extract any relationships, maybe your LLM is not working")
+ logger.warning(
+ "Didn't extract any relationships, maybe your LLM is not working"
+ )
return None
if entity_vdb is not None:
@@ -370,7 +375,10 @@ async def extract_entities(
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
- "content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
+ "content": dp["keywords"]
+ + dp["src_id"]
+ + dp["tgt_id"]
+ + dp["description"],
}
for dp in all_relationships_data
}
@@ -378,6 +386,7 @@ async def extract_entities(
return knwoledge_graph_inst
+
async def local_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -393,19 +402,24 @@ async def local_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
-
+
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
- keywords = ', '.join(keywords)
- except json.JSONDecodeError as e:
+ keywords = ", ".join(keywords)
+ except json.JSONDecodeError:
try:
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
- result = '{' + result.split('{')[1].split('}')[0] + '}'
+ result = (
+ result.replace(kw_prompt[:-1], "")
+ .replace("user", "")
+ .replace("model", "")
+ .strip()
+ )
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("low_level_keywords", [])
- keywords = ', '.join(keywords)
+ keywords = ", ".join(keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
@@ -430,11 +444,20 @@ async def local_query(
query,
system_prompt=sys_prompt,
)
- if len(response)>len(sys_prompt):
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip()
-
+ if len(response) > len(sys_prompt):
+ response = (
+ response.replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("", "")
+ .replace("", "")
+ .strip()
+ )
+
return response
+
async def _build_local_query_context(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -516,6 +539,7 @@ async def _build_local_query_context(
```
"""
+
async def _find_most_related_text_unit_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
return all_text_units
+
async def _find_most_related_edges_from_entities(
node_datas: list[dict],
query_param: QueryParam,
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
)
return all_edges_data
+
async def global_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -624,20 +650,25 @@ async def global_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
result = await use_model_func(kw_prompt)
-
+
try:
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
- keywords = ', '.join(keywords)
- except json.JSONDecodeError as e:
+ keywords = ", ".join(keywords)
+ except json.JSONDecodeError:
try:
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
- result = '{' + result.split('{')[1].split('}')[0] + '}'
+ result = (
+ result.replace(kw_prompt[:-1], "")
+ .replace("user", "")
+ .replace("model", "")
+ .strip()
+ )
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
keywords = keywords_data.get("high_level_keywords", [])
- keywords = ', '.join(keywords)
-
+ keywords = ", ".join(keywords)
+
except json.JSONDecodeError as e:
# Handle parsing error
print(f"JSON parsing error: {e}")
@@ -651,12 +682,12 @@ async def global_query(
text_chunks_db,
query_param,
)
-
+
if query_param.only_need_context:
return context
if context is None:
return PROMPTS["fail_response"]
-
+
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
@@ -665,11 +696,20 @@ async def global_query(
query,
system_prompt=sys_prompt,
)
- if len(response)>len(sys_prompt):
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip()
-
+ if len(response) > len(sys_prompt):
+ response = (
+ response.replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("", "")
+ .replace("", "")
+ .strip()
+ )
+
return response
+
async def _build_global_query_context(
keywords,
knowledge_graph_inst: BaseGraphStorage,
@@ -679,14 +719,14 @@ async def _build_global_query_context(
query_param: QueryParam,
):
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
-
+
if not len(results):
return None
-
+
edge_datas = await asyncio.gather(
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
)
-
+
if not all([n is not None for n in edge_datas]):
logger.warning("Some edges are missing, maybe the storage is damaged")
edge_degree = await asyncio.gather(
@@ -765,6 +805,7 @@ async def _build_global_query_context(
```
"""
+
async def _find_most_related_entities_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
@@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships(
for e in edge_datas:
entity_names.add(e["src_id"])
entity_names.add(e["tgt_id"])
-
+
node_datas = await asyncio.gather(
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
)
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
return node_datas
+
async def _find_related_text_unit_from_relationships(
edge_datas: list[dict],
query_param: QueryParam,
text_chunks_db: BaseKVStorage[TextChunkSchema],
knowledge_graph_inst: BaseGraphStorage,
):
-
text_units = [
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
for dp in edge_datas
@@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships(
"data": await text_chunks_db.get_by_id(c_id),
"order": index,
}
-
+
if any([v is None for v in all_text_units_lookup.values()]):
logger.warning("Text chunks are missing, maybe the storage is damaged")
all_text_units = [
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
]
- all_text_units = sorted(
- all_text_units, key=lambda x: x["order"]
- )
+ all_text_units = sorted(all_text_units, key=lambda x: x["order"])
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
return all_text_units
+
async def hybrid_query(
query,
knowledge_graph_inst: BaseGraphStorage,
@@ -849,24 +889,29 @@ async def hybrid_query(
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query)
-
+
result = await use_model_func(kw_prompt)
try:
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
- hl_keywords = ', '.join(hl_keywords)
- ll_keywords = ', '.join(ll_keywords)
- except json.JSONDecodeError as e:
+ hl_keywords = ", ".join(hl_keywords)
+ ll_keywords = ", ".join(ll_keywords)
+ except json.JSONDecodeError:
try:
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
- result = '{' + result.split('{')[1].split('}')[0] + '}'
+ result = (
+ result.replace(kw_prompt[:-1], "")
+ .replace("user", "")
+ .replace("model", "")
+ .strip()
+ )
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
- hl_keywords = ', '.join(hl_keywords)
- ll_keywords = ', '.join(ll_keywords)
+ hl_keywords = ", ".join(hl_keywords)
+ ll_keywords = ", ".join(ll_keywords)
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
@@ -897,7 +942,7 @@ async def hybrid_query(
return context
if context is None:
return PROMPTS["fail_response"]
-
+
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
@@ -906,53 +951,78 @@ async def hybrid_query(
query,
system_prompt=sys_prompt,
)
- if len(response)>len(sys_prompt):
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip()
+ if len(response) > len(sys_prompt):
+ response = (
+ response.replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("", "")
+ .replace("", "")
+ .strip()
+ )
return response
+
def combine_contexts(high_level_context, low_level_context):
# Function to extract entities, relationships, and sources from context strings
def extract_sections(context):
- entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
- relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
- sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
-
- entities = entities_match.group(1) if entities_match else ''
- relationships = relationships_match.group(1) if relationships_match else ''
- sources = sources_match.group(1) if sources_match else ''
-
+ entities_match = re.search(
+ r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+ )
+ relationships_match = re.search(
+ r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+ )
+ sources_match = re.search(
+ r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
+ )
+
+ entities = entities_match.group(1) if entities_match else ""
+ relationships = relationships_match.group(1) if relationships_match else ""
+ sources = sources_match.group(1) if sources_match else ""
+
return entities, relationships, sources
-
+
# Extract sections from both contexts
- if high_level_context==None:
- warnings.warn("High Level context is None. Return empty High entity/relationship/source")
- hl_entities, hl_relationships, hl_sources = '','',''
+ if high_level_context is None:
+ warnings.warn(
+ "High Level context is None. Return empty High entity/relationship/source"
+ )
+ hl_entities, hl_relationships, hl_sources = "", "", ""
else:
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
-
- if low_level_context==None:
- warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
- ll_entities, ll_relationships, ll_sources = '','',''
+ if low_level_context is None:
+ warnings.warn(
+ "Low Level context is None. Return empty Low entity/relationship/source"
+ )
+ ll_entities, ll_relationships, ll_sources = "", "", ""
else:
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
-
-
# Combine and deduplicate the entities
- combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
- combined_entities = '\n'.join(combined_entities_set)
-
+ combined_entities_set = set(
+ filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
+ )
+ combined_entities = "\n".join(combined_entities_set)
+
# Combine and deduplicate the relationships
- combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
- combined_relationships = '\n'.join(combined_relationships_set)
-
+ combined_relationships_set = set(
+ filter(
+ None,
+ hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
+ )
+ )
+ combined_relationships = "\n".join(combined_relationships_set)
+
# Combine and deduplicate the sources
- combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
- combined_sources = '\n'.join(combined_sources_set)
-
+ combined_sources_set = set(
+ filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
+ )
+ combined_sources = "\n".join(combined_sources_set)
+
# Format the combined context
return f"""
-----Entities-----
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
{combined_sources}
"""
+
async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
@@ -996,8 +1067,16 @@ async def naive_query(
system_prompt=sys_prompt,
)
- if len(response)>len(sys_prompt):
- response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('','').replace('','').strip()
-
- return response
+ if len(response) > len(sys_prompt):
+ response = (
+ response[len(sys_prompt) :]
+ .replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("", "")
+ .replace("", "")
+ .strip()
+ )
+ return response
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index 5d28e49c..6bd9b638 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
-PROMPTS[
- "entity_extraction"
-] = """-Goal-
+PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Steps-
@@ -32,7 +30,7 @@ Format each relationship as ("relationship"{tuple_delimiter}{tupl
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
Format the content-level key words as ("content_keywords"{tuple_delimiter})
-
+
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
5. When finished, output {completion_delimiter}
@@ -146,9 +144,7 @@ PROMPTS[
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
-PROMPTS[
- "rag_response"
-] = """---Role---
+PROMPTS["rag_response"] = """---Role---
You are a helpful assistant responding to questions about data in the tables provided.
@@ -241,9 +237,7 @@ Output:
"""
-PROMPTS[
- "naive_rag_response"
-] = """You're a helpful assistant
+PROMPTS["naive_rag_response"] = """You're a helpful assistant
Below are the knowledge you know:
{content_data}
---
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 2f2bb7d8..1f22fc56 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -1,16 +1,11 @@
import asyncio
import html
-import json
import os
-from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Any, Union, cast
-import pickle
-import hnswlib
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
-import xxhash
from .utils import load_json, logger, write_json
from .base import (
@@ -19,6 +14,7 @@ from .base import (
BaseVectorStorage,
)
+
@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
async def drop(self):
self._data = {}
+
@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
-
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self):
self._client.save()
+
@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
- node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
+ node_mapping = {
+ node: html.unescape(node.upper().strip()) for node in graph.nodes()
+ } # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)
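The multi-line `node_mapping` comprehension above is a pure reformat; as a quick illustration of what that step does, here is a toy example (the graph and node names are hypothetical):

```python
# Toy illustration of the node normalization in NetworkXStorage: labels are
# upper-cased, stripped, and HTML-unescaped before the graph is relabeled.
import html

import networkx as nx

graph = nx.Graph()
graph.add_edge(" alice &amp; co ", "bob")

node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}
graph = nx.relabel_nodes(graph, node_mapping)
print(sorted(graph.nodes()))  # ['ALICE & CO', 'BOB']
```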
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9496cf34..67d094c6 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -16,18 +16,22 @@ ENCODER = None
logger = logging.getLogger("lightrag")
+
def set_logger(log_file: str):
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ formatter = logging.Formatter(
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ )
file_handler.setFormatter(formatter)
if not logger.handlers:
logger.addHandler(file_handler)
+
@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -36,7 +40,8 @@ class EmbeddingFunc:
async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)
-
+
+
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
else:
return None
+
def convert_response_to_json(response: str) -> dict:
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
@@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict:
logger.error(f"Failed to parse JSON: {json_str}")
raise e from None
+
def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()
+
def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()
+
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
@@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
return final_decro
+
def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""
@@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
return final_decro
+
def load_json(file_name):
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)
+
def write_json(json_obj, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
+
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
@@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
content = ENCODER.decode(tokens)
return content
+
def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]
+
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
@@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]
+
# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
@@ -141,9 +157,11 @@ def clean_str(input: Any) -> str:
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
+
def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
+
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size"""
if max_token_size <= 0:
@@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i]
return list_data
+
def list_of_list_to_csv(data: list[list]):
return "\n".join(
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
)
+
def save_data_to_file(data, file_name):
- with open(file_name, 'w', encoding='utf-8') as f:
- json.dump(data, f, ensure_ascii=False, indent=4)
\ No newline at end of file
+ with open(file_name, "w", encoding="utf-8") as f:
+ json.dump(data, f, ensure_ascii=False, indent=4)
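Most of the utils.py changes are blank-line and quoting fixes; `pack_user_ass_to_openai_messages`, visible unchanged in the hunks above, is easy to misread, so a small usage sketch follows (the example strings are hypothetical):

```python
# Usage sketch: positional arguments alternate between the user and assistant roles.
# The function body is copied from the hunk above; the inputs are made up.
def pack_user_ass_to_openai_messages(*args: str):
    roles = ["user", "assistant"]
    return [
        {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
    ]


print(pack_user_ass_to_openai_messages("What is LightRAG?", "A graph-based RAG pipeline."))
# [{'role': 'user', 'content': 'What is LightRAG?'},
#  {'role': 'assistant', 'content': 'A graph-based RAG pipeline.'}]
```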
diff --git a/reproduce/Step_0.py b/reproduce/Step_0.py
index 9053aa40..2d97bd14 100644
--- a/reproduce/Step_0.py
+++ b/reproduce/Step_0.py
@@ -3,11 +3,11 @@ import json
import glob
import argparse
-def extract_unique_contexts(input_directory, output_directory):
+def extract_unique_contexts(input_directory, output_directory):
os.makedirs(output_directory, exist_ok=True)
- jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
+ jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
print(f"Found {len(jsonl_files)} JSONL files.")
for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
print(f"Processing file: {filename}")
try:
- with open(file_path, 'r', encoding='utf-8') as infile:
+ with open(file_path, "r", encoding="utf-8") as infile:
for line_number, line in enumerate(infile, start=1):
line = line.strip()
if not line:
continue
try:
json_obj = json.loads(line)
- context = json_obj.get('context')
+ context = json_obj.get("context")
if context and context not in unique_contexts_dict:
unique_contexts_dict[context] = None
except json.JSONDecodeError as e:
- print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
+ print(
+ f"JSON decoding error in file {filename} at line {line_number}: {e}"
+ )
except FileNotFoundError:
print(f"File not found: {filename}")
continue
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
continue
unique_contexts_list = list(unique_contexts_dict.keys())
- print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
+ print(
+ f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
+ )
try:
- with open(output_path, 'w', encoding='utf-8') as outfile:
+ with open(output_path, "w", encoding="utf-8") as outfile:
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
print(f"Unique `context` entries have been saved to: {output_filename}")
except Exception as e:
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
- parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
+ parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
+ parser.add_argument(
+ "-o", "--output_dir", type=str, default="../datasets/unique_contexts"
+ )
args = parser.parse_args()
diff --git a/reproduce/Step_1.py b/reproduce/Step_1.py
index 08e497cb..43c44056 100644
--- a/reproduce/Step_1.py
+++ b/reproduce/Step_1.py
@@ -4,10 +4,11 @@ import time
from lightrag import LightRAG
+
def insert_text(rag, file_path):
- with open(file_path, mode='r') as f:
+ with open(file_path, mode="r") as f:
unique_contexts = json.load(f)
-
+
retries = 0
max_retries = 3
while retries < max_retries:
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")
+
cls = "agriculture"
WORKING_DIR = "../{cls}"
@@ -29,4 +31,4 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG(working_dir=WORKING_DIR)
-insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
\ No newline at end of file
+insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
diff --git a/reproduce/Step_1_openai_compatible.py b/reproduce/Step_1_openai_compatible.py
index b5c6aef3..8e67cfb8 100644
--- a/reproduce/Step_1_openai_compatible.py
+++ b/reproduce/Step_1_openai_compatible.py
@@ -7,6 +7,7 @@ from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm import openai_complete_if_cache, openai_embedding
+
## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
async def llm_model_func(
@@ -19,22 +20,26 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
- **kwargs
+ **kwargs,
)
+
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
- base_url="https://api.upstage.ai/v1/solar"
+ base_url="https://api.upstage.ai/v1/solar",
)
+
+
## /For Upstage API
+
def insert_text(rag, file_path):
- with open(file_path, mode='r') as f:
+ with open(file_path, mode="r") as f:
unique_contexts = json.load(f)
-
+
retries = 0
max_retries = 3
while retries < max_retries:
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")
+
cls = "mix"
WORKING_DIR = f"../{cls}"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
-rag = LightRAG(working_dir=WORKING_DIR,
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=4096,
- max_token_size=8192,
- func=embedding_func
- )
- )
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=4096, max_token_size=8192, func=embedding_func
+ ),
+)
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
diff --git a/reproduce/Step_2.py b/reproduce/Step_2.py
index b00c19b8..557c7714 100644
--- a/reproduce/Step_2.py
+++ b/reproduce/Step_2.py
@@ -1,8 +1,8 @@
-import os
import json
from openai import OpenAI
from transformers import GPT2Tokenizer
+
def openai_complete_if_cache(
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -19,24 +19,26 @@ def openai_complete_if_cache(
)
return response.choices[0].message.content
-tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
def get_summary(context, tot_tokens=2000):
tokens = tokenizer.tokenize(context)
half_tokens = tot_tokens // 2
- start_tokens = tokens[1000:1000 + half_tokens]
- end_tokens = tokens[-(1000 + half_tokens):1000]
+ start_tokens = tokens[1000 : 1000 + half_tokens]
+ end_tokens = tokens[-(1000 + half_tokens) : 1000]
summary_tokens = start_tokens + end_tokens
summary = tokenizer.convert_tokens_to_string(summary_tokens)
-
+
return summary
-clses = ['agriculture']
+clses = ["agriculture"]
for cls in clses:
- with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
+ with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
unique_contexts = json.load(f)
summaries = [get_summary(context) for context in unique_contexts]
@@ -67,10 +69,10 @@ for cls in clses:
...
"""
- result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
+ result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)
file_path = f"../datasets/questions/{cls}_questions.txt"
with open(file_path, "w") as file:
file.write(result)
- print(f"{cls}_questions written to {file_path}")
\ No newline at end of file
+ print(f"{cls}_questions written to {file_path}")
diff --git a/reproduce/Step_3.py b/reproduce/Step_3.py
index a79ebd17..a56190fc 100644
--- a/reproduce/Step_3.py
+++ b/reproduce/Step_3.py
@@ -4,16 +4,18 @@ import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm
-def extract_queries(file_path):
- with open(file_path, 'r') as f:
- data = f.read()
-
- data = data.replace('**', '')
- queries = re.findall(r'- Question \d+: (.+)', data)
+def extract_queries(file_path):
+ with open(file_path, "r") as f:
+ data = f.read()
+
+ data = data.replace("**", "")
+
+ queries = re.findall(r"- Question \d+: (.+)", data)
return queries
+
async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
except Exception as e:
return None, {"query": query_text, "error": str(e)}
+
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)
return loop
-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+
+def run_queries_and_save_to_json(
+ queries, rag_instance, query_param, output_file, error_file
+):
loop = always_get_an_event_loop()
- with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+ with open(output_file, "a", encoding="utf-8") as result_file, open(
+ error_file, "a", encoding="utf-8"
+ ) as err_file:
result_file.write("[\n")
first_entry = True
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
- result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+ result, error = loop.run_until_complete(
+ process_query(query_text, rag_instance, query_param)
+ )
if result:
if not first_entry:
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
result_file.write("\n]")
+
if __name__ == "__main__":
cls = "agriculture"
mode = "hybrid"
@@ -59,4 +70,6 @@ if __name__ == "__main__":
query_param = QueryParam(mode=mode)
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
- run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
+ run_queries_and_save_to_json(
+ queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
+ )
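The hunks above show only the first and last lines of `always_get_an_event_loop`; for context, here is a sketch of how the full helper presumably reads. The `except RuntimeError` branch is an assumption, since it falls outside the diff context.

```python
import asyncio


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        # Reuse the current thread's loop if one exists.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Assumed branch (not shown in the hunk): create and register a new loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop
```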
diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py
index 7b3079a9..2be5ea5c 100644
--- a/reproduce/Step_3_openai_compatible.py
+++ b/reproduce/Step_3_openai_compatible.py
@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
+
## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
async def llm_model_func(
@@ -20,28 +21,33 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
- **kwargs
+ **kwargs,
)
+
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
- base_url="https://api.upstage.ai/v1/solar"
+ base_url="https://api.upstage.ai/v1/solar",
)
+
+
## /For Upstage API
-def extract_queries(file_path):
- with open(file_path, 'r') as f:
- data = f.read()
-
- data = data.replace('**', '')
- queries = re.findall(r'- Question \d+: (.+)', data)
+def extract_queries(file_path):
+ with open(file_path, "r") as f:
+ data = f.read()
+
+ data = data.replace("**", "")
+
+ queries = re.findall(r"- Question \d+: (.+)", data)
return queries
+
async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
except Exception as e:
return None, {"query": query_text, "error": str(e)}
+
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)
return loop
-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+
+def run_queries_and_save_to_json(
+ queries, rag_instance, query_param, output_file, error_file
+):
loop = always_get_an_event_loop()
- with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+ with open(output_file, "a", encoding="utf-8") as result_file, open(
+ error_file, "a", encoding="utf-8"
+ ) as err_file:
result_file.write("[\n")
first_entry = True
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
- result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+ result, error = loop.run_until_complete(
+ process_query(query_text, rag_instance, query_param)
+ )
if result:
if not first_entry:
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
result_file.write("\n]")
+
if __name__ == "__main__":
cls = "mix"
mode = "hybrid"
WORKING_DIR = f"../{cls}"
rag = LightRAG(working_dir=WORKING_DIR)
- rag = LightRAG(working_dir=WORKING_DIR,
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=4096,
- max_token_size=8192,
- func=embedding_func
- )
- )
+ rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=4096, max_token_size=8192, func=embedding_func
+ ),
+ )
query_param = QueryParam(mode=mode)
- base_dir='../datasets/questions'
+ base_dir = "../datasets/questions"
queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
- run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json")
+ run_queries_and_save_to_json(
+ queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
+ )
diff --git a/requirements.txt b/requirements.txt
index a1054692..d5479dab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,13 @@
+accelerate
aioboto3
-openai
-tiktoken
-networkx
graspologic
-nano-vectordb
hnswlib
-xxhash
-tenacity
-transformers
-torch
+nano-vectordb
+networkx
ollama
-accelerate
\ No newline at end of file
+openai
+tenacity
+tiktoken
+torch
+transformers
+xxhash