chore: added pre-commit-hooks and ruff formatting for commit-hooks
.gitignore (vendored): 1 line changed
@@ -2,3 +2,4 @@ __pycache__
 *.egg-info
 dickens/
 book.txt
+lightrag-dev/
.pre-commit-config.yaml (new file): 22 lines added
@@ -0,0 +1,22 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: requirements-txt-fixer
+
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.4
+    hooks:
+      - id: ruff-format
+      - id: ruff
+        args: [--fix]
+
+
+  - repo: https://github.com/mgedmin/check-manifest
+    rev: "0.49"
+    hooks:
+      - id: check-manifest
+        stages: [manual]
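For orientation: the hooks above are what produced the mechanical changes in the rest of this commit (quote normalization, argument reflowing, end-of-file and trailing-whitespace fixes). A minimal sketch of how a contributor might enable and run them locally, assuming `pre-commit` is installed from PyPI; the `subprocess` wrapper is only used to keep this example in Python like the rest of the repository's examples:

```python
# Sketch only: assumes `pip install pre-commit` has already been run.
import subprocess

# Install the git hook so the checks run automatically on every `git commit`.
subprocess.run(["pre-commit", "install"], check=True)

# Run all configured hooks (ruff, ruff-format, trailing-whitespace, ...) once
# over the whole repository; a run like this produced the formatting-only
# changes further down in this diff.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)
```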
@@ -83,7 +83,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
 <details>
 <summary> Using Open AI-like APIs </summary>

-LightRAG also support Open AI-like chat/embeddings APIs:
+LightRAG also supports Open AI-like chat/embeddings APIs:
 ```python
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
@@ -187,10 +187,10 @@ with open("./newText.txt") as f:
 ```
 ## Evaluation
 ### Dataset
-The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
+The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).

 ### Generate Query
-LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`.
+LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.

 <details>
 <summary> Prompt </summary>
@@ -384,7 +384,7 @@ def insert_text(rag, file_path):

 ### Step-2 Generate Queries

-We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
+We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.

 <details>
 <summary> Code </summary>
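To make the Step-2 description above concrete, here is an illustrative sketch of the token-splitting step. It is not the repository's implementation (that lives in `example/generate_query.py`); the token budget, the `tiktoken` encoding, and the helper name are assumptions for illustration:

```python
# Illustrative sketch only, not the actual example/generate_query.py code.
import tiktoken

TOKENS_PER_HALF = 2000  # assumed budget taken from each half of a context


def build_dataset_description(context: str) -> str:
    # Tokenize the full context, then sample from the start of the first half
    # and the start of the second half, as described in the README text above.
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(context)
    mid = len(tokens) // 2

    first_part = enc.decode(tokens[:TOKENS_PER_HALF])
    second_part = enc.decode(tokens[mid : mid + TOKENS_PER_HALF])

    # The combined text is what gets fed to the query-generation prompt.
    return first_part + "\n...\n" + second_part
```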
@@ -1,4 +1,3 @@
-import os
 import re
 import json
 import jsonlines
@@ -9,22 +8,22 @@ from openai import OpenAI
 def batch_eval(query_file, result1_file, result2_file, output_file_path):
     client = OpenAI()

-    with open(query_file, 'r') as f:
+    with open(query_file, "r") as f:
         data = f.read()

-    queries = re.findall(r'- Question \d+: (.+)', data)
+    queries = re.findall(r"- Question \d+: (.+)", data)

-    with open(result1_file, 'r') as f:
+    with open(result1_file, "r") as f:
         answers1 = json.load(f)
-    answers1 = [i['result'] for i in answers1]
+    answers1 = [i["result"] for i in answers1]

-    with open(result2_file, 'r') as f:
+    with open(result2_file, "r") as f:
         answers2 = json.load(f)
-    answers2 = [i['result'] for i in answers2]
+    answers2 = [i["result"] for i in answers2]

     requests = []
     for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
-        sys_prompt = f"""
+        sys_prompt = """
 ---Role---
 You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
 """
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
 }}
 """

-
         request_data = {
             "custom_id": f"request-{i+1}",
             "method": "POST",
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
                 "model": "gpt-4o-mini",
                 "messages": [
                     {"role": "system", "content": sys_prompt},
-                    {"role": "user", "content": prompt}
+                    {"role": "user", "content": prompt},
                 ],
-            }
+            },
         }

         requests.append(request_data)

-    with jsonlines.open(output_file_path, mode='w') as writer:
+    with jsonlines.open(output_file_path, mode="w") as writer:
         for request in requests:
             writer.write(request)

     print(f"Batch API requests written to {output_file_path}")

     batch_input_file = client.files.create(
-        file=open(output_file_path, "rb"),
-        purpose="batch"
+        file=open(output_file_path, "rb"), purpose="batch"
     )
     batch_input_file_id = batch_input_file.id

@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
         input_file_id=batch_input_file_id,
         endpoint="/v1/chat/completions",
         completion_window="24h",
-        metadata={
-            "description": "nightly eval job"
-        }
+        metadata={"description": "nightly eval job"},
     )

-    print(f'Batch {batch.id} has been created.')
+    print(f"Batch {batch.id} has been created.")


 if __name__ == "__main__":
     batch_eval()
@@ -1,9 +1,8 @@
-import os

 from openai import OpenAI

 # os.environ["OPENAI_API_KEY"] = ""


 def openai_complete_if_cache(
     model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -47,9 +46,9 @@ if __name__ == "__main__":
     ...
     """

-    result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt)
+    result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)

-    file_path = f"./queries.txt"
+    file_path = "./queries.txt"
     with open(file_path, "w") as file:
         file.write(result)

@@ -20,13 +20,11 @@ rag = LightRAG(
     llm_model_func=bedrock_complete,
     llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
     embedding_func=EmbeddingFunc(
-        embedding_dim=1024,
-        max_token_size=8192,
-        func=bedrock_embedding
-    )
+        embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
+    ),
 )

-with open("./book.txt", 'r', encoding='utf-8') as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
     rag.insert(f.read())

 for mode in ["naive", "local", "global", "hybrid"]:
@@ -34,8 +32,5 @@ for mode in ["naive", "local", "global", "hybrid"]:
     print(f"| {mode.capitalize()} |")
     print("+-" + "-" * len(mode) + "-+\n")
     print(
-        rag.query(
-            "What are the top themes in this story?",
-            param=QueryParam(mode=mode)
-        )
+        rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
     )
@@ -1,10 +1,9 @@
 import os
-import sys

 from lightrag import LightRAG, QueryParam
 from lightrag.llm import hf_model_complete, hf_embedding
 from lightrag.utils import EmbeddingFunc
-from transformers import AutoModel,AutoTokenizer
+from transformers import AutoModel, AutoTokenizer

 WORKING_DIR = "./dickens"

@@ -14,15 +13,19 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=hf_model_complete,
-    llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
+    llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=5000,
         func=lambda texts: hf_embedding(
             texts,
-            tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
-            embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-        )
+            tokenizer=AutoTokenizer.from_pretrained(
+                "sentence-transformers/all-MiniLM-L6-v2"
+            ),
+            embed_model=AutoModel.from_pretrained(
+                "sentence-transformers/all-MiniLM-L6-v2"
+            ),
+        ),
     ),
 )

@@ -31,13 +34,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())

 # Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)

 # Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)

 # Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)

 # Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
@@ -12,14 +12,11 @@ if not os.path.exists(WORKING_DIR):
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=ollama_model_complete,
-    llm_model_name='your_model_name',
+    llm_model_name="your_model_name",
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
-            texts,
-            embed_model="nomic-embed-text"
-        )
+        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
     ),
 )

@@ -28,13 +25,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())

 # Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)

 # Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)

 # Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)

 # Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
@@ -10,6 +10,7 @@ WORKING_DIR = "./dickens"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

+
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -20,17 +21,19 @@ async def llm_model_func(
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )

+
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )

+
 # function test
 async def test_funcs():
     result = await llm_model_func("How are you?")
@@ -39,6 +42,7 @@ async def test_funcs():
     result = await embedding_func(["How are you?"])
     print("embedding_func: ", result)

+
 asyncio.run(test_funcs())

@@ -46,10 +50,8 @@ rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=llm_model_func,
     embedding_func=EmbeddingFunc(
-        embedding_dim=4096,
-        max_token_size=8192,
-        func=embedding_func
-    )
+        embedding_dim=4096, max_token_size=8192, func=embedding_func
+    ),
 )

@@ -57,13 +59,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())

 # Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)

 # Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)

 # Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)

 # Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
@@ -1,9 +1,7 @@
 import os
-import sys

 from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
-from transformers import AutoModel,AutoTokenizer
+from lightrag.llm import gpt_4o_mini_complete

 WORKING_DIR = "./dickens"

@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):

 rag = LightRAG(
     working_dir=WORKING_DIR,
-    llm_model_func=gpt_4o_mini_complete
+    llm_model_func=gpt_4o_mini_complete,
     # llm_model_func=gpt_4o_complete
 )

@@ -21,13 +19,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())

 # Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)

 # Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)

 # Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)

 # Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
@@ -1,4 +1,4 @@
-from .lightrag import LightRAG, QueryParam
+from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

 __version__ = "0.0.6"
 __author__ = "Zirui Guo"
@@ -12,6 +12,7 @@ TextChunkSchema = TypedDict(

 T = TypeVar("T")

+
 @dataclass
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
@@ -36,6 +37,7 @@ class StorageNameSpace:
         """commit the storage operations after querying"""
         pass

+
 @dataclass
 class BaseVectorStorage(StorageNameSpace):
     embedding_func: EmbeddingFunc
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
         """
         raise NotImplementedError

+
 @dataclass
 class BaseKVStorage(Generic[T], StorageNameSpace):
     async def all_keys(self) -> list[str]:
@@ -3,10 +3,12 @@ import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast, Any
-from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
+from typing import Type, cast

-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
+from .llm import (
+    gpt_4o_mini_complete,
+    openai_embedding,
+)
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -37,6 +39,7 @@ from .base import (
     QueryParam,
 )

+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_running_loop()
@@ -69,7 +72,6 @@ class LightRAG:
             "dimensions": 1536,
             "num_walks": 10,
             "walk_length": 40,
-            "num_walks": 10,
             "window_size": 2,
             "iterations": 3,
             "random_seed": 3,
@@ -77,13 +79,13 @@ class LightRAG:
     )

     # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
-    embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16

     # LLM
-    llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
-    llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16

@@ -133,28 +135,22 @@ class LightRAG:
             self.embedding_func
         )

-        self.entities_vdb = (
-            self.vector_db_storage_cls(
-                namespace="entities",
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-                meta_fields={"entity_name"}
-            )
+        self.entities_vdb = self.vector_db_storage_cls(
+            namespace="entities",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+            meta_fields={"entity_name"},
         )
-        self.relationships_vdb = (
-            self.vector_db_storage_cls(
-                namespace="relationships",
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-                meta_fields={"src_id", "tgt_id"}
-            )
+        self.relationships_vdb = self.vector_db_storage_cls(
+            namespace="relationships",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+            meta_fields={"src_id", "tgt_id"},
         )
-        self.chunks_vdb = (
-            self.vector_db_storage_cls(
-                namespace="chunks",
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            )
+        self.chunks_vdb = self.vector_db_storage_cls(
+            namespace="chunks",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
         )

         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
@@ -177,7 +173,7 @@ class LightRAG:
         _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
         if not len(new_docs):
-            logger.warning(f"All docs are already in the storage")
+            logger.warning("All docs are already in the storage")
             return
         logger.info(f"[New Docs] inserting {len(new_docs)} docs")

@@ -203,7 +199,7 @@ class LightRAG:
                 k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
             }
             if not len(inserting_chunks):
-                logger.warning(f"All chunks are already in the storage")
+                logger.warning("All chunks are already in the storage")
                 return
             logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

@@ -291,7 +287,6 @@ class LightRAG:
             await self._query_done()
         return response

-
     async def _query_done(self):
         tasks = []
         for storage_inst in [self.llm_response_cache]:
@@ -299,5 +294,3 @@ class LightRAG:
                 continue
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
-
-
lightrag/llm.py: 221 lines changed
@@ -1,9 +1,7 @@
 import os
 import copy
 import json
-import botocore
 import aioboto3
-import botocore.errorfactory
 import numpy as np
 import ollama
 from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -13,24 +11,34 @@ from tenacity import (
     wait_exponential,
     retry_if_exception_type,
 )
-from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
-import copy

 os.environ["TOKENIZERS_PARALLELISM"] = "false"


 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def openai_complete_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    base_url=None,
+    api_key=None,
+    **kwargs,
 ) -> str:
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key

-    openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+    openai_async_client = (
+        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+    )
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
@@ -64,43 +72,56 @@ class BedrockError(Exception):
     retry=retry_if_exception_type((BedrockError)),
 )
 async def bedrock_complete_if_cache(
-    model, prompt, system_prompt=None, history_messages=[],
-    aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    aws_access_key_id=None,
+    aws_secret_access_key=None,
+    aws_session_token=None,
+    **kwargs,
 ) -> str:
-    os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
-    os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
-    os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
+    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+        "AWS_ACCESS_KEY_ID", aws_access_key_id
+    )
+    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+    )
+    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+        "AWS_SESSION_TOKEN", aws_session_token
+    )

     # Fix message history format
     messages = []
     for history_message in history_messages:
         message = copy.copy(history_message)
-        message['content'] = [{'text': message['content']}]
+        message["content"] = [{"text": message["content"]}]
         messages.append(message)

     # Add user prompt
-    messages.append({'role': "user", 'content': [{'text': prompt}]})
+    messages.append({"role": "user", "content": [{"text": prompt}]})

     # Initialize Converse API arguments
-    args = {
-        'modelId': model,
-        'messages': messages
-    }
+    args = {"modelId": model, "messages": messages}

     # Define system prompt
     if system_prompt:
-        args['system'] = [{'text': system_prompt}]
+        args["system"] = [{"text": system_prompt}]

     # Map and set up inference parameters
     inference_params_map = {
-        'max_tokens': "maxTokens",
-        'top_p': "topP",
-        'stop_sequences': "stopSequences"
+        "max_tokens": "maxTokens",
+        "top_p": "topP",
+        "stop_sequences": "stopSequences",
     }
-    if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
-        args['inferenceConfig'] = {}
+    if inference_params := list(
+        set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
+    ):
+        args["inferenceConfig"] = {}
         for param in inference_params:
-            args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
+            args["inferenceConfig"][inference_params_map.get(param, param)] = (
+                kwargs.pop(param)
+            )

     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     if hashing_kv is not None:
@@ -112,31 +133,33 @@ async def bedrock_complete_if_cache(
     # Call model via Converse API
     session = aioboto3.Session()
     async with session.client("bedrock-runtime") as bedrock_async_client:

         try:
             response = await bedrock_async_client.converse(**args, **kwargs)
         except Exception as e:
             raise BedrockError(e)

         if hashing_kv is not None:
-            await hashing_kv.upsert({
-                args_hash: {
-                    'return': response['output']['message']['content'][0]['text'],
-                    'model': model
-                }
-            })
+            await hashing_kv.upsert(
+                {
+                    args_hash: {
+                        "return": response["output"]["message"]["content"][0]["text"],
+                        "model": model,
+                    }
+                }
+            )

-        return response['output']['message']['content'][0]['text']
+        return response["output"]["message"]["content"][0]["text"]


 async def hf_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     model_name = model
-    hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
-    if hf_tokenizer.pad_token == None:
+    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+    if hf_tokenizer.pad_token is None:
         # print("use eos token")
         hf_tokenizer.pad_token = hf_tokenizer.eos_token
-    hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
@@ -149,30 +172,51 @@ async def hf_model_if_cache(
         if_cache_return = await hashing_kv.get_by_id(args_hash)
         if if_cache_return is not None:
             return if_cache_return["return"]
-    input_prompt = ''
+    input_prompt = ""
     try:
-        input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    except:
+        input_prompt = hf_tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+    except Exception:
         try:
             ori_message = copy.deepcopy(messages)
-            if messages[0]['role'] == "system":
-                messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
+            if messages[0]["role"] == "system":
+                messages[1]["content"] = (
+                    "<system>"
+                    + messages[0]["content"]
+                    + "</system>\n"
+                    + messages[1]["content"]
+                )
                 messages = messages[1:]
-            input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        except:
+            input_prompt = hf_tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        except Exception:
             len_message = len(ori_message)
             for msgid in range(len_message):
-                input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
+                input_prompt = (
+                    input_prompt
+                    + "<"
+                    + ori_message[msgid]["role"]
+                    + ">"
+                    + ori_message[msgid]["content"]
+                    + "</"
+                    + ori_message[msgid]["role"]
+                    + ">\n"
+                )

-    input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
-    output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
+    input_ids = hf_tokenizer(
+        input_prompt, return_tensors="pt", padding=True, truncation=True
+    ).to("cuda")
+    output = hf_model.generate(
+        **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
+    )
     response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
     if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {args_hash: {"return": response_text, "model": model}}
-        )
+        await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
     return response_text


 async def ollama_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -202,6 +246,7 @@ async def ollama_model_if_cache(

     return result

+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -241,7 +286,7 @@ async def bedrock_complete(
 async def hf_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    model_name = kwargs['hashing_kv'].global_config['llm_model_name']
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await hf_model_if_cache(
         model_name,
         prompt,
@@ -250,10 +295,11 @@ async def hf_model_complete(
         **kwargs,
     )

+
 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    model_name = kwargs['hashing_kv'].global_config['llm_model_name']
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await ollama_model_if_cache(
         model_name,
         prompt,
@@ -262,17 +308,25 @@ async def ollama_model_complete(
         **kwargs,
     )

+
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
-async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
+async def openai_embedding(
+    texts: list[str],
+    model: str = "text-embedding-3-small",
+    base_url: str = None,
+    api_key: str = None,
+) -> np.ndarray:
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key

-    openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+    openai_async_client = (
+        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+    )
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
     )
|
 # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),  # TODO: fix exceptions
 # )
 async def bedrock_embedding(
-    texts: list[str], model: str = "amazon.titan-embed-text-v2:0",
-    aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray:
-    os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
-    os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
-    os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
+    texts: list[str],
+    model: str = "amazon.titan-embed-text-v2:0",
+    aws_access_key_id=None,
+    aws_secret_access_key=None,
+    aws_session_token=None,
+) -> np.ndarray:
+    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
+        "AWS_ACCESS_KEY_ID", aws_access_key_id
+    )
+    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
+        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
+    )
+    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
+        "AWS_SESSION_TOKEN", aws_session_token
+    )

     session = aioboto3.Session()
     async with session.client("bedrock-runtime") as bedrock_async_client:

         if (model_provider := model.split(".")[0]) == "amazon":
             embed_texts = []
             for text in texts:
                 if "v2" in model:
-                    body = json.dumps({
-                        'inputText': text,
-                        # 'dimensions': embedding_dim,
-                        'embeddingTypes': ["float"]
-                    })
+                    body = json.dumps(
+                        {
+                            "inputText": text,
+                            # 'dimensions': embedding_dim,
+                            "embeddingTypes": ["float"],
+                        }
+                    )
                 elif "v1" in model:
-                    body = json.dumps({
-                        'inputText': text
-                    })
+                    body = json.dumps({"inputText": text})
                 else:
                     raise ValueError(f"Model {model} is not supported!")

|
                     modelId=model,
                     body=body,
                     accept="application/json",
-                    contentType="application/json"
+                    contentType="application/json",
                 )

-                response_body = await response.get('body').json()
+                response_body = await response.get("body").json()

-                embed_texts.append(response_body['embedding'])
+                embed_texts.append(response_body["embedding"])
         elif model_provider == "cohere":
-            body = json.dumps({
-                'texts': texts,
-                'input_type': "search_document",
-                'truncate': "NONE"
-            })
+            body = json.dumps(
+                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
+            )

             response = await bedrock_async_client.invoke_model(
                 model=model,
                 body=body,
                 accept="application/json",
-                contentType="application/json"
+                contentType="application/json",
             )

-            response_body = json.loads(response.get('body').read())
+            response_body = json.loads(response.get("body").read())

-            embed_texts = response_body['embeddings']
+            embed_texts = response_body["embeddings"]
         else:
             raise ValueError(f"Model provider '{model_provider}' is not supported!")

@@ -345,12 +406,15 @@ async def bedrock_embedding(


 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
-    input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
+    input_ids = tokenizer(
+        texts, return_tensors="pt", padding=True, truncation=True
+    ).input_ids
     with torch.no_grad():
         outputs = embed_model(input_ids)
         embeddings = outputs.last_hidden_state.mean(dim=1)
     return embeddings.detach().numpy()


 async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
     embed_text = []
     for text in texts:
@@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:

     return embed_text

+
 if __name__ == "__main__":
     import asyncio

     async def main():
-        result = await gpt_4o_mini_complete('How are you?')
+        result = await gpt_4o_mini_complete("How are you?")
         print(result)

     asyncio.run(main())
@@ -25,6 +25,7 @@ from .base import (
 )
 from .prompt import GRAPH_FIELD_SEP, PROMPTS

+
 def chunking_by_token_size(
     content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
 ):
@@ -45,6 +46,7 @@ def chunking_by_token_size(
         )
     return results

+
 async def _handle_entity_relation_summary(
     entity_or_relation_name: str,
     description: str,
@@ -232,6 +234,7 @@ async def _merge_edges_then_upsert(

     return edge_data

+
 async def extract_entities(
     chunks: dict[str, TextChunkSchema],
     knwoledge_graph_inst: BaseGraphStorage,
@@ -352,7 +355,9 @@ async def extract_entities(
         logger.warning("Didn't extract any entities, maybe your LLM is not working")
         return None
     if not len(all_relationships_data):
-        logger.warning("Didn't extract any relationships, maybe your LLM is not working")
+        logger.warning(
+            "Didn't extract any relationships, maybe your LLM is not working"
+        )
         return None

     if entity_vdb is not None:
@@ -370,7 +375,10 @@ async def extract_entities(
             compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                 "src_id": dp["src_id"],
                 "tgt_id": dp["tgt_id"],
-                "content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
+                "content": dp["keywords"]
+                + dp["src_id"]
+                + dp["tgt_id"]
+                + dp["description"],
             }
             for dp in all_relationships_data
         }
@@ -378,6 +386,7 @@ async def extract_entities(

     return knwoledge_graph_inst

+
 async def local_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
@@ -397,15 +406,20 @@ async def local_query(
     try:
         keywords_data = json.loads(result)
         keywords = keywords_data.get("low_level_keywords", [])
-        keywords = ', '.join(keywords)
-    except json.JSONDecodeError as e:
+        keywords = ", ".join(keywords)
+    except json.JSONDecodeError:
         try:
-            result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
-            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            result = (
+                result.replace(kw_prompt[:-1], "")
+                .replace("user", "")
+                .replace("model", "")
+                .strip()
+            )
+            result = "{" + result.split("{")[1].split("}")[0] + "}"

             keywords_data = json.loads(result)
             keywords = keywords_data.get("low_level_keywords", [])
-            keywords = ', '.join(keywords)
+            keywords = ", ".join(keywords)
         # Handle parsing error
         except json.JSONDecodeError as e:
             print(f"JSON parsing error: {e}")
@@ -430,11 +444,20 @@ async def local_query(
         query,
         system_prompt=sys_prompt,
     )
-    if len(response)>len(sys_prompt):
-        response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+    if len(response) > len(sys_prompt):
+        response = (
+            response.replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )

     return response


 async def _build_local_query_context(
     query,
     knowledge_graph_inst: BaseGraphStorage,
@@ -516,6 +539,7 @@ async def _build_local_query_context(
 ```
 """

+
 async def _find_most_related_text_unit_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
     all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
     return all_text_units

+
 async def _find_most_related_edges_from_entities(
     node_datas: list[dict],
     query_param: QueryParam,
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
     )
     return all_edges_data

+
 async def global_query(
     query,
     knowledge_graph_inst: BaseGraphStorage,
@@ -628,15 +654,20 @@ async def global_query(
     try:
         keywords_data = json.loads(result)
         keywords = keywords_data.get("high_level_keywords", [])
-        keywords = ', '.join(keywords)
-    except json.JSONDecodeError as e:
+        keywords = ", ".join(keywords)
+    except json.JSONDecodeError:
         try:
-            result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
-            result = '{' + result.split('{')[1].split('}')[0] + '}'
+            result = (
+                result.replace(kw_prompt[:-1], "")
+                .replace("user", "")
+                .replace("model", "")
+                .strip()
+            )
+            result = "{" + result.split("{")[1].split("}")[0] + "}"

             keywords_data = json.loads(result)
             keywords = keywords_data.get("high_level_keywords", [])
-            keywords = ', '.join(keywords)
+            keywords = ", ".join(keywords)

         except json.JSONDecodeError as e:
             # Handle parsing error
@@ -665,11 +696,20 @@ async def global_query(
         query,
         system_prompt=sys_prompt,
     )
-    if len(response)>len(sys_prompt):
-        response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
+    if len(response) > len(sys_prompt):
+        response = (
+            response.replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )

     return response


 async def _build_global_query_context(
     keywords,
     knowledge_graph_inst: BaseGraphStorage,
@@ -765,6 +805,7 @@ async def _build_global_query_context(
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def _find_most_related_entities_from_relationships(
|
async def _find_most_related_entities_from_relationships(
|
||||||
edge_datas: list[dict],
|
edge_datas: list[dict],
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
|
|||||||
|
|
||||||
return node_datas
|
return node_datas
|
||||||
|
|
||||||
|
|
||||||
async def _find_related_text_unit_from_relationships(
|
async def _find_related_text_unit_from_relationships(
|
||||||
edge_datas: list[dict],
|
edge_datas: list[dict],
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
):
|
):
|
||||||
|
|
||||||
text_units = [
|
text_units = [
|
||||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||||
for dp in edge_datas
|
for dp in edge_datas
|
||||||
@@ -822,9 +863,7 @@ async def _find_related_text_unit_from_relationships(
|
|||||||
all_text_units = [
|
all_text_units = [
|
||||||
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
||||||
]
|
]
|
||||||
all_text_units = sorted(
|
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
|
||||||
all_text_units, key=lambda x: x["order"]
|
|
||||||
)
|
|
||||||
all_text_units = truncate_list_by_token_size(
|
all_text_units = truncate_list_by_token_size(
|
||||||
all_text_units,
|
all_text_units,
|
||||||
key=lambda x: x["data"]["content"],
|
key=lambda x: x["data"]["content"],
|
||||||
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
|
|||||||
|
|
||||||
return all_text_units
|
return all_text_units
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_query(
|
async def hybrid_query(
|
||||||
query,
|
query,
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
@@ -855,18 +895,23 @@ async def hybrid_query(
|
|||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||||
hl_keywords = ', '.join(hl_keywords)
|
hl_keywords = ", ".join(hl_keywords)
|
||||||
ll_keywords = ', '.join(ll_keywords)
|
ll_keywords = ", ".join(ll_keywords)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError:
|
||||||
try:
|
try:
|
||||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
|
result = (
|
||||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
result.replace(kw_prompt[:-1], "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||||
|
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||||
hl_keywords = ', '.join(hl_keywords)
|
hl_keywords = ", ".join(hl_keywords)
|
||||||
ll_keywords = ', '.join(ll_keywords)
|
ll_keywords = ", ".join(ll_keywords)
|
||||||
# Handle parsing error
|
# Handle parsing error
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"JSON parsing error: {e}")
|
print(f"JSON parsing error: {e}")
|
||||||
@@ -906,52 +951,77 @@ async def hybrid_query(
|
|||||||
query,
|
query,
|
||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
)
|
)
|
||||||
if len(response)>len(sys_prompt):
|
if len(response) > len(sys_prompt):
|
||||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = (
|
||||||
|
response.replace(sys_prompt, "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.replace(query, "")
|
||||||
|
.replace("<system>", "")
|
||||||
|
.replace("</system>", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def combine_contexts(high_level_context, low_level_context):
|
def combine_contexts(high_level_context, low_level_context):
|
||||||
# Function to extract entities, relationships, and sources from context strings
|
# Function to extract entities, relationships, and sources from context strings
|
||||||
|
|
||||||
def extract_sections(context):
|
def extract_sections(context):
|
||||||
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
entities_match = re.search(
|
||||||
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||||
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
)
|
||||||
|
relationships_match = re.search(
|
||||||
|
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||||
|
)
|
||||||
|
sources_match = re.search(
|
||||||
|
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
entities = entities_match.group(1) if entities_match else ''
|
entities = entities_match.group(1) if entities_match else ""
|
||||||
relationships = relationships_match.group(1) if relationships_match else ''
|
relationships = relationships_match.group(1) if relationships_match else ""
|
||||||
sources = sources_match.group(1) if sources_match else ''
|
sources = sources_match.group(1) if sources_match else ""
|
||||||
|
|
||||||
return entities, relationships, sources
|
return entities, relationships, sources
|
||||||
|
|
||||||
# Extract sections from both contexts
|
# Extract sections from both contexts
|
||||||
|
|
||||||
if high_level_context==None:
|
if high_level_context is None:
|
||||||
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
|
warnings.warn(
|
||||||
hl_entities, hl_relationships, hl_sources = '','',''
|
"High Level context is None. Return empty High entity/relationship/source"
|
||||||
|
)
|
||||||
|
hl_entities, hl_relationships, hl_sources = "", "", ""
|
||||||
else:
|
else:
|
||||||
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
||||||
|
|
||||||
|
if low_level_context is None:
|
||||||
if low_level_context==None:
|
warnings.warn(
|
||||||
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
|
"Low Level context is None. Return empty Low entity/relationship/source"
|
||||||
ll_entities, ll_relationships, ll_sources = '','',''
|
)
|
||||||
|
ll_entities, ll_relationships, ll_sources = "", "", ""
|
||||||
else:
|
else:
|
||||||
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Combine and deduplicate the entities
|
# Combine and deduplicate the entities
|
||||||
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
combined_entities_set = set(
|
||||||
combined_entities = '\n'.join(combined_entities_set)
|
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
|
||||||
|
)
|
||||||
|
combined_entities = "\n".join(combined_entities_set)
|
||||||
|
|
||||||
# Combine and deduplicate the relationships
|
# Combine and deduplicate the relationships
|
||||||
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
|
combined_relationships_set = set(
|
||||||
combined_relationships = '\n'.join(combined_relationships_set)
|
filter(
|
||||||
|
None,
|
||||||
|
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
combined_relationships = "\n".join(combined_relationships_set)
|
||||||
|
|
||||||
# Combine and deduplicate the sources
|
# Combine and deduplicate the sources
|
||||||
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
|
combined_sources_set = set(
|
||||||
combined_sources = '\n'.join(combined_sources_set)
|
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
|
||||||
|
)
|
||||||
|
combined_sources = "\n".join(combined_sources_set)
|
||||||
|
|
||||||
# Format the combined context
|
# Format the combined context
|
||||||
return f"""
|
return f"""
|
||||||
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
|
|||||||
{combined_sources}
|
{combined_sources}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def naive_query(
|
async def naive_query(
|
||||||
query,
|
query,
|
||||||
chunks_vdb: BaseVectorStorage,
|
chunks_vdb: BaseVectorStorage,
|
||||||
@@ -996,8 +1067,16 @@ async def naive_query(
|
|||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(response)>len(sys_prompt):
|
if len(response) > len(sys_prompt):
|
||||||
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = (
|
||||||
|
response[len(sys_prompt) :]
|
||||||
|
.replace(sys_prompt, "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.replace(query, "")
|
||||||
|
.replace("<system>", "")
|
||||||
|
.replace("</system>", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "

PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]

-PROMPTS[
-"entity_extraction"
-] = """-Goal-
+PROMPTS["entity_extraction"] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.

-Steps-
@@ -146,9 +144,7 @@ PROMPTS[

PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."

-PROMPTS[
-"rag_response"
-] = """---Role---
+PROMPTS["rag_response"] = """---Role---

You are a helpful assistant responding to questions about data in the tables provided.

@@ -241,9 +237,7 @@ Output:

"""

-PROMPTS[
-"naive_rag_response"
-] = """You're a helpful assistant
+PROMPTS["naive_rag_response"] = """You're a helpful assistant
Below are the knowledge you know:
{content_data}
---
@@ -1,16 +1,11 @@
import asyncio
import html
-import json
import os
-from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Any, Union, cast
-import pickle
-import hnswlib
import networkx as nx
import numpy as np
from nano_vectordb import NanoVectorDB
-import xxhash

from .utils import load_json, logger, write_json
from .base import (
@@ -19,6 +14,7 @@ from .base import (
BaseVectorStorage,
)


@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
async def drop(self):
self._data = {}


@dataclass
class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2

def __post_init__(self):

self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self):
self._client.save()


@dataclass
class NetworkXStorage(BaseGraphStorage):
@staticmethod
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):

graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph))
-node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
+node_mapping = {
+node: html.unescape(node.upper().strip()) for node in graph.nodes()
+} # type: ignore
graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph)

@@ -16,18 +16,22 @@ ENCODER = None

logger = logging.getLogger("lightrag")


def set_logger(log_file: str):
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)

-formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+formatter = logging.Formatter(
+"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
file_handler.setFormatter(formatter)

if not logger.handlers:
logger.addHandler(file_handler)


@dataclass
class EmbeddingFunc:
embedding_dim: int
@@ -37,6 +41,7 @@ class EmbeddingFunc:
async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)


def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
else:
return None


def convert_response_to_json(response: str) -> dict:
json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}"
@@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict:
logger.error(f"Failed to parse JSON: {json_str}")
raise e from None


def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()


def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()


def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""

@@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):

return final_decro


def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes"""

@@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs):

return final_decro


def load_json(file_name):
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)


def write_json(json_obj, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)


def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER
if ENCODER is None:
@@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
content = ENCODER.decode(tokens)
return content


def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"]
return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers"""
if not markers:
@@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()]


# Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
@@ -141,9 +157,11 @@ def clean_str(input: Any) -> str:
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)


def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size"""
if max_token_size <= 0:
@@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i]
return list_data


def list_of_list_to_csv(data: list[list]):
return "\n".join(
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
)


def save_data_to_file(data, file_name):
-with open(file_name, 'w', encoding='utf-8') as f:
+with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
@@ -3,11 +3,11 @@ import json
import glob
import argparse

-def extract_unique_contexts(input_directory, output_directory):

+def extract_unique_contexts(input_directory, output_directory):
os.makedirs(output_directory, exist_ok=True)

-jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
+jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
print(f"Found {len(jsonl_files)} JSONL files.")

for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
print(f"Processing file: {filename}")

try:
-with open(file_path, 'r', encoding='utf-8') as infile:
+with open(file_path, "r", encoding="utf-8") as infile:
for line_number, line in enumerate(infile, start=1):
line = line.strip()
if not line:
continue
try:
json_obj = json.loads(line)
-context = json_obj.get('context')
+context = json_obj.get("context")
if context and context not in unique_contexts_dict:
unique_contexts_dict[context] = None
except json.JSONDecodeError as e:
-print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
+print(
+f"JSON decoding error in file {filename} at line {line_number}: {e}"
+)
except FileNotFoundError:
print(f"File not found: {filename}")
continue
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
continue

unique_contexts_list = list(unique_contexts_dict.keys())
-print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
+print(
+f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
+)

try:
-with open(output_path, 'w', encoding='utf-8') as outfile:
+with open(output_path, "w", encoding="utf-8") as outfile:
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
print(f"Unique `context` entries have been saved to: {output_filename}")
except Exception as e:
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
-parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
-parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
+parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
+parser.add_argument(
+"-o", "--output_dir", type=str, default="../datasets/unique_contexts"
+)

args = parser.parse_args()

@@ -4,8 +4,9 @@ import time

from lightrag import LightRAG


def insert_text(rag, file_path):
-with open(file_path, mode='r') as f:
+with open(file_path, mode="r") as f:
unique_contexts = json.load(f)

retries = 0
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")


cls = "agriculture"
WORKING_DIR = "../{cls}"

@@ -7,6 +7,7 @@ from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm import openai_complete_if_cache, openai_embedding


## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
async def llm_model_func(
@@ -19,20 +20,24 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
-**kwargs
+**kwargs,
)


async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
-base_url="https://api.upstage.ai/v1/solar"
+base_url="https://api.upstage.ai/v1/solar",
)


## /For Upstage API


def insert_text(rag, file_path):
-with open(file_path, mode='r') as f:
+with open(file_path, mode="r") as f:
unique_contexts = json.load(f)

retries = 0
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")


cls = "mix"
WORKING_DIR = f"../{cls}"

if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)

-rag = LightRAG(working_dir=WORKING_DIR,
-llm_model_func=llm_model_func,
-embedding_func=EmbeddingFunc(
-embedding_dim=4096,
-max_token_size=8192,
-func=embedding_func
+rag = LightRAG(
+working_dir=WORKING_DIR,
+llm_model_func=llm_model_func,
+embedding_func=EmbeddingFunc(
+embedding_dim=4096, max_token_size=8192, func=embedding_func
+),
)
-)

insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
@@ -1,8 +1,8 @@
-import os
import json
from openai import OpenAI
from transformers import GPT2Tokenizer


def openai_complete_if_cache(
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -19,14 +19,16 @@ def openai_complete_if_cache(
)
return response.choices[0].message.content

-tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def get_summary(context, tot_tokens=2000):
tokens = tokenizer.tokenize(context)
half_tokens = tot_tokens // 2

-start_tokens = tokens[1000:1000 + half_tokens]
-end_tokens = tokens[-(1000 + half_tokens):1000]
+start_tokens = tokens[1000 : 1000 + half_tokens]
+end_tokens = tokens[-(1000 + half_tokens) : 1000]

summary_tokens = start_tokens + end_tokens
summary = tokenizer.convert_tokens_to_string(summary_tokens)
@@ -34,9 +36,9 @@ def get_summary(context, tot_tokens=2000):
return summary


-clses = ['agriculture']
+clses = ["agriculture"]
for cls in clses:
-with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
+with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
unique_contexts = json.load(f)

summaries = [get_summary(context) for context in unique_contexts]
@@ -67,7 +69,7 @@ for cls in clses:
...
"""

-result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
+result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)

file_path = f"../datasets/questions/{cls}_questions.txt"
with open(file_path, "w") as file:
@@ -4,16 +4,18 @@ import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm


def extract_queries(file_path):
-with open(file_path, 'r') as f:
+with open(file_path, "r") as f:
data = f.read()

-data = data.replace('**', '')
+data = data.replace("**", "")

-queries = re.findall(r'- Question \d+: (.+)', data)
+queries = re.findall(r"- Question \d+: (.+)", data)

return queries


async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
except Exception as e:
return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)
return loop

-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+def run_queries_and_save_to_json(
+queries, rag_instance, query_param, output_file, error_file
+):
loop = always_get_an_event_loop()

-with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+with open(output_file, "a", encoding="utf-8") as result_file, open(
+error_file, "a", encoding="utf-8"
+) as err_file:
result_file.write("[\n")
first_entry = True

for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+result, error = loop.run_until_complete(
+process_query(query_text, rag_instance, query_param)
+)

if result:
if not first_entry:
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file

result_file.write("\n]")


if __name__ == "__main__":
cls = "agriculture"
mode = "hybrid"
@@ -59,4 +70,6 @@ if __name__ == "__main__":
query_param = QueryParam(mode=mode)

queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
-run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
+run_queries_and_save_to_json(
+queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
+)
@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np


## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
async def llm_model_func(
@@ -20,28 +21,33 @@ async def llm_model_func(
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
-**kwargs
+**kwargs,
)


async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
-base_url="https://api.upstage.ai/v1/solar"
+base_url="https://api.upstage.ai/v1/solar",
)


## /For Upstage API


def extract_queries(file_path):
-with open(file_path, 'r') as f:
+with open(file_path, "r") as f:
data = f.read()

-data = data.replace('**', '')
+data = data.replace("**", "")

-queries = re.findall(r'- Question \d+: (.+)', data)
+queries = re.findall(r"- Question \d+: (.+)", data)

return queries


async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
except Exception as e:
return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
asyncio.set_event_loop(loop)
return loop

-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+def run_queries_and_save_to_json(
+queries, rag_instance, query_param, output_file, error_file
+):
loop = always_get_an_event_loop()

-with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+with open(output_file, "a", encoding="utf-8") as result_file, open(
+error_file, "a", encoding="utf-8"
+) as err_file:
result_file.write("[\n")
first_entry = True

for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+result, error = loop.run_until_complete(
+process_query(query_text, rag_instance, query_param)
+)

if result:
if not first_entry:
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file

result_file.write("\n]")


if __name__ == "__main__":
cls = "mix"
mode = "hybrid"
WORKING_DIR = f"../{cls}"

rag = LightRAG(working_dir=WORKING_DIR)
-rag = LightRAG(working_dir=WORKING_DIR,
-llm_model_func=llm_model_func,
-embedding_func=EmbeddingFunc(
-embedding_dim=4096,
-max_token_size=8192,
-func=embedding_func
+rag = LightRAG(
+working_dir=WORKING_DIR,
+llm_model_func=llm_model_func,
+embedding_func=EmbeddingFunc(
+embedding_dim=4096, max_token_size=8192, func=embedding_func
+),
)
-)
query_param = QueryParam(mode=mode)

-base_dir='../datasets/questions'
+base_dir = "../datasets/questions"
queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
-run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json")
+run_queries_and_save_to_json(
+queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
+)
@@ -1,13 +1,13 @@
-aioboto3
-openai
-tiktoken
-networkx
-graspologic
-nano-vectordb
-hnswlib
-xxhash
-tenacity
-transformers
-torch
-ollama
accelerate
+aioboto3
+graspologic
+hnswlib
+nano-vectordb
+networkx
+ollama
+openai
+tenacity
+tiktoken
+torch
+transformers
+xxhash