Merge branch 'main' into before-sync-28-10-2024
30  .github/workflows/linting.yaml (vendored)
@@ -0,0 +1,30 @@
name: Linting and Formatting

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  lint-and-format:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.x'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pre-commit

      - name: Run pre-commit
        run: pre-commit run --all-files
167  README.md
@@ -8,7 +8,7 @@
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
<a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
<a href='https://discord.gg/mvsfu2Tg'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
<a href='https://discord.gg/rdE8YVPm'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
</p>
<p>
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
@@ -28,6 +28,11 @@ This repository hosts the code of LightRAG. The structure of this code is based
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!

## Algorithm Flowchart

## Install

* Install from source (Recommend)
@@ -58,8 +63,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########

WORKING_DIR = "./dickens"
@@ -157,7 +162,7 @@ rag = LightRAG(
<details>
<summary> Using Ollama Models </summary>

* If you want to use Ollama models, you only need to set LightRAG as follows:

```python
@@ -204,7 +209,25 @@ ollama create -f Modelfile qwen2m
</details>

### Query Param

```python
class QueryParam:
    mode: Literal["local", "global", "hybrid", "naive"] = "global"
    only_need_context: bool = False
    response_type: str = "Multiple Paragraphs"
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of tokens for the original chunks.
    max_token_for_text_unit: int = 4000
    # Number of tokens for the relationship descriptions
    max_token_for_global_context: int = 4000
    # Number of tokens for the entity descriptions
    max_token_for_local_context: int = 4000
```
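
These defaults can be overridden per query. As a hedged illustration (it assumes the `rag` instance created in the Quick Start; the specific values are examples, not recommendations), a call that narrows retrieval and asks for a shorter answer might look like this:

```python
from lightrag import QueryParam

# Example values only: retrieve fewer items and request a more compact answer.
param = QueryParam(
    mode="local",                      # "naive", "local", "global", or "hybrid"
    top_k=20,                          # entities in "local" mode, relationships in "global" mode
    response_type="Single Paragraph",  # free-form description of the desired answer format
)
print(rag.query("What are the top themes in this story?", param=param))
```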

### Batch Insert

```python
# Batch Insert: Insert multiple texts at once
rag.insert(["TEXT1", "TEXT2",...])
@@ -214,7 +237,15 @@ rag.insert(["TEXT1", "TEXT2",...])
```python
# Incremental Insert: Insert new documents into an existing LightRAG instance
rag = LightRAG(working_dir="./dickens")
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=embedding_dimension,
        max_token_size=8192,
        func=embedding_func,
    ),
)

with open("./newText.txt") as f:
    rag.insert(f.read())
@@ -310,8 +341,8 @@ def main():
    SET e.entity_type = node.entity_type,
        e.description = node.description,
        e.source_id = node.source_id,
        e.displayName = node.id
    REMOVE e:Entity
    WITH e, node
    CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
    RETURN count(*)
@@ -364,7 +395,7 @@ def main():
    except Exception as e:
        print(f"Error occurred: {e}")

    finally:
        driver.close()
@@ -374,6 +405,125 @@ if __name__ == "__main__":
</details>

## API Server Implementation

LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.

### Setting up the API Server
<details>
<summary>Click to expand setup instructions</summary>

1. First, ensure you have the required dependencies:
```bash
pip install fastapi uvicorn pydantic
```

2. Set up your environment variables:
```bash
export RAG_DIR="your_index_directory"  # Optional: Defaults to "index_default"
```

3. Run the API server:
```bash
python examples/lightrag_api_openai_compatible_demo.py
```

The server will start on `http://0.0.0.0:8020`.
</details>

### API Endpoints

The API server provides the following endpoints:
#### 1. Query Endpoint
<details>
<summary>Click to view Query endpoint details</summary>

- **URL:** `/query`
- **Method:** POST
- **Body:**
```json
{
    "query": "Your question here",
    "mode": "hybrid"  // Can be "naive", "local", "global", or "hybrid"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/query" \
     -H "Content-Type: application/json" \
     -d '{"query": "What are the main themes?", "mode": "hybrid"}'
```
</details>

#### 2. Insert Text Endpoint
<details>
<summary>Click to view Insert Text endpoint details</summary>

- **URL:** `/insert`
- **Method:** POST
- **Body:**
```json
{
    "text": "Your text content here"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/insert" \
     -H "Content-Type: application/json" \
     -d '{"text": "Content to be inserted into RAG"}'
```
</details>

#### 3. Insert File Endpoint
<details>
<summary>Click to view Insert File endpoint details</summary>

- **URL:** `/insert_file`
- **Method:** POST
- **Body:**
```json
{
    "file_path": "path/to/your/file.txt"
}
```
- **Example:**
```bash
curl -X POST "http://127.0.0.1:8020/insert_file" \
     -H "Content-Type: application/json" \
     -d '{"file_path": "./book.txt"}'
```
</details>

#### 4. Health Check Endpoint
<details>
<summary>Click to view Health Check endpoint details</summary>

- **URL:** `/health`
- **Method:** GET
- **Example:**
```bash
curl -X GET "http://127.0.0.1:8020/health"
```
</details>
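
The same endpoints can also be exercised from Python. The sketch below is illustrative only; it assumes the `requests` package is installed (it is not among LightRAG's listed dependencies) and that the server from the setup steps above is running on port 8020:

```python
import requests  # assumption: installed separately, e.g. `pip install requests`

BASE_URL = "http://127.0.0.1:8020"

# Insert a piece of text, then query it back in hybrid mode.
requests.post(f"{BASE_URL}/insert", json={"text": "Content to be inserted into RAG"}).raise_for_status()

response = requests.post(
    f"{BASE_URL}/query",
    json={"query": "What are the main themes?", "mode": "hybrid"},
)
response.raise_for_status()
print(response.json()["data"])  # the server's Response model carries "status", "data", "message"
```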
### Configuration

The API server can be configured using environment variables:
- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers

### Error Handling
<details>
<summary>Click to view error handling details</summary>

The API includes comprehensive error handling:
- File not found errors (404)
- Processing errors (500)
- Supports multiple file encodings (UTF-8 and GBK)
</details>
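
The encoding support mentioned above follows a try-UTF-8-then-fall-back-to-GBK pattern, mirroring the logic in `examples/lightrag_api_openai_compatible_demo.py`; a minimal sketch of that helper (the function name is hypothetical) is:

```python
def read_text_with_fallback(file_path: str) -> str:
    # Try UTF-8 first; fall back to GBK if decoding fails.
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    except UnicodeDecodeError:
        with open(file_path, "r", encoding="gbk") as f:
            return f.read()
```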
## Evaluation
### Dataset
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -629,6 +779,7 @@ def extract_queries(file_path):
│   ├── lightrag_ollama_demo.py
│   ├── lightrag_openai_compatible_demo.py
│   ├── lightrag_openai_demo.py
│   ├── lightrag_siliconcloud_demo.py
│   └── vram_management_demo.py
├── lightrag
│   ├── __init__.py
@@ -3,17 +3,17 @@ from pyvis.network import Network
import random

# Load the GraphML file
G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")

# Create a Pyvis network
net = Network(notebook=True)
net = Network(height="100vh", notebook=True)

# Convert NetworkX graph to Pyvis network
net.from_nx(G)

# Add colors to nodes
for node in net.nodes:
    node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
    node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))

# Save and display the network
net.show('knowledge_graph.html')
net.show("knowledge_graph.html")
@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"


def convert_xml_to_json(xml_path, output_path):
    """Converts XML file to JSON and saves the output."""
    if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):

    json_data = xml_to_json(xml_path)
    if json_data:
        with open(output_path, 'w', encoding='utf-8') as f:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(json_data, f, ensure_ascii=False, indent=2)
            print(f"JSON file created: {output_path}")
            return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
        print("Failed to create JSON data")
        return None


def process_in_batches(tx, query, data, batch_size):
    """Process data in batches and execute the given query."""
    for i in range(0, len(data), batch_size):
        batch = data[i:i + batch_size]
        batch = data[i : i + batch_size]
        tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})


def main():
    # Paths
    xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
    json_file = os.path.join(WORKING_DIR, 'graph_data.json')
    xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
    json_file = os.path.join(WORKING_DIR, "graph_data.json")

    # Convert XML to JSON
    json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
        return

    # Load nodes and edges
    nodes = json_data.get('nodes', [])
    edges = json_data.get('edges', [])
    nodes = json_data.get("nodes", [])
    edges = json_data.get("edges", [])

    # Neo4j queries
    create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
    SET e.entity_type = node.entity_type,
        e.description = node.description,
        e.source_id = node.source_id,
        e.displayName = node.id
    REMOVE e:Entity
    WITH e, node
    CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
    RETURN count(*)
@@ -100,19 +103,24 @@ def main():
        # Execute queries in batches
        with driver.session() as session:
            # Insert nodes in batches
            session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
            session.execute_write(
                process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
            )

            # Insert edges in batches
            session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
            session.execute_write(
                process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
            )

            # Set displayName and labels
            session.run(set_displayname_and_labels_query)

    except Exception as e:
        print(f"Error occurred: {e}")

    finally:
        driver.close()


if __name__ == "__main__":
    main()
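
After the batches have been written, a quick count from Python can confirm that nodes and relationships actually landed in Neo4j. This is an optional verification sketch using the official `neo4j` driver and the connection constants defined above; the queries are illustrative:

```python
from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "your_password"))
try:
    with driver.session() as session:
        nodes = session.run("MATCH (n) RETURN count(n) AS c").single()["c"]
        rels = session.run("MATCH ()-[r]->() RETURN count(r) AS c").single()["c"]
        print(f"Imported {nodes} nodes and {rels} relationships")
finally:
    driver.close()
```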
164  examples/lightrag_api_openai_compatible_demo.py
@@ -0,0 +1,164 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
|
||||
# Apply nest_asyncio to solve event loop issues
|
||||
nest_asyncio.apply()
|
||||
|
||||
DEFAULT_RAG_DIR = "index_default"
|
||||
app = FastAPI(title="LightRAG API", description="API for RAG operations")
|
||||
|
||||
# Configure working directory
|
||||
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
|
||||
print(f"WORKING_DIR: {WORKING_DIR}")
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
# LLM model function
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
"gpt-4o-mini",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key="YOUR_API_KEY",
|
||||
base_url="YourURL/v1",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# Embedding function
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await openai_embedding(
|
||||
texts,
|
||||
model="text-embedding-3-large",
|
||||
api_key="YOUR_API_KEY",
|
||||
base_url="YourURL/v1",
|
||||
)
|
||||
|
||||
|
||||
# Initialize RAG instance
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=3072, max_token_size=8192, func=embedding_func
|
||||
),
|
||||
)
|
||||
|
||||
# Data models
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str
|
||||
mode: str = "hybrid"
|
||||
|
||||
|
||||
class InsertRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class InsertFileRequest(BaseModel):
|
||||
file_path: str
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
status: str
|
||||
data: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
# API routes
|
||||
|
||||
|
||||
@app.post("/query", response_model=Response)
|
||||
async def query_endpoint(request: QueryRequest):
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
|
||||
)
|
||||
return Response(status="success", data=result)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/insert", response_model=Response)
|
||||
async def insert_endpoint(request: InsertRequest):
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lambda: rag.insert(request.text))
|
||||
return Response(status="success", message="Text inserted successfully")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/insert_file", response_model=Response)
|
||||
async def insert_file(request: InsertFileRequest):
|
||||
try:
|
||||
# Check if file exists
|
||||
if not os.path.exists(request.file_path):
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"File not found: {request.file_path}"
|
||||
)
|
||||
|
||||
# Read file content
|
||||
try:
|
||||
with open(request.file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
except UnicodeDecodeError:
|
||||
# If UTF-8 decoding fails, try other encodings
|
||||
with open(request.file_path, "r", encoding="gbk") as f:
|
||||
content = f.read()
|
||||
|
||||
# Insert file content
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, lambda: rag.insert(content))
|
||||
|
||||
return Response(
|
||||
status="success",
|
||||
message=f"File content from {request.file_path} inserted successfully",
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8020)
|
||||
|
||||
# Usage example
|
||||
# To run the server, use the following command in your terminal:
|
||||
# python lightrag_api_openai_compatible_demo.py
|
||||
|
||||
# Example requests:
|
||||
# 1. Query:
|
||||
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
|
||||
|
||||
# 2. Insert text:
|
||||
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
|
||||
|
||||
# 3. Insert file:
|
||||
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
|
||||
|
||||
# 4. Health check:
|
||||
# curl -X GET "http://127.0.0.1:8020/health"
|
@@ -30,7 +30,7 @@ rag = LightRAG(
)


with open("./book.txt") as f:
with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Perform naive search
75  examples/lightrag_lmdeploy_demo.py
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
||||
async def lmdeploy_model_complete(
|
||||
prompt=None, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||
return await lmdeploy_model_if_cache(
|
||||
model_name,
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
## please specify chat_template if your local path does not follow original HF file name,
|
||||
## or model_name is a pytorch model on huggingface.co,
|
||||
## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
|
||||
## for a list of chat_template available in lmdeploy.
|
||||
chat_template="llama3",
|
||||
# model_format ='awq', # if you are using awq quantization model.
|
||||
# quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8.
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=lmdeploy_model_complete,
|
||||
llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=384,
|
||||
max_token_size=5000,
|
||||
func=lambda texts: hf_embedding(
|
||||
texts,
|
||||
tokenizer=AutoTokenizer.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L6-v2"
|
||||
),
|
||||
embed_model=AutoModel.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L6-v2"
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||
)
|
@@ -28,7 +28,7 @@ rag = LightRAG(
    ),
)

with open("./book.txt") as f:
with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Perform naive search
@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
)
|
||||
|
||||
|
||||
async def get_embedding_dim():
|
||||
test_text = ["This is a test sentence."]
|
||||
embedding = await embedding_func(test_text)
|
||||
embedding_dim = embedding.shape[1]
|
||||
return embedding_dim
|
||||
|
||||
|
||||
# function test
|
||||
async def test_funcs():
|
||||
result = await llm_model_func("How are you?")
|
||||
@@ -43,37 +50,59 @@ async def test_funcs():
|
||||
print("embedding_func: ", result)
|
||||
|
||||
|
||||
asyncio.run(test_funcs())
|
||||
# asyncio.run(test_funcs())
|
||||
|
||||
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
||||
),
|
||||
)
|
||||
async def main():
|
||||
try:
|
||||
embedding_dimension = await get_embedding_dim()
|
||||
print(f"Detected embedding dimension: {embedding_dimension}")
|
||||
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=embedding_dimension,
|
||||
max_token_size=8192,
|
||||
func=embedding_func,
|
||||
),
|
||||
)
|
||||
|
||||
with open("./book.txt", "r", encoding="utf-8") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||
)
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||
)
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?",
|
||||
param=QueryParam(mode="global"),
|
||||
)
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?",
|
||||
param=QueryParam(mode="hybrid"),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
|
||||
|
||||
with open("./book.txt") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
@@ -15,7 +15,7 @@ rag = LightRAG(
)


with open("./book.txt") as f:
with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Perform naive search
79  examples/lightrag_siliconcloud_demo.py
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
import asyncio
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
import numpy as np
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
"Qwen/Qwen2.5-7B-Instruct",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
base_url="https://api.siliconflow.cn/v1/",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await siliconcloud_embedding(
|
||||
texts,
|
||||
model="netease-youdao/bce-embedding-base_v1",
|
||||
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
||||
max_token_size=512,
|
||||
)
|
||||
|
||||
|
||||
# function test
|
||||
async def test_funcs():
|
||||
result = await llm_model_func("How are you?")
|
||||
print("llm_model_func: ", result)
|
||||
|
||||
result = await embedding_func(["How are you?"])
|
||||
print("embedding_func: ", result)
|
||||
|
||||
|
||||
asyncio.run(test_funcs())
|
||||
|
||||
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=768, max_token_size=512, func=embedding_func
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
with open("./book.txt") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
||||
)
|
||||
|
||||
# Perform local search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
||||
)
|
||||
|
||||
# Perform global search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
||||
)
|
||||
|
||||
# Perform hybrid search
|
||||
print(
|
||||
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
||||
)
|
@@ -27,11 +27,12 @@ rag = LightRAG(
|
||||
# Read all .txt files from the TEXT_FILES_DIR directory
|
||||
texts = []
|
||||
for filename in os.listdir(TEXT_FILES_DIR):
|
||||
if filename.endswith('.txt'):
|
||||
if filename.endswith(".txt"):
|
||||
file_path = os.path.join(TEXT_FILES_DIR, filename)
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
texts.append(file.read())
|
||||
|
||||
|
||||
# Batch insert texts into LightRAG with a retry mechanism
|
||||
def insert_texts_with_retry(rag, texts, retries=3, delay=5):
|
||||
for _ in range(retries):
|
||||
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
|
||||
rag.insert(texts)
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
|
||||
print(
|
||||
f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
|
||||
)
|
||||
time.sleep(delay)
|
||||
raise RuntimeError("Failed to insert texts after multiple retries.")
|
||||
|
||||
|
||||
insert_texts_with_retry(rag, texts)
|
||||
|
||||
# Perform different types of queries and handle potential errors
|
||||
try:
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error performing naive search: {e}")
|
||||
|
||||
try:
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="local")
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error performing local search: {e}")
|
||||
|
||||
try:
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="global")
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error performing global search: {e}")
|
||||
|
||||
try:
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
||||
print(
|
||||
rag.query(
|
||||
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error performing hybrid search: {e}")
|
||||
|
||||
|
||||
# Function to clear VRAM resources
|
||||
def clear_vram():
|
||||
os.system("sudo nvidia-smi --gpu-reset")
|
||||
|
||||
|
||||
# Regularly clear VRAM to prevent overflow
|
||||
clear_vram_interval = 3600 # Clear once every hour
|
||||
start_time = time.time()
|
||||
|
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

__version__ = "0.0.7"
__version__ = "0.0.8"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
@@ -18,9 +18,13 @@ class QueryParam:
    mode: Literal["local", "global", "hybrid", "naive"] = "global"
    only_need_context: bool = False
    response_type: str = "Multiple Paragraphs"
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of tokens for the original chunks.
    max_token_for_text_unit: int = 4000
    # Number of tokens for the relationship descriptions
    max_token_for_global_context: int = 4000
    # Number of tokens for the entity descriptions
    max_token_for_local_context: int = 4000
@@ -209,7 +209,7 @@ class LightRAG:
        logger.info("[Entity Extraction]...")
        maybe_new_kg = await extract_entities(
            inserting_chunks,
            knwoledge_graph_inst=self.chunk_entity_relation_graph,
            knowledge_graph_inst=self.chunk_entity_relation_graph,
            entity_vdb=self.entities_vdb,
            relationships_vdb=self.relationships_vdb,
            global_config=asdict(self),
382  lightrag/llm.py
@@ -1,10 +1,23 @@
|
||||
import os
|
||||
import copy
|
||||
from functools import lru_cache
|
||||
import json
|
||||
import aioboto3
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import ollama
|
||||
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
||||
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
AsyncAzureOpenAI,
|
||||
)
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
@@ -13,6 +26,8 @@ from tenacity import (
|
||||
)
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Callable, Any
|
||||
from .base import BaseKVStorage
|
||||
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
||||
|
||||
@@ -62,6 +77,55 @@ async def openai_complete_if_cache(
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def azure_openai_complete_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
base_url=None,
|
||||
api_key=None,
|
||||
**kwargs,
|
||||
):
|
||||
if api_key:
|
||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||
if base_url:
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||
|
||||
openai_async_client = AsyncAzureOpenAI(
|
||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
|
||||
)
|
||||
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
if prompt is not None:
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
response = await openai_async_client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert(
|
||||
{args_hash: {"return": response.choices[0].message.content, "model": model}}
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
"""Generic error for issues related to Amazon Bedrock"""
|
||||
|
||||
@@ -151,15 +215,25 @@ async def bedrock_complete_if_cache(
|
||||
return response["output"]["message"]["content"][0]["text"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_hf_model(model_name):
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
if hf_tokenizer.pad_token is None:
|
||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||
|
||||
return hf_model, hf_tokenizer
|
||||
|
||||
|
||||
async def hf_model_if_cache(
|
||||
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
model_name = model
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
|
||||
if hf_tokenizer.pad_token is None:
|
||||
# print("use eos token")
|
||||
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
||||
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages = []
|
||||
if system_prompt:
|
||||
@@ -208,10 +282,13 @@ async def hf_model_if_cache(
|
||||
input_ids = hf_tokenizer(
|
||||
input_prompt, return_tensors="pt", padding=True, truncation=True
|
||||
).to("cuda")
|
||||
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
|
||||
output = hf_model.generate(
|
||||
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
|
||||
**input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
|
||||
)
|
||||
response_text = hf_tokenizer.decode(
|
||||
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
|
||||
)
|
||||
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
|
||||
return response_text
|
||||
@@ -249,6 +326,135 @@ async def ollama_model_if_cache(
|
||||
return result
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def initialize_lmdeploy_pipeline(
|
||||
model,
|
||||
tp=1,
|
||||
chat_template=None,
|
||||
log_level="WARNING",
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
):
|
||||
from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
|
||||
|
||||
lmdeploy_pipe = pipeline(
|
||||
model_path=model,
|
||||
backend_config=TurbomindEngineConfig(
|
||||
tp=tp, model_format=model_format, quant_policy=quant_policy
|
||||
),
|
||||
chat_template_config=ChatTemplateConfig(model_name=chat_template)
|
||||
if chat_template
|
||||
else None,
|
||||
log_level="WARNING",
|
||||
)
|
||||
return lmdeploy_pipe
|
||||
|
||||
|
||||
async def lmdeploy_model_if_cache(
|
||||
model,
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=[],
|
||||
chat_template=None,
|
||||
model_format="hf",
|
||||
quant_policy=0,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
model (str): The path to the model.
|
||||
It could be one of the following options:
|
||||
- i) A local directory path of a turbomind model which is
|
||||
converted by `lmdeploy convert` command or download
|
||||
from ii) and iii).
|
||||
- ii) The model_id of a lmdeploy-quantized model hosted
|
||||
inside a model repo on huggingface.co, such as
|
||||
"InternLM/internlm-chat-20b-4bit",
|
||||
"lmdeploy/llama2-chat-70b-4bit", etc.
|
||||
- iii) The model_id of a model hosted inside a model repo
|
||||
on huggingface.co, such as "internlm/internlm-chat-7b",
|
||||
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
|
||||
and so on.
|
||||
chat_template (str): needed when model is a pytorch model on
|
||||
huggingface.co, such as "internlm-chat-7b",
|
||||
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
|
||||
and when the model name of local path did not match the original model name in HF.
|
||||
tp (int): tensor parallel
|
||||
prompt (Union[str, List[str]]): input texts to be completed.
|
||||
do_preprocess (bool): whether pre-process the messages. Default to
|
||||
True, which means chat_template will be applied.
|
||||
skip_special_tokens (bool): Whether or not to remove special tokens
|
||||
in the decoding. Default to be True.
|
||||
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
|
||||
Default to be False, which means greedy decoding will be applied.
|
||||
"""
|
||||
try:
|
||||
import lmdeploy
|
||||
from lmdeploy import version_info, GenerationConfig
|
||||
except Exception:
|
||||
raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
|
||||
|
||||
kwargs.pop("response_format", None)
|
||||
max_new_tokens = kwargs.pop("max_tokens", 512)
|
||||
tp = kwargs.pop("tp", 1)
|
||||
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
|
||||
do_preprocess = kwargs.pop("do_preprocess", True)
|
||||
do_sample = kwargs.pop("do_sample", False)
|
||||
gen_params = kwargs
|
||||
|
||||
version = version_info
|
||||
if do_sample is not None and version < (0, 6, 0):
|
||||
raise RuntimeError(
|
||||
"`do_sample` parameter is not supported by lmdeploy until "
|
||||
f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
|
||||
)
|
||||
else:
|
||||
do_sample = True
|
||||
gen_params.update(do_sample=do_sample)
|
||||
|
||||
lmdeploy_pipe = initialize_lmdeploy_pipeline(
|
||||
model=model,
|
||||
tp=tp,
|
||||
chat_template=chat_template,
|
||||
model_format=model_format,
|
||||
quant_policy=quant_policy,
|
||||
log_level="WARNING",
|
||||
)
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
if hashing_kv is not None:
|
||||
args_hash = compute_args_hash(model, messages)
|
||||
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
||||
if if_cache_return is not None:
|
||||
return if_cache_return["return"]
|
||||
|
||||
gen_config = GenerationConfig(
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**gen_params,
|
||||
)
|
||||
|
||||
response = ""
|
||||
async for res in lmdeploy_pipe.generate(
|
||||
messages,
|
||||
gen_config=gen_config,
|
||||
do_preprocess=do_preprocess,
|
||||
stream_response=False,
|
||||
session_id=1,
|
||||
):
|
||||
response += res.response
|
||||
|
||||
if hashing_kv is not None:
|
||||
await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
|
||||
return response
|
||||
|
||||
|
||||
async def gpt_4o_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
@@ -273,6 +479,18 @@ async def gpt_4o_mini_complete(
|
||||
)
|
||||
|
||||
|
||||
async def azure_openai_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await azure_openai_complete_if_cache(
|
||||
"conversation-4o-mini",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def bedrock_complete(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
@@ -314,7 +532,7 @@ async def ollama_model_complete(
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def openai_embedding(
|
||||
@@ -335,6 +553,73 @@ async def openai_embedding(
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def azure_openai_embedding(
|
||||
texts: list[str],
|
||||
model: str = "text-embedding-3-small",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key:
|
||||
os.environ["AZURE_OPENAI_API_KEY"] = api_key
|
||||
if base_url:
|
||||
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
|
||||
|
||||
openai_async_client = AsyncAzureOpenAI(
|
||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
|
||||
)
|
||||
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="float"
|
||||
)
|
||||
return np.array([dp.embedding for dp in response.data])
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
||||
)
|
||||
async def siliconcloud_embedding(
|
||||
texts: list[str],
|
||||
model: str = "netease-youdao/bce-embedding-base_v1",
|
||||
base_url: str = "https://api.siliconflow.cn/v1/embeddings",
|
||||
max_token_size: int = 512,
|
||||
api_key: str = None,
|
||||
) -> np.ndarray:
|
||||
if api_key and not api_key.startswith("Bearer "):
|
||||
api_key = "Bearer " + api_key
|
||||
|
||||
headers = {"Authorization": api_key, "Content-Type": "application/json"}
|
||||
|
||||
truncate_texts = [text[0:max_token_size] for text in texts]
|
||||
|
||||
payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
|
||||
|
||||
base64_strings = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(base_url, headers=headers, json=payload) as response:
|
||||
content = await response.json()
|
||||
if "code" in content:
|
||||
raise ValueError(content)
|
||||
base64_strings = [item["embedding"] for item in content["data"]]
|
||||
|
||||
embeddings = []
|
||||
for string in base64_strings:
|
||||
decode_bytes = base64.b64decode(string)
|
||||
n = len(decode_bytes) // 4
|
||||
float_array = struct.unpack("<" + "f" * n, decode_bytes)
|
||||
embeddings.append(float_array)
|
||||
return np.array(embeddings)
|
||||
|
||||
|
||||
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
|
||||
# @retry(
|
||||
# stop=stop_after_attempt(3),
|
||||
@@ -427,6 +712,85 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra
|
||||
return embed_text
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
"""
|
||||
This is a Pydantic model class named 'Model' that is used to define a custom language model.
|
||||
|
||||
Attributes:
|
||||
gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
|
||||
The function should take any argument and return a string.
|
||||
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
|
||||
This could include parameters such as the model name, API key, etc.
|
||||
|
||||
Example usage:
|
||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
|
||||
|
||||
In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
|
||||
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
|
||||
"""
|
||||
|
||||
gen_func: Callable[[Any], str] = Field(
|
||||
...,
|
||||
description="A function that generates the response from the llm. The response must be a string",
|
||||
)
|
||||
kwargs: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class MultiModel:
|
||||
"""
|
||||
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
|
||||
Could also be used for spliting across diffrent models or providers.
|
||||
|
||||
Attributes:
|
||||
models (List[Model]): A list of language models to be used.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
models = [
|
||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
|
||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
|
||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
|
||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
|
||||
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
|
||||
]
|
||||
multi_model = MultiModel(models)
|
||||
rag = LightRAG(
|
||||
llm_model_func=multi_model.llm_model_func
|
||||
/ ..other args
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, models: List[Model]):
|
||||
self._models = models
|
||||
self._current_model = 0
|
||||
|
||||
def _next_model(self):
|
||||
self._current_model = (self._current_model + 1) % len(self._models)
|
||||
return self._models[self._current_model]
|
||||
|
||||
async def llm_model_func(
|
||||
self, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
kwargs.pop("model", None) # stop from overwriting the custom model name
|
||||
next_model = self._next_model()
|
||||
args = dict(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
**next_model.kwargs,
|
||||
)
|
||||
|
||||
return await next_model.gen_func(**args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
|
@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
|
||||
async def _merge_nodes_then_upsert(
|
||||
entity_name: str,
|
||||
nodes_data: list[dict],
|
||||
knwoledge_graph_inst: BaseGraphStorage,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
global_config: dict,
|
||||
):
|
||||
already_entitiy_types = []
|
||||
already_source_ids = []
|
||||
already_description = []
|
||||
|
||||
already_node = await knwoledge_graph_inst.get_node(entity_name)
|
||||
already_node = await knowledge_graph_inst.get_node(entity_name)
|
||||
if already_node is not None:
|
||||
already_entitiy_types.append(already_node["entity_type"])
|
||||
already_source_ids.extend(
|
||||
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
|
||||
description=description,
|
||||
source_id=source_id,
|
||||
)
|
||||
await knwoledge_graph_inst.upsert_node(
|
||||
await knowledge_graph_inst.upsert_node(
|
||||
entity_name,
|
||||
node_data=node_data,
|
||||
)
|
||||
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
|
||||
src_id: str,
|
||||
tgt_id: str,
|
||||
edges_data: list[dict],
|
||||
knwoledge_graph_inst: BaseGraphStorage,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
global_config: dict,
|
||||
):
|
||||
already_weights = []
|
||||
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
|
||||
already_description = []
|
||||
already_keywords = []
|
||||
|
||||
if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
|
||||
already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
|
||||
if await knowledge_graph_inst.has_edge(src_id, tgt_id):
|
||||
already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
|
||||
already_weights.append(already_edge["weight"])
|
||||
already_source_ids.extend(
|
||||
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
|
||||
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
|
||||
set([dp["source_id"] for dp in edges_data] + already_source_ids)
|
||||
)
|
||||
for need_insert_id in [src_id, tgt_id]:
|
||||
if not (await knwoledge_graph_inst.has_node(need_insert_id)):
|
||||
await knwoledge_graph_inst.upsert_node(
|
||||
if not (await knowledge_graph_inst.has_node(need_insert_id)):
|
||||
await knowledge_graph_inst.upsert_node(
|
||||
need_insert_id,
|
||||
node_data={
|
||||
"source_id": source_id,
|
||||
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
|
||||
description = await _handle_entity_relation_summary(
|
||||
(src_id, tgt_id), description, global_config
|
||||
)
|
||||
await knwoledge_graph_inst.upsert_edge(
|
||||
await knowledge_graph_inst.upsert_edge(
|
||||
src_id,
|
||||
tgt_id,
|
||||
edge_data=dict(
|
||||
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
|
||||
|
||||
async def extract_entities(
|
||||
chunks: dict[str, TextChunkSchema],
|
||||
knwoledge_graph_inst: BaseGraphStorage,
|
||||
knowledge_graph_inst: BaseGraphStorage,
|
||||
entity_vdb: BaseVectorStorage,
|
||||
relationships_vdb: BaseVectorStorage,
|
||||
global_config: dict,
|
||||
@@ -341,13 +341,13 @@ async def extract_entities(
|
||||
maybe_edges[tuple(sorted(k))].extend(v)
|
||||
all_entities_data = await asyncio.gather(
|
||||
*[
|
||||
_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
|
||||
_merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
|
||||
for k, v in maybe_nodes.items()
|
||||
]
|
||||
)
|
||||
all_relationships_data = await asyncio.gather(
|
||||
*[
|
||||
_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
|
||||
_merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
|
||||
for k, v in maybe_edges.items()
|
||||
]
|
||||
)
|
||||
@@ -384,7 +384,7 @@ async def extract_entities(
|
||||
}
|
||||
await relationships_vdb.upsert(data_for_vdb)
|
||||
|
||||
return knwoledge_graph_inst
|
||||
return knowledge_graph_inst
|
||||
|
||||
|
||||
async def local_query(
|
||||
|
@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
|
||||
with open(file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def xml_to_json(xml_file):
|
||||
try:
|
||||
tree = ET.parse(xml_file)
|
||||
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
|
||||
print(f"Root element: {root.tag}")
|
||||
print(f"Root attributes: {root.attrib}")
|
||||
|
||||
data = {
|
||||
"nodes": [],
|
||||
"edges": []
|
||||
}
|
||||
data = {"nodes": [], "edges": []}
|
||||
|
||||
# Use namespace
|
||||
namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
|
||||
namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
|
||||
|
||||
for node in root.findall('.//node', namespace):
|
||||
for node in root.findall(".//node", namespace):
|
||||
node_data = {
|
||||
"id": node.get('id').strip('"'),
|
||||
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
|
||||
"description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
|
||||
"source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
|
||||
"id": node.get("id").strip('"'),
|
||||
"entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
|
||||
if node.find("./data[@key='d0']", namespace) is not None
|
||||
else "",
|
||||
"description": node.find("./data[@key='d1']", namespace).text
|
||||
if node.find("./data[@key='d1']", namespace) is not None
|
||||
else "",
|
||||
"source_id": node.find("./data[@key='d2']", namespace).text
|
||||
if node.find("./data[@key='d2']", namespace) is not None
|
||||
else "",
|
||||
}
|
||||
data["nodes"].append(node_data)
|
||||
|
||||
for edge in root.findall('.//edge', namespace):
|
||||
for edge in root.findall(".//edge", namespace):
|
||||
edge_data = {
|
||||
"source": edge.get('source').strip('"'),
|
||||
"target": edge.get('target').strip('"'),
|
||||
"weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
|
||||
"description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
|
||||
"keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
|
||||
"source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
|
||||
"source": edge.get("source").strip('"'),
|
||||
"target": edge.get("target").strip('"'),
|
||||
"weight": float(edge.find("./data[@key='d3']", namespace).text)
|
||||
if edge.find("./data[@key='d3']", namespace) is not None
|
||||
else 0.0,
|
||||
"description": edge.find("./data[@key='d4']", namespace).text
|
||||
if edge.find("./data[@key='d4']", namespace) is not None
|
||||
else "",
|
||||
"keywords": edge.find("./data[@key='d5']", namespace).text
|
||||
if edge.find("./data[@key='d5']", namespace) is not None
|
||||
else "",
|
||||
"source_id": edge.find("./data[@key='d6']", namespace).text
|
||||
if edge.find("./data[@key='d6']", namespace) is not None
|
||||
else "",
|
||||
}
|
||||
data["edges"].append(edge_data)
|
||||
|
||||
|
@@ -50,8 +50,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param):
    try:
        result, context = await rag_instance.aquery(query_text, param=query_param)
        return {"query": query_text, "result": result, "context": context}, None
        result = await rag_instance.aquery(query_text, param=query_param)
        return {"query": query_text, "result": result}, None
    except Exception as e:
        return None, {"query": query_text, "error": str(e)}
@@ -1,14 +1,16 @@
accelerate
aioboto3
aiohttp
graspologic
hnswlib
nano-vectordb
networkx
ollama
openai
pyvis
tenacity
tiktoken
torch
transformers
xxhash
pyvis
# lmdeploy[all]
93  setup.py
@@ -1,39 +1,88 @@
|
||||
import setuptools
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
vars2find = ["__author__", "__version__", "__url__"]
|
||||
vars2readme = {}
|
||||
with open("./lightrag/__init__.py") as f:
|
||||
for line in f.readlines():
|
||||
for v in vars2find:
|
||||
if line.startswith(v):
|
||||
line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
|
||||
vars2readme[v] = line.split("=")[1]
|
||||
# Reading the long description from README.md
|
||||
def read_long_description():
|
||||
try:
|
||||
return Path("README.md").read_text(encoding="utf-8")
|
||||
except FileNotFoundError:
|
||||
return "A description of LightRAG is currently unavailable."
|
||||
|
||||
deps = []
|
||||
with open("./requirements.txt") as f:
|
||||
for line in f.readlines():
|
||||
if not line.strip():
|
||||
continue
|
||||
deps.append(line.strip())
|
||||
|
||||
# Retrieving metadata from __init__.py
|
||||
def retrieve_metadata():
|
||||
vars2find = ["__author__", "__version__", "__url__"]
|
||||
vars2readme = {}
|
||||
try:
|
||||
with open("./lightrag/__init__.py") as f:
|
||||
for line in f.readlines():
|
||||
for v in vars2find:
|
||||
if line.startswith(v):
|
||||
line = (
|
||||
line.replace(" ", "")
|
||||
.replace('"', "")
|
||||
.replace("'", "")
|
||||
.strip()
|
||||
)
|
||||
vars2readme[v] = line.split("=")[1]
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.")
|
||||
|
||||
# Checking if all required variables are found
|
||||
missing_vars = [v for v in vars2find if v not in vars2readme]
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Missing required metadata variables in __init__.py: {missing_vars}"
|
||||
)
|
||||
|
||||
return vars2readme
|
||||
|
||||
|
||||
# Reading dependencies from requirements.txt
|
||||
def read_requirements():
|
||||
deps = []
|
||||
try:
|
||||
with open("./requirements.txt") as f:
|
||||
deps = [line.strip() for line in f if line.strip()]
|
||||
except FileNotFoundError:
|
||||
print(
|
||||
"Warning: 'requirements.txt' not found. No dependencies will be installed."
|
||||
)
|
||||
return deps
|
||||
|
||||
|
||||
metadata = retrieve_metadata()
|
||||
long_description = read_long_description()
|
||||
requirements = read_requirements()
|
||||
|
||||
setuptools.setup(
|
||||
name="lightrag-hku",
|
||||
url=vars2readme["__url__"],
|
||||
version=vars2readme["__version__"],
|
||||
author=vars2readme["__author__"],
|
||||
url=metadata["__url__"],
|
||||
version=metadata["__version__"],
|
||||
author=metadata["__author__"],
|
||||
description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
packages=["lightrag"],
|
||||
packages=setuptools.find_packages(
|
||||
exclude=("tests*", "docs*")
|
||||
), # Automatically find packages
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
],
|
||||
python_requires=">=3.9",
|
||||
install_requires=deps,
|
||||
install_requires=requirements,
|
||||
include_package_data=True, # Includes non-code files from MANIFEST.in
|
||||
project_urls={ # Additional project metadata
|
||||
"Documentation": metadata.get("__url__", ""),
|
||||
"Source": metadata.get("__url__", ""),
|
||||
"Tracker": f"{metadata.get('__url__', '')}/issues"
|
||||
if metadata.get("__url__")
|
||||
else "",
|
||||
},
|
||||
)
|
||||
|