Merge branch 'main' into main

This commit is contained in:
wiltshirek
2024-11-01 16:50:45 -04:00
committed by GitHub
20 changed files with 932 additions and 183 deletions

30
.github/workflows/linting.yaml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: Linting and Formatting
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
lint-and-format:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit
- name: Run pre-commit
run: pre-commit run --all-files

3
.gitignore vendored
View File

@@ -8,4 +8,5 @@ dist/
env/
local_neo4jWorkDir/
neo4jWorkDir/
ignore_this.txt
ignore_this.txt
.venv/

227
README.md
View File

@@ -8,7 +8,7 @@
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
<a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
<a href='https://discord.gg/mvsfu2Tg'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
<a href='https://discord.gg/rdE8YVPm'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
</p>
<p>
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
@@ -22,11 +22,17 @@ This repository hosts the code of LightRAG. The structure of this code is based
</div>
## 🎉 News
- [x] [2024.10.20]🎯🎯📢📢Weve added a new feature to LightRAG: Graph Visualization.
- [x] [2024.10.18]🎯🎯📢📢Weve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
- [x] [2024.10.20]🎯📢Weve added a new feature to LightRAG: Graph Visualization.
- [x] [2024.10.18]🎯📢Weve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
- [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
- [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
## Algorithm Flowchart
![LightRAG_Self excalidraw](https://github.com/user-attachments/assets/aa5c4892-2e44-49e6-a116-2403ed80a1a3)
## Install
@@ -58,8 +64,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"
@@ -190,8 +196,11 @@ see test_neo4j.py for a working example.
<details>
<summary> Using Ollama Models </summary>
* If you want to use Ollama models, you only need to set LightRAG as follows:
### Overview
If you want to use Ollama models, you need to pull model you plan to use and embedding model, for example `nomic-embed-text`.
Then you only need to set LightRAG as follows:
```python
from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -213,28 +222,59 @@ rag = LightRAG(
)
```
* Increasing the `num_ctx` parameter:
### Increasing context size
In order for LightRAG to work context should be at least 32k tokens. By default Ollama models have context size of 8k. You can achieve this using one of two ways:
#### Increasing the `num_ctx` parameter in Modelfile.
1. Pull the model:
```python
```bash
ollama pull qwen2
```
2. Display the model file:
```python
```bash
ollama show --modelfile qwen2 > Modelfile
```
3. Edit the Modelfile by adding the following line:
```python
```bash
PARAMETER num_ctx 32768
```
4. Create the modified model:
```python
```bash
ollama create -f Modelfile qwen2m
```
#### Setup `num_ctx` via Ollama API.
Tiy can use `llm_model_kwargs` param to configure ollama:
```python
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete, # Use Ollama model for text generation
llm_model_name='your_model_name', # Your model name
llm_model_kwargs={"options": {"num_ctx": 32768}},
# Use Ollama embedding function
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
texts,
embed_model="nomic-embed-text"
)
),
)
```
#### Fully functional example
There fully functional example `examples/lightrag_ollama_demo.py` that utilizes `gemma2:2b` model, runs only 4 requests in parallel and set context size to 32k.
#### Low RAM GPUs
In order to run this experiment on low RAM GPU you should select small model and tune context window (increasing context increase memory consumption). For example, running this ollama example on repurposed mining GPU with 6Gb of RAM required to set context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations on `book.txt`.
</details>
### Query Param
@@ -265,12 +305,33 @@ rag.insert(["TEXT1", "TEXT2",...])
```python
# Incremental Insert: Insert new documents into an existing LightRAG instance
rag = LightRAG(working_dir="./dickens")
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=8192,
func=embedding_func,
),
)
with open("./newText.txt") as f:
rag.insert(f.read())
```
### Multi-file Type Support
The `testract` supports reading file types such as TXT, DOCX, PPTX, CSV, and PDF.
```python
import textract
file_path = 'TEXT.pdf'
text_content = textract.process(file_path)
rag.insert(text_content.decode('utf-8'))
```
### Graph Visualization
<details>
@@ -361,8 +422,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
e.displayName = node.id
REMOVE e:Entity
e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -415,7 +476,7 @@ def main():
except Exception as e:
print(f"Error occurred: {e}")
finally:
driver.close()
@@ -425,6 +486,125 @@ if __name__ == "__main__":
</details>
## API Server Implementation
LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
### Setting up the API Server
<details>
<summary>Click to expand setup instructions</summary>
1. First, ensure you have the required dependencies:
```bash
pip install fastapi uvicorn pydantic
```
2. Set up your environment variables:
```bash
export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
```
3. Run the API server:
```bash
python examples/lightrag_api_openai_compatible_demo.py
```
The server will start on `http://0.0.0.0:8020`.
</details>
### API Endpoints
The API server provides the following endpoints:
#### 1. Query Endpoint
<details>
<summary>Click to view Query endpoint details</summary>
- **URL:** `/query`
- **Method:** POST
- **Body:**
```json
{
"query": "Your question here",
"mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/query" \
-H "Content-Type: application/json" \
-d '{"query": "What are the main themes?", "mode": "hybrid"}'
```
</details>
#### 2. Insert Text Endpoint
<details>
<summary>Click to view Insert Text endpoint details</summary>
- **URL:** `/insert`
- **Method:** POST
- **Body:**
```json
{
"text": "Your text content here"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/insert" \
-H "Content-Type: application/json" \
-d '{"text": "Content to be inserted into RAG"}'
```
</details>
#### 3. Insert File Endpoint
<details>
<summary>Click to view Insert File endpoint details</summary>
- **URL:** `/insert_file`
- **Method:** POST
- **Body:**
```json
{
"file_path": "path/to/your/file.txt"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/insert_file" \
-H "Content-Type: application/json" \
-d '{"file_path": "./book.txt"}'
```
</details>
#### 4. Health Check Endpoint
<details>
<summary>Click to view Health Check endpoint details</summary>
- **URL:** `/health`
- **Method:** GET
- **Example:**
```bash
curl -X GET "http://127.0.0.1:8020/health"
```
</details>
### Configuration
The API server can be configured using environment variables:
- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
### Error Handling
<details>
<summary>Click to view error handling details</summary>
The API includes comprehensive error handling:
- File not found errors (404)
- Processing errors (500)
- Supports multiple file encodings (UTF-8 and GBK)
</details>
## Evaluation
### Dataset
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -671,12 +851,14 @@ def extract_queries(file_path):
.
├── examples
├── batch_eval.py
├── generate_query.py
├── graph_visual_with_html.py
├── graph_visual_with_neo4j.py
├── generate_query.py
├── lightrag_api_openai_compatible_demo.py
├── lightrag_azure_openai_demo.py
├── lightrag_bedrock_demo.py
├── lightrag_hf_demo.py
├── lightrag_lmdeploy_demo.py
├── lightrag_ollama_demo.py
├── lightrag_openai_compatible_demo.py
├── lightrag_openai_demo.py
@@ -693,8 +875,10 @@ def extract_queries(file_path):
└── utils.py
├── reproduce
├── Step_0.py
├── Step_1_openai_compatible.py
├── Step_1.py
├── Step_2.py
├── Step_3_openai_compatible.py
└── Step_3.py
├── .gitignore
├── .pre-commit-config.yaml
@@ -726,3 +910,6 @@ archivePrefix={arXiv},
primaryClass={cs.IR}
}
```

View File

@@ -3,17 +3,17 @@ from pyvis.network import Network
import random
# Load the GraphML file
G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
# Create a Pyvis network
net = Network(notebook=True)
net = Network(height="100vh", notebook=True)
# Convert NetworkX graph to Pyvis network
net.from_nx(G)
# Add colors to nodes
for node in net.nodes:
node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
# Save and display the network
net.show('knowledge_graph.html')
net.show("knowledge_graph.html")

View File

@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"
def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
json_data = xml_to_json(xml_path)
if json_data:
with open(output_path, 'w', encoding='utf-8') as f:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}")
return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
print("Failed to create JSON data")
return None
def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size):
batch = data[i:i + batch_size]
batch = data[i : i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
def main():
# Paths
xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
json_file = os.path.join(WORKING_DIR, 'graph_data.json')
xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
json_file = os.path.join(WORKING_DIR, "graph_data.json")
# Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
return
# Load nodes and edges
nodes = json_data.get('nodes', [])
edges = json_data.get('edges', [])
nodes = json_data.get("nodes", [])
edges = json_data.get("edges", [])
# Neo4j queries
create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
e.displayName = node.id
REMOVE e:Entity
e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -100,19 +103,24 @@ def main():
# Execute queries in batches
with driver.session() as session:
# Insert nodes in batches
session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
session.execute_write(
process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
)
# Insert edges in batches
session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
session.execute_write(
process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
)
# Set displayName and labels
session.run(set_displayname_and_labels_query)
except Exception as e:
print(f"Error occurred: {e}")
finally:
driver.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,164 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from typing import Optional
import asyncio
import nest_asyncio
# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()
DEFAULT_RAG_DIR = "index_default"
app = FastAPI(title="LightRAG API", description="API for RAG operations")
# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# LLM model function
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
**kwargs,
)
# Embedding function
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="text-embedding-3-large",
api_key="YOUR_API_KEY",
base_url="YourURL/v1",
)
# Initialize RAG instance
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=3072, max_token_size=8192, func=embedding_func
),
)
# Data models
class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
class InsertRequest(BaseModel):
text: str
class InsertFileRequest(BaseModel):
file_path: str
class Response(BaseModel):
status: str
data: Optional[str] = None
message: Optional[str] = None
# API routes
@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
try:
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
)
return Response(status="success", data=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
try:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(request.text))
return Response(status="success", message="Text inserted successfully")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/insert_file", response_model=Response)
async def insert_file(request: InsertFileRequest):
try:
# Check if file exists
if not os.path.exists(request.file_path):
raise HTTPException(
status_code=404, detail=f"File not found: {request.file_path}"
)
# Read file content
try:
with open(request.file_path, "r", encoding="utf-8") as f:
content = f.read()
except UnicodeDecodeError:
# If UTF-8 decoding fails, try other encodings
with open(request.file_path, "r", encoding="gbk") as f:
content = f.read()
# Insert file content
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, lambda: rag.insert(content))
return Response(
status="success",
message=f"File content from {request.file_path} inserted successfully",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8020)
# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_openai_compatible_demo.py
# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"

View File

@@ -0,0 +1,75 @@
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def lmdeploy_model_complete(
prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await lmdeploy_model_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
## please specify chat_template if your local path does not follow original HF file name,
## or model_name is a pytorch model on huggingface.co,
## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
## for a list of chat_template available in lmdeploy.
chat_template="llama3",
# model_format ='awq', # if you are using awq quantization model.
# quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8.
**kwargs,
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=lmdeploy_model_complete,
llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model
embedding_func=EmbeddingFunc(
embedding_dim=384,
max_token_size=5000,
func=lambda texts: hf_embedding(
texts,
tokenizer=AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
),
embed_model=AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2"
),
),
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)

View File

@@ -1,26 +1,32 @@
import os
import logging
from lightrag import LightRAG, QueryParam
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc
WORKING_DIR = "./dickens"
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete,
llm_model_name="your_model_name",
llm_model_name="gemma2:2b",
llm_model_max_async=4,
llm_model_max_token_size=32768,
llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
func=lambda texts: ollama_embedding(
texts, embed_model="nomic-embed-text", host="http://localhost:11434"
),
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())

View File

@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
)
async def get_embedding_dim():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
embedding_dim = embedding.shape[1]
return embedding_dim
# function test
async def test_funcs():
result = await llm_model_func("How are you?")
@@ -43,37 +50,59 @@ async def test_funcs():
print("embedding_func: ", result)
asyncio.run(test_funcs())
# asyncio.run(test_funcs())
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096, max_token_size=8192, func=embedding_func
),
)
async def main():
try:
embedding_dimension = await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=8192,
func=embedding_func,
),
)
with open("./book.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read())
# Perform naive search
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
# Perform local search
print(
await rag.aquery(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
# Perform global search
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="global"),
)
)
# Perform hybrid search
print(
await rag.aquery(
"What are the top themes in this story?",
param=QueryParam(mode="hybrid"),
)
)
except Exception as e:
print(f"An error occurred: {e}")
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)
# Perform local search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)
# Perform global search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)
# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
texts,
model="netease-youdao/bce-embedding-base_v1",
api_key=os.getenv("SILICONFLOW_API_KEY"),
max_token_size=512
max_token_size=512,
)

View File

@@ -27,11 +27,12 @@ rag = LightRAG(
# Read all .txt files from the TEXT_FILES_DIR directory
texts = []
for filename in os.listdir(TEXT_FILES_DIR):
if filename.endswith('.txt'):
if filename.endswith(".txt"):
file_path = os.path.join(TEXT_FILES_DIR, filename)
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
texts.append(file.read())
# Batch insert texts into LightRAG with a retry mechanism
def insert_texts_with_retry(rag, texts, retries=3, delay=5):
for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
rag.insert(texts)
return
except Exception as e:
print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
print(
f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
)
time.sleep(delay)
raise RuntimeError("Failed to insert texts after multiple retries.")
insert_texts_with_retry(rag, texts)
# Perform different types of queries and handle potential errors
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
except Exception as e:
print(f"Error performing naive search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
except Exception as e:
print(f"Error performing local search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)
except Exception as e:
print(f"Error performing global search: {e}")
try:
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
)
)
except Exception as e:
print(f"Error performing hybrid search: {e}")
# Function to clear VRAM resources
def clear_vram():
os.system("sudo nvidia-smi --gpu-reset")
# Regularly clear VRAM to prevent overflow
clear_vram_interval = 3600 # Clear once every hour
start_time = time.time()

View File

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "0.0.7"
__version__ = "0.0.8"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View File

@@ -109,6 +109,7 @@ class LightRAG:
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
llm_model_max_token_size: int = 32768
llm_model_max_async: int = 16
llm_model_kwargs: dict = field(default_factory=dict)
# storage
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
@@ -179,7 +180,11 @@ class LightRAG:
)
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
partial(
self.llm_model_func,
hashing_kv=self.llm_response_cache,
**self.llm_model_kwargs,
)
)
def _get_storage_class(self) -> Type[BaseGraphStorage]:
return {
@@ -239,7 +244,7 @@ class LightRAG:
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
knwoledge_graph_inst=self.chunk_entity_relation_graph,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),

View File

@@ -7,7 +7,13 @@ import aiohttp
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
Timeout,
AsyncAzureOpenAI,
)
import base64
import struct
@@ -70,26 +76,31 @@ async def openai_complete_if_cache(
)
return response.choices[0].message.content
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def azure_openai_complete_if_cache(model,
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs):
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
@@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model,
)
return response.choices[0].message.content
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@@ -205,8 +217,12 @@ async def bedrock_complete_if_cache(
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
hf_model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
if hf_tokenizer.pad_token is None:
hf_tokenizer.pad_token = hf_tokenizer.eos_token
@@ -266,10 +282,13 @@ async def hf_model_if_cache(
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
)
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
@@ -280,8 +299,10 @@ async def ollama_model_if_cache(
) -> str:
kwargs.pop("max_tokens", None)
kwargs.pop("response_format", None)
host = kwargs.pop("host", None)
timeout = kwargs.pop("timeout", None)
ollama_client = ollama.AsyncClient()
ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
@@ -305,6 +326,135 @@ async def ollama_model_if_cache(
return result
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
model,
tp=1,
chat_template=None,
log_level="WARNING",
model_format="hf",
quant_policy=0,
):
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
lmdeploy_pipe = pipeline(
model_path=model,
backend_config=TurbomindEngineConfig(
tp=tp, model_format=model_format, quant_policy=quant_policy
),
chat_template_config=ChatTemplateConfig(model_name=chat_template)
if chat_template
else None,
log_level="WARNING",
)
return lmdeploy_pipe
async def lmdeploy_model_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
chat_template=None,
model_format="hf",
quant_policy=0,
**kwargs,
) -> str:
"""
Args:
model (str): The path to the model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
chat_template (str): needed when model is a pytorch model on
huggingface.co, such as "internlm-chat-7b",
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
and when the model name of local path did not match the original model name in HF.
tp (int): tensor parallel
prompt (Union[str, List[str]]): input texts to be completed.
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
except Exception:
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
kwargs.pop("response_format", None)
max_new_tokens = kwargs.pop("max_tokens", 512)
tp = kwargs.pop("tp", 1)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
do_preprocess = kwargs.pop("do_preprocess", True)
do_sample = kwargs.pop("do_sample", False)
gen_params = kwargs
version = version_info
if do_sample is not None and version < (0, 6, 0):
raise RuntimeError(
"`do_sample` parameter is not supported by lmdeploy until "
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
)
else:
do_sample = True
gen_params.update(do_sample=do_sample)
lmdeploy_pipe = initialize_lmdeploy_pipeline(
model=model,
tp=tp,
chat_template=chat_template,
model_format=model_format,
quant_policy=quant_policy,
log_level="WARNING",
)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
if hashing_kv is not None:
args_hash = compute_args_hash(model, messages)
if_cache_return = await hashing_kv.get_by_id(args_hash)
if if_cache_return is not None:
return if_cache_return["return"]
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens,
max_new_tokens=max_new_tokens,
**gen_params,
)
response = ""
async for res in lmdeploy_pipe.generate(
messages,
gen_config=gen_config,
do_preprocess=do_preprocess,
stream_response=False,
session_id=1,
):
response += res.response
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
return response
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -328,8 +478,9 @@ async def gpt_4o_mini_complete(
**kwargs,
)
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await azure_openai_complete_if_cache(
"conversation-4o-mini",
@@ -339,6 +490,7 @@ async def azure_openai_complete(
**kwargs,
)
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -418,9 +570,11 @@ async def azure_openai_embedding(
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
@@ -440,35 +594,28 @@ async def siliconcloud_embedding(
max_token_size: int = 512,
api_key: str = None,
) -> np.ndarray:
if api_key and not api_key.startswith('Bearer '):
api_key = 'Bearer ' + api_key
if api_key and not api_key.startswith("Bearer "):
api_key = "Bearer " + api_key
headers = {
"Authorization": api_key,
"Content-Type": "application/json"
}
headers = {"Authorization": api_key, "Content-Type": "application/json"}
truncate_texts = [text[0:max_token_size] for text in texts]
payload = {
"model": model,
"input": truncate_texts,
"encoding_format": "base64"
}
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
base64_strings = []
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
content = await response.json()
if 'code' in content:
if "code" in content:
raise ValueError(content)
base64_strings = [item['embedding'] for item in content['data']]
base64_strings = [item["embedding"] for item in content["data"]]
embeddings = []
for string in base64_strings:
decode_bytes = base64.b64decode(string)
n = len(decode_bytes) // 4
float_array = struct.unpack('<' + 'f' * n, decode_bytes)
float_array = struct.unpack("<" + "f" * n, decode_bytes)
embeddings.append(float_array)
return np.array(embeddings)
@@ -555,14 +702,16 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
return embeddings.detach().numpy()
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
embed_text = []
ollama_client = ollama.Client(**kwargs)
for text in texts:
data = ollama.embeddings(model=embed_model, prompt=text)
data = ollama_client.embeddings(model=embed_model, prompt=text)
embed_text.append(data["embedding"])
return embed_text
class Model(BaseModel):
"""
This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -580,14 +729,20 @@ class Model(BaseModel):
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
"""
gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
gen_func: Callable[[Any], str] = Field(
...,
description="A function that generates the response from the llm. The response must be a string",
)
kwargs: Dict[str, Any] = Field(
...,
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
)
class Config:
arbitrary_types_allowed = True
class MultiModel():
class MultiModel:
"""
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
Could also be used for spliting across diffrent models or providers.
@@ -611,26 +766,31 @@ class MultiModel():
)
```
"""
def __init__(self, models: List[Model]):
self._models = models
self._current_model = 0
def _next_model(self):
self._current_model = (self._current_model + 1) % len(self._models)
return self._models[self._current_model]
async def llm_model_func(
self,
prompt, system_prompt=None, history_messages=[], **kwargs
self, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
kwargs.pop("model", None) # stop from overwriting the custom model name
kwargs.pop("model", None) # stop from overwriting the custom model name
next_model = self._next_model()
args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
return await next_model.gen_func(
**args
args = dict(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
**next_model.kwargs,
)
return await next_model.gen_func(**args)
if __name__ == "__main__":
import asyncio

View File

@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
knwoledge_graph_inst: BaseGraphStorage,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entitiy_types = []
already_source_ids = []
already_description = []
already_node = await knwoledge_graph_inst.get_node(entity_name)
already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None:
already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend(
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
description=description,
source_id=source_id,
)
await knwoledge_graph_inst.upsert_node(
await knowledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
src_id: str,
tgt_id: str,
edges_data: list[dict],
knwoledge_graph_inst: BaseGraphStorage,
knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_weights = []
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
already_description = []
already_keywords = []
if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
set([dp["source_id"] for dp in edges_data] + already_source_ids)
)
for need_insert_id in [src_id, tgt_id]:
if not (await knwoledge_graph_inst.has_node(need_insert_id)):
await knwoledge_graph_inst.upsert_node(
if not (await knowledge_graph_inst.has_node(need_insert_id)):
await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
description = await _handle_entity_relation_summary(
(src_id, tgt_id), description, global_config
)
await knwoledge_graph_inst.upsert_edge(
await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
edge_data=dict(
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
async def extract_entities(
chunks: dict[str, TextChunkSchema],
knwoledge_graph_inst: BaseGraphStorage,
knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
@@ -341,13 +341,13 @@ async def extract_entities(
maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather(
*[
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
all_relationships_data = await asyncio.gather(
*[
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
_merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
@@ -384,7 +384,7 @@ async def extract_entities(
}
await relationships_vdb.upsert(data_for_vdb)
return knwoledge_graph_inst
return knowledge_graph_inst
async def local_query(

View File

@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
data = {
"nodes": [],
"edges": []
}
data = {"nodes": [], "edges": []}
# Use namespace
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
for node in root.findall('.//node', namespace):
for node in root.findall(".//node", namespace):
node_data = {
"id": node.get('id').strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
"id": node.get("id").strip('"'),
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
if node.find("./data[@key='d0']", namespace) is not None
else "",
"description": node.find("./data[@key='d1']", namespace).text
if node.find("./data[@key='d1']", namespace) is not None
else "",
"source_id": node.find("./data[@key='d2']", namespace).text
if node.find("./data[@key='d2']", namespace) is not None
else "",
}
data["nodes"].append(node_data)
for edge in root.findall('.//edge', namespace):
for edge in root.findall(".//edge", namespace):
edge_data = {
"source": edge.get('source').strip('"'),
"target": edge.get('target').strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
"source": edge.get("source").strip('"'),
"target": edge.get("target").strip('"'),
"weight": float(edge.find("./data[@key='d3']", namespace).text)
if edge.find("./data[@key='d3']", namespace) is not None
else 0.0,
"description": edge.find("./data[@key='d4']", namespace).text
if edge.find("./data[@key='d4']", namespace) is not None
else "",
"keywords": edge.find("./data[@key='d5']", namespace).text
if edge.find("./data[@key='d5']", namespace) is not None
else "",
"source_id": edge.find("./data[@key='d6']", namespace).text
if edge.find("./data[@key='d6']", namespace) is not None
else "",
}
data["edges"].append(edge_data)

View File

@@ -18,8 +18,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
return {"query": query_text, "result": result, "context": context}, None
result = await rag_instance.aquery(query_text, param=query_param)
return {"query": query_text, "result": result}, None
except Exception as e:
return None, {"query": query_text, "error": str(e)}

View File

@@ -50,8 +50,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param):
try:
result, context = await rag_instance.aquery(query_text, param=query_param)
return {"query": query_text, "result": result, "context": context}, None
result = await rag_instance.aquery(query_text, param=query_param)
return {"query": query_text, "result": result}, None
except Exception as e:
return None, {"query": query_text, "error": str(e)}

View File

@@ -1,16 +1,17 @@
accelerate
aioboto3
aiohttp
graspologic
hnswlib
nano-vectordb
neo4j
networkx
ollama
openai
pyvis
tenacity
tiktoken
torch
transformers
xxhash
pyvis
aiohttp
neo4j
# lmdeploy[all]

View File

@@ -1,39 +1,88 @@
import setuptools
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
from pathlib import Path
vars2find = ["__author__", "__version__", "__url__"]
vars2readme = {}
with open("./lightrag/__init__.py") as f:
for line in f.readlines():
for v in vars2find:
if line.startswith(v):
line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
vars2readme[v] = line.split("=")[1]
# Reading the long description from README.md
def read_long_description():
try:
return Path("README.md").read_text(encoding="utf-8")
except FileNotFoundError:
return "A description of LightRAG is currently unavailable."
deps = []
with open("./requirements.txt") as f:
for line in f.readlines():
if not line.strip():
continue
deps.append(line.strip())
# Retrieving metadata from __init__.py
def retrieve_metadata():
vars2find = ["__author__", "__version__", "__url__"]
vars2readme = {}
try:
with open("./lightrag/__init__.py") as f:
for line in f.readlines():
for v in vars2find:
if line.startswith(v):
line = (
line.replace(" ", "")
.replace('"', "")
.replace("'", "")
.strip()
)
vars2readme[v] = line.split("=")[1]
except FileNotFoundError:
raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.")
# Checking if all required variables are found
missing_vars = [v for v in vars2find if v not in vars2readme]
if missing_vars:
raise ValueError(
f"Missing required metadata variables in __init__.py: {missing_vars}"
)
return vars2readme
# Reading dependencies from requirements.txt
def read_requirements():
deps = []
try:
with open("./requirements.txt") as f:
deps = [line.strip() for line in f if line.strip()]
except FileNotFoundError:
print(
"Warning: 'requirements.txt' not found. No dependencies will be installed."
)
return deps
metadata = retrieve_metadata()
long_description = read_long_description()
requirements = read_requirements()
setuptools.setup(
name="lightrag-hku",
url=vars2readme["__url__"],
version=vars2readme["__version__"],
author=vars2readme["__author__"],
url=metadata["__url__"],
version=metadata["__version__"],
author=metadata["__author__"],
description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
long_description=long_description,
long_description_content_type="text/markdown",
packages=["lightrag"],
packages=setuptools.find_packages(
exclude=("tests*", "docs*")
), # Automatically find packages
classifiers=[
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Intended Audience :: Developers",
"Topic :: Software Development :: Libraries :: Python Modules",
],
python_requires=">=3.9",
install_requires=deps,
install_requires=requirements,
include_package_data=True, # Includes non-code files from MANIFEST.in
project_urls={ # Additional project metadata
"Documentation": metadata.get("__url__", ""),
"Source": metadata.get("__url__", ""),
"Tracker": f"{metadata.get('__url__', '')}/issues"
if metadata.get("__url__")
else "",
},
)