Merge branch 'HKUDS:main' into main
This commit is contained in:
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
__pycache__
|
||||||
|
*.egg-info
|
||||||
|
dickens/
|
||||||
|
book.txt
|
||||||
|
lightrag-dev/
|
||||||
|
.idea/
|
||||||
|
dist/
|
22
.pre-commit-config.yaml
Normal file
22
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: requirements-txt-fixer
|
||||||
|
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.6.4
|
||||||
|
hooks:
|
||||||
|
- id: ruff-format
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix]
|
||||||
|
|
||||||
|
|
||||||
|
- repo: https://github.com/mgedmin/check-manifest
|
||||||
|
rev: "0.49"
|
||||||
|
hooks:
|
||||||
|
- id: check-manifest
|
||||||
|
stages: [manual]
|
220
README.md
220
README.md
@@ -6,10 +6,12 @@
|
|||||||
<div align='center'>
|
<div align='center'>
|
||||||
<p>
|
<p>
|
||||||
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
|
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
|
||||||
|
<a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
|
||||||
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
|
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
|
||||||
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
|
<a href='https://discord.gg/mvsfu2Tg'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
|
||||||
</p>
|
</p>
|
||||||
<p>
|
<p>
|
||||||
|
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
|
||||||
<img src="https://img.shields.io/badge/python->=3.9.11-blue">
|
<img src="https://img.shields.io/badge/python->=3.9.11-blue">
|
||||||
<a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
|
<a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
|
||||||
<a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
|
<a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
|
||||||
@@ -20,12 +22,15 @@ This repository hosts the code of LightRAG. The structure of this code is based
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
## 🎉 News
|
## 🎉 News
|
||||||
|
- [x] [2024.10.20]🎯🎯📢📢We’ve added a new feature to LightRAG: Graph Visualization.
|
||||||
|
- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
|
||||||
|
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
|
||||||
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
||||||
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
* Install from source
|
* Install from source (Recommend)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd LightRAG
|
cd LightRAG
|
||||||
@@ -43,12 +48,21 @@ pip install lightrag-hku
|
|||||||
```bash
|
```bash
|
||||||
curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt
|
curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt
|
||||||
```
|
```
|
||||||
Use the below Python snippet to initialize LightRAG and perform queries:
|
Use the below Python snippet (in a script) to initialize LightRAG and perform queries:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
|
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
|
||||||
|
|
||||||
|
#########
|
||||||
|
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
|
||||||
|
# import nest_asyncio
|
||||||
|
# nest_asyncio.apply()
|
||||||
|
#########
|
||||||
|
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
|
||||||
|
|
||||||
WORKING_DIR = "./dickens"
|
WORKING_DIR = "./dickens"
|
||||||
|
|
||||||
if not os.path.exists(WORKING_DIR):
|
if not os.path.exists(WORKING_DIR):
|
||||||
@@ -79,7 +93,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
|
|||||||
<details>
|
<details>
|
||||||
<summary> Using Open AI-like APIs </summary>
|
<summary> Using Open AI-like APIs </summary>
|
||||||
|
|
||||||
LightRAG also support Open AI-like chat/embeddings APIs:
|
* LightRAG also supports Open AI-like chat/embeddings APIs:
|
||||||
```python
|
```python
|
||||||
async def llm_model_func(
|
async def llm_model_func(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
@@ -117,7 +131,7 @@ rag = LightRAG(
|
|||||||
<details>
|
<details>
|
||||||
<summary> Using Hugging Face Models </summary>
|
<summary> Using Hugging Face Models </summary>
|
||||||
|
|
||||||
If you want to use Hugging Face models, you only need to set LightRAG as follows:
|
* If you want to use Hugging Face models, you only need to set LightRAG as follows:
|
||||||
```python
|
```python
|
||||||
from lightrag.llm import hf_model_complete, hf_embedding
|
from lightrag.llm import hf_model_complete, hf_embedding
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer
|
||||||
@@ -143,7 +157,8 @@ rag = LightRAG(
|
|||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary> Using Ollama Models </summary>
|
<summary> Using Ollama Models </summary>
|
||||||
If you want to use Ollama models, you only need to set LightRAG as follows:
|
|
||||||
|
* If you want to use Ollama models, you only need to set LightRAG as follows:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from lightrag.llm import ollama_model_complete, ollama_embedding
|
from lightrag.llm import ollama_model_complete, ollama_embedding
|
||||||
@@ -164,6 +179,29 @@ rag = LightRAG(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
* Increasing the `num_ctx` parameter:
|
||||||
|
|
||||||
|
1. Pull the model:
|
||||||
|
```python
|
||||||
|
ollama pull qwen2
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Display the model file:
|
||||||
|
```python
|
||||||
|
ollama show --modelfile qwen2 > Modelfile
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Edit the Modelfile by adding the following line:
|
||||||
|
```python
|
||||||
|
PARAMETER num_ctx 32768
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Create the modified model:
|
||||||
|
```python
|
||||||
|
ollama create -f Modelfile qwen2m
|
||||||
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### Batch Insert
|
### Batch Insert
|
||||||
@@ -181,12 +219,167 @@ rag = LightRAG(working_dir="./dickens")
|
|||||||
with open("./newText.txt") as f:
|
with open("./newText.txt") as f:
|
||||||
rag.insert(f.read())
|
rag.insert(f.read())
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Graph Visualization
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Graph visualization with html </summary>
|
||||||
|
|
||||||
|
* The following code can be found in `examples/graph_visual_with_html.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
import networkx as nx
|
||||||
|
from pyvis.network import Network
|
||||||
|
|
||||||
|
# Load the GraphML file
|
||||||
|
G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
|
||||||
|
|
||||||
|
# Create a Pyvis network
|
||||||
|
net = Network(notebook=True)
|
||||||
|
|
||||||
|
# Convert NetworkX graph to Pyvis network
|
||||||
|
net.from_nx(G)
|
||||||
|
|
||||||
|
# Save and display the network
|
||||||
|
net.show('knowledge_graph.html')
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Graph visualization with Neo4j </summary>
|
||||||
|
|
||||||
|
* The following code can be found in `examples/graph_visual_with_neo4j.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from lightrag.utils import xml_to_json
|
||||||
|
from neo4j import GraphDatabase
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
BATCH_SIZE_NODES = 500
|
||||||
|
BATCH_SIZE_EDGES = 100
|
||||||
|
|
||||||
|
# Neo4j connection credentials
|
||||||
|
NEO4J_URI = "bolt://localhost:7687"
|
||||||
|
NEO4J_USERNAME = "neo4j"
|
||||||
|
NEO4J_PASSWORD = "your_password"
|
||||||
|
|
||||||
|
def convert_xml_to_json(xml_path, output_path):
|
||||||
|
"""Converts XML file to JSON and saves the output."""
|
||||||
|
if not os.path.exists(xml_path):
|
||||||
|
print(f"Error: File not found - {xml_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_data = xml_to_json(xml_path)
|
||||||
|
if json_data:
|
||||||
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(json_data, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"JSON file created: {output_path}")
|
||||||
|
return json_data
|
||||||
|
else:
|
||||||
|
print("Failed to create JSON data")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_in_batches(tx, query, data, batch_size):
|
||||||
|
"""Process data in batches and execute the given query."""
|
||||||
|
for i in range(0, len(data), batch_size):
|
||||||
|
batch = data[i:i + batch_size]
|
||||||
|
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Paths
|
||||||
|
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
|
||||||
|
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
|
||||||
|
|
||||||
|
# Convert XML to JSON
|
||||||
|
json_data = convert_xml_to_json(xml_file, json_file)
|
||||||
|
if json_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load nodes and edges
|
||||||
|
nodes = json_data.get('nodes', [])
|
||||||
|
edges = json_data.get('edges', [])
|
||||||
|
|
||||||
|
# Neo4j queries
|
||||||
|
create_nodes_query = """
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (e:Entity {id: node.id})
|
||||||
|
SET e.entity_type = node.entity_type,
|
||||||
|
e.description = node.description,
|
||||||
|
e.source_id = node.source_id,
|
||||||
|
e.displayName = node.id
|
||||||
|
REMOVE e:Entity
|
||||||
|
WITH e, node
|
||||||
|
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
create_edges_query = """
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (source {id: edge.source})
|
||||||
|
MATCH (target {id: edge.target})
|
||||||
|
WITH source, target, edge,
|
||||||
|
CASE
|
||||||
|
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
|
||||||
|
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
|
||||||
|
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
|
||||||
|
WHEN edge.keywords CONTAINS 'located' THEN 'located'
|
||||||
|
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
|
||||||
|
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
|
||||||
|
END AS relType
|
||||||
|
CALL apoc.create.relationship(source, relType, {
|
||||||
|
weight: edge.weight,
|
||||||
|
description: edge.description,
|
||||||
|
keywords: edge.keywords,
|
||||||
|
source_id: edge.source_id
|
||||||
|
}, target) YIELD rel
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
set_displayname_and_labels_query = """
|
||||||
|
MATCH (n)
|
||||||
|
SET n.displayName = n.id
|
||||||
|
WITH n
|
||||||
|
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a Neo4j driver
|
||||||
|
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Execute queries in batches
|
||||||
|
with driver.session() as session:
|
||||||
|
# Insert nodes in batches
|
||||||
|
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
|
||||||
|
|
||||||
|
# Insert edges in batches
|
||||||
|
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
|
||||||
|
|
||||||
|
# Set displayName and labels
|
||||||
|
session.run(set_displayname_and_labels_query)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
driver.close()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## Evaluation
|
## Evaluation
|
||||||
### Dataset
|
### Dataset
|
||||||
The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
|
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
|
||||||
|
|
||||||
### Generate Query
|
### Generate Query
|
||||||
LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`.
|
LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary> Prompt </summary>
|
<summary> Prompt </summary>
|
||||||
@@ -380,7 +573,7 @@ def insert_text(rag, file_path):
|
|||||||
|
|
||||||
### Step-2 Generate Queries
|
### Step-2 Generate Queries
|
||||||
|
|
||||||
We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
|
We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary> Code </summary>
|
<summary> Code </summary>
|
||||||
@@ -427,11 +620,16 @@ def extract_queries(file_path):
|
|||||||
.
|
.
|
||||||
├── examples
|
├── examples
|
||||||
│ ├── batch_eval.py
|
│ ├── batch_eval.py
|
||||||
|
│ ├── graph_visual_with_html.py
|
||||||
|
│ ├── graph_visual_with_neo4j.py
|
||||||
│ ├── generate_query.py
|
│ ├── generate_query.py
|
||||||
|
│ ├── lightrag_azure_openai_demo.py
|
||||||
|
│ ├── lightrag_bedrock_demo.py
|
||||||
│ ├── lightrag_hf_demo.py
|
│ ├── lightrag_hf_demo.py
|
||||||
│ ├── lightrag_ollama_demo.py
|
│ ├── lightrag_ollama_demo.py
|
||||||
│ ├── lightrag_openai_compatible_demo.py
|
│ ├── lightrag_openai_compatible_demo.py
|
||||||
│ └── lightrag_openai_demo.py
|
│ ├── lightrag_openai_demo.py
|
||||||
|
│ └── vram_management_demo.py
|
||||||
├── lightrag
|
├── lightrag
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── base.py
|
│ ├── base.py
|
||||||
@@ -446,6 +644,8 @@ def extract_queries(file_path):
|
|||||||
│ ├── Step_1.py
|
│ ├── Step_1.py
|
||||||
│ ├── Step_2.py
|
│ ├── Step_2.py
|
||||||
│ └── Step_3.py
|
│ └── Step_3.py
|
||||||
|
├── .gitignore
|
||||||
|
├── .pre-commit-config.yaml
|
||||||
├── LICENSE
|
├── LICENSE
|
||||||
├── README.md
|
├── README.md
|
||||||
├── requirements.txt
|
├── requirements.txt
|
||||||
|
@@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
import jsonlines
|
import jsonlines
|
||||||
@@ -9,22 +8,22 @@ from openai import OpenAI
|
|||||||
def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
||||||
client = OpenAI()
|
client = OpenAI()
|
||||||
|
|
||||||
with open(query_file, 'r') as f:
|
with open(query_file, "r") as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
|
||||||
queries = re.findall(r'- Question \d+: (.+)', data)
|
queries = re.findall(r"- Question \d+: (.+)", data)
|
||||||
|
|
||||||
with open(result1_file, 'r') as f:
|
with open(result1_file, "r") as f:
|
||||||
answers1 = json.load(f)
|
answers1 = json.load(f)
|
||||||
answers1 = [i['result'] for i in answers1]
|
answers1 = [i["result"] for i in answers1]
|
||||||
|
|
||||||
with open(result2_file, 'r') as f:
|
with open(result2_file, "r") as f:
|
||||||
answers2 = json.load(f)
|
answers2 = json.load(f)
|
||||||
answers2 = [i['result'] for i in answers2]
|
answers2 = [i["result"] for i in answers2]
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
|
for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
|
||||||
sys_prompt = f"""
|
sys_prompt = """
|
||||||
---Role---
|
---Role---
|
||||||
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
||||||
"""
|
"""
|
||||||
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"custom_id": f"request-{i+1}",
|
"custom_id": f"request-{i+1}",
|
||||||
"method": "POST",
|
"method": "POST",
|
||||||
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
|||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": sys_prompt},
|
{"role": "system", "content": sys_prompt},
|
||||||
{"role": "user", "content": prompt}
|
{"role": "user", "content": prompt},
|
||||||
],
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
requests.append(request_data)
|
requests.append(request_data)
|
||||||
|
|
||||||
with jsonlines.open(output_file_path, mode='w') as writer:
|
with jsonlines.open(output_file_path, mode="w") as writer:
|
||||||
for request in requests:
|
for request in requests:
|
||||||
writer.write(request)
|
writer.write(request)
|
||||||
|
|
||||||
print(f"Batch API requests written to {output_file_path}")
|
print(f"Batch API requests written to {output_file_path}")
|
||||||
|
|
||||||
batch_input_file = client.files.create(
|
batch_input_file = client.files.create(
|
||||||
file=open(output_file_path, "rb"),
|
file=open(output_file_path, "rb"), purpose="batch"
|
||||||
purpose="batch"
|
|
||||||
)
|
)
|
||||||
batch_input_file_id = batch_input_file.id
|
batch_input_file_id = batch_input_file.id
|
||||||
|
|
||||||
@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
|||||||
input_file_id=batch_input_file_id,
|
input_file_id=batch_input_file_id,
|
||||||
endpoint="/v1/chat/completions",
|
endpoint="/v1/chat/completions",
|
||||||
completion_window="24h",
|
completion_window="24h",
|
||||||
metadata={
|
metadata={"description": "nightly eval job"},
|
||||||
"description": "nightly eval job"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f'Batch {batch.id} has been created.')
|
print(f"Batch {batch.id} has been created.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
batch_eval()
|
batch_eval()
|
@@ -1,9 +1,8 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
# os.environ["OPENAI_API_KEY"] = ""
|
# os.environ["OPENAI_API_KEY"] = ""
|
||||||
|
|
||||||
|
|
||||||
def openai_complete_if_cache(
|
def openai_complete_if_cache(
|
||||||
model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -47,9 +46,9 @@ if __name__ == "__main__":
|
|||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt)
|
result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)
|
||||||
|
|
||||||
file_path = f"./queries.txt"
|
file_path = "./queries.txt"
|
||||||
with open(file_path, "w") as file:
|
with open(file_path, "w") as file:
|
||||||
file.write(result)
|
file.write(result)
|
||||||
|
|
||||||
|
19
examples/graph_visual_with_html.py
Normal file
19
examples/graph_visual_with_html.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import networkx as nx
|
||||||
|
from pyvis.network import Network
|
||||||
|
import random
|
||||||
|
|
||||||
|
# Load the GraphML file
|
||||||
|
G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
|
||||||
|
|
||||||
|
# Create a Pyvis network
|
||||||
|
net = Network(notebook=True)
|
||||||
|
|
||||||
|
# Convert NetworkX graph to Pyvis network
|
||||||
|
net.from_nx(G)
|
||||||
|
|
||||||
|
# Add colors to nodes
|
||||||
|
for node in net.nodes:
|
||||||
|
node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
|
||||||
|
|
||||||
|
# Save and display the network
|
||||||
|
net.show('knowledge_graph.html')
|
118
examples/graph_visual_with_neo4j.py
Normal file
118
examples/graph_visual_with_neo4j.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from lightrag.utils import xml_to_json
|
||||||
|
from neo4j import GraphDatabase
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
BATCH_SIZE_NODES = 500
|
||||||
|
BATCH_SIZE_EDGES = 100
|
||||||
|
|
||||||
|
# Neo4j connection credentials
|
||||||
|
NEO4J_URI = "bolt://localhost:7687"
|
||||||
|
NEO4J_USERNAME = "neo4j"
|
||||||
|
NEO4J_PASSWORD = "your_password"
|
||||||
|
|
||||||
|
def convert_xml_to_json(xml_path, output_path):
|
||||||
|
"""Converts XML file to JSON and saves the output."""
|
||||||
|
if not os.path.exists(xml_path):
|
||||||
|
print(f"Error: File not found - {xml_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_data = xml_to_json(xml_path)
|
||||||
|
if json_data:
|
||||||
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(json_data, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"JSON file created: {output_path}")
|
||||||
|
return json_data
|
||||||
|
else:
|
||||||
|
print("Failed to create JSON data")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_in_batches(tx, query, data, batch_size):
|
||||||
|
"""Process data in batches and execute the given query."""
|
||||||
|
for i in range(0, len(data), batch_size):
|
||||||
|
batch = data[i:i + batch_size]
|
||||||
|
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Paths
|
||||||
|
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
|
||||||
|
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
|
||||||
|
|
||||||
|
# Convert XML to JSON
|
||||||
|
json_data = convert_xml_to_json(xml_file, json_file)
|
||||||
|
if json_data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load nodes and edges
|
||||||
|
nodes = json_data.get('nodes', [])
|
||||||
|
edges = json_data.get('edges', [])
|
||||||
|
|
||||||
|
# Neo4j queries
|
||||||
|
create_nodes_query = """
|
||||||
|
UNWIND $nodes AS node
|
||||||
|
MERGE (e:Entity {id: node.id})
|
||||||
|
SET e.entity_type = node.entity_type,
|
||||||
|
e.description = node.description,
|
||||||
|
e.source_id = node.source_id,
|
||||||
|
e.displayName = node.id
|
||||||
|
REMOVE e:Entity
|
||||||
|
WITH e, node
|
||||||
|
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
create_edges_query = """
|
||||||
|
UNWIND $edges AS edge
|
||||||
|
MATCH (source {id: edge.source})
|
||||||
|
MATCH (target {id: edge.target})
|
||||||
|
WITH source, target, edge,
|
||||||
|
CASE
|
||||||
|
WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
|
||||||
|
WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
|
||||||
|
WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
|
||||||
|
WHEN edge.keywords CONTAINS 'located' THEN 'located'
|
||||||
|
WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
|
||||||
|
ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
|
||||||
|
END AS relType
|
||||||
|
CALL apoc.create.relationship(source, relType, {
|
||||||
|
weight: edge.weight,
|
||||||
|
description: edge.description,
|
||||||
|
keywords: edge.keywords,
|
||||||
|
source_id: edge.source_id
|
||||||
|
}, target) YIELD rel
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
set_displayname_and_labels_query = """
|
||||||
|
MATCH (n)
|
||||||
|
SET n.displayName = n.id
|
||||||
|
WITH n
|
||||||
|
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
|
||||||
|
RETURN count(*)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a Neo4j driver
|
||||||
|
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Execute queries in batches
|
||||||
|
with driver.session() as session:
|
||||||
|
# Insert nodes in batches
|
||||||
|
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
|
||||||
|
|
||||||
|
# Insert edges in batches
|
||||||
|
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
|
||||||
|
|
||||||
|
# Set displayName and labels
|
||||||
|
session.run(set_displayname_and_labels_query)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
driver.close()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
125
examples/lightrag_azure_openai_demo.py
Normal file
125
examples/lightrag_azure_openai_demo.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
import numpy as np
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import aiohttp
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
|
||||||
|
AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||||
|
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
|
||||||
|
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||||
|
|
||||||
|
AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
|
||||||
|
AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")
|
||||||
|
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
|
||||||
|
if os.path.exists(WORKING_DIR):
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
shutil.rmtree(WORKING_DIR)
|
||||||
|
|
||||||
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
async def llm_model_func(
|
||||||
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"api-key": AZURE_OPENAI_API_KEY,
|
||||||
|
}
|
||||||
|
endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
if history_messages:
|
||||||
|
messages.extend(history_messages)
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": kwargs.get("temperature", 0),
|
||||||
|
"top_p": kwargs.get("top_p", 1),
|
||||||
|
"n": kwargs.get("n", 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(endpoint, headers=headers, json=payload) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(
|
||||||
|
f"Request failed with status {response.status}: {await response.text()}"
|
||||||
|
)
|
||||||
|
result = await response.json()
|
||||||
|
return result["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"api-key": AZURE_OPENAI_API_KEY,
|
||||||
|
}
|
||||||
|
endpoint = f"{AZURE_OPENAI_ENDPOINT}openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"
|
||||||
|
|
||||||
|
payload = {"input": texts}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(endpoint, headers=headers, json=payload) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(
|
||||||
|
f"Request failed with status {response.status}: {await response.text()}"
|
||||||
|
)
|
||||||
|
result = await response.json()
|
||||||
|
embeddings = [item["embedding"] for item in result["data"]]
|
||||||
|
return np.array(embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_funcs():
|
||||||
|
result = await llm_model_func("How are you?")
|
||||||
|
print("Resposta do llm_model_func: ", result)
|
||||||
|
|
||||||
|
result = await embedding_func(["How are you?"])
|
||||||
|
print("Resultado do embedding_func: ", result.shape)
|
||||||
|
print("Dimensão da embedding: ", result.shape[1])
|
||||||
|
|
||||||
|
|
||||||
|
asyncio.run(test_funcs())
|
||||||
|
|
||||||
|
embedding_dimension = 3072
|
||||||
|
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=embedding_dimension,
|
||||||
|
max_token_size=8192,
|
||||||
|
func=embedding_func,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
book1 = open("./book_1.txt", encoding="utf-8")
|
||||||
|
book2 = open("./book_2.txt", encoding="utf-8")
|
||||||
|
|
||||||
|
rag.insert([book1.read(), book2.read()])
|
||||||
|
|
||||||
|
query_text = "What are the main themes?"
|
||||||
|
|
||||||
|
print("Result (Naive):")
|
||||||
|
print(rag.query(query_text, param=QueryParam(mode="naive")))
|
||||||
|
|
||||||
|
print("\nResult (Local):")
|
||||||
|
print(rag.query(query_text, param=QueryParam(mode="local")))
|
||||||
|
|
||||||
|
print("\nResult (Global):")
|
||||||
|
print(rag.query(query_text, param=QueryParam(mode="global")))
|
||||||
|
|
||||||
|
print("\nResult (Hybrid):")
|
||||||
|
print(rag.query(query_text, param=QueryParam(mode="hybrid")))
|
36
examples/lightrag_bedrock_demo.py
Normal file
36
examples/lightrag_bedrock_demo.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""
|
||||||
|
LightRAG meets Amazon Bedrock ⛰️
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.llm import bedrock_complete, bedrock_embedding
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
|
||||||
|
logging.getLogger("aiobotocore").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
if not os.path.exists(WORKING_DIR):
|
||||||
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=bedrock_complete,
|
||||||
|
llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||||
|
rag.insert(f.read())
|
||||||
|
|
||||||
|
for mode in ["naive", "local", "global", "hybrid"]:
|
||||||
|
print("\n+-" + "-" * len(mode) + "-+")
|
||||||
|
print(f"| {mode.capitalize()} |")
|
||||||
|
print("+-" + "-" * len(mode) + "-+\n")
|
||||||
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
|
||||||
|
)
|
@@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
|
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.llm import hf_model_complete, hf_embedding
|
from lightrag.llm import hf_model_complete, hf_embedding
|
||||||
@@ -14,15 +13,19 @@ if not os.path.exists(WORKING_DIR):
|
|||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
llm_model_func=hf_model_complete,
|
llm_model_func=hf_model_complete,
|
||||||
llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
|
llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
embedding_func=EmbeddingFunc(
|
embedding_func=EmbeddingFunc(
|
||||||
embedding_dim=384,
|
embedding_dim=384,
|
||||||
max_token_size=5000,
|
max_token_size=5000,
|
||||||
func=lambda texts: hf_embedding(
|
func=lambda texts: hf_embedding(
|
||||||
texts,
|
texts,
|
||||||
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
|
tokenizer=AutoTokenizer.from_pretrained(
|
||||||
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
"sentence-transformers/all-MiniLM-L6-v2"
|
||||||
)
|
),
|
||||||
|
embed_model=AutoModel.from_pretrained(
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,13 +34,21 @@ with open("./book.txt") as f:
|
|||||||
rag.insert(f.read())
|
rag.insert(f.read())
|
||||||
|
|
||||||
# Perform naive search
|
# Perform naive search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform local search
|
# Perform local search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform global search
|
# Perform global search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform hybrid search
|
# Perform hybrid search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||||
|
)
|
||||||
|
@@ -12,14 +12,11 @@ if not os.path.exists(WORKING_DIR):
|
|||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
llm_model_func=ollama_model_complete,
|
llm_model_func=ollama_model_complete,
|
||||||
llm_model_name='your_model_name',
|
llm_model_name="your_model_name",
|
||||||
embedding_func=EmbeddingFunc(
|
embedding_func=EmbeddingFunc(
|
||||||
embedding_dim=768,
|
embedding_dim=768,
|
||||||
max_token_size=8192,
|
max_token_size=8192,
|
||||||
func=lambda texts: ollama_embedding(
|
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
|
||||||
texts,
|
|
||||||
embed_model="nomic-embed-text"
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,13 +25,21 @@ with open("./book.txt") as f:
|
|||||||
rag.insert(f.read())
|
rag.insert(f.read())
|
||||||
|
|
||||||
# Perform naive search
|
# Perform naive search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform local search
|
# Perform local search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform global search
|
# Perform global search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform hybrid search
|
# Perform hybrid search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||||
|
)
|
||||||
|
@@ -10,6 +10,7 @@ WORKING_DIR = "./dickens"
|
|||||||
if not os.path.exists(WORKING_DIR):
|
if not os.path.exists(WORKING_DIR):
|
||||||
os.mkdir(WORKING_DIR)
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
|
||||||
async def llm_model_func(
|
async def llm_model_func(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -20,17 +21,19 @@ async def llm_model_func(
|
|||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
api_key=os.getenv("UPSTAGE_API_KEY"),
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
base_url="https://api.upstage.ai/v1/solar",
|
base_url="https://api.upstage.ai/v1/solar",
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
return await openai_embedding(
|
return await openai_embedding(
|
||||||
texts,
|
texts,
|
||||||
model="solar-embedding-1-large-query",
|
model="solar-embedding-1-large-query",
|
||||||
api_key=os.getenv("UPSTAGE_API_KEY"),
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
base_url="https://api.upstage.ai/v1/solar"
|
base_url="https://api.upstage.ai/v1/solar",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# function test
|
# function test
|
||||||
async def test_funcs():
|
async def test_funcs():
|
||||||
result = await llm_model_func("How are you?")
|
result = await llm_model_func("How are you?")
|
||||||
@@ -39,6 +42,7 @@ async def test_funcs():
|
|||||||
result = await embedding_func(["How are you?"])
|
result = await embedding_func(["How are you?"])
|
||||||
print("embedding_func: ", result)
|
print("embedding_func: ", result)
|
||||||
|
|
||||||
|
|
||||||
asyncio.run(test_funcs())
|
asyncio.run(test_funcs())
|
||||||
|
|
||||||
|
|
||||||
@@ -46,10 +50,8 @@ rag = LightRAG(
|
|||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
llm_model_func=llm_model_func,
|
llm_model_func=llm_model_func,
|
||||||
embedding_func=EmbeddingFunc(
|
embedding_func=EmbeddingFunc(
|
||||||
embedding_dim=4096,
|
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
||||||
max_token_size=8192,
|
),
|
||||||
func=embedding_func
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -57,13 +59,21 @@ with open("./book.txt") as f:
|
|||||||
rag.insert(f.read())
|
rag.insert(f.read())
|
||||||
|
|
||||||
# Perform naive search
|
# Perform naive search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform local search
|
# Perform local search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform global search
|
# Perform global search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform hybrid search
|
# Perform hybrid search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||||
|
)
|
||||||
|
@@ -1,9 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
|
|
||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
|
from lightrag.llm import gpt_4o_mini_complete
|
||||||
from transformers import AutoModel,AutoTokenizer
|
|
||||||
|
|
||||||
WORKING_DIR = "./dickens"
|
WORKING_DIR = "./dickens"
|
||||||
|
|
||||||
@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
|
|||||||
|
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
working_dir=WORKING_DIR,
|
working_dir=WORKING_DIR,
|
||||||
llm_model_func=gpt_4o_mini_complete
|
llm_model_func=gpt_4o_mini_complete,
|
||||||
# llm_model_func=gpt_4o_complete
|
# llm_model_func=gpt_4o_complete
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,13 +19,21 @@ with open("./book.txt") as f:
|
|||||||
rag.insert(f.read())
|
rag.insert(f.read())
|
||||||
|
|
||||||
# Perform naive search
|
# Perform naive search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform local search
|
# Perform local search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform global search
|
# Perform global search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||||
|
)
|
||||||
|
|
||||||
# Perform hybrid search
|
# Perform hybrid search
|
||||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
print(
|
||||||
|
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||||
|
)
|
||||||
|
82
examples/vram_management_demo.py
Normal file
82
examples/vram_management_demo.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from lightrag.llm import ollama_model_complete, ollama_embedding
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
|
||||||
|
# Working directory and the directory path for text files
|
||||||
|
WORKING_DIR = "./dickens"
|
||||||
|
TEXT_FILES_DIR = "/llm/mt"
|
||||||
|
|
||||||
|
# Create the working directory if it doesn't exist
|
||||||
|
if not os.path.exists(WORKING_DIR):
|
||||||
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
# Initialize LightRAG
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=ollama_model_complete,
|
||||||
|
llm_model_name="qwen2.5:3b-instruct-max-context",
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=768,
|
||||||
|
max_token_size=8192,
|
||||||
|
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read all .txt files from the TEXT_FILES_DIR directory
|
||||||
|
texts = []
|
||||||
|
for filename in os.listdir(TEXT_FILES_DIR):
|
||||||
|
if filename.endswith('.txt'):
|
||||||
|
file_path = os.path.join(TEXT_FILES_DIR, filename)
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as file:
|
||||||
|
texts.append(file.read())
|
||||||
|
|
||||||
|
# Batch insert texts into LightRAG with a retry mechanism
|
||||||
|
def insert_texts_with_retry(rag, texts, retries=3, delay=5):
|
||||||
|
for _ in range(retries):
|
||||||
|
try:
|
||||||
|
rag.insert(texts)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
|
||||||
|
time.sleep(delay)
|
||||||
|
raise RuntimeError("Failed to insert texts after multiple retries.")
|
||||||
|
|
||||||
|
insert_texts_with_retry(rag, texts)
|
||||||
|
|
||||||
|
# Perform different types of queries and handle potential errors
|
||||||
|
try:
|
||||||
|
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error performing naive search: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error performing local search: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error performing global search: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error performing hybrid search: {e}")
|
||||||
|
|
||||||
|
# Function to clear VRAM resources
|
||||||
|
def clear_vram():
|
||||||
|
os.system("sudo nvidia-smi --gpu-reset")
|
||||||
|
|
||||||
|
# Regularly clear VRAM to prevent overflow
|
||||||
|
clear_vram_interval = 3600 # Clear once every hour
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - start_time > clear_vram_interval:
|
||||||
|
clear_vram()
|
||||||
|
start_time = current_time
|
||||||
|
time.sleep(60) # Check the time every minute
|
@@ -1,5 +1,5 @@
|
|||||||
from .lightrag import LightRAG, QueryParam
|
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
||||||
|
|
||||||
__version__ = "0.0.6"
|
__version__ = "0.0.7"
|
||||||
__author__ = "Zirui Guo"
|
__author__ = "Zirui Guo"
|
||||||
__url__ = "https://github.com/HKUDS/LightRAG"
|
__url__ = "https://github.com/HKUDS/LightRAG"
|
||||||
|
@@ -12,6 +12,7 @@ TextChunkSchema = TypedDict(
|
|||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class QueryParam:
|
class QueryParam:
|
||||||
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
||||||
@@ -36,6 +37,7 @@ class StorageNameSpace:
|
|||||||
"""commit the storage operations after querying"""
|
"""commit the storage operations after querying"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseVectorStorage(StorageNameSpace):
|
class BaseVectorStorage(StorageNameSpace):
|
||||||
embedding_func: EmbeddingFunc
|
embedding_func: EmbeddingFunc
|
||||||
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseKVStorage(Generic[T], StorageNameSpace):
|
class BaseKVStorage(Generic[T], StorageNameSpace):
|
||||||
async def all_keys(self) -> list[str]:
|
async def all_keys(self) -> list[str]:
|
||||||
|
@@ -3,10 +3,12 @@ import os
|
|||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Type, cast, Any
|
from typing import Type, cast
|
||||||
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
|
|
||||||
|
|
||||||
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
|
from .llm import (
|
||||||
|
gpt_4o_mini_complete,
|
||||||
|
openai_embedding,
|
||||||
|
)
|
||||||
from .operate import (
|
from .operate import (
|
||||||
chunking_by_token_size,
|
chunking_by_token_size,
|
||||||
extract_entities,
|
extract_entities,
|
||||||
@@ -37,6 +39,7 @@ from .base import (
|
|||||||
QueryParam,
|
QueryParam,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
@@ -69,7 +72,6 @@ class LightRAG:
|
|||||||
"dimensions": 1536,
|
"dimensions": 1536,
|
||||||
"num_walks": 10,
|
"num_walks": 10,
|
||||||
"walk_length": 40,
|
"walk_length": 40,
|
||||||
"num_walks": 10,
|
|
||||||
"window_size": 2,
|
"window_size": 2,
|
||||||
"iterations": 3,
|
"iterations": 3,
|
||||||
"random_seed": 3,
|
"random_seed": 3,
|
||||||
@@ -83,7 +85,7 @@ class LightRAG:
|
|||||||
|
|
||||||
# LLM
|
# LLM
|
||||||
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
|
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
|
||||||
llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||||
llm_model_max_token_size: int = 32768
|
llm_model_max_token_size: int = 32768
|
||||||
llm_model_max_async: int = 16
|
llm_model_max_async: int = 16
|
||||||
|
|
||||||
@@ -133,29 +135,23 @@ class LightRAG:
|
|||||||
self.embedding_func
|
self.embedding_func
|
||||||
)
|
)
|
||||||
|
|
||||||
self.entities_vdb = (
|
self.entities_vdb = self.vector_db_storage_cls(
|
||||||
self.vector_db_storage_cls(
|
|
||||||
namespace="entities",
|
namespace="entities",
|
||||||
global_config=asdict(self),
|
global_config=asdict(self),
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
meta_fields={"entity_name"}
|
meta_fields={"entity_name"},
|
||||||
)
|
)
|
||||||
)
|
self.relationships_vdb = self.vector_db_storage_cls(
|
||||||
self.relationships_vdb = (
|
|
||||||
self.vector_db_storage_cls(
|
|
||||||
namespace="relationships",
|
namespace="relationships",
|
||||||
global_config=asdict(self),
|
global_config=asdict(self),
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
meta_fields={"src_id", "tgt_id"}
|
meta_fields={"src_id", "tgt_id"},
|
||||||
)
|
)
|
||||||
)
|
self.chunks_vdb = self.vector_db_storage_cls(
|
||||||
self.chunks_vdb = (
|
|
||||||
self.vector_db_storage_cls(
|
|
||||||
namespace="chunks",
|
namespace="chunks",
|
||||||
global_config=asdict(self),
|
global_config=asdict(self),
|
||||||
embedding_func=self.embedding_func,
|
embedding_func=self.embedding_func,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||||
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
||||||
@@ -177,7 +173,7 @@ class LightRAG:
|
|||||||
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
|
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
|
||||||
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
||||||
if not len(new_docs):
|
if not len(new_docs):
|
||||||
logger.warning(f"All docs are already in the storage")
|
logger.warning("All docs are already in the storage")
|
||||||
return
|
return
|
||||||
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
||||||
|
|
||||||
@@ -203,7 +199,7 @@ class LightRAG:
|
|||||||
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
||||||
}
|
}
|
||||||
if not len(inserting_chunks):
|
if not len(inserting_chunks):
|
||||||
logger.warning(f"All chunks are already in the storage")
|
logger.warning("All chunks are already in the storage")
|
||||||
return
|
return
|
||||||
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
|
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
|
||||||
|
|
||||||
@@ -291,7 +287,6 @@ class LightRAG:
|
|||||||
await self._query_done()
|
await self._query_done()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def _query_done(self):
|
async def _query_done(self):
|
||||||
tasks = []
|
tasks = []
|
||||||
for storage_inst in [self.llm_response_cache]:
|
for storage_inst in [self.llm_response_cache]:
|
||||||
@@ -299,5 +294,3 @@ class LightRAG:
|
|||||||
continue
|
continue
|
||||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
275
lightrag/llm.py
275
lightrag/llm.py
@@ -1,4 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import aioboto3
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import ollama
|
import ollama
|
||||||
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
||||||
@@ -8,24 +11,34 @@ from tenacity import (
|
|||||||
wait_exponential,
|
wait_exponential,
|
||||||
retry_if_exception_type,
|
retry_if_exception_type,
|
||||||
)
|
)
|
||||||
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
import torch
|
import torch
|
||||||
from .base import BaseKVStorage
|
from .base import BaseKVStorage
|
||||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||||
import copy
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||||
)
|
)
|
||||||
async def openai_complete_if_cache(
|
async def openai_complete_if_cache(
|
||||||
model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
|
model,
|
||||||
|
prompt,
|
||||||
|
system_prompt=None,
|
||||||
|
history_messages=[],
|
||||||
|
base_url=None,
|
||||||
|
api_key=None,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
if api_key:
|
if api_key:
|
||||||
os.environ["OPENAI_API_KEY"] = api_key
|
os.environ["OPENAI_API_KEY"] = api_key
|
||||||
|
|
||||||
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
openai_async_client = (
|
||||||
|
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||||
|
)
|
||||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||||
messages = []
|
messages = []
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -48,15 +61,105 @@ async def openai_complete_if_cache(
|
|||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockError(Exception):
|
||||||
|
"""Generic error for issues related to Amazon Bedrock"""
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(5),
|
||||||
|
wait=wait_exponential(multiplier=1, max=60),
|
||||||
|
retry=retry_if_exception_type((BedrockError)),
|
||||||
|
)
|
||||||
|
async def bedrock_complete_if_cache(
|
||||||
|
model,
|
||||||
|
prompt,
|
||||||
|
system_prompt=None,
|
||||||
|
history_messages=[],
|
||||||
|
aws_access_key_id=None,
|
||||||
|
aws_secret_access_key=None,
|
||||||
|
aws_session_token=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||||
|
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||||
|
)
|
||||||
|
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
||||||
|
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
||||||
|
)
|
||||||
|
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
||||||
|
"AWS_SESSION_TOKEN", aws_session_token
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fix message history format
|
||||||
|
messages = []
|
||||||
|
for history_message in history_messages:
|
||||||
|
message = copy.copy(history_message)
|
||||||
|
message["content"] = [{"text": message["content"]}]
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
# Add user prompt
|
||||||
|
messages.append({"role": "user", "content": [{"text": prompt}]})
|
||||||
|
|
||||||
|
# Initialize Converse API arguments
|
||||||
|
args = {"modelId": model, "messages": messages}
|
||||||
|
|
||||||
|
# Define system prompt
|
||||||
|
if system_prompt:
|
||||||
|
args["system"] = [{"text": system_prompt}]
|
||||||
|
|
||||||
|
# Map and set up inference parameters
|
||||||
|
inference_params_map = {
|
||||||
|
"max_tokens": "maxTokens",
|
||||||
|
"top_p": "topP",
|
||||||
|
"stop_sequences": "stopSequences",
|
||||||
|
}
|
||||||
|
if inference_params := list(
|
||||||
|
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
|
||||||
|
):
|
||||||
|
args["inferenceConfig"] = {}
|
||||||
|
for param in inference_params:
|
||||||
|
args["inferenceConfig"][inference_params_map.get(param, param)] = (
|
||||||
|
kwargs.pop(param)
|
||||||
|
)
|
||||||
|
|
||||||
|
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||||
|
if hashing_kv is not None:
|
||||||
|
args_hash = compute_args_hash(model, messages)
|
||||||
|
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||||
|
if if_cache_return is not None:
|
||||||
|
return if_cache_return["return"]
|
||||||
|
|
||||||
|
# Call model via Converse API
|
||||||
|
session = aioboto3.Session()
|
||||||
|
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||||
|
try:
|
||||||
|
response = await bedrock_async_client.converse(**args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockError(e)
|
||||||
|
|
||||||
|
if hashing_kv is not None:
|
||||||
|
await hashing_kv.upsert(
|
||||||
|
{
|
||||||
|
args_hash: {
|
||||||
|
"return": response["output"]["message"]["content"][0]["text"],
|
||||||
|
"model": model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return response["output"]["message"]["content"][0]["text"]
|
||||||
|
|
||||||
|
|
||||||
async def hf_model_if_cache(
|
async def hf_model_if_cache(
|
||||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
model_name = model
|
model_name = model
|
||||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
|
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
|
||||||
if hf_tokenizer.pad_token == None:
|
if hf_tokenizer.pad_token is None:
|
||||||
# print("use eos token")
|
# print("use eos token")
|
||||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
|
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
||||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||||
messages = []
|
messages = []
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
@@ -69,30 +172,51 @@ async def hf_model_if_cache(
|
|||||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||||
if if_cache_return is not None:
|
if if_cache_return is not None:
|
||||||
return if_cache_return["return"]
|
return if_cache_return["return"]
|
||||||
input_prompt = ''
|
input_prompt = ""
|
||||||
try:
|
try:
|
||||||
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
input_prompt = hf_tokenizer.apply_chat_template(
|
||||||
except:
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
try:
|
try:
|
||||||
ori_message = copy.deepcopy(messages)
|
ori_message = copy.deepcopy(messages)
|
||||||
if messages[0]['role'] == "system":
|
if messages[0]["role"] == "system":
|
||||||
messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
|
messages[1]["content"] = (
|
||||||
|
"<system>"
|
||||||
|
+ messages[0]["content"]
|
||||||
|
+ "</system>\n"
|
||||||
|
+ messages[1]["content"]
|
||||||
|
)
|
||||||
messages = messages[1:]
|
messages = messages[1:]
|
||||||
input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
input_prompt = hf_tokenizer.apply_chat_template(
|
||||||
except:
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
len_message = len(ori_message)
|
len_message = len(ori_message)
|
||||||
for msgid in range(len_message):
|
for msgid in range(len_message):
|
||||||
input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
|
input_prompt = (
|
||||||
|
input_prompt
|
||||||
|
+ "<"
|
||||||
|
+ ori_message[msgid]["role"]
|
||||||
|
+ ">"
|
||||||
|
+ ori_message[msgid]["content"]
|
||||||
|
+ "</"
|
||||||
|
+ ori_message[msgid]["role"]
|
||||||
|
+ ">\n"
|
||||||
|
)
|
||||||
|
|
||||||
input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
|
input_ids = hf_tokenizer(
|
||||||
output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
|
input_prompt, return_tensors="pt", padding=True, truncation=True
|
||||||
|
).to("cuda")
|
||||||
|
output = hf_model.generate(
|
||||||
|
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
|
||||||
|
)
|
||||||
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
|
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
if hashing_kv is not None:
|
if hashing_kv is not None:
|
||||||
await hashing_kv.upsert(
|
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
|
||||||
{args_hash: {"return": response_text, "model": model}}
|
|
||||||
)
|
|
||||||
return response_text
|
return response_text
|
||||||
|
|
||||||
|
|
||||||
async def ollama_model_if_cache(
|
async def ollama_model_if_cache(
|
||||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -122,6 +246,7 @@ async def ollama_model_if_cache(
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def gpt_4o_complete(
|
async def gpt_4o_complete(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -145,10 +270,23 @@ async def gpt_4o_mini_complete(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def bedrock_complete(
|
||||||
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
return await bedrock_complete_if_cache(
|
||||||
|
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def hf_model_complete(
|
async def hf_model_complete(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
|
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||||
return await hf_model_if_cache(
|
return await hf_model_if_cache(
|
||||||
model_name,
|
model_name,
|
||||||
prompt,
|
prompt,
|
||||||
@@ -157,10 +295,11 @@ async def hf_model_complete(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def ollama_model_complete(
|
async def ollama_model_complete(
|
||||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
model_name = kwargs['hashing_kv'].global_config['llm_model_name']
|
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||||
return await ollama_model_if_cache(
|
return await ollama_model_if_cache(
|
||||||
model_name,
|
model_name,
|
||||||
prompt,
|
prompt,
|
||||||
@@ -169,30 +308,113 @@ async def ollama_model_complete(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||||
)
|
)
|
||||||
async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
|
async def openai_embedding(
|
||||||
|
texts: list[str],
|
||||||
|
model: str = "text-embedding-3-small",
|
||||||
|
base_url: str = None,
|
||||||
|
api_key: str = None,
|
||||||
|
) -> np.ndarray:
|
||||||
if api_key:
|
if api_key:
|
||||||
os.environ["OPENAI_API_KEY"] = api_key
|
os.environ["OPENAI_API_KEY"] = api_key
|
||||||
|
|
||||||
openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
openai_async_client = (
|
||||||
|
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
||||||
|
)
|
||||||
response = await openai_async_client.embeddings.create(
|
response = await openai_async_client.embeddings.create(
|
||||||
model=model, input=texts, encoding_format="float"
|
model=model, input=texts, encoding_format="float"
|
||||||
)
|
)
|
||||||
return np.array([dp.embedding for dp in response.data])
|
return np.array([dp.embedding for dp in response.data])
|
||||||
|
|
||||||
|
|
||||||
|
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||||
|
# @retry(
|
||||||
|
# stop=stop_after_attempt(3),
|
||||||
|
# wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
||||||
|
# )
|
||||||
|
async def bedrock_embedding(
|
||||||
|
texts: list[str],
|
||||||
|
model: str = "amazon.titan-embed-text-v2:0",
|
||||||
|
aws_access_key_id=None,
|
||||||
|
aws_secret_access_key=None,
|
||||||
|
aws_session_token=None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
||||||
|
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
||||||
|
)
|
||||||
|
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
||||||
|
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
||||||
|
)
|
||||||
|
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
||||||
|
"AWS_SESSION_TOKEN", aws_session_token
|
||||||
|
)
|
||||||
|
|
||||||
|
session = aioboto3.Session()
|
||||||
|
async with session.client("bedrock-runtime") as bedrock_async_client:
|
||||||
|
if (model_provider := model.split(".")[0]) == "amazon":
|
||||||
|
embed_texts = []
|
||||||
|
for text in texts:
|
||||||
|
if "v2" in model:
|
||||||
|
body = json.dumps(
|
||||||
|
{
|
||||||
|
"inputText": text,
|
||||||
|
# 'dimensions': embedding_dim,
|
||||||
|
"embeddingTypes": ["float"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif "v1" in model:
|
||||||
|
body = json.dumps({"inputText": text})
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Model {model} is not supported!")
|
||||||
|
|
||||||
|
response = await bedrock_async_client.invoke_model(
|
||||||
|
modelId=model,
|
||||||
|
body=body,
|
||||||
|
accept="application/json",
|
||||||
|
contentType="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
response_body = await response.get("body").json()
|
||||||
|
|
||||||
|
embed_texts.append(response_body["embedding"])
|
||||||
|
elif model_provider == "cohere":
|
||||||
|
body = json.dumps(
|
||||||
|
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await bedrock_async_client.invoke_model(
|
||||||
|
model=model,
|
||||||
|
body=body,
|
||||||
|
accept="application/json",
|
||||||
|
contentType="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
response_body = json.loads(response.get("body").read())
|
||||||
|
|
||||||
|
embed_texts = response_body["embeddings"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
||||||
|
|
||||||
|
return np.array(embed_texts)
|
||||||
|
|
||||||
|
|
||||||
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||||
input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
|
input_ids = tokenizer(
|
||||||
|
texts, return_tensors="pt", padding=True, truncation=True
|
||||||
|
).input_ids
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = embed_model(input_ids)
|
outputs = embed_model(input_ids)
|
||||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||||
return embeddings.detach().numpy()
|
return embeddings.detach().numpy()
|
||||||
|
|
||||||
|
|
||||||
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
||||||
embed_text = []
|
embed_text = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
@@ -201,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
|||||||
|
|
||||||
return embed_text
|
return embed_text
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
result = await gpt_4o_mini_complete('How are you?')
|
result = await gpt_4o_mini_complete("How are you?")
|
||||||
print(result)
|
print(result)
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
@@ -25,6 +25,7 @@ from .base import (
|
|||||||
)
|
)
|
||||||
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
||||||
|
|
||||||
|
|
||||||
def chunking_by_token_size(
|
def chunking_by_token_size(
|
||||||
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
||||||
):
|
):
|
||||||
@@ -45,6 +46,7 @@ def chunking_by_token_size(
|
|||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def _handle_entity_relation_summary(
|
async def _handle_entity_relation_summary(
|
||||||
entity_or_relation_name: str,
|
entity_or_relation_name: str,
|
||||||
description: str,
|
description: str,
|
||||||
@@ -76,7 +78,7 @@ async def _handle_single_entity_extraction(
|
|||||||
record_attributes: list[str],
|
record_attributes: list[str],
|
||||||
chunk_key: str,
|
chunk_key: str,
|
||||||
):
|
):
|
||||||
if record_attributes[0] != '"entity"' or len(record_attributes) < 4:
|
if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
|
||||||
return None
|
return None
|
||||||
# add this record as a node in the G
|
# add this record as a node in the G
|
||||||
entity_name = clean_str(record_attributes[1].upper())
|
entity_name = clean_str(record_attributes[1].upper())
|
||||||
@@ -97,7 +99,7 @@ async def _handle_single_relationship_extraction(
|
|||||||
record_attributes: list[str],
|
record_attributes: list[str],
|
||||||
chunk_key: str,
|
chunk_key: str,
|
||||||
):
|
):
|
||||||
if record_attributes[0] != '"relationship"' or len(record_attributes) < 5:
|
if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
|
||||||
return None
|
return None
|
||||||
# add this record as edge
|
# add this record as edge
|
||||||
source = clean_str(record_attributes[1].upper())
|
source = clean_str(record_attributes[1].upper())
|
||||||
@@ -232,6 +234,7 @@ async def _merge_edges_then_upsert(
|
|||||||
|
|
||||||
return edge_data
|
return edge_data
|
||||||
|
|
||||||
|
|
||||||
async def extract_entities(
|
async def extract_entities(
|
||||||
chunks: dict[str, TextChunkSchema],
|
chunks: dict[str, TextChunkSchema],
|
||||||
knwoledge_graph_inst: BaseGraphStorage,
|
knwoledge_graph_inst: BaseGraphStorage,
|
||||||
@@ -352,7 +355,9 @@ async def extract_entities(
|
|||||||
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
||||||
return None
|
return None
|
||||||
if not len(all_relationships_data):
|
if not len(all_relationships_data):
|
||||||
logger.warning("Didn't extract any relationships, maybe your LLM is not working")
|
logger.warning(
|
||||||
|
"Didn't extract any relationships, maybe your LLM is not working"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if entity_vdb is not None:
|
if entity_vdb is not None:
|
||||||
@@ -370,7 +375,10 @@ async def extract_entities(
|
|||||||
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
||||||
"src_id": dp["src_id"],
|
"src_id": dp["src_id"],
|
||||||
"tgt_id": dp["tgt_id"],
|
"tgt_id": dp["tgt_id"],
|
||||||
"content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
|
"content": dp["keywords"]
|
||||||
|
+ dp["src_id"]
|
||||||
|
+ dp["tgt_id"]
|
||||||
|
+ dp["description"],
|
||||||
}
|
}
|
||||||
for dp in all_relationships_data
|
for dp in all_relationships_data
|
||||||
}
|
}
|
||||||
@@ -378,6 +386,7 @@ async def extract_entities(
|
|||||||
|
|
||||||
return knwoledge_graph_inst
|
return knwoledge_graph_inst
|
||||||
|
|
||||||
|
|
||||||
async def local_query(
|
async def local_query(
|
||||||
query,
|
query,
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
@@ -387,6 +396,7 @@ async def local_query(
|
|||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
global_config: dict,
|
global_config: dict,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
context = None
|
||||||
use_model_func = global_config["llm_model_func"]
|
use_model_func = global_config["llm_model_func"]
|
||||||
|
|
||||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
@@ -396,17 +406,25 @@ async def local_query(
|
|||||||
try:
|
try:
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
keywords = keywords_data.get("low_level_keywords", [])
|
keywords = keywords_data.get("low_level_keywords", [])
|
||||||
keywords = ', '.join(keywords)
|
keywords = ", ".join(keywords)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError:
|
||||||
try:
|
try:
|
||||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
|
result = (
|
||||||
|
result.replace(kw_prompt[:-1], "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||||
|
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
keywords = keywords_data.get("low_level_keywords", [])
|
keywords = keywords_data.get("low_level_keywords", [])
|
||||||
keywords = ', '.join(keywords)
|
keywords = ", ".join(keywords)
|
||||||
# Handle parsing error
|
# Handle parsing error
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"JSON parsing error: {e}")
|
print(f"JSON parsing error: {e}")
|
||||||
return PROMPTS["fail_response"]
|
return PROMPTS["fail_response"]
|
||||||
|
if keywords:
|
||||||
context = await _build_local_query_context(
|
context = await _build_local_query_context(
|
||||||
keywords,
|
keywords,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
@@ -427,10 +445,19 @@ async def local_query(
|
|||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
)
|
)
|
||||||
if len(response) > len(sys_prompt):
|
if len(response) > len(sys_prompt):
|
||||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = (
|
||||||
|
response.replace(sys_prompt, "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.replace(query, "")
|
||||||
|
.replace("<system>", "")
|
||||||
|
.replace("</system>", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def _build_local_query_context(
|
async def _build_local_query_context(
|
||||||
query,
|
query,
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
@@ -512,6 +539,7 @@ async def _build_local_query_context(
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def _find_most_related_text_unit_from_entities(
|
async def _find_most_related_text_unit_from_entities(
|
||||||
node_datas: list[dict],
|
node_datas: list[dict],
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
@@ -572,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
|
|||||||
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
||||||
return all_text_units
|
return all_text_units
|
||||||
|
|
||||||
|
|
||||||
async def _find_most_related_edges_from_entities(
|
async def _find_most_related_edges_from_entities(
|
||||||
node_datas: list[dict],
|
node_datas: list[dict],
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
@@ -605,6 +634,7 @@ async def _find_most_related_edges_from_entities(
|
|||||||
)
|
)
|
||||||
return all_edges_data
|
return all_edges_data
|
||||||
|
|
||||||
|
|
||||||
async def global_query(
|
async def global_query(
|
||||||
query,
|
query,
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
@@ -614,6 +644,7 @@ async def global_query(
|
|||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
global_config: dict,
|
global_config: dict,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
context = None
|
||||||
use_model_func = global_config["llm_model_func"]
|
use_model_func = global_config["llm_model_func"]
|
||||||
|
|
||||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
@@ -623,19 +654,26 @@ async def global_query(
|
|||||||
try:
|
try:
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
keywords = keywords_data.get("high_level_keywords", [])
|
keywords = keywords_data.get("high_level_keywords", [])
|
||||||
keywords = ', '.join(keywords)
|
keywords = ", ".join(keywords)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError:
|
||||||
try:
|
try:
|
||||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
|
result = (
|
||||||
|
result.replace(kw_prompt[:-1], "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||||
|
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
keywords = keywords_data.get("high_level_keywords", [])
|
keywords = keywords_data.get("high_level_keywords", [])
|
||||||
keywords = ', '.join(keywords)
|
keywords = ", ".join(keywords)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
# Handle parsing error
|
# Handle parsing error
|
||||||
print(f"JSON parsing error: {e}")
|
print(f"JSON parsing error: {e}")
|
||||||
return PROMPTS["fail_response"]
|
return PROMPTS["fail_response"]
|
||||||
|
if keywords:
|
||||||
context = await _build_global_query_context(
|
context = await _build_global_query_context(
|
||||||
keywords,
|
keywords,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
@@ -659,10 +697,19 @@ async def global_query(
|
|||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
)
|
)
|
||||||
if len(response) > len(sys_prompt):
|
if len(response) > len(sys_prompt):
|
||||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = (
|
||||||
|
response.replace(sys_prompt, "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.replace(query, "")
|
||||||
|
.replace("<system>", "")
|
||||||
|
.replace("</system>", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def _build_global_query_context(
|
async def _build_global_query_context(
|
||||||
keywords,
|
keywords,
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
@@ -758,6 +805,7 @@ async def _build_global_query_context(
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def _find_most_related_entities_from_relationships(
|
async def _find_most_related_entities_from_relationships(
|
||||||
edge_datas: list[dict],
|
edge_datas: list[dict],
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
@@ -788,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
|
|||||||
|
|
||||||
return node_datas
|
return node_datas
|
||||||
|
|
||||||
|
|
||||||
async def _find_related_text_unit_from_relationships(
|
async def _find_related_text_unit_from_relationships(
|
||||||
edge_datas: list[dict],
|
edge_datas: list[dict],
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
):
|
):
|
||||||
|
|
||||||
text_units = [
|
text_units = [
|
||||||
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
||||||
for dp in edge_datas
|
for dp in edge_datas
|
||||||
@@ -815,9 +863,7 @@ async def _find_related_text_unit_from_relationships(
|
|||||||
all_text_units = [
|
all_text_units = [
|
||||||
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
||||||
]
|
]
|
||||||
all_text_units = sorted(
|
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
|
||||||
all_text_units, key=lambda x: x["order"]
|
|
||||||
)
|
|
||||||
all_text_units = truncate_list_by_token_size(
|
all_text_units = truncate_list_by_token_size(
|
||||||
all_text_units,
|
all_text_units,
|
||||||
key=lambda x: x["data"]["content"],
|
key=lambda x: x["data"]["content"],
|
||||||
@@ -827,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
|
|||||||
|
|
||||||
return all_text_units
|
return all_text_units
|
||||||
|
|
||||||
|
|
||||||
async def hybrid_query(
|
async def hybrid_query(
|
||||||
query,
|
query,
|
||||||
knowledge_graph_inst: BaseGraphStorage,
|
knowledge_graph_inst: BaseGraphStorage,
|
||||||
@@ -836,6 +883,8 @@ async def hybrid_query(
|
|||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
global_config: dict,
|
global_config: dict,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
low_level_context = None
|
||||||
|
high_level_context = None
|
||||||
use_model_func = global_config["llm_model_func"]
|
use_model_func = global_config["llm_model_func"]
|
||||||
|
|
||||||
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
||||||
@@ -846,21 +895,29 @@ async def hybrid_query(
|
|||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||||
hl_keywords = ', '.join(hl_keywords)
|
hl_keywords = ", ".join(hl_keywords)
|
||||||
ll_keywords = ', '.join(ll_keywords)
|
ll_keywords = ", ".join(ll_keywords)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError:
|
||||||
try:
|
try:
|
||||||
result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip().strip('```').strip('json')
|
result = (
|
||||||
|
result.replace(kw_prompt[:-1], "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||||
|
|
||||||
keywords_data = json.loads(result)
|
keywords_data = json.loads(result)
|
||||||
hl_keywords = keywords_data.get("high_level_keywords", [])
|
hl_keywords = keywords_data.get("high_level_keywords", [])
|
||||||
ll_keywords = keywords_data.get("low_level_keywords", [])
|
ll_keywords = keywords_data.get("low_level_keywords", [])
|
||||||
hl_keywords = ', '.join(hl_keywords)
|
hl_keywords = ", ".join(hl_keywords)
|
||||||
ll_keywords = ', '.join(ll_keywords)
|
ll_keywords = ", ".join(ll_keywords)
|
||||||
# Handle parsing error
|
# Handle parsing error
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"JSON parsing error: {e}")
|
print(f"JSON parsing error: {e}")
|
||||||
return PROMPTS["fail_response"]
|
return PROMPTS["fail_response"]
|
||||||
|
|
||||||
|
if ll_keywords:
|
||||||
low_level_context = await _build_local_query_context(
|
low_level_context = await _build_local_query_context(
|
||||||
ll_keywords,
|
ll_keywords,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
@@ -869,6 +926,7 @@ async def hybrid_query(
|
|||||||
query_param,
|
query_param,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if hl_keywords:
|
||||||
high_level_context = await _build_global_query_context(
|
high_level_context = await _build_global_query_context(
|
||||||
hl_keywords,
|
hl_keywords,
|
||||||
knowledge_graph_inst,
|
knowledge_graph_inst,
|
||||||
@@ -894,51 +952,76 @@ async def hybrid_query(
|
|||||||
system_prompt=sys_prompt,
|
system_prompt=sys_prompt,
|
||||||
)
|
)
|
||||||
if len(response) > len(sys_prompt):
|
if len(response) > len(sys_prompt):
|
||||||
response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = (
|
||||||
|
response.replace(sys_prompt, "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.replace(query, "")
|
||||||
|
.replace("<system>", "")
|
||||||
|
.replace("</system>", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def combine_contexts(high_level_context, low_level_context):
|
def combine_contexts(high_level_context, low_level_context):
|
||||||
# Function to extract entities, relationships, and sources from context strings
|
# Function to extract entities, relationships, and sources from context strings
|
||||||
|
|
||||||
def extract_sections(context):
|
def extract_sections(context):
|
||||||
entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
entities_match = re.search(
|
||||||
relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||||
sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
|
)
|
||||||
|
relationships_match = re.search(
|
||||||
|
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||||
|
)
|
||||||
|
sources_match = re.search(
|
||||||
|
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
||||||
|
)
|
||||||
|
|
||||||
entities = entities_match.group(1) if entities_match else ''
|
entities = entities_match.group(1) if entities_match else ""
|
||||||
relationships = relationships_match.group(1) if relationships_match else ''
|
relationships = relationships_match.group(1) if relationships_match else ""
|
||||||
sources = sources_match.group(1) if sources_match else ''
|
sources = sources_match.group(1) if sources_match else ""
|
||||||
|
|
||||||
return entities, relationships, sources
|
return entities, relationships, sources
|
||||||
|
|
||||||
# Extract sections from both contexts
|
# Extract sections from both contexts
|
||||||
|
|
||||||
if high_level_context==None:
|
if high_level_context is None:
|
||||||
warnings.warn("High Level context is None. Return empty High entity/relationship/source")
|
warnings.warn(
|
||||||
hl_entities, hl_relationships, hl_sources = '','',''
|
"High Level context is None. Return empty High entity/relationship/source"
|
||||||
|
)
|
||||||
|
hl_entities, hl_relationships, hl_sources = "", "", ""
|
||||||
else:
|
else:
|
||||||
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
||||||
|
|
||||||
|
if low_level_context is None:
|
||||||
if low_level_context==None:
|
warnings.warn(
|
||||||
warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
|
"Low Level context is None. Return empty Low entity/relationship/source"
|
||||||
ll_entities, ll_relationships, ll_sources = '','',''
|
)
|
||||||
|
ll_entities, ll_relationships, ll_sources = "", "", ""
|
||||||
else:
|
else:
|
||||||
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Combine and deduplicate the entities
|
# Combine and deduplicate the entities
|
||||||
combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
|
combined_entities_set = set(
|
||||||
combined_entities = '\n'.join(combined_entities_set)
|
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
|
||||||
|
)
|
||||||
|
combined_entities = "\n".join(combined_entities_set)
|
||||||
|
|
||||||
# Combine and deduplicate the relationships
|
# Combine and deduplicate the relationships
|
||||||
combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
|
combined_relationships_set = set(
|
||||||
combined_relationships = '\n'.join(combined_relationships_set)
|
filter(
|
||||||
|
None,
|
||||||
|
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
combined_relationships = "\n".join(combined_relationships_set)
|
||||||
|
|
||||||
# Combine and deduplicate the sources
|
# Combine and deduplicate the sources
|
||||||
combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
|
combined_sources_set = set(
|
||||||
combined_sources = '\n'.join(combined_sources_set)
|
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
|
||||||
|
)
|
||||||
|
combined_sources = "\n".join(combined_sources_set)
|
||||||
|
|
||||||
# Format the combined context
|
# Format the combined context
|
||||||
return f"""
|
return f"""
|
||||||
@@ -951,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
|
|||||||
{combined_sources}
|
{combined_sources}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def naive_query(
|
async def naive_query(
|
||||||
query,
|
query,
|
||||||
chunks_vdb: BaseVectorStorage,
|
chunks_vdb: BaseVectorStorage,
|
||||||
@@ -984,7 +1068,15 @@ async def naive_query(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(response) > len(sys_prompt):
|
if len(response) > len(sys_prompt):
|
||||||
response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
|
response = (
|
||||||
|
response[len(sys_prompt) :]
|
||||||
|
.replace(sys_prompt, "")
|
||||||
|
.replace("user", "")
|
||||||
|
.replace("model", "")
|
||||||
|
.replace(query, "")
|
||||||
|
.replace("<system>", "")
|
||||||
|
.replace("</system>", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
|
|||||||
|
|
||||||
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
|
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
|
||||||
|
|
||||||
PROMPTS[
|
PROMPTS["entity_extraction"] = """-Goal-
|
||||||
"entity_extraction"
|
|
||||||
] = """-Goal-
|
|
||||||
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
||||||
|
|
||||||
-Steps-
|
-Steps-
|
||||||
@@ -146,9 +144,7 @@ PROMPTS[
|
|||||||
|
|
||||||
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
|
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
|
||||||
|
|
||||||
PROMPTS[
|
PROMPTS["rag_response"] = """---Role---
|
||||||
"rag_response"
|
|
||||||
] = """---Role---
|
|
||||||
|
|
||||||
You are a helpful assistant responding to questions about data in the tables provided.
|
You are a helpful assistant responding to questions about data in the tables provided.
|
||||||
|
|
||||||
@@ -163,25 +159,10 @@ Do not include information where the supporting evidence for it is not provided.
|
|||||||
|
|
||||||
{response_type}
|
{response_type}
|
||||||
|
|
||||||
|
|
||||||
---Data tables---
|
---Data tables---
|
||||||
|
|
||||||
{context_data}
|
{context_data}
|
||||||
|
|
||||||
|
|
||||||
---Goal---
|
|
||||||
|
|
||||||
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
|
|
||||||
|
|
||||||
If you don't know the answer, just say so. Do not make anything up.
|
|
||||||
|
|
||||||
Do not include information where the supporting evidence for it is not provided.
|
|
||||||
|
|
||||||
|
|
||||||
---Target response length and format---
|
|
||||||
|
|
||||||
{response_type}
|
|
||||||
|
|
||||||
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
|
Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -241,9 +222,7 @@ Output:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PROMPTS[
|
PROMPTS["naive_rag_response"] = """You're a helpful assistant
|
||||||
"naive_rag_response"
|
|
||||||
] = """You're a helpful assistant
|
|
||||||
Below are the knowledge you know:
|
Below are the knowledge you know:
|
||||||
{content_data}
|
{content_data}
|
||||||
---
|
---
|
||||||
|
@@ -1,16 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import html
|
import html
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from dataclasses import dataclass
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Union, cast
|
from typing import Any, Union, cast
|
||||||
import pickle
|
|
||||||
import hnswlib
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nano_vectordb import NanoVectorDB
|
from nano_vectordb import NanoVectorDB
|
||||||
import xxhash
|
|
||||||
|
|
||||||
from .utils import load_json, logger, write_json
|
from .utils import load_json, logger, write_json
|
||||||
from .base import (
|
from .base import (
|
||||||
@@ -19,6 +14,7 @@ from .base import (
|
|||||||
BaseVectorStorage,
|
BaseVectorStorage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class JsonKVStorage(BaseKVStorage):
|
class JsonKVStorage(BaseKVStorage):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
|
|||||||
async def drop(self):
|
async def drop(self):
|
||||||
self._data = {}
|
self._data = {}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NanoVectorDBStorage(BaseVectorStorage):
|
class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
cosine_better_than_threshold: float = 0.2
|
cosine_better_than_threshold: float = 0.2
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
||||||
self._client_file_name = os.path.join(
|
self._client_file_name = os.path.join(
|
||||||
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
||||||
)
|
)
|
||||||
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|||||||
async def index_done_callback(self):
|
async def index_done_callback(self):
|
||||||
self._client.save()
|
self._client.save()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NetworkXStorage(BaseGraphStorage):
|
class NetworkXStorage(BaseGraphStorage):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
graph = graph.copy()
|
graph = graph.copy()
|
||||||
graph = cast(nx.Graph, largest_connected_component(graph))
|
graph = cast(nx.Graph, largest_connected_component(graph))
|
||||||
node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
|
node_mapping = {
|
||||||
|
node: html.unescape(node.upper().strip()) for node in graph.nodes()
|
||||||
|
} # type: ignore
|
||||||
graph = nx.relabel_nodes(graph, node_mapping)
|
graph = nx.relabel_nodes(graph, node_mapping)
|
||||||
return NetworkXStorage._stabilize_graph(graph)
|
return NetworkXStorage._stabilize_graph(graph)
|
||||||
|
|
||||||
|
@@ -8,6 +8,7 @@ from dataclasses import dataclass
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tiktoken
|
import tiktoken
|
||||||
@@ -16,18 +17,22 @@ ENCODER = None
|
|||||||
|
|
||||||
logger = logging.getLogger("lightrag")
|
logger = logging.getLogger("lightrag")
|
||||||
|
|
||||||
|
|
||||||
def set_logger(log_file: str):
|
def set_logger(log_file: str):
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
file_handler = logging.FileHandler(log_file)
|
file_handler = logging.FileHandler(log_file)
|
||||||
file_handler.setLevel(logging.DEBUG)
|
file_handler.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
formatter = logging.Formatter(
|
||||||
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
|
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingFunc:
|
class EmbeddingFunc:
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
@@ -37,6 +42,7 @@ class EmbeddingFunc:
|
|||||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||||
return await self.func(*args, **kwargs)
|
return await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
||||||
"""Locate the JSON string body from a string"""
|
"""Locate the JSON string body from a string"""
|
||||||
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
||||||
@@ -45,6 +51,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def convert_response_to_json(response: str) -> dict:
|
def convert_response_to_json(response: str) -> dict:
|
||||||
json_str = locate_json_string_body_from_string(response)
|
json_str = locate_json_string_body_from_string(response)
|
||||||
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
||||||
@@ -55,12 +62,15 @@ def convert_response_to_json(response: str) -> dict:
|
|||||||
logger.error(f"Failed to parse JSON: {json_str}")
|
logger.error(f"Failed to parse JSON: {json_str}")
|
||||||
raise e from None
|
raise e from None
|
||||||
|
|
||||||
|
|
||||||
def compute_args_hash(*args):
|
def compute_args_hash(*args):
|
||||||
return md5(str(args).encode()).hexdigest()
|
return md5(str(args).encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def compute_mdhash_id(content, prefix: str = ""):
|
def compute_mdhash_id(content, prefix: str = ""):
|
||||||
return prefix + md5(content.encode()).hexdigest()
|
return prefix + md5(content.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
||||||
"""Add restriction of maximum async calling times for a async func"""
|
"""Add restriction of maximum async calling times for a async func"""
|
||||||
|
|
||||||
@@ -82,6 +92,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
|||||||
|
|
||||||
return final_decro
|
return final_decro
|
||||||
|
|
||||||
|
|
||||||
def wrap_embedding_func_with_attrs(**kwargs):
|
def wrap_embedding_func_with_attrs(**kwargs):
|
||||||
"""Wrap a function with attributes"""
|
"""Wrap a function with attributes"""
|
||||||
|
|
||||||
@@ -91,16 +102,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
|
|||||||
|
|
||||||
return final_decro
|
return final_decro
|
||||||
|
|
||||||
|
|
||||||
def load_json(file_name):
|
def load_json(file_name):
|
||||||
if not os.path.exists(file_name):
|
if not os.path.exists(file_name):
|
||||||
return None
|
return None
|
||||||
with open(file_name, encoding="utf-8") as f:
|
with open(file_name, encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
def write_json(json_obj, file_name):
|
def write_json(json_obj, file_name):
|
||||||
with open(file_name, "w", encoding="utf-8") as f:
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
||||||
global ENCODER
|
global ENCODER
|
||||||
if ENCODER is None:
|
if ENCODER is None:
|
||||||
@@ -116,12 +130,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
|
|||||||
content = ENCODER.decode(tokens)
|
content = ENCODER.decode(tokens)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
def pack_user_ass_to_openai_messages(*args: str):
|
def pack_user_ass_to_openai_messages(*args: str):
|
||||||
roles = ["user", "assistant"]
|
roles = ["user", "assistant"]
|
||||||
return [
|
return [
|
||||||
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
||||||
"""Split a string by multiple markers"""
|
"""Split a string by multiple markers"""
|
||||||
if not markers:
|
if not markers:
|
||||||
@@ -129,6 +145,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
|
|||||||
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
||||||
return [r.strip() for r in results if r.strip()]
|
return [r.strip() for r in results if r.strip()]
|
||||||
|
|
||||||
|
|
||||||
# Refer the utils functions of the official GraphRAG implementation:
|
# Refer the utils functions of the official GraphRAG implementation:
|
||||||
# https://github.com/microsoft/graphrag
|
# https://github.com/microsoft/graphrag
|
||||||
def clean_str(input: Any) -> str:
|
def clean_str(input: Any) -> str:
|
||||||
@@ -141,9 +158,11 @@ def clean_str(input: Any) -> str:
|
|||||||
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
||||||
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
||||||
|
|
||||||
|
|
||||||
def is_float_regex(value):
|
def is_float_regex(value):
|
||||||
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
||||||
|
|
||||||
|
|
||||||
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
||||||
"""Truncate a list of data by token size"""
|
"""Truncate a list of data by token size"""
|
||||||
if max_token_size <= 0:
|
if max_token_size <= 0:
|
||||||
@@ -155,11 +174,61 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
|
|||||||
return list_data[:i]
|
return list_data[:i]
|
||||||
return list_data
|
return list_data
|
||||||
|
|
||||||
|
|
||||||
def list_of_list_to_csv(data: list[list]):
|
def list_of_list_to_csv(data: list[list]):
|
||||||
return "\n".join(
|
return "\n".join(
|
||||||
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
|
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_data_to_file(data, file_name):
|
def save_data_to_file(data, file_name):
|
||||||
with open(file_name, 'w', encoding='utf-8') as f:
|
with open(file_name, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
def xml_to_json(xml_file):
|
||||||
|
try:
|
||||||
|
tree = ET.parse(xml_file)
|
||||||
|
root = tree.getroot()
|
||||||
|
|
||||||
|
# Print the root element's tag and attributes to confirm the file has been correctly loaded
|
||||||
|
print(f"Root element: {root.tag}")
|
||||||
|
print(f"Root attributes: {root.attrib}")
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"nodes": [],
|
||||||
|
"edges": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use namespace
|
||||||
|
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
|
||||||
|
|
||||||
|
for node in root.findall('.//node', namespace):
|
||||||
|
node_data = {
|
||||||
|
"id": node.get('id').strip('"'),
|
||||||
|
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
|
||||||
|
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
|
||||||
|
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
|
||||||
|
}
|
||||||
|
data["nodes"].append(node_data)
|
||||||
|
|
||||||
|
for edge in root.findall('.//edge', namespace):
|
||||||
|
edge_data = {
|
||||||
|
"source": edge.get('source').strip('"'),
|
||||||
|
"target": edge.get('target').strip('"'),
|
||||||
|
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
|
||||||
|
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
|
||||||
|
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
|
||||||
|
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
|
||||||
|
}
|
||||||
|
data["edges"].append(edge_data)
|
||||||
|
|
||||||
|
# Print the number of nodes and edges found
|
||||||
|
print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
|
||||||
|
|
||||||
|
return data
|
||||||
|
except ET.ParseError as e:
|
||||||
|
print(f"Error parsing XML file: {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred: {e}")
|
||||||
|
return None
|
||||||
|
@@ -3,11 +3,11 @@ import json
|
|||||||
import glob
|
import glob
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
def extract_unique_contexts(input_directory, output_directory):
|
|
||||||
|
|
||||||
|
def extract_unique_contexts(input_directory, output_directory):
|
||||||
os.makedirs(output_directory, exist_ok=True)
|
os.makedirs(output_directory, exist_ok=True)
|
||||||
|
|
||||||
jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
|
jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
|
||||||
print(f"Found {len(jsonl_files)} JSONL files.")
|
print(f"Found {len(jsonl_files)} JSONL files.")
|
||||||
|
|
||||||
for file_path in jsonl_files:
|
for file_path in jsonl_files:
|
||||||
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
|
|||||||
print(f"Processing file: {filename}")
|
print(f"Processing file: {filename}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'r', encoding='utf-8') as infile:
|
with open(file_path, "r", encoding="utf-8") as infile:
|
||||||
for line_number, line in enumerate(infile, start=1):
|
for line_number, line in enumerate(infile, start=1):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
json_obj = json.loads(line)
|
json_obj = json.loads(line)
|
||||||
context = json_obj.get('context')
|
context = json_obj.get("context")
|
||||||
if context and context not in unique_contexts_dict:
|
if context and context not in unique_contexts_dict:
|
||||||
unique_contexts_dict[context] = None
|
unique_contexts_dict[context] = None
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
|
print(
|
||||||
|
f"JSON decoding error in file {filename} at line {line_number}: {e}"
|
||||||
|
)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print(f"File not found: {filename}")
|
print(f"File not found: {filename}")
|
||||||
continue
|
continue
|
||||||
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
unique_contexts_list = list(unique_contexts_dict.keys())
|
unique_contexts_list = list(unique_contexts_dict.keys())
|
||||||
print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
|
print(
|
||||||
|
f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(output_path, 'w', encoding='utf-8') as outfile:
|
with open(output_path, "w", encoding="utf-8") as outfile:
|
||||||
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
|
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
|
||||||
print(f"Unique `context` entries have been saved to: {output_filename}")
|
print(f"Unique `context` entries have been saved to: {output_filename}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
|
parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
|
||||||
parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
|
parser.add_argument(
|
||||||
|
"-o", "--output_dir", type=str, default="../datasets/unique_contexts"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@@ -4,8 +4,9 @@ import time
|
|||||||
|
|
||||||
from lightrag import LightRAG
|
from lightrag import LightRAG
|
||||||
|
|
||||||
|
|
||||||
def insert_text(rag, file_path):
|
def insert_text(rag, file_path):
|
||||||
with open(file_path, mode='r') as f:
|
with open(file_path, mode="r") as f:
|
||||||
unique_contexts = json.load(f)
|
unique_contexts = json.load(f)
|
||||||
|
|
||||||
retries = 0
|
retries = 0
|
||||||
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
|
|||||||
if retries == max_retries:
|
if retries == max_retries:
|
||||||
print("Insertion failed after exceeding the maximum number of retries")
|
print("Insertion failed after exceeding the maximum number of retries")
|
||||||
|
|
||||||
|
|
||||||
cls = "agriculture"
|
cls = "agriculture"
|
||||||
WORKING_DIR = "../{cls}"
|
WORKING_DIR = "../{cls}"
|
||||||
|
|
||||||
|
71
reproduce/Step_1_openai_compatible.py
Normal file
71
reproduce/Step_1_openai_compatible.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lightrag import LightRAG
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
||||||
|
|
||||||
|
|
||||||
|
## For Upstage API
|
||||||
|
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
|
||||||
|
async def llm_model_func(
|
||||||
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
return await openai_complete_if_cache(
|
||||||
|
"solar-mini",
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.upstage.ai/v1/solar",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
|
return await openai_embedding(
|
||||||
|
texts,
|
||||||
|
model="solar-embedding-1-large-query",
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.upstage.ai/v1/solar",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
## /For Upstage API
|
||||||
|
|
||||||
|
|
||||||
|
def insert_text(rag, file_path):
|
||||||
|
with open(file_path, mode="r") as f:
|
||||||
|
unique_contexts = json.load(f)
|
||||||
|
|
||||||
|
retries = 0
|
||||||
|
max_retries = 3
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
rag.insert(unique_contexts)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
retries += 1
|
||||||
|
print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
|
||||||
|
time.sleep(10)
|
||||||
|
if retries == max_retries:
|
||||||
|
print("Insertion failed after exceeding the maximum number of retries")
|
||||||
|
|
||||||
|
|
||||||
|
cls = "mix"
|
||||||
|
WORKING_DIR = f"../{cls}"
|
||||||
|
|
||||||
|
if not os.path.exists(WORKING_DIR):
|
||||||
|
os.mkdir(WORKING_DIR)
|
||||||
|
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
@@ -1,8 +1,8 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from transformers import GPT2Tokenizer
|
from transformers import GPT2Tokenizer
|
||||||
|
|
||||||
|
|
||||||
def openai_complete_if_cache(
|
def openai_complete_if_cache(
|
||||||
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -19,7 +19,9 @@ def openai_complete_if_cache(
|
|||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
|
|
||||||
def get_summary(context, tot_tokens=2000):
|
def get_summary(context, tot_tokens=2000):
|
||||||
tokens = tokenizer.tokenize(context)
|
tokens = tokenizer.tokenize(context)
|
||||||
@@ -34,9 +36,9 @@ def get_summary(context, tot_tokens=2000):
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
clses = ['agriculture']
|
clses = ["agriculture"]
|
||||||
for cls in clses:
|
for cls in clses:
|
||||||
with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
|
with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
|
||||||
unique_contexts = json.load(f)
|
unique_contexts = json.load(f)
|
||||||
|
|
||||||
summaries = [get_summary(context) for context in unique_contexts]
|
summaries = [get_summary(context) for context in unique_contexts]
|
||||||
@@ -67,7 +69,7 @@ for cls in clses:
|
|||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
|
result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)
|
||||||
|
|
||||||
file_path = f"../datasets/questions/{cls}_questions.txt"
|
file_path = f"../datasets/questions/{cls}_questions.txt"
|
||||||
with open(file_path, "w") as file:
|
with open(file_path, "w") as file:
|
||||||
|
@@ -4,16 +4,18 @@ import asyncio
|
|||||||
from lightrag import LightRAG, QueryParam
|
from lightrag import LightRAG, QueryParam
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def extract_queries(file_path):
|
def extract_queries(file_path):
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, "r") as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
|
||||||
data = data.replace('**', '')
|
data = data.replace("**", "")
|
||||||
|
|
||||||
queries = re.findall(r'- Question \d+: (.+)', data)
|
queries = re.findall(r"- Question \d+: (.+)", data)
|
||||||
|
|
||||||
return queries
|
return queries
|
||||||
|
|
||||||
|
|
||||||
async def process_query(query_text, rag_instance, query_param):
|
async def process_query(query_text, rag_instance, query_param):
|
||||||
try:
|
try:
|
||||||
result, context = await rag_instance.aquery(query_text, param=query_param)
|
result, context = await rag_instance.aquery(query_text, param=query_param)
|
||||||
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None, {"query": query_text, "error": str(e)}
|
return None, {"query": query_text, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
|||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
return loop
|
return loop
|
||||||
|
|
||||||
def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
|
|
||||||
|
def run_queries_and_save_to_json(
|
||||||
|
queries, rag_instance, query_param, output_file, error_file
|
||||||
|
):
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
|
|
||||||
with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
|
with open(output_file, "a", encoding="utf-8") as result_file, open(
|
||||||
|
error_file, "a", encoding="utf-8"
|
||||||
|
) as err_file:
|
||||||
result_file.write("[\n")
|
result_file.write("[\n")
|
||||||
first_entry = True
|
first_entry = True
|
||||||
|
|
||||||
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
||||||
result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
|
result, error = loop.run_until_complete(
|
||||||
|
process_query(query_text, rag_instance, query_param)
|
||||||
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
if not first_entry:
|
if not first_entry:
|
||||||
@@ -50,13 +60,16 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
|
|||||||
|
|
||||||
result_file.write("\n]")
|
result_file.write("\n]")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cls = "agriculture"
|
cls = "agriculture"
|
||||||
mode = "hybrid"
|
mode = "hybrid"
|
||||||
WORKING_DIR = "../{cls}"
|
WORKING_DIR = f"../{cls}"
|
||||||
|
|
||||||
rag = LightRAG(working_dir=WORKING_DIR)
|
rag = LightRAG(working_dir=WORKING_DIR)
|
||||||
query_param = QueryParam(mode=mode)
|
query_param = QueryParam(mode=mode)
|
||||||
|
|
||||||
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
|
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
|
||||||
run_queries_and_save_to_json(queries, rag, query_param, "result.json", "errors.json")
|
run_queries_and_save_to_json(
|
||||||
|
queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
|
||||||
|
)
|
||||||
|
115
reproduce/Step_3_openai_compatible.py
Normal file
115
reproduce/Step_3_openai_compatible.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from lightrag import LightRAG, QueryParam
|
||||||
|
from tqdm import tqdm
|
||||||
|
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
||||||
|
from lightrag.utils import EmbeddingFunc
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
## For Upstage API
|
||||||
|
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
|
||||||
|
async def llm_model_func(
|
||||||
|
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
return await openai_complete_if_cache(
|
||||||
|
"solar-mini",
|
||||||
|
prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.upstage.ai/v1/solar",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
|
return await openai_embedding(
|
||||||
|
texts,
|
||||||
|
model="solar-embedding-1-large-query",
|
||||||
|
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||||
|
base_url="https://api.upstage.ai/v1/solar",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
## /For Upstage API
|
||||||
|
|
||||||
|
|
||||||
|
def extract_queries(file_path):
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
data = data.replace("**", "")
|
||||||
|
|
||||||
|
queries = re.findall(r"- Question \d+: (.+)", data)
|
||||||
|
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
async def process_query(query_text, rag_instance, query_param):
|
||||||
|
try:
|
||||||
|
result, context = await rag_instance.aquery(query_text, param=query_param)
|
||||||
|
return {"query": query_text, "result": result, "context": context}, None
|
||||||
|
except Exception as e:
|
||||||
|
return None, {"query": query_text, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
def run_queries_and_save_to_json(
|
||||||
|
queries, rag_instance, query_param, output_file, error_file
|
||||||
|
):
|
||||||
|
loop = always_get_an_event_loop()
|
||||||
|
|
||||||
|
with open(output_file, "a", encoding="utf-8") as result_file, open(
|
||||||
|
error_file, "a", encoding="utf-8"
|
||||||
|
) as err_file:
|
||||||
|
result_file.write("[\n")
|
||||||
|
first_entry = True
|
||||||
|
|
||||||
|
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
||||||
|
result, error = loop.run_until_complete(
|
||||||
|
process_query(query_text, rag_instance, query_param)
|
||||||
|
)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
if not first_entry:
|
||||||
|
result_file.write(",\n")
|
||||||
|
json.dump(result, result_file, ensure_ascii=False, indent=4)
|
||||||
|
first_entry = False
|
||||||
|
elif error:
|
||||||
|
json.dump(error, err_file, ensure_ascii=False, indent=4)
|
||||||
|
err_file.write("\n")
|
||||||
|
|
||||||
|
result_file.write("\n]")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cls = "mix"
|
||||||
|
mode = "hybrid"
|
||||||
|
WORKING_DIR = f"../{cls}"
|
||||||
|
|
||||||
|
rag = LightRAG(working_dir=WORKING_DIR)
|
||||||
|
rag = LightRAG(
|
||||||
|
working_dir=WORKING_DIR,
|
||||||
|
llm_model_func=llm_model_func,
|
||||||
|
embedding_func=EmbeddingFunc(
|
||||||
|
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
||||||
|
),
|
||||||
|
)
|
||||||
|
query_param = QueryParam(mode=mode)
|
||||||
|
|
||||||
|
base_dir = "../datasets/questions"
|
||||||
|
queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
|
||||||
|
run_queries_and_save_to_json(
|
||||||
|
queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
|
||||||
|
)
|
@@ -1,12 +1,14 @@
|
|||||||
openai
|
|
||||||
tiktoken
|
|
||||||
networkx
|
|
||||||
graspologic
|
|
||||||
nano-vectordb
|
|
||||||
hnswlib
|
|
||||||
xxhash
|
|
||||||
tenacity
|
|
||||||
transformers
|
|
||||||
torch
|
|
||||||
ollama
|
|
||||||
accelerate
|
accelerate
|
||||||
|
aioboto3
|
||||||
|
graspologic
|
||||||
|
hnswlib
|
||||||
|
nano-vectordb
|
||||||
|
networkx
|
||||||
|
ollama
|
||||||
|
openai
|
||||||
|
tenacity
|
||||||
|
tiktoken
|
||||||
|
torch
|
||||||
|
transformers
|
||||||
|
xxhash
|
||||||
|
pyvis
|
Reference in New Issue
Block a user