Merge branch 'HKUDS:main' into main

Author: JavieHush
Date: 2024-10-21 09:06:06 +08:00
Committed by: GitHub
29 changed files with 1553 additions and 348 deletions

.gitignore (new file)

@@ -0,0 +1,7 @@
__pycache__
*.egg-info
dickens/
book.txt
lightrag-dev/
.idea/
dist/

.pre-commit-config.yaml (new file)

@@ -0,0 +1,22 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: requirements-txt-fixer
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
hooks:
- id: ruff-format
- id: ruff
args: [--fix]
- repo: https://github.com/mgedmin/check-manifest
rev: "0.49"
hooks:
- id: check-manifest
stages: [manual]

README.md

@@ -6,10 +6,12 @@
<div align='center'> <div align='center'>
<p> <p>
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
<a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a> <a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
<a href='https://discord.gg/mvsfu2Tg'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a> </p> </p>
</p> </p>
<p> <p>
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
<img src="https://img.shields.io/badge/python->=3.9.11-blue"> <img src="https://img.shields.io/badge/python->=3.9.11-blue">
<a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a> <a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
<a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a> <a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
@@ -20,12 +22,15 @@ This repository hosts the code of LightRAG. The structure of this code is based
</div> </div>
## 🎉 News ## 🎉 News
- [x] [2024.10.20]🎯🎯📢📢We've added a new feature to LightRAG: Graph Visualization.
- [x] [2024.10.18]🎯🎯📢📢We've added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
## Install ## Install
* Install from source (Recommended)
```bash ```bash
cd LightRAG cd LightRAG
@@ -43,12 +48,21 @@ pip install lightrag-hku
```bash ```bash
curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt
``` ```
Use the Python snippet below (in a script) to initialize LightRAG and perform queries:
```python ```python
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"
WORKING_DIR = "./dickens" WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
@@ -79,7 +93,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
<details> <details>
<summary> Using Open AI-like APIs </summary> <summary> Using Open AI-like APIs </summary>
* LightRAG also supports Open AI-like chat/embeddings APIs:
```python ```python
async def llm_model_func( async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
@@ -117,7 +131,7 @@ rag = LightRAG(
<details> <details>
<summary> Using Hugging Face Models </summary> <summary> Using Hugging Face Models </summary>
* If you want to use Hugging Face models, you only need to set LightRAG as follows:
```python ```python
from lightrag.llm import hf_model_complete, hf_embedding from lightrag.llm import hf_model_complete, hf_embedding
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
@@ -143,7 +157,8 @@ rag = LightRAG(
<details> <details>
<summary> Using Ollama Models </summary> <summary> Using Ollama Models </summary>
* If you want to use Ollama models, you only need to set LightRAG as follows:
```python ```python
from lightrag.llm import ollama_model_complete, ollama_embedding from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -164,6 +179,29 @@ rag = LightRAG(
), ),
) )
``` ```
* Increasing the `num_ctx` parameter (a request-time alternative is sketched after these steps):
1. Pull the model:
```bash
ollama pull qwen2
```
2. Display the model file:
```bash
ollama show --modelfile qwen2 > Modelfile
```
3. Edit the Modelfile by adding the following line:
```bash
PARAMETER num_ctx 32768
```
4. Create the modified model:
```bash
ollama create -f Modelfile qwen2m
```
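If you would rather not maintain a separate Modelfile, the `ollama` Python client can also set `num_ctx` per request through its `options` mapping. A minimal sketch under that assumption (it requires the `ollama` package and a running Ollama server, `qwen2` is only an example model, and the call goes through the Ollama client directly, not through LightRAG):
```python
# Minimal sketch: enlarge the context window for a single request via options,
# instead of creating a modified model with a Modelfile.
import ollama

response = ollama.chat(
    model="qwen2",  # example model name; replace with the model you pulled
    messages=[{"role": "user", "content": "Summarize A Christmas Carol in one sentence."}],
    options={"num_ctx": 32768},  # applies only to this call
)
print(response["message"]["content"])
```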
</details> </details>
### Batch Insert ### Batch Insert
@@ -181,12 +219,167 @@ rag = LightRAG(working_dir="./dickens")
with open("./newText.txt") as f: with open("./newText.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
``` ```
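In addition to inserting a single string as above, `rag.insert` also accepts a list of strings, which is how `examples/vram_management_demo.py` feeds multiple documents at once. A minimal sketch, assuming `rag` has already been initialized as in the Quick Start and using placeholder file names:
```python
# Minimal sketch: batch insert by passing a list of strings to rag.insert.
# Assumes `rag` is already initialized; the file names below are placeholders.
texts = []
for path in ["./book_1.txt", "./book_2.txt"]:
    with open(path, encoding="utf-8") as f:
        texts.append(f.read())

rag.insert(texts)  # a single call indexes every document in the list
```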
### Graph Visualization
<details>
<summary> Graph visualization with html </summary>
* The following code can be found in `examples/graph_visual_with_html.py`
```python
import networkx as nx
from pyvis.network import Network
# Load the GraphML file
G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
# Create a Pyvis network
net = Network(notebook=True)
# Convert NetworkX graph to Pyvis network
net.from_nx(G)
# Save and display the network
net.show('knowledge_graph.html')
```
</details>
<details>
<summary> Graph visualization with Neo4j </summary>
* The following code can be found in `examples/graph_visual_with_neo4j.py`
```python
import os
import json
from lightrag.utils import xml_to_json
from neo4j import GraphDatabase
# Constants
WORKING_DIR = "./dickens"
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100
# Neo4j connection credentials
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"
def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path):
print(f"Error: File not found - {xml_path}")
return None
json_data = xml_to_json(xml_path)
if json_data:
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}")
return json_data
else:
print("Failed to create JSON data")
return None
def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size):
batch = data[i:i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
def main():
# Paths
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
# Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file)
if json_data is None:
return
# Load nodes and edges
nodes = json_data.get('nodes', [])
edges = json_data.get('edges', [])
# Neo4j queries
create_nodes_query = """
UNWIND $nodes AS node
MERGE (e:Entity {id: node.id})
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
"""
create_edges_query = """
UNWIND $edges AS edge
MATCH (source {id: edge.source})
MATCH (target {id: edge.target})
WITH source, target, edge,
CASE
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
WHEN edge.keywords CONTAINS 'located' THEN 'located'
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
END AS relType
CALL apoc.create.relationship(source, relType, {
weight: edge.weight,
description: edge.description,
keywords: edge.keywords,
source_id: edge.source_id
}, target) YIELD rel
RETURN count(*)
"""
set_displayname_and_labels_query = """
MATCH (n)
SET n.displayName = n.id
WITH n
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
RETURN count(*)
"""
# Create a Neo4j driver
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
try:
# Execute queries in batches
with driver.session() as session:
# Insert nodes in batches
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
# Insert edges in batches
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
# Set displayName and labels
session.run(set_displayname_and_labels_query)
except Exception as e:
print(f"Error occurred: {e}")
finally:
driver.close()
if __name__ == "__main__":
main()
```
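Once `main()` has finished, a quick sanity check is to count what actually landed in Neo4j. The following is an optional, minimal sketch that reuses the connection constants defined above (it only needs the core `neo4j` driver, not APOC):
```python
# Minimal sketch: count imported nodes and relationships to confirm the load.
# Reuses NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD from the script above.
from neo4j import GraphDatabase

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
with driver.session() as session:
    node_count = session.run("MATCH (n) RETURN count(n) AS c").single()["c"]
    rel_count = session.run("MATCH ()-[r]->() RETURN count(r) AS c").single()["c"]
    print(f"Imported {node_count} nodes and {rel_count} relationships")
driver.close()
```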
</details>
## Evaluation ## Evaluation
### Dataset ### Dataset
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
### Generate Query ### Generate Query
LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.
<details> <details>
<summary> Prompt </summary> <summary> Prompt </summary>
@@ -380,7 +573,7 @@ def insert_text(rag, file_path):
### Step-2 Generate Queries ### Step-2 Generate Queries
We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.
<details> <details>
<summary> Code </summary> <summary> Code </summary>
@@ -427,11 +620,16 @@ def extract_queries(file_path):
. .
├── examples ├── examples
├── batch_eval.py ├── batch_eval.py
├── graph_visual_with_html.py
├── graph_visual_with_neo4j.py
├── generate_query.py ├── generate_query.py
├── lightrag_azure_openai_demo.py
├── lightrag_bedrock_demo.py
├── lightrag_hf_demo.py ├── lightrag_hf_demo.py
├── lightrag_ollama_demo.py ├── lightrag_ollama_demo.py
├── lightrag_openai_compatible_demo.py ├── lightrag_openai_compatible_demo.py
├── lightrag_openai_demo.py
└── vram_management_demo.py
├── lightrag ├── lightrag
├── __init__.py ├── __init__.py
├── base.py ├── base.py
@@ -446,6 +644,8 @@ def extract_queries(file_path):
├── Step_1.py ├── Step_1.py
├── Step_2.py ├── Step_2.py
└── Step_3.py └── Step_3.py
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE ├── LICENSE
├── README.md ├── README.md
├── requirements.txt ├── requirements.txt

examples/batch_eval.py

@@ -1,4 +1,3 @@
import os
import re import re
import json import json
import jsonlines import jsonlines
@@ -9,22 +8,22 @@ from openai import OpenAI
def batch_eval(query_file, result1_file, result2_file, output_file_path): def batch_eval(query_file, result1_file, result2_file, output_file_path):
client = OpenAI() client = OpenAI()
with open(query_file, 'r') as f: with open(query_file, "r") as f:
data = f.read() data = f.read()
queries = re.findall(r'- Question \d+: (.+)', data) queries = re.findall(r"- Question \d+: (.+)", data)
with open(result1_file, 'r') as f: with open(result1_file, "r") as f:
answers1 = json.load(f) answers1 = json.load(f)
answers1 = [i['result'] for i in answers1] answers1 = [i["result"] for i in answers1]
with open(result2_file, 'r') as f: with open(result2_file, "r") as f:
answers2 = json.load(f) answers2 = json.load(f)
answers2 = [i['result'] for i in answers2] answers2 = [i["result"] for i in answers2]
requests = [] requests = []
for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)): for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
sys_prompt = f""" sys_prompt = """
---Role--- ---Role---
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**. You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
""" """
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
}} }}
""" """
request_data = { request_data = {
"custom_id": f"request-{i+1}", "custom_id": f"request-{i+1}",
"method": "POST", "method": "POST",
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
"model": "gpt-4o-mini", "model": "gpt-4o-mini",
"messages": [ "messages": [
{"role": "system", "content": sys_prompt}, {"role": "system", "content": sys_prompt},
{"role": "user", "content": prompt} {"role": "user", "content": prompt},
], ],
} },
} }
requests.append(request_data) requests.append(request_data)
with jsonlines.open(output_file_path, mode='w') as writer: with jsonlines.open(output_file_path, mode="w") as writer:
for request in requests: for request in requests:
writer.write(request) writer.write(request)
print(f"Batch API requests written to {output_file_path}") print(f"Batch API requests written to {output_file_path}")
batch_input_file = client.files.create( batch_input_file = client.files.create(
file=open(output_file_path, "rb"), file=open(output_file_path, "rb"), purpose="batch"
purpose="batch"
) )
batch_input_file_id = batch_input_file.id batch_input_file_id = batch_input_file.id
@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
input_file_id=batch_input_file_id, input_file_id=batch_input_file_id,
endpoint="/v1/chat/completions", endpoint="/v1/chat/completions",
completion_window="24h", completion_window="24h",
metadata={ metadata={"description": "nightly eval job"},
"description": "nightly eval job"
}
) )
print(f'Batch {batch.id} has been created.') print(f"Batch {batch.id} has been created.")
if __name__ == "__main__": if __name__ == "__main__":
batch_eval() batch_eval()

examples/generate_query.py

@@ -1,9 +1,8 @@
import os
from openai import OpenAI from openai import OpenAI
# os.environ["OPENAI_API_KEY"] = "" # os.environ["OPENAI_API_KEY"] = ""
def openai_complete_if_cache( def openai_complete_if_cache(
model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
@@ -47,9 +46,9 @@ if __name__ == "__main__":
... ...
""" """
result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt) result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)
file_path = f"./queries.txt" file_path = "./queries.txt"
with open(file_path, "w") as file: with open(file_path, "w") as file:
file.write(result) file.write(result)

examples/graph_visual_with_html.py (new file)

@@ -0,0 +1,19 @@
import networkx as nx
from pyvis.network import Network
import random
# Load the GraphML file
G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
# Create a Pyvis network
net = Network(notebook=True)
# Convert NetworkX graph to Pyvis network
net.from_nx(G)
# Add colors to nodes
for node in net.nodes:
node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
# Save and display the network
net.show('knowledge_graph.html')

examples/graph_visual_with_neo4j.py (new file)

@@ -0,0 +1,118 @@
import os
import json
from lightrag.utils import xml_to_json
from neo4j import GraphDatabase
# Constants
WORKING_DIR = "./dickens"
BATCH_SIZE_NODES = 500
BATCH_SIZE_EDGES = 100
# Neo4j connection credentials
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"
def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path):
print(f"Error: File not found - {xml_path}")
return None
json_data = xml_to_json(xml_path)
if json_data:
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}")
return json_data
else:
print("Failed to create JSON data")
return None
def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size):
batch = data[i:i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
def main():
# Paths
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
# Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file)
if json_data is None:
return
# Load nodes and edges
nodes = json_data.get('nodes', [])
edges = json_data.get('edges', [])
# Neo4j queries
create_nodes_query = """
UNWIND $nodes AS node
MERGE (e:Entity {id: node.id})
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
"""
create_edges_query = """
UNWIND $edges AS edge
MATCH (source {id: edge.source})
MATCH (target {id: edge.target})
WITH source, target, edge,
CASE
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
WHEN edge.keywords CONTAINS 'located' THEN 'located'
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
END AS relType
CALL apoc.create.relationship(source, relType, {
weight: edge.weight,
description: edge.description,
keywords: edge.keywords,
source_id: edge.source_id
}, target) YIELD rel
RETURN count(*)
"""
set_displayname_and_labels_query = """
MATCH (n)
SET n.displayName = n.id
WITH n
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
RETURN count(*)
"""
# Create a Neo4j driver
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
try:
# Execute queries in batches
with driver.session() as session:
# Insert nodes in batches
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
# Insert edges in batches
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
# Set displayName and labels
session.run(set_displayname_and_labels_query)
except Exception as e:
print(f"Error occurred: {e}")
finally:
driver.close()
if __name__ == "__main__":
main()

examples/lightrag_azure_openai_demo.py (new file)

@@ -0,0 +1,125 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
import numpy as np
from dotenv import load_dotenv
import aiohttp
import logging
logging.basicConfig(level=logging.INFO)
load_dotenv()
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")
WORKING_DIR = "./dickens"
if os.path.exists(WORKING_DIR):
import shutil
shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
headers = {
"Content-Type": "application/json",
"api-key": AZURE_OPENAI_API_KEY,
}
endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if history_messages:
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
payload = {
"messages": messages,
"temperature": kwargs.get("temperature", 0),
"top_p": kwargs.get("top_p", 1),
"n": kwargs.get("n", 1),
}
async with aiohttp.ClientSession() as session:
async with session.post(endpoint, headers=headers, json=payload) as response:
if response.status != 200:
raise ValueError(
f"Request failed with status {response.status}: {await response.text()}"
)
result = await response.json()
return result["choices"][0]["message"]["content"]
async def embedding_func(texts: list[str]) -> np.ndarray:
headers = {
"Content-Type": "application/json",
"api-key": AZURE_OPENAI_API_KEY,
}
endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"
payload = {"input": texts}
async with aiohttp.ClientSession() as session:
async with session.post(endpoint, headers=headers, json=payload) as response:
if response.status != 200:
raise ValueError(
f"Request failed with status {response.status}: {await response.text()}"
)
result = await response.json()
embeddings = [item["embedding"] for item in result["data"]]
return np.array(embeddings)
async def test_funcs():
result = await llm_model_func("How are you?")
print("Resposta do llm_model_func: ", result)
result = await embedding_func(["How are you?"])
print("Resultado do embedding_func: ", result.shape)
print("Dimensão da embedding: ", result.shape[1])
asyncio.run(test_funcs())
embedding_dimension = 3072
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=8192,
func=embedding_func,
),
)
book1 = open("./book_1.txt", encoding="utf-8")
book2 = open("./book_2.txt", encoding="utf-8")
rag.insert([book1.read(), book2.read()])
query_text = "What are the main themes?"
print("Result (Naive):")
print(rag.query(query_text, param=QueryParam(mode="naive")))
print("\nResult (Local):")
print(rag.query(query_text, param=QueryParam(mode="local")))
print("\nResult (Global):")
print(rag.query(query_text, param=QueryParam(mode="global")))
print("\nResult (Hybrid):")
print(rag.query(query_text, param=QueryParam(mode="hybrid")))

examples/lightrag_bedrock_demo.py (new file)

@@ -0,0 +1,36 @@
"""
LightRAG meets Amazon Bedrock ⛰️
"""
import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import bedrock_complete, bedrock_embedding
from lightrag.utils import EmbeddingFunc
logging.getLogger("aiobotocore").setLevel(logging.WARNING)
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=bedrock_complete,
llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
embedding_func=EmbeddingFunc(
embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
for mode in ["naive", "local", "global", "hybrid"]:
print("\n+-" + "-" * len(mode) + "-+")
print(f"| {mode.capitalize()} |")
print("+-" + "-" * len(mode) + "-+\n")
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
)

examples/lightrag_hf_demo.py

@@ -1,10 +1,9 @@
import os import os
import sys
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import hf_model_complete, hf_embedding from lightrag.llm import hf_model_complete, hf_embedding
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc
from transformers import AutoModel,AutoTokenizer from transformers import AutoModel, AutoTokenizer
WORKING_DIR = "./dickens" WORKING_DIR = "./dickens"
@@ -14,15 +13,19 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=hf_model_complete, llm_model_func=hf_model_complete,
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(
embedding_dim=384, embedding_dim=384,
max_token_size=5000, max_token_size=5000,
func=lambda texts: hf_embedding( func=lambda texts: hf_embedding(
texts, texts,
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"), tokenizer=AutoTokenizer.from_pretrained(
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") "sentence-transformers/all-MiniLM-L6-v2"
) ),
embed_model=AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
),
),
), ),
) )
@@ -31,13 +34,21 @@ with open("./book.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
# Perform naive search # Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search # Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search # Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search # Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

examples/lightrag_ollama_demo.py

@@ -12,14 +12,11 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete, llm_model_func=ollama_model_complete,
llm_model_name='your_model_name', llm_model_name="your_model_name",
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(
embedding_dim=768, embedding_dim=768,
max_token_size=8192, max_token_size=8192,
func=lambda texts: ollama_embedding( func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
texts,
embed_model="nomic-embed-text"
)
), ),
) )
@@ -28,13 +25,21 @@ with open("./book.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
# Perform naive search # Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search # Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search # Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search # Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

examples/lightrag_openai_compatible_demo.py

@@ -10,6 +10,7 @@ WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
async def llm_model_func( async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
@@ -20,17 +21,19 @@ async def llm_model_func(
history_messages=history_messages, history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"), api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar", base_url="https://api.upstage.ai/v1/solar",
**kwargs **kwargs,
) )
async def embedding_func(texts: list[str]) -> np.ndarray: async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding( return await openai_embedding(
texts, texts,
model="solar-embedding-1-large-query", model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"), api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar" base_url="https://api.upstage.ai/v1/solar",
) )
# function test # function test
async def test_funcs(): async def test_funcs():
result = await llm_model_func("How are you?") result = await llm_model_func("How are you?")
@@ -39,6 +42,7 @@ async def test_funcs():
result = await embedding_func(["How are you?"]) result = await embedding_func(["How are you?"])
print("embedding_func: ", result) print("embedding_func: ", result)
asyncio.run(test_funcs()) asyncio.run(test_funcs())
@@ -46,10 +50,8 @@ rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=llm_model_func, llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(
embedding_dim=4096, embedding_dim=4096, max_token_size=8192, func=embedding_func
max_token_size=8192, ),
func=embedding_func
)
) )
@@ -57,13 +59,21 @@ with open("./book.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
# Perform naive search # Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search # Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search # Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search # Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

examples/lightrag_openai_demo.py

@@ -1,9 +1,7 @@
import os import os
import sys
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete from lightrag.llm import gpt_4o_mini_complete
from transformers import AutoModel,AutoTokenizer
WORKING_DIR = "./dickens" WORKING_DIR = "./dickens"
@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete llm_model_func=gpt_4o_mini_complete,
# llm_model_func=gpt_4o_complete # llm_model_func=gpt_4o_complete
) )
@@ -21,13 +19,21 @@ with open("./book.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
# Perform naive search # Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search # Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search # Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search # Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

examples/vram_management_demo.py (new file)

@@ -0,0 +1,82 @@
import os
import time
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc
# Working directory and the directory path for text files
WORKING_DIR = "./dickens"
TEXT_FILES_DIR = "/llm/mt"
# Create the working directory if it doesn't exist
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# Initialize LightRAG
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete,
llm_model_name="qwen2.5:3b-instruct-max-context",
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
),
)
# Read all .txt files from the TEXT_FILES_DIR directory
texts = []
for filename in os.listdir(TEXT_FILES_DIR):
if filename.endswith('.txt'):
file_path = os.path.join(TEXT_FILES_DIR, filename)
with open(file_path, 'r', encoding='utf-8') as file:
texts.append(file.read())
# Batch insert texts into LightRAG with a retry mechanism
def insert_texts_with_retry(rag, texts, retries=3, delay=5):
for _ in range(retries):
try:
rag.insert(texts)
return
except Exception as e:
print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
time.sleep(delay)
raise RuntimeError("Failed to insert texts after multiple retries.")
insert_texts_with_retry(rag, texts)
# Perform different types of queries and handle potential errors
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
except Exception as e:
print(f"Error performing naive search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
except Exception as e:
print(f"Error performing local search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
except Exception as e:
print(f"Error performing global search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
except Exception as e:
print(f"Error performing hybrid search: {e}")
# Function to clear VRAM resources
def clear_vram():
os.system("sudo nvidia-smi --gpu-reset")
# Regularly clear VRAM to prevent overflow
clear_vram_interval = 3600 # Clear once every hour
start_time = time.time()
while True:
current_time = time.time()
if current_time - start_time > clear_vram_interval:
clear_vram()
start_time = current_time
time.sleep(60) # Check the time every minute

lightrag/__init__.py

@@ -1,5 +1,5 @@
from .lightrag import LightRAG, QueryParam from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "0.0.6" __version__ = "0.0.7"
__author__ = "Zirui Guo" __author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG" __url__ = "https://github.com/HKUDS/LightRAG"

lightrag/base.py

@@ -12,6 +12,7 @@ TextChunkSchema = TypedDict(
T = TypeVar("T") T = TypeVar("T")
@dataclass @dataclass
class QueryParam: class QueryParam:
mode: Literal["local", "global", "hybrid", "naive"] = "global" mode: Literal["local", "global", "hybrid", "naive"] = "global"
@@ -36,6 +37,7 @@ class StorageNameSpace:
"""commit the storage operations after querying""" """commit the storage operations after querying"""
pass pass
@dataclass @dataclass
class BaseVectorStorage(StorageNameSpace): class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc embedding_func: EmbeddingFunc
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
""" """
raise NotImplementedError raise NotImplementedError
@dataclass @dataclass
class BaseKVStorage(Generic[T], StorageNameSpace): class BaseKVStorage(Generic[T], StorageNameSpace):
async def all_keys(self) -> list[str]: async def all_keys(self) -> list[str]:

lightrag/lightrag.py

@@ -3,10 +3,12 @@ import os
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Type, cast, Any from typing import Type, cast
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding from .llm import (
gpt_4o_mini_complete,
openai_embedding,
)
from .operate import ( from .operate import (
chunking_by_token_size, chunking_by_token_size,
extract_entities, extract_entities,
@@ -37,6 +39,7 @@ from .base import (
QueryParam, QueryParam,
) )
def always_get_an_event_loop() -> asyncio.AbstractEventLoop: def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@@ -69,7 +72,6 @@ class LightRAG:
"dimensions": 1536, "dimensions": 1536,
"num_walks": 10, "num_walks": 10,
"walk_length": 40, "walk_length": 40,
"num_walks": 10,
"window_size": 2, "window_size": 2,
"iterations": 3, "iterations": 3,
"random_seed": 3, "random_seed": 3,
@@ -77,13 +79,13 @@ class LightRAG:
) )
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding) # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding) embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
embedding_batch_num: int = 32 embedding_batch_num: int = 32
embedding_func_max_async: int = 16 embedding_func_max_async: int = 16
# LLM # LLM
llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete# llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768 llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16 llm_model_max_async: int = 16
@@ -133,28 +135,22 @@ class LightRAG:
self.embedding_func self.embedding_func
) )
self.entities_vdb = ( self.entities_vdb = self.vector_db_storage_cls(
self.vector_db_storage_cls( namespace="entities",
namespace="entities", global_config=asdict(self),
global_config=asdict(self), embedding_func=self.embedding_func,
embedding_func=self.embedding_func, meta_fields={"entity_name"},
meta_fields={"entity_name"}
)
) )
self.relationships_vdb = ( self.relationships_vdb = self.vector_db_storage_cls(
self.vector_db_storage_cls( namespace="relationships",
namespace="relationships", global_config=asdict(self),
global_config=asdict(self), embedding_func=self.embedding_func,
embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id"},
meta_fields={"src_id", "tgt_id"}
)
) )
self.chunks_vdb = ( self.chunks_vdb = self.vector_db_storage_cls(
self.vector_db_storage_cls( namespace="chunks",
namespace="chunks", global_config=asdict(self),
global_config=asdict(self), embedding_func=self.embedding_func,
embedding_func=self.embedding_func,
)
) )
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
@@ -177,7 +173,7 @@ class LightRAG:
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if not len(new_docs): if not len(new_docs):
logger.warning(f"All docs are already in the storage") logger.warning("All docs are already in the storage")
return return
logger.info(f"[New Docs] inserting {len(new_docs)} docs") logger.info(f"[New Docs] inserting {len(new_docs)} docs")
@@ -203,7 +199,7 @@ class LightRAG:
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
} }
if not len(inserting_chunks): if not len(inserting_chunks):
logger.warning(f"All chunks are already in the storage") logger.warning("All chunks are already in the storage")
return return
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
@@ -291,7 +287,6 @@ class LightRAG:
await self._query_done() await self._query_done()
return response return response
async def _query_done(self): async def _query_done(self):
tasks = [] tasks = []
for storage_inst in [self.llm_response_cache]: for storage_inst in [self.llm_response_cache]:
@@ -299,5 +294,3 @@ class LightRAG:
continue continue
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

lightrag/llm.py

@@ -1,4 +1,7 @@
import os import os
import copy
import json
import aioboto3
import numpy as np import numpy as np
import ollama import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -8,24 +11,34 @@ from tenacity import (
wait_exponential, wait_exponential,
retry_if_exception_type, retry_if_exception_type,
) )
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM
import torch import torch
from .base import BaseKVStorage from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs from .utils import compute_args_hash, wrap_embedding_func_with_attrs
import copy
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
) )
async def openai_complete_if_cache( async def openai_complete_if_cache(
model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
) -> str: ) -> str:
if api_key: if api_key:
os.environ["OPENAI_API_KEY"] = api_key os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = [] messages = []
if system_prompt: if system_prompt:
@@ -48,15 +61,105 @@ async def openai_complete_if_cache(
) )
return response.choices[0].message.content return response.choices[0].message.content
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, max=60),
retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> str:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
# Fix message history format
messages = []
for history_message in history_messages:
message = copy.copy(history_message)
message["content"] = [{"text": message["content"]}]
messages.append(message)
# Add user prompt
messages.append({"role": "user", "content": [{"text": prompt}]})
# Initialize Converse API arguments
args = {"modelId": model, "messages": messages}
# Define system prompt
if system_prompt:
args["system"] = [{"text": system_prompt}]
# Map and set up inference parameters
inference_params_map = {
"max_tokens": "maxTokens",
"top_p": "topP",
"stop_sequences": "stopSequences",
}
if inference_params := list(
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
):
args["inferenceConfig"] = {}
for param in inference_params:
args["inferenceConfig"][inference_params_map.get(param, param)] = (
kwargs.pop(param)
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
# Call model via Converse API
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
try:
response = await bedrock_async_client.converse(**args, **kwargs)
except Exception as e:
raise BedrockError(e)
if hashing_kv is not None:
await hashing_kv.upsert(
{
args_hash: {
"return": response["output"]["message"]["content"][0]["text"],
"model": model,
}
}
)
return response["output"]["message"]["content"][0]["text"]
async def hf_model_if_cache( async def hf_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
model_name = model model_name = model
hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto') hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
if hf_tokenizer.pad_token == None: if hf_tokenizer.pad_token is None:
# print("use eos token") # print("use eos token")
hf_tokenizer.pad_token = hf_tokenizer.eos_token hf_tokenizer.pad_token = hf_tokenizer.eos_token
hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto') hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = [] messages = []
if system_prompt: if system_prompt:
@@ -69,30 +172,51 @@ async def hf_model_if_cache(
if_cache_return = await hashing_kv.get_by_id(args_hash) if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None: if if_cache_return is not None:
return if_cache_return["return"] return if_cache_return["return"]
input_prompt = '' input_prompt = ""
try: try:
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) input_prompt = hf_tokenizer.apply_chat_template(
except: messages, tokenize=False, add_generation_prompt=True
)
except Exception:
try: try:
ori_message = copy.deepcopy(messages) ori_message = copy.deepcopy(messages)
if messages[0]['role'] == "system": if messages[0]["role"] == "system":
messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content'] messages[1]["content"] = (
"<system>"
+ messages[0]["content"]
+ "</system>\n"
+ messages[1]["content"]
)
messages = messages[1:] messages = messages[1:]
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) input_prompt = hf_tokenizer.apply_chat_template(
except: messages, tokenize=False, add_generation_prompt=True
)
except Exception:
len_message = len(ori_message) len_message = len(ori_message)
for msgid in range(len_message): for msgid in range(len_message):
input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n' input_prompt = (
input_prompt
+ "<"
+ ori_message[msgid]["role"]
+ ">"
+ ori_message[msgid]["content"]
+ "</"
+ ori_message[msgid]["role"]
+ ">\n"
)
input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda") input_ids = hf_tokenizer(
output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True) input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
output = hf_model.generate(
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True) response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None: if hashing_kv is not None:
await hashing_kv.upsert( await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
{args_hash: {"return": response_text, "model": model}}
)
return response_text return response_text
async def ollama_model_if_cache( async def ollama_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
@@ -122,6 +246,7 @@ async def ollama_model_if_cache(
return result return result
async def gpt_4o_complete( async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
@@ -145,10 +270,23 @@ async def gpt_4o_mini_complete(
**kwargs, **kwargs,
) )
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await bedrock_complete_if_cache(
"anthropic.claude-3-haiku-20240307-v1:0",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def hf_model_complete( async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
model_name = kwargs['hashing_kv'].global_config['llm_model_name'] model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await hf_model_if_cache( return await hf_model_if_cache(
model_name, model_name,
prompt, prompt,
@@ -157,10 +295,11 @@ async def hf_model_complete(
**kwargs, **kwargs,
) )
async def ollama_model_complete( async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
model_name = kwargs['hashing_kv'].global_config['llm_model_name'] model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await ollama_model_if_cache( return await ollama_model_if_cache(
model_name, model_name,
prompt, prompt,
@@ -169,30 +308,113 @@ async def ollama_model_complete(
**kwargs, **kwargs,
) )
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
) )
async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray: async def openai_embedding(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key: if api_key:
os.environ["OPENAI_API_KEY"] = api_key os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
response = await openai_async_client.embeddings.create( response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float" model=model, input=texts, encoding_format="float"
) )
return np.array([dp.embedding for dp in response.data]) return np.array([dp.embedding for dp in response.data])
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embedding(
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
) -> np.ndarray:
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
"AWS_ACCESS_KEY_ID", aws_access_key_id
)
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
)
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
"AWS_SESSION_TOKEN", aws_session_token
)
session = aioboto3.Session()
async with session.client("bedrock-runtime") as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = await response.get("body").json()
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
return np.array(embed_texts)
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray: async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids input_ids = tokenizer(
texts, return_tensors="pt", padding=True, truncation=True
).input_ids
with torch.no_grad(): with torch.no_grad():
outputs = embed_model(input_ids) outputs = embed_model(input_ids)
embeddings = outputs.last_hidden_state.mean(dim=1) embeddings = outputs.last_hidden_state.mean(dim=1)
return embeddings.detach().numpy() return embeddings.detach().numpy()
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray: async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
embed_text = [] embed_text = []
for text in texts: for text in texts:
@@ -201,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
return embed_text return embed_text
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
async def main(): async def main():
result = await gpt_4o_mini_complete('How are you?') result = await gpt_4o_mini_complete("How are you?")
print(result) print(result)
asyncio.run(main()) asyncio.run(main())

lightrag/operate.py

@@ -25,6 +25,7 @@ from .base import (
)
from .prompt import GRAPH_FIELD_SEP, PROMPTS


def chunking_by_token_size(
    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
@@ -45,6 +46,7 @@ def chunking_by_token_size(
    )
    return results


async def _handle_entity_relation_summary(
    entity_or_relation_name: str,
    description: str,
@@ -76,7 +78,7 @@ async def _handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # add this record as a node in the G
    entity_name = clean_str(record_attributes[1].upper())
@@ -97,7 +99,7 @@ async def _handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as edge
    source = clean_str(record_attributes[1].upper())
@@ -232,6 +234,7 @@ async def _merge_edges_then_upsert(
    return edge_data


async def extract_entities(
    chunks: dict[str, TextChunkSchema],
    knwoledge_graph_inst: BaseGraphStorage,
@@ -352,7 +355,9 @@ async def extract_entities(
        logger.warning("Didn't extract any entities, maybe your LLM is not working")
        return None
    if not len(all_relationships_data):
        logger.warning(
            "Didn't extract any relationships, maybe your LLM is not working"
        )
        return None
    if entity_vdb is not None:
@@ -370,7 +375,10 @@ async def extract_entities(
            compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                "src_id": dp["src_id"],
                "tgt_id": dp["tgt_id"],
                "content": dp["keywords"]
                + dp["src_id"]
                + dp["tgt_id"]
                + dp["description"],
            }
            for dp in all_relationships_data
        }
@@ -378,6 +386,7 @@ async def extract_entities(
    return knwoledge_graph_inst


async def local_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
@@ -387,6 +396,7 @@ async def local_query(
    query_param: QueryParam,
    global_config: dict,
) -> str:
    context = None
    use_model_func = global_config["llm_model_func"]
    kw_prompt_temp = PROMPTS["keywords_extraction"]
@@ -396,24 +406,32 @@ async def local_query(
    try:
        keywords_data = json.loads(result)
        keywords = keywords_data.get("low_level_keywords", [])
        keywords = ", ".join(keywords)
    except json.JSONDecodeError:
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            keywords_data = json.loads(result)
            keywords = keywords_data.get("low_level_keywords", [])
            keywords = ", ".join(keywords)
        # Handle parsing error
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]
    if keywords:
        context = await _build_local_query_context(
            keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
    if query_param.only_need_context:
        return context
    if context is None:
@@ -426,11 +444,20 @@ async def local_query(
        query,
        system_prompt=sys_prompt,
    )
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response
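
The same keyword-extraction fallback appears in `local_query`, `global_query`, and `hybrid_query`; a distilled, standalone sketch of that recovery logic follows (illustrative only, not part of the diff; the helper name and parameters are made up).

```python
# Illustrative helper (not in the diff): recover a JSON keywords payload from a
# noisy LLM reply the same way the query functions above do.
import json


def parse_keywords(result: str, field: str, prompt_prefix: str) -> str:
    try:
        data = json.loads(result)
    except json.JSONDecodeError:
        # Strip the echoed prompt and chat markers, then keep the first {...} span.
        cleaned = (
            result.replace(prompt_prefix, "")
            .replace("user", "")
            .replace("model", "")
            .strip()
        )
        cleaned = "{" + cleaned.split("{")[1].split("}")[0] + "}"
        # May raise again; the query functions then return PROMPTS["fail_response"].
        data = json.loads(cleaned)
    return ", ".join(data.get(field, []))


# e.g. parse_keywords(raw_llm_reply, "low_level_keywords", kw_prompt[:-1])
```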
async def _build_local_query_context(
    query,
    knowledge_graph_inst: BaseGraphStorage,
@@ -512,6 +539,7 @@ async def _build_local_query_context(
```
"""


async def _find_most_related_text_unit_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
@@ -572,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
    all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
    return all_text_units


async def _find_most_related_edges_from_entities(
    node_datas: list[dict],
    query_param: QueryParam,
@@ -605,6 +634,7 @@ async def _find_most_related_edges_from_entities(
    )
    return all_edges_data


async def global_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
@@ -614,6 +644,7 @@ async def global_query(
    query_param: QueryParam,
    global_config: dict,
) -> str:
    context = None
    use_model_func = global_config["llm_model_func"]
    kw_prompt_temp = PROMPTS["keywords_extraction"]
@@ -623,27 +654,34 @@ async def global_query(
    try:
        keywords_data = json.loads(result)
        keywords = keywords_data.get("high_level_keywords", [])
        keywords = ", ".join(keywords)
    except json.JSONDecodeError:
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            keywords_data = json.loads(result)
            keywords = keywords_data.get("high_level_keywords", [])
            keywords = ", ".join(keywords)
        except json.JSONDecodeError as e:
            # Handle parsing error
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]
    if keywords:
        context = await _build_global_query_context(
            keywords,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
    if query_param.only_need_context:
        return context
@@ -658,11 +696,20 @@ async def global_query(
        query,
        system_prompt=sys_prompt,
    )
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response
async def _build_global_query_context(
    keywords,
    knowledge_graph_inst: BaseGraphStorage,
@@ -758,6 +805,7 @@ async def _build_global_query_context(
```
"""


async def _find_most_related_entities_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
@@ -788,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
    return node_datas


async def _find_related_text_unit_from_relationships(
    edge_datas: list[dict],
    query_param: QueryParam,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    knowledge_graph_inst: BaseGraphStorage,
):
    text_units = [
        split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
        for dp in edge_datas
@@ -815,9 +863,7 @@ async def _find_related_text_unit_from_relationships(
    all_text_units = [
        {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
    ]
    all_text_units = sorted(all_text_units, key=lambda x: x["order"])
    all_text_units = truncate_list_by_token_size(
        all_text_units,
        key=lambda x: x["data"]["content"],
@@ -827,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
    return all_text_units


async def hybrid_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
@@ -836,6 +883,8 @@ async def hybrid_query(
    query_param: QueryParam,
    global_config: dict,
) -> str:
    low_level_context = None
    high_level_context = None
    use_model_func = global_config["llm_model_func"]
    kw_prompt_temp = PROMPTS["keywords_extraction"]
@@ -846,37 +895,46 @@ async def hybrid_query(
        keywords_data = json.loads(result)
        hl_keywords = keywords_data.get("high_level_keywords", [])
        ll_keywords = keywords_data.get("low_level_keywords", [])
        hl_keywords = ", ".join(hl_keywords)
        ll_keywords = ", ".join(ll_keywords)
    except json.JSONDecodeError:
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
                .replace("user", "")
                .replace("model", "")
                .strip()
            )
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            keywords_data = json.loads(result)
            hl_keywords = keywords_data.get("high_level_keywords", [])
            ll_keywords = keywords_data.get("low_level_keywords", [])
            hl_keywords = ", ".join(hl_keywords)
            ll_keywords = ", ".join(ll_keywords)
        # Handle parsing error
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            return PROMPTS["fail_response"]
    if ll_keywords:
        low_level_context = await _build_local_query_context(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
    if hl_keywords:
        high_level_context = await _build_global_query_context(
            hl_keywords,
            knowledge_graph_inst,
            entities_vdb,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )

    context = combine_contexts(high_level_context, low_level_context)
@@ -893,52 +951,77 @@ async def hybrid_query(
        query,
        system_prompt=sys_prompt,
    )
    if len(response) > len(sys_prompt):
        response = (
            response.replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response
def combine_contexts(high_level_context, low_level_context):
    # Function to extract entities, relationships, and sources from context strings
    def extract_sections(context):
        entities_match = re.search(
            r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )
        relationships_match = re.search(
            r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )
        sources_match = re.search(
            r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
        )
        entities = entities_match.group(1) if entities_match else ""
        relationships = relationships_match.group(1) if relationships_match else ""
        sources = sources_match.group(1) if sources_match else ""
        return entities, relationships, sources

    # Extract sections from both contexts
    if high_level_context is None:
        warnings.warn(
            "High Level context is None. Return empty High entity/relationship/source"
        )
        hl_entities, hl_relationships, hl_sources = "", "", ""
    else:
        hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)

    if low_level_context is None:
        warnings.warn(
            "Low Level context is None. Return empty Low entity/relationship/source"
        )
        ll_entities, ll_relationships, ll_sources = "", "", ""
    else:
        ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)

    # Combine and deduplicate the entities
    combined_entities_set = set(
        filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
    )
    combined_entities = "\n".join(combined_entities_set)

    # Combine and deduplicate the relationships
    combined_relationships_set = set(
        filter(
            None,
            hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
        )
    )
    combined_relationships = "\n".join(combined_relationships_set)

    # Combine and deduplicate the sources
    combined_sources_set = set(
        filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
    )
    combined_sources = "\n".join(combined_sources_set)

    # Format the combined context
    return f"""
@@ -951,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
{combined_sources}
"""
async def naive_query(
    query,
    chunks_vdb: BaseVectorStorage,
@@ -983,8 +1067,16 @@ async def naive_query(
        system_prompt=sys_prompt,
    )
    if len(response) > len(sys_prompt):
        response = (
            response[len(sys_prompt) :]
            .replace(sys_prompt, "")
            .replace("user", "")
            .replace("model", "")
            .replace(query, "")
            .replace("<system>", "")
            .replace("</system>", "")
            .strip()
        )
    return response

View File

@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"] PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
PROMPTS[ PROMPTS["entity_extraction"] = """-Goal-
"entity_extraction"
] = """-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Steps- -Steps-
@@ -146,9 +144,7 @@ PROMPTS[
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question." PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
PROMPTS[ PROMPTS["rag_response"] = """---Role---
"rag_response"
] = """---Role---
You are a helpful assistant responding to questions about data in the tables provided. You are a helpful assistant responding to questions about data in the tables provided.
@@ -163,25 +159,10 @@ Do not include information where the supporting evidence for it is not provided.
{response_type} {response_type}
---Data tables--- ---Data tables---
{context_data} {context_data}
---Goal---
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
If you don't know the answer, just say so. Do not make anything up.
Do not include information where the supporting evidence for it is not provided.
---Target response length and format---
{response_type}
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
""" """
@@ -241,9 +222,7 @@ Output:
""" """
PROMPTS[ PROMPTS["naive_rag_response"] = """You're a helpful assistant
"naive_rag_response"
] = """You're a helpful assistant
Below are the knowledge you know: Below are the knowledge you know:
{content_data} {content_data}
--- ---

View File

@@ -1,16 +1,11 @@
import asyncio import asyncio
import html import html
import json
import os import os
from collections import defaultdict from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Union, cast from typing import Any, Union, cast
import pickle
import hnswlib
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from nano_vectordb import NanoVectorDB from nano_vectordb import NanoVectorDB
import xxhash
from .utils import load_json, logger, write_json from .utils import load_json, logger, write_json
from .base import ( from .base import (
@@ -19,6 +14,7 @@ from .base import (
BaseVectorStorage, BaseVectorStorage,
) )
@dataclass @dataclass
class JsonKVStorage(BaseKVStorage): class JsonKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
async def drop(self): async def drop(self):
self._data = {} self._data = {}
@dataclass @dataclass
class NanoVectorDBStorage(BaseVectorStorage): class NanoVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2 cosine_better_than_threshold: float = 0.2
def __post_init__(self): def __post_init__(self):
self._client_file_name = os.path.join( self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json" self.global_config["working_dir"], f"vdb_{self.namespace}.json"
) )
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def index_done_callback(self): async def index_done_callback(self):
self._client.save() self._client.save()
@dataclass @dataclass
class NetworkXStorage(BaseGraphStorage): class NetworkXStorage(BaseGraphStorage):
@staticmethod @staticmethod
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
graph = graph.copy() graph = graph.copy()
graph = cast(nx.Graph, largest_connected_component(graph)) graph = cast(nx.Graph, largest_connected_component(graph))
        node_mapping = {
            node: html.unescape(node.upper().strip()) for node in graph.nodes()
        }  # type: ignore
graph = nx.relabel_nodes(graph, node_mapping) graph = nx.relabel_nodes(graph, node_mapping)
return NetworkXStorage._stabilize_graph(graph) return NetworkXStorage._stabilize_graph(graph)

View File

@@ -8,6 +8,7 @@ from dataclasses import dataclass
from functools import wraps from functools import wraps
from hashlib import md5 from hashlib import md5
from typing import Any, Union from typing import Any, Union
import xml.etree.ElementTree as ET
import numpy as np import numpy as np
import tiktoken import tiktoken
@@ -16,18 +17,22 @@ ENCODER = None
logger = logging.getLogger("lightrag") logger = logging.getLogger("lightrag")
def set_logger(log_file: str): def set_logger(log_file: str):
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(log_file) file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG) file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
if not logger.handlers: if not logger.handlers:
logger.addHandler(file_handler) logger.addHandler(file_handler)
@dataclass @dataclass
class EmbeddingFunc: class EmbeddingFunc:
embedding_dim: int embedding_dim: int
@@ -37,6 +42,7 @@ class EmbeddingFunc:
async def __call__(self, *args, **kwargs) -> np.ndarray: async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs) return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]: def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string""" """Locate the JSON string body from a string"""
maybe_json_str = re.search(r"{.*}", content, re.DOTALL) maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -45,6 +51,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
else: else:
return None return None
def convert_response_to_json(response: str) -> dict: def convert_response_to_json(response: str) -> dict:
json_str = locate_json_string_body_from_string(response) json_str = locate_json_string_body_from_string(response)
assert json_str is not None, f"Unable to parse JSON from response: {response}" assert json_str is not None, f"Unable to parse JSON from response: {response}"
@@ -55,12 +62,15 @@ def convert_response_to_json(response: str) -> dict:
logger.error(f"Failed to parse JSON: {json_str}") logger.error(f"Failed to parse JSON: {json_str}")
raise e from None raise e from None
def compute_args_hash(*args): def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest() return md5(str(args).encode()).hexdigest()
def compute_mdhash_id(content, prefix: str = ""): def compute_mdhash_id(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest() return prefix + md5(content.encode()).hexdigest()
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func""" """Add restriction of maximum async calling times for a async func"""
@@ -82,6 +92,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
return final_decro return final_decro
def wrap_embedding_func_with_attrs(**kwargs): def wrap_embedding_func_with_attrs(**kwargs):
"""Wrap a function with attributes""" """Wrap a function with attributes"""
@@ -91,16 +102,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
return final_decro return final_decro
def load_json(file_name): def load_json(file_name):
if not os.path.exists(file_name): if not os.path.exists(file_name):
return None return None
with open(file_name, encoding="utf-8") as f: with open(file_name, encoding="utf-8") as f:
return json.load(f) return json.load(f)
def write_json(json_obj, file_name): def write_json(json_obj, file_name):
with open(file_name, "w", encoding="utf-8") as f: with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False) json.dump(json_obj, f, indent=2, ensure_ascii=False)
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"): def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
global ENCODER global ENCODER
if ENCODER is None: if ENCODER is None:
@@ -116,12 +130,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
content = ENCODER.decode(tokens) content = ENCODER.decode(tokens)
return content return content
def pack_user_ass_to_openai_messages(*args: str): def pack_user_ass_to_openai_messages(*args: str):
roles = ["user", "assistant"] roles = ["user", "assistant"]
return [ return [
{"role": roles[i % 2], "content": content} for i, content in enumerate(args) {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
] ]
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
"""Split a string by multiple markers""" """Split a string by multiple markers"""
if not markers: if not markers:
@@ -129,6 +145,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
results = re.split("|".join(re.escape(marker) for marker in markers), content) results = re.split("|".join(re.escape(marker) for marker in markers), content)
return [r.strip() for r in results if r.strip()] return [r.strip() for r in results if r.strip()]
# Refer the utils functions of the official GraphRAG implementation: # Refer the utils functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag # https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str: def clean_str(input: Any) -> str:
@@ -141,9 +158,11 @@ def clean_str(input: Any) -> str:
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
def is_float_regex(value): def is_float_regex(value):
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
"""Truncate a list of data by token size""" """Truncate a list of data by token size"""
if max_token_size <= 0: if max_token_size <= 0:
@@ -155,11 +174,61 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
return list_data[:i] return list_data[:i]
return list_data return list_data
def list_of_list_to_csv(data: list[list]):
    return "\n".join(
        [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
    )


def save_data_to_file(data, file_name):
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
root = tree.getroot()
# Print the root element's tag and attributes to confirm the file has been correctly loaded
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
data = {
"nodes": [],
"edges": []
}
# Use namespace
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
for node in root.findall('.//node', namespace):
node_data = {
"id": node.get('id').strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
}
data["nodes"].append(node_data)
for edge in root.findall('.//edge', namespace):
edge_data = {
"source": edge.get('source').strip('"'),
"target": edge.get('target').strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
}
data["edges"].append(edge_data)
# Print the number of nodes and edges found
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
return data
except ET.ParseError as e:
print(f"Error parsing XML file: {e}")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None
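
The new `xml_to_json` helper is the bridge from the stored GraphML to the graph-visualization tooling; a hedged usage sketch follows (the GraphML file name is an assumption about LightRAG's working-directory layout, not something this diff states).

```python
# Hypothetical usage (not part of the diff): dump the knowledge graph written by
# NetworkXStorage to JSON for visualization. The .graphml file name below is an
# assumption about what LightRAG stores in its working directory.
import json

from lightrag.utils import xml_to_json

graph_data = xml_to_json("./dickens/graph_chunk_entity_relation.graphml")
if graph_data is not None:
    with open("graph_data.json", "w", encoding="utf-8") as f:
        json.dump(graph_data, f, ensure_ascii=False, indent=2)
    print(f"nodes: {len(graph_data['nodes'])}, edges: {len(graph_data['edges'])}")
```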

View File

@@ -3,11 +3,11 @@ import json
import glob
import argparse


def extract_unique_contexts(input_directory, output_directory):
    os.makedirs(output_directory, exist_ok=True)

    jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
    print(f"Found {len(jsonl_files)} JSONL files.")

    for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
        print(f"Processing file: {filename}")

        try:
            with open(file_path, "r", encoding="utf-8") as infile:
                for line_number, line in enumerate(infile, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        json_obj = json.loads(line)
                        context = json_obj.get("context")
                        if context and context not in unique_contexts_dict:
                            unique_contexts_dict[context] = None
                    except json.JSONDecodeError as e:
                        print(
                            f"JSON decoding error in file {filename} at line {line_number}: {e}"
                        )
        except FileNotFoundError:
            print(f"File not found: {filename}")
            continue
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
            continue

        unique_contexts_list = list(unique_contexts_dict.keys())
        print(
            f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
        )

        try:
            with open(output_path, "w", encoding="utf-8") as outfile:
                json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
            print(f"Unique `context` entries have been saved to: {output_filename}")
        except Exception as e:
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
    parser.add_argument(
        "-o", "--output_dir", type=str, default="../datasets/unique_contexts"
    )
    args = parser.parse_args()
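
For reference, the dataset-preparation step above can also be driven without the CLI flags; a small sketch, assuming the script lives at `reproduce/Step_0.py` and is importable as a module (paths and the module name are assumptions).

```python
# Hypothetical invocation (not part of the diff): call the extractor directly
# instead of going through argparse. Directory paths are illustrative.
from Step_0 import extract_unique_contexts

extract_unique_contexts("../datasets", "../datasets/unique_contexts")
```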

View File

@@ -4,8 +4,9 @@ import time
from lightrag import LightRAG


def insert_text(rag, file_path):
    with open(file_path, mode="r") as f:
        unique_contexts = json.load(f)

    retries = 0
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
    if retries == max_retries:
        print("Insertion failed after exceeding the maximum number of retries")


cls = "agriculture"
WORKING_DIR = "../{cls}"

View File

@@ -0,0 +1,71 @@
import os
import json
import time
import numpy as np
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm import openai_complete_if_cache, openai_embedding
## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"solar-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
)
## /For Upstage API
def insert_text(rag, file_path):
with open(file_path, mode="r") as f:
unique_contexts = json.load(f)
retries = 0
max_retries = 3
while retries < max_retries:
try:
rag.insert(unique_contexts)
break
except Exception as e:
retries += 1
print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
time.sleep(10)
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")
cls = "mix"
WORKING_DIR = f"../{cls}"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096, max_token_size=8192, func=embedding_func
),
)
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")

View File

@@ -1,8 +1,8 @@
import os
import json
from openai import OpenAI
from transformers import GPT2Tokenizer


def openai_complete_if_cache(
    model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -19,14 +19,16 @@ def openai_complete_if_cache(
    )
    return response.choices[0].message.content


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def get_summary(context, tot_tokens=2000):
    tokens = tokenizer.tokenize(context)
    half_tokens = tot_tokens // 2

    start_tokens = tokens[1000 : 1000 + half_tokens]
    end_tokens = tokens[-(1000 + half_tokens) : 1000]

    summary_tokens = start_tokens + end_tokens
    summary = tokenizer.convert_tokens_to_string(summary_tokens)
@@ -34,9 +36,9 @@ def get_summary(context, tot_tokens=2000):
    return summary


clses = ["agriculture"]
for cls in clses:
    with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
        unique_contexts = json.load(f)

    summaries = [get_summary(context) for context in unique_contexts]
@@ -67,7 +69,7 @@ for cls in clses:
    ...
    """

    result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)

    file_path = f"../datasets/questions/{cls}_questions.txt"
    with open(file_path, "w") as file:

View File

@@ -4,16 +4,18 @@ import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm


def extract_queries(file_path):
    with open(file_path, "r") as f:
        data = f.read()

    data = data.replace("**", "")

    queries = re.findall(r"- Question \d+: (.+)", data)

    return queries


async def process_query(query_text, rag_instance, query_param):
    try:
        result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
    except Exception as e:
        return None, {"query": query_text, "error": str(e)}


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    try:
        loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
        asyncio.set_event_loop(loop)
    return loop


def run_queries_and_save_to_json(
    queries, rag_instance, query_param, output_file, error_file
):
    loop = always_get_an_event_loop()

    with open(output_file, "a", encoding="utf-8") as result_file, open(
        error_file, "a", encoding="utf-8"
    ) as err_file:
        result_file.write("[\n")
        first_entry = True

        for query_text in tqdm(queries, desc="Processing queries", unit="query"):
            result, error = loop.run_until_complete(
                process_query(query_text, rag_instance, query_param)
            )

            if result:
                if not first_entry:
@@ -50,13 +60,16 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
        result_file.write("\n]")


if __name__ == "__main__":
    cls = "agriculture"
    mode = "hybrid"
    WORKING_DIR = f"../{cls}"

    rag = LightRAG(working_dir=WORKING_DIR)
    query_param = QueryParam(mode=mode)

    queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
    run_queries_and_save_to_json(
        queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
    )

View File

@@ -0,0 +1,115 @@
import os
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
## For Upstage API
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"solar-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
)
## /For Upstage API
def extract_queries(file_path):
with open(file_path, "r") as f:
data = f.read()
data = data.replace("**", "")
queries = re.findall(r"- Question \d+: (.+)", data)
return queries
async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
return {"query": query_text, "result": result, "context": context}, None
except Exception as e:
return None, {"query": query_text, "error": str(e)}
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def run_queries_and_save_to_json(
queries, rag_instance, query_param, output_file, error_file
):
loop = always_get_an_event_loop()
with open(output_file, "a", encoding="utf-8") as result_file, open(
error_file, "a", encoding="utf-8"
) as err_file:
result_file.write("[\n")
first_entry = True
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
result, error = loop.run_until_complete(
process_query(query_text, rag_instance, query_param)
)
if result:
if not first_entry:
result_file.write(",\n")
json.dump(result, result_file, ensure_ascii=False, indent=4)
first_entry = False
elif error:
json.dump(error, err_file, ensure_ascii=False, indent=4)
err_file.write("\n")
result_file.write("\n]")
if __name__ == "__main__":
cls = "mix"
mode = "hybrid"
WORKING_DIR = f"../{cls}"
rag = LightRAG(working_dir=WORKING_DIR)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096, max_token_size=8192, func=embedding_func
),
)
query_param = QueryParam(mode=mode)
base_dir = "../datasets/questions"
queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
run_queries_and_save_to_json(
queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
)

View File

@@ -1,12 +1,14 @@
accelerate
aioboto3
graspologic
hnswlib
nano-vectordb
networkx
ollama
openai
tenacity
tiktoken
torch
transformers
xxhash
pyvis