Merge branch 'main' into main

Author: wiltshirek
Date: 2024-11-01 16:50:45 -04:00
Committed by: GitHub
20 changed files with 932 additions and 183 deletions

.github/workflows/linting.yaml (new file, 30 lines)

@@ -0,0 +1,30 @@
name: Linting and Formatting
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
lint-and-format:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit
- name: Run pre-commit
run: pre-commit run --all-files
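The same checks can be run locally before pushing, using the commands from the workflow above (a minimal sketch; it assumes a working Python environment and the repository's existing `.pre-commit-config.yaml`):

```bash
# install the hook runner and run every configured hook against the whole repo
pip install pre-commit
pre-commit run --all-files
```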

.gitignore (3 changed lines)

@@ -8,4 +8,5 @@ dist/
env/ env/
local_neo4jWorkDir/ local_neo4jWorkDir/
neo4jWorkDir/ neo4jWorkDir/
ignore_this.txt ignore_this.txt
.venv/

README.md (227 changed lines)

@@ -8,7 +8,7 @@
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
<a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a> <a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a> <a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
<a href='https://discord.gg/mvsfu2Tg'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a> <a href='https://discord.gg/rdE8YVPm'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
</p> </p>
<p> <p>
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' /> <img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
@@ -22,11 +22,17 @@ This repository hosts the code of LightRAG. The structure of this code is based
</div> </div>
## 🎉 News ## 🎉 News
- [x] [2024.10.20]🎯🎯📢📢We've added a new feature to LightRAG: Graph Visualization. - [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
- [x] [2024.10.18]🎯🎯📢📢We've added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author! - [x] [2024.10.20]🎯📢We've added a new feature to LightRAG: Graph Visualization.
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉 - [x] [2024.10.18]🎯📢We've added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
## Algorithm Flowchart
![LightRAG_Self excalidraw](https://github.com/user-attachments/assets/aa5c4892-2e44-49e6-a116-2403ed80a1a3)
## Install ## Install
@@ -58,8 +64,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
######### #########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio # import nest_asyncio
# nest_asyncio.apply() # nest_asyncio.apply()
######### #########
WORKING_DIR = "./dickens" WORKING_DIR = "./dickens"
@@ -190,8 +196,11 @@ see test_neo4j.py for a working example.
<details> <details>
<summary> Using Ollama Models </summary> <summary> Using Ollama Models </summary>
* If you want to use Ollama models, you only need to set LightRAG as follows: ### Overview
If you want to use Ollama models, you need to pull the model you plan to use as well as an embedding model, for example `nomic-embed-text`.
Then you only need to set LightRAG as follows:
```python ```python
from lightrag.llm import ollama_model_complete, ollama_embedding from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -213,28 +222,59 @@ rag = LightRAG(
) )
``` ```
* Increasing the `num_ctx` parameter: ### Increasing context size
For LightRAG to work, the context size should be at least 32k tokens. By default, Ollama models have a context size of 8k. You can increase it in one of two ways:
#### Increasing the `num_ctx` parameter in the Modelfile
1. Pull the model: 1. Pull the model:
```python ```bash
ollama pull qwen2 ollama pull qwen2
``` ```
2. Display the model file: 2. Display the model file:
```python ```bash
ollama show --modelfile qwen2 > Modelfile ollama show --modelfile qwen2 > Modelfile
``` ```
3. Edit the Modelfile by adding the following line: 3. Edit the Modelfile by adding the following line:
```python ```bash
PARAMETER num_ctx 32768 PARAMETER num_ctx 32768
``` ```
4. Create the modified model: 4. Create the modified model:
```python ```bash
ollama create -f Modelfile qwen2m ollama create -f Modelfile qwen2m
``` ```
#### Setting `num_ctx` via the Ollama API
You can use the `llm_model_kwargs` parameter to configure Ollama:
```python
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete, # Use Ollama model for text generation
llm_model_name='your_model_name', # Your model name
llm_model_kwargs={"options": {"num_ctx": 32768}},
# Use Ollama embedding function
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
texts,
embed_model="nomic-embed-text"
)
),
)
```
#### Fully functional example
There is a fully functional example, `examples/lightrag_ollama_demo.py`, that uses the `gemma2:2b` model, runs only 4 requests in parallel, and sets the context size to 32k.
#### Low RAM GPUs
To run this experiment on a low-RAM GPU, select a small model and tune the context window (increasing the context increases memory consumption). For example, running this Ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`; it was able to find 197 entities and 19 relations on `book.txt`. A configuration sketch follows below.
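A minimal configuration sketch for such a low-RAM setup, assuming the same Ollama models as above (the 26k `num_ctx` value mirrors the experiment described here; adjust it to your card):

```python
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="gemma2:2b",
    llm_model_kwargs={"options": {"num_ctx": 26000}},  # reduced context to fit ~6 GB of VRAM
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
    ),
)
```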
</details> </details>
### Query Param ### Query Param
@@ -265,12 +305,33 @@ rag.insert(["TEXT1", "TEXT2",...])
```python ```python
# Incremental Insert: Insert new documents into an existing LightRAG instance # Incremental Insert: Insert new documents into an existing LightRAG instance
rag = LightRAG(working_dir="./dickens") rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=8192,
func=embedding_func,
),
)
with open("./newText.txt") as f: with open("./newText.txt") as f:
rag.insert(f.read()) rag.insert(f.read())
``` ```
### Multi-file Type Support
`textract` supports reading file types such as TXT, DOCX, PPTX, CSV, and PDF.
```python
import textract
file_path = 'TEXT.pdf'
text_content = textract.process(file_path)
rag.insert(text_content.decode('utf-8'))
```
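Note that `textract` is not part of `requirements.txt` in this commit, so it may need to be installed separately (an assumption about your environment):

```bash
pip install textract
```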
### Graph Visualization ### Graph Visualization
<details> <details>
@@ -361,8 +422,8 @@ def main():
SET e.entity_type = node.entity_type, SET e.entity_type = node.entity_type,
e.description = node.description, e.description = node.description,
e.source_id = node.source_id, e.source_id = node.source_id,
e.displayName = node.id e.displayName = node.id
REMOVE e:Entity REMOVE e:Entity
WITH e, node WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*) RETURN count(*)
@@ -415,7 +476,7 @@ def main():
except Exception as e: except Exception as e:
print(f"Error occurred: {e}") print(f"Error occurred: {e}")
finally: finally:
driver.close() driver.close()
@@ -425,6 +486,125 @@ if __name__ == "__main__":
</details> </details>
## API Server Implementation
LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
### Setting up the API Server
<details>
<summary>Click to expand setup instructions</summary>
1. First, ensure you have the required dependencies:
```bash
pip install fastapi uvicorn pydantic
```
2. Set up your environment variables:
```bash
export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
```
3. Run the API server:
```bash
python examples/lightrag_api_openai_compatible_demo.py
```
The server will start on `http://0.0.0.0:8020`.
</details>
### API Endpoints
The API server provides the following endpoints:
#### 1. Query Endpoint
<details>
<summary>Click to view Query endpoint details</summary>
- **URL:** `/query`
- **Method:** POST
- **Body:**
```json
{
"query": "Your question here",
"mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/query" \
-H "Content-Type: application/json" \
-d '{"query": "What are the main themes?", "mode": "hybrid"}'
```
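Based on the `Response` model in the demo script, a successful query returns JSON of roughly this shape (the `data` value below is illustrative):

```json
{
  "status": "success",
  "data": "The main themes are ...",
  "message": null
}
```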
</details>
#### 2. Insert Text Endpoint
<details>
<summary>Click to view Insert Text endpoint details</summary>
- **URL:** `/insert`
- **Method:** POST
- **Body:**
```json
{
"text": "Your text content here"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/insert" \
-H "Content-Type: application/json" \
-d '{"text": "Content to be inserted into RAG"}'
```
</details>
#### 3. Insert File Endpoint
<details>
<summary>Click to view Insert File endpoint details</summary>
- **URL:** `/insert_file`
- **Method:** POST
- **Body:**
```json
{
"file_path": "path/to/your/file.txt"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/insert_file" \
-H "Content-Type: application/json" \
-d '{"file_path": "./book.txt"}'
```
</details>
#### 4. Health Check Endpoint
<details>
<summary>Click to view Health Check endpoint details</summary>
- **URL:** `/health`
- **Method:** GET
- **Example:**
```bash
curl -X GET "http://127.0.0.1:8020/health"
```
</details>
### Configuration
The API server can be configured using environment variables:
- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
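For example, a minimal launch with a custom index directory, using the demo script from the setup steps above (`./my_index` is an arbitrary example path):

```bash
export RAG_DIR="./my_index"   # optional; defaults to "index_default"
python examples/lightrag_api_openai_compatible_demo.py
```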
### Error Handling
<details>
<summary>Click to view error handling details</summary>
The API includes comprehensive error handling:
- File not found errors (404)
- Processing errors (500)
- Supports multiple file encodings (UTF-8 and GBK)
</details>
## Evaluation ## Evaluation
### Dataset ### Dataset
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain). The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -671,12 +851,14 @@ def extract_queries(file_path):
. .
├── examples ├── examples
├── batch_eval.py ├── batch_eval.py
├── generate_query.py
├── graph_visual_with_html.py ├── graph_visual_with_html.py
├── graph_visual_with_neo4j.py ├── graph_visual_with_neo4j.py
├── generate_query.py ├── lightrag_api_openai_compatible_demo.py
├── lightrag_azure_openai_demo.py ├── lightrag_azure_openai_demo.py
├── lightrag_bedrock_demo.py ├── lightrag_bedrock_demo.py
├── lightrag_hf_demo.py ├── lightrag_hf_demo.py
├── lightrag_lmdeploy_demo.py
├── lightrag_ollama_demo.py ├── lightrag_ollama_demo.py
├── lightrag_openai_compatible_demo.py ├── lightrag_openai_compatible_demo.py
├── lightrag_openai_demo.py ├── lightrag_openai_demo.py
@@ -693,8 +875,10 @@ def extract_queries(file_path):
└── utils.py └── utils.py
├── reproduce ├── reproduce
├── Step_0.py ├── Step_0.py
├── Step_1_openai_compatible.py
├── Step_1.py ├── Step_1.py
├── Step_2.py ├── Step_2.py
├── Step_3_openai_compatible.py
└── Step_3.py └── Step_3.py
├── .gitignore ├── .gitignore
├── .pre-commit-config.yaml ├── .pre-commit-config.yaml
@@ -726,3 +910,6 @@ archivePrefix={arXiv},
primaryClass={cs.IR} primaryClass={cs.IR}
} }
``` ```

View File

@@ -3,17 +3,17 @@ from pyvis.network import Network
import random import random
# Load the GraphML file # Load the GraphML file
G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml') G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
# Create a Pyvis network # Create a Pyvis network
net = Network(notebook=True) net = Network(height="100vh", notebook=True)
# Convert NetworkX graph to Pyvis network # Convert NetworkX graph to Pyvis network
net.from_nx(G) net.from_nx(G)
# Add colors to nodes # Add colors to nodes
for node in net.nodes: for node in net.nodes:
node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
# Save and display the network # Save and display the network
net.show('knowledge_graph.html') net.show("knowledge_graph.html")

View File

@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j" NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password" NEO4J_PASSWORD = "your_password"
def convert_xml_to_json(xml_path, output_path): def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output.""" """Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path): if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
json_data = xml_to_json(xml_path) json_data = xml_to_json(xml_path)
if json_data: if json_data:
with open(output_path, 'w', encoding='utf-8') as f: with open(output_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=2) json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}") print(f"JSON file created: {output_path}")
return json_data return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
print("Failed to create JSON data") print("Failed to create JSON data")
return None return None
def process_in_batches(tx, query, data, batch_size): def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query.""" """Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size): for i in range(0, len(data), batch_size):
batch = data[i:i + batch_size] batch = data[i : i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch}) tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
def main(): def main():
# Paths # Paths
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml') xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
json_file = os.path.join(WORKING_DIR, 'graph_data.json') json_file = os.path.join(WORKING_DIR, "graph_data.json")
# Convert XML to JSON # Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file) json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
return return
# Load nodes and edges # Load nodes and edges
nodes = json_data.get('nodes', []) nodes = json_data.get("nodes", [])
edges = json_data.get('edges', []) edges = json_data.get("edges", [])
# Neo4j queries # Neo4j queries
create_nodes_query = """ create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
SET e.entity_type = node.entity_type, SET e.entity_type = node.entity_type,
e.description = node.description, e.description = node.description,
e.source_id = node.source_id, e.source_id = node.source_id,
e.displayName = node.id e.displayName = node.id
REMOVE e:Entity REMOVE e:Entity
WITH e, node WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*) RETURN count(*)
@@ -100,19 +103,24 @@ def main():
# Execute queries in batches # Execute queries in batches
with driver.session() as session: with driver.session() as session:
# Insert nodes in batches # Insert nodes in batches
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES) session.execute_write(
process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
)
# Insert edges in batches # Insert edges in batches
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES) session.execute_write(
process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
)
# Set displayName and labels # Set displayName and labels
session.run(set_displayname_and_labels_query) session.run(set_displayname_and_labels_query)
except Exception as e: except Exception as e:
print(f"Error occurred: {e}") print(f"Error occurred: {e}")
finally: finally:
driver.close() driver.close()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -0,0 +1,164 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from typing import Optional
import asyncio
import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
DEFAULT_RAG_DIR = "index_default"
app = FastAPI(title="LightRAG API", description="API for RAG operations")
# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# LLM model function
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
**kwargs,
)
# Embedding function
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="text-embedding-3-large",
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
)
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=3072, max_token_size=8192, func=embedding_func
),
)
# Data models
class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
class InsertRequest(BaseModel):
text: str
class InsertFileRequest(BaseModel):
file_path: str
class Response(BaseModel):
status: str
data: Optional[str] = None
message: Optional[str] = None
# API routes
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
)
return Response(status="success", data=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(request.text))
return Response(status="success", message="Text inserted successfully")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/insert_file", response_model=Response)
async def insert_file(request: InsertFileRequest):
try:
# Check if file exists
if not os.path.exists(request.file_path):
raise HTTPException(
status_code=404, detail=f"File not found: {request.file_path}"
)
# Read file content
try:
with open(request.file_path, "r", encoding="utf-8") as f:
content = f.read()
except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings
with open(request.file_path, "r", encoding="gbk") as f:
content = f.read()
# Insert file content
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content))
return Response(
status="success",
message=f"File content from {request.file_path} inserted successfully",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8020)
# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_openai_compatible_demo.py
# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -0,0 +1,75 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def lmdeploy_model_complete(
prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await lmdeploy_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
## please specify chat_template if your local path does not follow original HF file name,
## or model_name is a pytorch model on huggingface.co,
## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
## for a list of chat_template available in lmdeploy.
chat_template="llama3",
# model_format ='awq', # if you are using awq quantization model.
# quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8.
**kwargs,
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=lmdeploy_model_complete,
llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
func=lambda texts: hf_embedding(
texts,
tokenizer=AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
),
embed_model=AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
),
),
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -1,26 +1,32 @@
import os import os
import logging
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens" WORKING_DIR = "./dickens"
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
rag = LightRAG( rag = LightRAG(
working_dir=WORKING_DIR, working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete, llm_model_func=ollama_model_complete,
llm_model_name="your_model_name", llm_model_name="gemma2:2b",
llm_model_max_async=4,
llm_model_max_token_size=32768,
llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
embedding_func=EmbeddingFunc( embedding_func=EmbeddingFunc(
embedding_dim=768, embedding_dim=768,
max_token_size=8192, max_token_size=8192,
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"), func=lambda texts: ollama_embedding(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
), ),
) )
with open("./book.txt", "r", encoding="utf-8") as f: with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read()) rag.insert(f.read())

View File

@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
) )
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
return embedding_dim
# function test # function test
async def test_funcs(): async def test_funcs():
result = await llm_model_func("How are you?") result = await llm_model_func("How are you?")
@@ -43,37 +50,59 @@ async def test_funcs():
print("embedding_func: ", result) print("embedding_func: ", result)
asyncio.run(test_funcs()) # asyncio.run(test_funcs())
rag = LightRAG( async def main():
working_dir=WORKING_DIR, try:
llm_model_func=llm_model_func, embedding_dimension = await get_embedding_dim()
embedding_func=EmbeddingFunc( print(f"Detected embedding dimension: {embedding_dimension}")
embedding_dim=4096, max_token_size=8192, func=embedding_func
), rag = LightRAG(
) working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=8192,
func=embedding_func,
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read())
# Perform naive search
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
# Perform local search
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
# Perform global search
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="global"),
)
)
# Perform hybrid search
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
)
)
except Exception as e:
print(f"An error occurred: {e}")
with open("./book.txt", "r", encoding="utf-8") as f: if __name__ == "__main__":
rag.insert(f.read()) asyncio.run(main())
# Perform naive search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
texts, texts,
model="netease-youdao/bce-embedding-base_v1", model="netease-youdao/bce-embedding-base_v1",
api_key=os.getenv("SILICONFLOW_API_KEY"), api_key=os.getenv("SILICONFLOW_API_KEY"),
max_token_size=512 max_token_size=512,
) )

View File

@@ -27,11 +27,12 @@ rag = LightRAG(
# Read all .txt files from the TEXT_FILES_DIR directory # Read all .txt files from the TEXT_FILES_DIR directory
texts = [] texts = []
for filename in os.listdir(TEXT_FILES_DIR): for filename in os.listdir(TEXT_FILES_DIR):
if filename.endswith('.txt'): if filename.endswith(".txt"):
file_path = os.path.join(TEXT_FILES_DIR, filename) file_path = os.path.join(TEXT_FILES_DIR, filename)
with open(file_path, 'r', encoding='utf-8') as file: with open(file_path, "r", encoding="utf-8") as file:
texts.append(file.read()) texts.append(file.read())
# Batch insert texts into LightRAG with a retry mechanism # Batch insert texts into LightRAG with a retry mechanism
def insert_texts_with_retry(rag, texts, retries=3, delay=5): def insert_texts_with_retry(rag, texts, retries=3, delay=5):
for _ in range(retries): for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
rag.insert(texts) rag.insert(texts)
return return
except Exception as e: except Exception as e:
print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...") print(
f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
)
time.sleep(delay) time.sleep(delay)
raise RuntimeError("Failed to insert texts after multiple retries.") raise RuntimeError("Failed to insert texts after multiple retries.")
insert_texts_with_retry(rag, texts) insert_texts_with_retry(rag, texts)
# Perform different types of queries and handle potential errors # Perform different types of queries and handle potential errors
try: try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
except Exception as e: except Exception as e:
print(f"Error performing naive search: {e}") print(f"Error performing naive search: {e}")
try: try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
except Exception as e: except Exception as e:
print(f"Error performing local search: {e}") print(f"Error performing local search: {e}")
try: try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)
except Exception as e: except Exception as e:
print(f"Error performing global search: {e}") print(f"Error performing global search: {e}")
try: try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
)
)
except Exception as e: except Exception as e:
print(f"Error performing hybrid search: {e}") print(f"Error performing hybrid search: {e}")
# Function to clear VRAM resources # Function to clear VRAM resources
def clear_vram(): def clear_vram():
os.system("sudo nvidia-smi --gpu-reset") os.system("sudo nvidia-smi --gpu-reset")
# Regularly clear VRAM to prevent overflow # Regularly clear VRAM to prevent overflow
clear_vram_interval = 3600 # Clear once every hour clear_vram_interval = 3600 # Clear once every hour
start_time = time.time() start_time = time.time()

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "0.0.7" __version__ = "0.0.8"
__author__ = "Zirui Guo" __author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG" __url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -109,6 +109,7 @@ class LightRAG:
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768 llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16 llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)
# storage # storage
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
@@ -179,7 +180,11 @@ class LightRAG:
) )
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache) partial(
self.llm_model_func,
hashing_kv=self.llm_response_cache,
**self.llm_model_kwargs,
)
) )
def _get_storage_class(self) -> Type[BaseGraphStorage]: def _get_storage_class(self) -> Type[BaseGraphStorage]:
return { return {
@@ -239,7 +244,7 @@ class LightRAG:
logger.info("[Entity Extraction]...") logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities( maybe_new_kg = await extract_entities(
inserting_chunks, inserting_chunks,
knwoledge_graph_inst=self.chunk_entity_relation_graph, knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb, entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb, relationships_vdb=self.relationships_vdb,
global_config=asdict(self), global_config=asdict(self),

View File

@@ -7,7 +7,13 @@ import aiohttp
import numpy as np import numpy as np
import ollama import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
)
import base64 import base64
import struct import struct
@@ -70,26 +76,31 @@ async def openai_complete_if_cache(
) )
return response.choices[0].message.content return response.choices[0].message.content
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
) )
async def azure_openai_complete_if_cache(model, async def azure_openai_complete_if_cache(
model,
prompt, prompt,
system_prompt=None, system_prompt=None,
history_messages=[], history_messages=[],
base_url=None, base_url=None,
api_key=None, api_key=None,
**kwargs): **kwargs,
):
if api_key: if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url: if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), openai_async_client = AsyncAzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"), azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION")) api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = [] messages = []
@@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model,
) )
return response.choices[0].message.content return response.choices[0].message.content
class BedrockError(Exception): class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock""" """Generic error for issues related to Amazon Bedrock"""
@@ -205,8 +217,12 @@ async def bedrock_complete_if_cache(
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def initialize_hf_model(model_name): def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True) hf_tokenizer = AutoTokenizer.from_pretrained(
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True) model_name, device_map="auto", trust_remote_code=True
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None: if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token hf_tokenizer.pad_token = hf_tokenizer.eos_token
@@ -266,10 +282,13 @@ async def hf_model_if_cache(
input_ids = hf_tokenizer( input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda") ).to("cuda")
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate( output = hf_model.generate(
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
) )
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None: if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text return response_text
@@ -280,8 +299,10 @@ async def ollama_model_if_cache(
) -> str: ) -> str:
kwargs.pop("max_tokens", None) kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None) kwargs.pop("response_format", None)
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
ollama_client = ollama.AsyncClient() ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
messages = [] messages = []
if system_prompt: if system_prompt:
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
@@ -305,6 +326,135 @@ async def ollama_model_if_cache(
return result return result
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
model,
tp=1,
chat_template=None,
log_level="WARNING",
model_format="hf",
quant_policy=0,
):
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
lmdeploy_pipe = pipeline(
model_path=model,
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=ChatTemplateConfig(model_name=chat_template)
if chat_template
else None,
log_level="WARNING",
)
return lmdeploy_pipe
async def lmdeploy_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
chat_template=None,
model_format="hf",
quant_policy=0,
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
do_preprocess = kwargs.pop("do_preprocess", True)
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
"`do_sample` parameter is not supported by lmdeploy until "
f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
)
else:
do_sample = True
gen_params.update(do_sample=do_sample)
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
log_level="WARNING",
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
response = ""
async for res in lmdeploy_pipe.generate(
messages,
gen_config=gen_config,
do_preprocess=do_preprocess,
stream_response=False,
session_id=1,
):
response += res.response
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
return response
async def gpt_4o_complete( async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
@@ -328,8 +478,9 @@ async def gpt_4o_mini_complete(
**kwargs, **kwargs,
) )
async def azure_openai_complete( async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
return await azure_openai_complete_if_cache( return await azure_openai_complete_if_cache(
"conversation-4o-mini", "conversation-4o-mini",
@@ -339,6 +490,7 @@ async def azure_openai_complete(
**kwargs, **kwargs,
) )
async def bedrock_complete( async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
@@ -418,9 +570,11 @@ async def azure_openai_embedding(
if base_url: if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), openai_async_client = AsyncAzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"), azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION")) api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create( response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float" model=model, input=texts, encoding_format="float"
@@ -440,35 +594,28 @@ async def siliconcloud_embedding(
max_token_size: int = 512, max_token_size: int = 512,
api_key: str = None, api_key: str = None,
) -> np.ndarray: ) -> np.ndarray:
if api_key and not api_key.startswith('Bearer '): if api_key and not api_key.startswith("Bearer "):
api_key = 'Bearer ' + api_key api_key = "Bearer " + api_key
headers = { headers = {"Authorization": api_key, "Content-Type": "application/json"}
"Authorization": api_key,
"Content-Type": "application/json"
}
truncate_texts = [text[0:max_token_size] for text in texts] truncate_texts = [text[0:max_token_size] for text in texts]
payload = { payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
"model": model,
"input": truncate_texts,
"encoding_format": "base64"
}
base64_strings = [] base64_strings = []
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response: async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json() content = await response.json()
if 'code' in content: if "code" in content:
raise ValueError(content) raise ValueError(content)
base64_strings = [item['embedding'] for item in content['data']] base64_strings = [item["embedding"] for item in content["data"]]
embeddings = [] embeddings = []
for string in base64_strings: for string in base64_strings:
decode_bytes = base64.b64decode(string) decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4 n = len(decode_bytes) // 4
float_array = struct.unpack('<' + 'f' * n, decode_bytes) float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array) embeddings.append(float_array)
return np.array(embeddings) return np.array(embeddings)
@@ -555,14 +702,16 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
return embeddings.detach().numpy() return embeddings.detach().numpy()
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray: async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
embed_text = [] embed_text = []
ollama_client = ollama.Client(**kwargs)
for text in texts: for text in texts:
data = ollama.embeddings(model=embed_model, prompt=text) data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"]) embed_text.append(data["embedding"])
return embed_text return embed_text
class Model(BaseModel): class Model(BaseModel):
""" """
This is a Pydantic model class named 'Model' that is used to define a custom language model. This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -580,14 +729,20 @@ class Model(BaseModel):
The 'kwargs' dictionary contains the model name and API key to be passed to the function. The 'kwargs' dictionary contains the model name and API key to be passed to the function.
""" """
gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string") gen_func: Callable[[Any], str] = Field(
kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc") ...,
description="A function that generates the response from the llm. The response must be a string",
)
kwargs: Dict[str, Any] = Field(
...,
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
)
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
class MultiModel(): class MultiModel:
""" """
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier. Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
Could also be used for splitting across different models or providers. Could also be used for splitting across different models or providers.
@@ -611,26 +766,31 @@ class MultiModel():
) )
``` ```
""" """
def __init__(self, models: List[Model]): def __init__(self, models: List[Model]):
self._models = models self._models = models
self._current_model = 0 self._current_model = 0
def _next_model(self): def _next_model(self):
self._current_model = (self._current_model + 1) % len(self._models) self._current_model = (self._current_model + 1) % len(self._models)
return self._models[self._current_model] return self._models[self._current_model]
async def llm_model_func( async def llm_model_func(
self, self, prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str: ) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name kwargs.pop("model", None) # stop from overwriting the custom model name
next_model = self._next_model() next_model = self._next_model()
args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs) args = dict(
prompt=prompt,
return await next_model.gen_func( system_prompt=system_prompt,
**args history_messages=history_messages,
**kwargs,
**next_model.kwargs,
) )
return await next_model.gen_func(**args)
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio

View File

@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
async def _merge_nodes_then_upsert( async def _merge_nodes_then_upsert(
entity_name: str, entity_name: str,
nodes_data: list[dict], nodes_data: list[dict],
knwoledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
global_config: dict, global_config: dict,
): ):
already_entitiy_types = [] already_entitiy_types = []
already_source_ids = [] already_source_ids = []
already_description = [] already_description = []
already_node = await knwoledge_graph_inst.get_node(entity_name) already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None: if already_node is not None:
already_entitiy_types.append(already_node["entity_type"]) already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend( already_source_ids.extend(
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
description=description, description=description,
source_id=source_id, source_id=source_id,
) )
await knwoledge_graph_inst.upsert_node( await knowledge_graph_inst.upsert_node(
entity_name, entity_name,
node_data=node_data, node_data=node_data,
) )
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
src_id: str, src_id: str,
tgt_id: str, tgt_id: str,
edges_data: list[dict], edges_data: list[dict],
knwoledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
global_config: dict, global_config: dict,
): ):
already_weights = [] already_weights = []
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
already_description = [] already_description = []
already_keywords = [] already_keywords = []
if await knwoledge_graph_inst.has_edge(src_id, tgt_id): if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id) already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"]) already_weights.append(already_edge["weight"])
already_source_ids.extend( already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
set([dp["source_id"] for dp in edges_data] + already_source_ids) set([dp["source_id"] for dp in edges_data] + already_source_ids)
) )
for need_insert_id in [src_id, tgt_id]: for need_insert_id in [src_id, tgt_id]:
if not (await knwoledge_graph_inst.has_node(need_insert_id)): if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knwoledge_graph_inst.upsert_node( await knowledge_graph_inst.upsert_node(
need_insert_id, need_insert_id,
node_data={ node_data={
"source_id": source_id, "source_id": source_id,
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
description = await _handle_entity_relation_summary( description = await _handle_entity_relation_summary(
(src_id, tgt_id), description, global_config (src_id, tgt_id), description, global_config
) )
await knwoledge_graph_inst.upsert_edge( await knowledge_graph_inst.upsert_edge(
src_id, src_id,
tgt_id, tgt_id,
edge_data=dict( edge_data=dict(
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
async def extract_entities( async def extract_entities(
chunks: dict[str, TextChunkSchema], chunks: dict[str, TextChunkSchema],
knwoledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage, entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage,
global_config: dict, global_config: dict,
@@ -341,13 +341,13 @@ async def extract_entities(
maybe_edges[tuple(sorted(k))].extend(v) maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather( all_entities_data = await asyncio.gather(
*[ *[
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config) _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items() for k, v in maybe_nodes.items()
] ]
) )
all_relationships_data = await asyncio.gather( all_relationships_data = await asyncio.gather(
*[ *[
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config) _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
for k, v in maybe_edges.items() for k, v in maybe_edges.items()
] ]
) )
@@ -384,7 +384,7 @@ async def extract_entities(
} }
await relationships_vdb.upsert(data_for_vdb) await relationships_vdb.upsert(data_for_vdb)
return knwoledge_graph_inst return knowledge_graph_inst
async def local_query( async def local_query(

View File

@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f: with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4) json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file): def xml_to_json(xml_file):
try: try:
tree = ET.parse(xml_file) tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
print(f"Root element: {root.tag}") print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}") print(f"Root attributes: {root.attrib}")
data = { data = {"nodes": [], "edges": []}
"nodes": [],
"edges": []
}
# Use namespace # Use namespace
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'} namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
for node in root.findall('.//node', namespace): for node in root.findall(".//node", namespace):
node_data = { node_data = {
"id": node.get('id').strip('"'), "id": node.get("id").strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "", "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "", if node.find("./data[@key='d0']", namespace) is not None
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else "" else "",
"description": node.find("./data[@key='d1']", namespace).text
if node.find("./data[@key='d1']", namespace) is not None
else "",
"source_id": node.find("./data[@key='d2']", namespace).text
if node.find("./data[@key='d2']", namespace) is not None
else "",
} }
data["nodes"].append(node_data) data["nodes"].append(node_data)
for edge in root.findall('.//edge', namespace): for edge in root.findall(".//edge", namespace):
edge_data = { edge_data = {
"source": edge.get('source').strip('"'), "source": edge.get("source").strip('"'),
"target": edge.get('target').strip('"'), "target": edge.get("target").strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0, "weight": float(edge.find("./data[@key='d3']", namespace).text)
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "", if edge.find("./data[@key='d3']", namespace) is not None
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "", else 0.0,
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else "" "description": edge.find("./data[@key='d4']", namespace).text
if edge.find("./data[@key='d4']", namespace) is not None
else "",
"keywords": edge.find("./data[@key='d5']", namespace).text
if edge.find("./data[@key='d5']", namespace) is not None
else "",
"source_id": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
} }
data["edges"].append(edge_data) data["edges"].append(edge_data)

View File

@@ -18,8 +18,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param): async def process_query(query_text, rag_instance, query_param):
try: try:
result, context = await rag_instance.aquery(query_text, param=query_param) result = await rag_instance.aquery(query_text, param=query_param)
return {"query": query_text, "result": result, "context": context}, None return {"query": query_text, "result": result}, None
except Exception as e: except Exception as e:
return None, {"query": query_text, "error": str(e)} return None, {"query": query_text, "error": str(e)}

View File

@@ -50,8 +50,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param): async def process_query(query_text, rag_instance, query_param):
try: try:
result, context = await rag_instance.aquery(query_text, param=query_param) result = await rag_instance.aquery(query_text, param=query_param)
return {"query": query_text, "result": result, "context": context}, None return {"query": query_text, "result": result}, None
except Exception as e: except Exception as e:
return None, {"query": query_text, "error": str(e)} return None, {"query": query_text, "error": str(e)}

View File

@@ -1,16 +1,17 @@
accelerate accelerate
aioboto3 aioboto3
aiohttp
graspologic graspologic
hnswlib hnswlib
nano-vectordb nano-vectordb
neo4j
networkx networkx
ollama ollama
openai openai
pyvis
tenacity tenacity
tiktoken tiktoken
torch torch
transformers transformers
xxhash xxhash
pyvis # lmdeploy[all]
aiohttp
neo4j

View File

@@ -1,39 +1,88 @@
import setuptools import setuptools
from pathlib import Path
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
vars2find = ["__author__", "__version__", "__url__"] # Reading the long description from README.md
vars2readme = {} def read_long_description():
with open("./lightrag/__init__.py") as f: try:
for line in f.readlines(): return Path("README.md").read_text(encoding="utf-8")
for v in vars2find: except FileNotFoundError:
if line.startswith(v): return "A description of LightRAG is currently unavailable."
line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
vars2readme[v] = line.split("=")[1]
deps = []
with open("./requirements.txt") as f: # Retrieving metadata from __init__.py
for line in f.readlines(): def retrieve_metadata():
if not line.strip(): vars2find = ["__author__", "__version__", "__url__"]
continue vars2readme = {}
deps.append(line.strip()) try:
with open("./lightrag/__init__.py") as f:
for line in f.readlines():
for v in vars2find:
if line.startswith(v):
line = (
line.replace(" ", "")
.replace('"', "")
.replace("'", "")
.strip()
)
vars2readme[v] = line.split("=")[1]
except FileNotFoundError:
raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.")
# Checking if all required variables are found
missing_vars = [v for v in vars2find if v not in vars2readme]
if missing_vars:
raise ValueError(
f"Missing required metadata variables in __init__.py: {missing_vars}"
)
return vars2readme
# Reading dependencies from requirements.txt
def read_requirements():
deps = []
try:
with open("./requirements.txt") as f:
deps = [line.strip() for line in f if line.strip()]
except FileNotFoundError:
print(
"Warning: 'requirements.txt' not found. No dependencies will be installed."
)
return deps
metadata = retrieve_metadata()
long_description = read_long_description()
requirements = read_requirements()
setuptools.setup( setuptools.setup(
name="lightrag-hku", name="lightrag-hku",
url=vars2readme["__url__"], url=metadata["__url__"],
version=vars2readme["__version__"], version=metadata["__version__"],
author=vars2readme["__author__"], author=metadata["__author__"],
description="LightRAG: Simple and Fast Retrieval-Augmented Generation", description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
packages=["lightrag"], packages=setuptools.find_packages(
exclude=("tests*", "docs*")
), # Automatically find packages
classifiers=[ classifiers=[
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Intended Audience :: Developers",
"Topic :: Software Development :: Libraries :: Python Modules",
], ],
python_requires=">=3.9", python_requires=">=3.9",
install_requires=deps, install_requires=requirements,
include_package_data=True, # Includes non-code files from MANIFEST.in
project_urls={ # Additional project metadata
"Documentation": metadata.get("__url__", ""),
"Source": metadata.get("__url__", ""),
"Tracker": f"{metadata.get('__url__', '')}/issues"
if metadata.get("__url__")
else "",
},
) )