Merge branch 'main' into before-sync-28-10-2024

.github/workflows/linting.yaml (vendored, new file, 30 lines)
@@ -0,0 +1,30 @@
name: Linting and Formatting

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  lint-and-format:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.x'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pre-commit

      - name: Run pre-commit
        run: pre-commit run --all-files

README.md (155 changed lines)
@@ -8,7 +8,7 @@
     <a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
     <a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
     <a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
-    <a href='https://discord.gg/mvsfu2Tg'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
+    <a href='https://discord.gg/rdE8YVPm'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
 </p>
 <p>
     <img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
@@ -28,6 +28,11 @@ This repository hosts the code of LightRAG. The structure of this code is based
 - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
 - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
 
+## Algorithm Flowchart
+
+
+
+
 ## Install
 
 * Install from source (Recommend)
@@ -204,7 +209,25 @@ ollama create -f Modelfile qwen2m
 
 </details>
 
+### Query Param
+
+```python
+class QueryParam:
+    mode: Literal["local", "global", "hybrid", "naive"] = "global"
+    only_need_context: bool = False
+    response_type: str = "Multiple Paragraphs"
+    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+    top_k: int = 60
+    # Number of tokens for the original chunks.
+    max_token_for_text_unit: int = 4000
+    # Number of tokens for the relationship descriptions
+    max_token_for_global_context: int = 4000
+    # Number of tokens for the entity descriptions
+    max_token_for_local_context: int = 4000
+```
+
 ### Batch Insert
 
 ```python
 # Batch Insert: Insert multiple texts at once
 rag.insert(["TEXT1", "TEXT2",...])
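For reference, here is a minimal sketch of how these query parameters are typically passed to `rag.query` (the `rag` instance and the specific values below are illustrative assumptions, not part of this commit):

```python
from lightrag import QueryParam

# Assumes an already-initialized LightRAG instance named `rag`.
custom_param = QueryParam(
    mode="local",                      # retrieval strategy: "naive", "local", "global", or "hybrid"
    top_k=30,                          # retrieve fewer entities/relationships than the default of 60
    max_token_for_text_unit=2000,      # tighter token budget for the original chunks
    response_type="Single Paragraph",  # free-form hint for the answer format
)
print(rag.query("What are the top themes in this story?", param=custom_param))
```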
@@ -214,7 +237,15 @@ rag.insert(["TEXT1", "TEXT2",...])
 
 ```python
 # Incremental Insert: Insert new documents into an existing LightRAG instance
-rag = LightRAG(working_dir="./dickens")
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=embedding_dimension,
+        max_token_size=8192,
+        func=embedding_func,
+    ),
+)
 
 with open("./newText.txt") as f:
     rag.insert(f.read())
@@ -374,6 +405,125 @@ if __name__ == "__main__":
 
 </details>
 
+## API Server Implementation
+
+LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
+
+### Setting up the API Server
+<details>
+<summary>Click to expand setup instructions</summary>
+
+1. First, ensure you have the required dependencies:
+```bash
+pip install fastapi uvicorn pydantic
+```
+
+2. Set up your environment variables:
+```bash
+export RAG_DIR="your_index_directory"  # Optional: Defaults to "index_default"
+```
+
+3. Run the API server:
+```bash
+python examples/lightrag_api_openai_compatible_demo.py
+```
+
+The server will start on `http://0.0.0.0:8020`.
+</details>
+
+### API Endpoints
+
+The API server provides the following endpoints:
+
+#### 1. Query Endpoint
+<details>
+<summary>Click to view Query endpoint details</summary>
+
+- **URL:** `/query`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "query": "Your question here",
+    "mode": "hybrid"  // Can be "naive", "local", "global", or "hybrid"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/query" \
+     -H "Content-Type: application/json" \
+     -d '{"query": "What are the main themes?", "mode": "hybrid"}'
+```
+</details>
+
+#### 2. Insert Text Endpoint
+<details>
+<summary>Click to view Insert Text endpoint details</summary>
+
+- **URL:** `/insert`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "text": "Your text content here"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert" \
+     -H "Content-Type: application/json" \
+     -d '{"text": "Content to be inserted into RAG"}'
+```
+</details>
+
+#### 3. Insert File Endpoint
+<details>
+<summary>Click to view Insert File endpoint details</summary>
+
+- **URL:** `/insert_file`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "file_path": "path/to/your/file.txt"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert_file" \
+     -H "Content-Type: application/json" \
+     -d '{"file_path": "./book.txt"}'
+```
+</details>
+
+#### 4. Health Check Endpoint
+<details>
+<summary>Click to view Health Check endpoint details</summary>
+
+- **URL:** `/health`
+- **Method:** GET
+- **Example:**
+```bash
+curl -X GET "http://127.0.0.1:8020/health"
+```
+</details>
+
+### Configuration
+
+The API server can be configured using environment variables:
+- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
+- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
+
+### Error Handling
+<details>
+<summary>Click to view error handling details</summary>
+
+The API includes comprehensive error handling:
+- File not found errors (404)
+- Processing errors (500)
+- Supports multiple file encodings (UTF-8 and GBK)
+</details>
+
 ## Evaluation
 ### Dataset
 The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
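As a companion to the curl examples above, here is a minimal Python client sketch for the `/query` endpoint (the `requests` dependency and a server already running on port 8020 are assumptions for illustration):

```python
import requests

API_URL = "http://127.0.0.1:8020"  # assumes the demo server from examples/lightrag_api_openai_compatible_demo.py is running


def query_rag(question: str, mode: str = "hybrid") -> str:
    """POST a question to the /query endpoint and return the answer text."""
    resp = requests.post(
        f"{API_URL}/query", json={"query": question, "mode": mode}, timeout=120
    )
    resp.raise_for_status()
    payload = resp.json()  # matches the Response model: {"status": ..., "data": ..., "message": ...}
    return payload["data"]


if __name__ == "__main__":
    print(query_rag("What are the main themes?", mode="hybrid"))
```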
@@ -629,6 +779,7 @@ def extract_queries(file_path):
 │   ├── lightrag_ollama_demo.py
 │   ├── lightrag_openai_compatible_demo.py
 │   ├── lightrag_openai_demo.py
+│   ├── lightrag_siliconcloud_demo.py
 │   └── vram_management_demo.py
 ├── lightrag
 │   ├── __init__.py
@@ -3,17 +3,17 @@ from pyvis.network import Network
 import random
 
 # Load the GraphML file
-G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
+G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
 
 # Create a Pyvis network
-net = Network(notebook=True)
+net = Network(height="100vh", notebook=True)
 
 # Convert NetworkX graph to Pyvis network
 net.from_nx(G)
 
 # Add colors to nodes
 for node in net.nodes:
-    node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
+    node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
 
 # Save and display the network
-net.show('knowledge_graph.html')
+net.show("knowledge_graph.html")
@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
 NEO4J_USERNAME = "neo4j"
 NEO4J_PASSWORD = "your_password"
 
+
 def convert_xml_to_json(xml_path, output_path):
     """Converts XML file to JSON and saves the output."""
     if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
 
     json_data = xml_to_json(xml_path)
     if json_data:
-        with open(output_path, 'w', encoding='utf-8') as f:
+        with open(output_path, "w", encoding="utf-8") as f:
             json.dump(json_data, f, ensure_ascii=False, indent=2)
         print(f"JSON file created: {output_path}")
         return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
         print("Failed to create JSON data")
         return None
 
+
 def process_in_batches(tx, query, data, batch_size):
     """Process data in batches and execute the given query."""
     for i in range(0, len(data), batch_size):
-        batch = data[i:i + batch_size]
+        batch = data[i : i + batch_size]
         tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
 
+
 def main():
     # Paths
-    xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
-    json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+    xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
+    json_file = os.path.join(WORKING_DIR, "graph_data.json")
 
     # Convert XML to JSON
     json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
         return
 
     # Load nodes and edges
-    nodes = json_data.get('nodes', [])
-    edges = json_data.get('edges', [])
+    nodes = json_data.get("nodes", [])
+    edges = json_data.get("edges", [])
 
     # Neo4j queries
     create_nodes_query = """
@@ -100,10 +103,14 @@ def main():
         # Execute queries in batches
         with driver.session() as session:
             # Insert nodes in batches
-            session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+            session.execute_write(
+                process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
+            )
 
             # Insert edges in batches
-            session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+            session.execute_write(
+                process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
+            )
 
             # Set displayName and labels
             session.run(set_displayname_and_labels_query)
@@ -114,5 +121,6 @@ def main():
     finally:
         driver.close()
 
+
 if __name__ == "__main__":
     main()

examples/lightrag_api_openai_compatible_demo.py (new file, 164 lines)
@@ -0,0 +1,164 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np
from typing import Optional
import asyncio
import nest_asyncio

# Apply nest_asyncio to solve event loop issues
nest_asyncio.apply()

DEFAULT_RAG_DIR = "index_default"
app = FastAPI(title="LightRAG API", description="API for RAG operations")

# Configure working directory
WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
print(f"WORKING_DIR: {WORKING_DIR}")
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

# LLM model function


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key="YOUR_API_KEY",
        base_url="YourURL/v1",
        **kwargs,
    )


# Embedding function


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embedding(
        texts,
        model="text-embedding-3-large",
        api_key="YOUR_API_KEY",
        base_url="YourURL/v1",
    )


# Initialize RAG instance
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=3072, max_token_size=8192, func=embedding_func
    ),
)

# Data models


class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"


class InsertRequest(BaseModel):
    text: str


class InsertFileRequest(BaseModel):
    file_path: str


class Response(BaseModel):
    status: str
    data: Optional[str] = None
    message: Optional[str] = None


# API routes


@app.post("/query", response_model=Response)
async def query_endpoint(request: QueryRequest):
    try:
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(
            None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
        )
        return Response(status="success", data=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/insert", response_model=Response)
async def insert_endpoint(request: InsertRequest):
    try:
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(request.text))
        return Response(status="success", message="Text inserted successfully")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/insert_file", response_model=Response)
async def insert_file(request: InsertFileRequest):
    try:
        # Check if file exists
        if not os.path.exists(request.file_path):
            raise HTTPException(
                status_code=404, detail=f"File not found: {request.file_path}"
            )

        # Read file content
        try:
            with open(request.file_path, "r", encoding="utf-8") as f:
                content = f.read()
        except UnicodeDecodeError:
            # If UTF-8 decoding fails, try other encodings
            with open(request.file_path, "r", encoding="gbk") as f:
                content = f.read()

        # Insert file content
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, lambda: rag.insert(content))

        return Response(
            status="success",
            message=f"File content from {request.file_path} inserted successfully",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8020)

# Usage example
# To run the server, use the following command in your terminal:
# python lightrag_api_openai_compatible_demo.py

# Example requests:
# 1. Query:
# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'

# 2. Insert text:
# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'

# 3. Insert file:
# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'

# 4. Health check:
# curl -X GET "http://127.0.0.1:8020/health"
@@ -30,7 +30,7 @@ rag = LightRAG(
 )
 
 
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
     rag.insert(f.read())
 
 # Perform naive search

examples/lightrag_lmdeploy_demo.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def lmdeploy_model_complete(
    prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await lmdeploy_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        ## please specify chat_template if your local path does not follow original HF file name,
        ## or model_name is a pytorch model on huggingface.co,
        ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
        ## for a list of chat_template available in lmdeploy.
        chat_template="llama3",
        # model_format ='awq', # if you are using awq quantization model.
        # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8.
        **kwargs,
    )


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=lmdeploy_model_complete,
    llm_model_name="meta-llama/Llama-3.1-8B-Instruct",  # please use definite path for local model
    embedding_func=EmbeddingFunc(
        embedding_dim=384,
        max_token_size=5000,
        func=lambda texts: hf_embedding(
            texts,
            tokenizer=AutoTokenizer.from_pretrained(
                "sentence-transformers/all-MiniLM-L6-v2"
            ),
            embed_model=AutoModel.from_pretrained(
                "sentence-transformers/all-MiniLM-L6-v2"
            ),
        ),
    ),
)


with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
@@ -28,7 +28,7 @@ rag = LightRAG(
     ),
 )
 
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
     rag.insert(f.read())
 
 # Perform naive search
@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
     )
 
 
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    return embedding_dim
+
+
 # function test
 async def test_funcs():
     result = await llm_model_func("How are you?")
@@ -43,37 +50,59 @@ async def test_funcs():
     print("embedding_func: ", result)
 
 
-asyncio.run(test_funcs())
+# asyncio.run(test_funcs())
 
 
-rag = LightRAG(
-    working_dir=WORKING_DIR,
-    llm_model_func=llm_model_func,
-    embedding_func=EmbeddingFunc(
-        embedding_dim=4096, max_token_size=8192, func=embedding_func
-    ),
-)
-
-
-with open("./book.txt") as f:
-    rag.insert(f.read())
-
-# Perform naive search
-print(
-    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
-)
-
-# Perform local search
-print(
-    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
-)
-
-# Perform global search
-print(
-    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
-)
-
-# Perform hybrid search
-print(
-    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
-)
+async def main():
+    try:
+        embedding_dimension = await get_embedding_dim()
+        print(f"Detected embedding dimension: {embedding_dimension}")
+
+        rag = LightRAG(
+            working_dir=WORKING_DIR,
+            llm_model_func=llm_model_func,
+            embedding_func=EmbeddingFunc(
+                embedding_dim=embedding_dimension,
+                max_token_size=8192,
+                func=embedding_func,
+            ),
+        )
+
+        with open("./book.txt", "r", encoding="utf-8") as f:
+            rag.insert(f.read())
+
+        # Perform naive search
+        print(
+            rag.query(
+                "What are the top themes in this story?", param=QueryParam(mode="naive")
+            )
+        )
+
+        # Perform local search
+        print(
+            rag.query(
+                "What are the top themes in this story?", param=QueryParam(mode="local")
+            )
+        )
+
+        # Perform global search
+        print(
+            rag.query(
+                "What are the top themes in this story?",
+                param=QueryParam(mode="global"),
+            )
+        )
+
+        # Perform hybrid search
+        print(
+            rag.query(
+                "What are the top themes in this story?",
+                param=QueryParam(mode="hybrid"),
+            )
+        )
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
@@ -15,7 +15,7 @@ rag = LightRAG(
 )
 
 
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
     rag.insert(f.read())
 
 # Perform naive search

examples/lightrag_siliconcloud_demo.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "Qwen/Qwen2.5-7B-Instruct",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=os.getenv("SILICONFLOW_API_KEY"),
        base_url="https://api.siliconflow.cn/v1/",
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await siliconcloud_embedding(
        texts,
        model="netease-youdao/bce-embedding-base_v1",
        api_key=os.getenv("SILICONFLOW_API_KEY"),
        max_token_size=512,
    )


# function test
async def test_funcs():
    result = await llm_model_func("How are you?")
    print("llm_model_func: ", result)

    result = await embedding_func(["How are you?"])
    print("embedding_func: ", result)


asyncio.run(test_funcs())


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=768, max_token_size=512, func=embedding_func
    ),
)


with open("./book.txt") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
@@ -27,11 +27,12 @@ rag = LightRAG(
 # Read all .txt files from the TEXT_FILES_DIR directory
 texts = []
 for filename in os.listdir(TEXT_FILES_DIR):
-    if filename.endswith('.txt'):
+    if filename.endswith(".txt"):
         file_path = os.path.join(TEXT_FILES_DIR, filename)
-        with open(file_path, 'r', encoding='utf-8') as file:
+        with open(file_path, "r", encoding="utf-8") as file:
             texts.append(file.read())
 
+
 # Batch insert texts into LightRAG with a retry mechanism
 def insert_texts_with_retry(rag, texts, retries=3, delay=5):
     for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
             rag.insert(texts)
             return
         except Exception as e:
-            print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
+            print(
+                f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
+            )
             time.sleep(delay)
     raise RuntimeError("Failed to insert texts after multiple retries.")
 
 
 insert_texts_with_retry(rag, texts)
 
 # Perform different types of queries and handle potential errors
 try:
-    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        )
+    )
 except Exception as e:
     print(f"Error performing naive search: {e}")
 
 try:
-    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="local")
+        )
+    )
 except Exception as e:
     print(f"Error performing local search: {e}")
 
 try:
-    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
+    )
 except Exception as e:
     print(f"Error performing global search: {e}")
 
 try:
-    print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+    print(
+        rag.query(
+            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        )
+    )
 except Exception as e:
     print(f"Error performing hybrid search: {e}")
 
 
 # Function to clear VRAM resources
 def clear_vram():
     os.system("sudo nvidia-smi --gpu-reset")
 
 
 # Regularly clear VRAM to prevent overflow
 clear_vram_interval = 3600  # Clear once every hour
 start_time = time.time()
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "0.0.7"
+__version__ = "0.0.8"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
@@ -18,9 +18,13 @@ class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
     response_type: str = "Multiple Paragraphs"
+    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
+    # Number of tokens for the original chunks.
     max_token_for_text_unit: int = 4000
+    # Number of tokens for the relationship descriptions
    max_token_for_global_context: int = 4000
+    # Number of tokens for the entity descriptions
     max_token_for_local_context: int = 4000
 
 
@@ -209,7 +209,7 @@ class LightRAG:
         logger.info("[Entity Extraction]...")
         maybe_new_kg = await extract_entities(
             inserting_chunks,
-            knwoledge_graph_inst=self.chunk_entity_relation_graph,
+            knowledge_graph_inst=self.chunk_entity_relation_graph,
             entity_vdb=self.entities_vdb,
             relationships_vdb=self.relationships_vdb,
             global_config=asdict(self),

lightrag/llm.py (382 changed lines)
@@ -1,10 +1,23 @@
 import os
 import copy
+from functools import lru_cache
 import json
 import aioboto3
+import aiohttp
 import numpy as np
 import ollama
-from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
+
+from openai import (
+    AsyncOpenAI,
+    APIConnectionError,
+    RateLimitError,
+    Timeout,
+    AsyncAzureOpenAI,
+)
+
+import base64
+import struct
+
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -13,6 +26,8 @@ from tenacity import (
 )
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from pydantic import BaseModel, Field
+from typing import List, Dict, Callable, Any
 from .base import BaseKVStorage
 from .utils import compute_args_hash, wrap_embedding_func_with_attrs
 
@@ -62,6 +77,55 @@ async def openai_complete_if_cache(
     return response.choices[0].message.content
 
 
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def azure_openai_complete_if_cache(
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    base_url=None,
+    api_key=None,
+    **kwargs,
+):
+    if api_key:
+        os.environ["AZURE_OPENAI_API_KEY"] = api_key
+    if base_url:
+        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+
+    openai_async_client = AsyncAzureOpenAI(
+        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+    )
+
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    if prompt is not None:
+        messages.append({"role": "user", "content": prompt})
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+
+    response = await openai_async_client.chat.completions.create(
+        model=model, messages=messages, **kwargs
+    )
+
+    if hashing_kv is not None:
+        await hashing_kv.upsert(
+            {args_hash: {"return": response.choices[0].message.content, "model": model}}
+        )
+    return response.choices[0].message.content
+
+
 class BedrockError(Exception):
     """Generic error for issues related to Amazon Bedrock"""
 
@@ -151,15 +215,25 @@ async def bedrock_complete_if_cache(
     return response["output"]["message"]["content"][0]["text"]
 
 
+@lru_cache(maxsize=1)
+def initialize_hf_model(model_name):
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        model_name, device_map="auto", trust_remote_code=True
+    )
+    hf_model = AutoModelForCausalLM.from_pretrained(
+        model_name, device_map="auto", trust_remote_code=True
+    )
+    if hf_tokenizer.pad_token is None:
+        hf_tokenizer.pad_token = hf_tokenizer.eos_token
+
+    return hf_model, hf_tokenizer
+
+
 async def hf_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     model_name = model
-    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
-    if hf_tokenizer.pad_token is None:
-        # print("use eos token")
-        hf_tokenizer.pad_token = hf_tokenizer.eos_token
-    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
+    hf_model, hf_tokenizer = initialize_hf_model(model_name)
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
@@ -208,10 +282,13 @@ async def hf_model_if_cache(
     input_ids = hf_tokenizer(
         input_prompt, return_tensors="pt", padding=True, truncation=True
     ).to("cuda")
+    inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
     output = hf_model.generate(
-        **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
+        **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
+    )
+    response_text = hf_tokenizer.decode(
+        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
-    response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
     if hashing_kv is not None:
         await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
     return response_text
@@ -249,6 +326,135 @@ async def ollama_model_if_cache(
     return result
 
 
+@lru_cache(maxsize=1)
+def initialize_lmdeploy_pipeline(
+    model,
+    tp=1,
+    chat_template=None,
+    log_level="WARNING",
+    model_format="hf",
+    quant_policy=0,
+):
+    from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
+
+    lmdeploy_pipe = pipeline(
+        model_path=model,
+        backend_config=TurbomindEngineConfig(
+            tp=tp, model_format=model_format, quant_policy=quant_policy
+        ),
+        chat_template_config=ChatTemplateConfig(model_name=chat_template)
+        if chat_template
+        else None,
+        log_level="WARNING",
+    )
+    return lmdeploy_pipe
+
+
+async def lmdeploy_model_if_cache(
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    chat_template=None,
+    model_format="hf",
+    quant_policy=0,
+    **kwargs,
+) -> str:
+    """
+    Args:
+        model (str): The path to the model.
+            It could be one of the following options:
+                - i) A local directory path of a turbomind model which is
+                    converted by `lmdeploy convert` command or download
+                    from ii) and iii).
+                - ii) The model_id of a lmdeploy-quantized model hosted
+                    inside a model repo on huggingface.co, such as
+                    "InternLM/internlm-chat-20b-4bit",
+                    "lmdeploy/llama2-chat-70b-4bit", etc.
+                - iii) The model_id of a model hosted inside a model repo
+                    on huggingface.co, such as "internlm/internlm-chat-7b",
+                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
+                    and so on.
+        chat_template (str): needed when model is a pytorch model on
+            huggingface.co, such as "internlm-chat-7b",
+            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
+            and when the model name of local path did not match the original model name in HF.
+        tp (int): tensor parallel
+        prompt (Union[str, List[str]]): input texts to be completed.
+        do_preprocess (bool): whether pre-process the messages. Default to
+            True, which means chat_template will be applied.
+        skip_special_tokens (bool): Whether or not to remove special tokens
+            in the decoding. Default to be True.
+        do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
+            Default to be False, which means greedy decoding will be applied.
+    """
+    try:
+        import lmdeploy
+        from lmdeploy import version_info, GenerationConfig
+    except Exception:
+        raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")
+
+    kwargs.pop("response_format", None)
+    max_new_tokens = kwargs.pop("max_tokens", 512)
+    tp = kwargs.pop("tp", 1)
+    skip_special_tokens = kwargs.pop("skip_special_tokens", True)
+    do_preprocess = kwargs.pop("do_preprocess", True)
+    do_sample = kwargs.pop("do_sample", False)
+    gen_params = kwargs
+
+    version = version_info
+    if do_sample is not None and version < (0, 6, 0):
+        raise RuntimeError(
+            "`do_sample` parameter is not supported by lmdeploy until "
+            f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
+        )
+    else:
+        do_sample = True
+        gen_params.update(do_sample=do_sample)
+
+    lmdeploy_pipe = initialize_lmdeploy_pipeline(
+        model=model,
+        tp=tp,
+        chat_template=chat_template,
+        model_format=model_format,
+        quant_policy=quant_policy,
+        log_level="WARNING",
+    )
+
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(model, messages)
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            return if_cache_return["return"]
+
+    gen_config = GenerationConfig(
+        skip_special_tokens=skip_special_tokens,
+        max_new_tokens=max_new_tokens,
+        **gen_params,
+    )
+
+    response = ""
+    async for res in lmdeploy_pipe.generate(
+        messages,
+        gen_config=gen_config,
+        do_preprocess=do_preprocess,
+        stream_response=False,
+        session_id=1,
+    ):
+        response += res.response
+
+    if hashing_kv is not None:
+        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
+    return response
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -273,6 +479,18 @@ async def gpt_4o_mini_complete(
     )
 
 
+async def azure_openai_complete(
+    prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+    return await azure_openai_complete_if_cache(
+        "conversation-4o-mini",
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def bedrock_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
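A minimal sketch of how the new Azure helpers might be wired into LightRAG (the endpoint, key, API version, working directory, and embedding settings below are illustrative assumptions; only the function names and environment variables come from this commit):

```python
import os

from lightrag import LightRAG
from lightrag.llm import azure_openai_complete, azure_openai_embedding
from lightrag.utils import EmbeddingFunc

# The helpers read Azure credentials from these environment variables.
os.environ["AZURE_OPENAI_ENDPOINT"] = "<your-azure-openai-endpoint>"
os.environ["AZURE_OPENAI_API_KEY"] = "<your-api-key>"
os.environ["AZURE_OPENAI_API_VERSION"] = "<api-version>"

rag = LightRAG(
    working_dir="./azure_index",
    llm_model_func=azure_openai_complete,  # uses the deployment name hard-coded in azure_openai_complete
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,  # text-embedding-3-small, the default model of azure_openai_embedding
        max_token_size=8192,
        func=azure_openai_embedding,
    ),
)
```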
@@ -314,7 +532,7 @@ async def ollama_model_complete(
 @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def openai_embedding(
@@ -335,6 +553,73 @@ async def openai_embedding(
     return np.array([dp.embedding for dp in response.data])
 
 
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def azure_openai_embedding(
+    texts: list[str],
+    model: str = "text-embedding-3-small",
+    base_url: str = None,
+    api_key: str = None,
+) -> np.ndarray:
+    if api_key:
+        os.environ["AZURE_OPENAI_API_KEY"] = api_key
+    if base_url:
+        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+
+    openai_async_client = AsyncAzureOpenAI(
+        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+    )
+
+    response = await openai_async_client.embeddings.create(
+        model=model, input=texts, encoding_format="float"
+    )
+    return np.array([dp.embedding for dp in response.data])
+
+
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def siliconcloud_embedding(
+    texts: list[str],
+    model: str = "netease-youdao/bce-embedding-base_v1",
+    base_url: str = "https://api.siliconflow.cn/v1/embeddings",
+    max_token_size: int = 512,
+    api_key: str = None,
+) -> np.ndarray:
+    if api_key and not api_key.startswith("Bearer "):
+        api_key = "Bearer " + api_key
+
+    headers = {"Authorization": api_key, "Content-Type": "application/json"}
+
+    truncate_texts = [text[0:max_token_size] for text in texts]
+
+    payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
+
+    base64_strings = []
+    async with aiohttp.ClientSession() as session:
+        async with session.post(base_url, headers=headers, json=payload) as response:
+            content = await response.json()
+            if "code" in content:
+                raise ValueError(content)
+            base64_strings = [item["embedding"] for item in content["data"]]
+
+    embeddings = []
+    for string in base64_strings:
+        decode_bytes = base64.b64decode(string)
+        n = len(decode_bytes) // 4
+        float_array = struct.unpack("<" + "f" * n, decode_bytes)
+        embeddings.append(float_array)
+    return np.array(embeddings)
+
+
 # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 # @retry(
 #     stop=stop_after_attempt(3),
@@ -427,6 +712,85 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra
|
|||||||
return embed_text
|
return embed_text
|
||||||
|
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
"""
|
||||||
|
This is a Pydantic model class named 'Model' that is used to define a custom language model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
|
||||||
|
The function should take any argument and return a string.
|
||||||
|
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
|
||||||
|
This could include parameters such as the model name, API key, etc.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
|
||||||
|
|
||||||
|
In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
|
||||||
|
The 'kwargs' dictionary contains the model name and API key to be passed to the function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
gen_func: Callable[[Any], str] = Field(
|
||||||
|
...,
|
||||||
|
description="A function that generates the response from the llm. The response must be a string",
|
||||||
|
)
|
||||||
|
kwargs: Dict[str, Any] = Field(
|
||||||
|
...,
|
||||||
|
description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModel:
|
||||||
|
"""
|
||||||
|
Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
|
||||||
|
Could also be used for spliting across diffrent models or providers.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
models (List[Model]): A list of language models to be used.
|
||||||
|
|
||||||
|
Usage example:
|
||||||
|
```python
|
||||||
|
models = [
|
||||||
|
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
|
||||||
|
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
|
||||||
|
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
|
||||||
|
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
|
||||||
|
Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
|
||||||
|
]
|
||||||
|
multi_model = MultiModel(models)
|
||||||
|
rag = LightRAG(
|
||||||
|
llm_model_func=multi_model.llm_model_func
|
||||||
|
/ ..other args
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, models: List[Model]):
|
||||||
|
self._models = models
|
||||||
|
self._current_model = 0
|
||||||
|
|
||||||
|
def _next_model(self):
|
||||||
|
self._current_model = (self._current_model + 1) % len(self._models)
|
||||||
|
return self._models[self._current_model]
|
||||||
|
|
||||||
|
async def llm_model_func(
|
||||||
|
self, prompt, system_prompt=None, history_messages=[], **kwargs
|
||||||
|
) -> str:
|
||||||
|
kwargs.pop("model", None) # stop from overwriting the custom model name
|
||||||
|
next_model = self._next_model()
|
||||||
|
args = dict(
|
||||||
|
prompt=prompt,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history_messages=history_messages,
|
||||||
|
**kwargs,
|
||||||
|
**next_model.kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await next_model.gen_func(**args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
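The round-robin dispatch added above is easiest to see with stub generation functions. The sketch below is illustrative only: it assumes `Model` and `MultiModel` are importable from `lightrag.llm` (the module this hunk modifies), and `stub_gen_func` and its `tag` kwarg are invented for the example rather than taken from the diff.

```python
# Minimal sketch of MultiModel's round-robin behaviour, using stub generators
# instead of real API calls (no keys needed).
import asyncio

from lightrag.llm import Model, MultiModel


async def stub_gen_func(prompt, system_prompt=None, history_messages=[], tag=None, **kwargs) -> str:
    # A real deployment would call e.g. openai_complete_if_cache here.
    return f"[{tag}] {prompt}"


multi_model = MultiModel(
    [
        Model(gen_func=stub_gen_func, kwargs={"tag": "key-1"}),
        Model(gen_func=stub_gen_func, kwargs={"tag": "key-2"}),
    ]
)


async def main():
    for i in range(4):
        # _next_model() advances the counter before use, so calls alternate
        # key-2, key-1, key-2, key-1, ...
        print(await multi_model.llm_model_func(f"question {i}"))


asyncio.run(main())
```

Note that the wrapper pops a caller-supplied `model` keyword so it cannot clash with the per-model `kwargs`.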
@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
 async def _merge_nodes_then_upsert(
     entity_name: str,
     nodes_data: list[dict],
-    knwoledge_graph_inst: BaseGraphStorage,
+    knowledge_graph_inst: BaseGraphStorage,
     global_config: dict,
 ):
     already_entitiy_types = []
     already_source_ids = []
     already_description = []
 
-    already_node = await knwoledge_graph_inst.get_node(entity_name)
+    already_node = await knowledge_graph_inst.get_node(entity_name)
     if already_node is not None:
         already_entitiy_types.append(already_node["entity_type"])
         already_source_ids.extend(
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
         description=description,
         source_id=source_id,
     )
-    await knwoledge_graph_inst.upsert_node(
+    await knowledge_graph_inst.upsert_node(
         entity_name,
         node_data=node_data,
     )
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
     src_id: str,
     tgt_id: str,
     edges_data: list[dict],
-    knwoledge_graph_inst: BaseGraphStorage,
+    knowledge_graph_inst: BaseGraphStorage,
     global_config: dict,
 ):
     already_weights = []
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
     already_description = []
     already_keywords = []
 
-    if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
-        already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
+    if await knowledge_graph_inst.has_edge(src_id, tgt_id):
+        already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
         already_weights.append(already_edge["weight"])
         already_source_ids.extend(
             split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
         set([dp["source_id"] for dp in edges_data] + already_source_ids)
     )
     for need_insert_id in [src_id, tgt_id]:
-        if not (await knwoledge_graph_inst.has_node(need_insert_id)):
-            await knwoledge_graph_inst.upsert_node(
+        if not (await knowledge_graph_inst.has_node(need_insert_id)):
+            await knowledge_graph_inst.upsert_node(
                 need_insert_id,
                 node_data={
                     "source_id": source_id,
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
     description = await _handle_entity_relation_summary(
         (src_id, tgt_id), description, global_config
     )
-    await knwoledge_graph_inst.upsert_edge(
+    await knowledge_graph_inst.upsert_edge(
         src_id,
         tgt_id,
         edge_data=dict(
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
 
 async def extract_entities(
     chunks: dict[str, TextChunkSchema],
-    knwoledge_graph_inst: BaseGraphStorage,
+    knowledge_graph_inst: BaseGraphStorage,
     entity_vdb: BaseVectorStorage,
     relationships_vdb: BaseVectorStorage,
     global_config: dict,
@@ -341,13 +341,13 @@ async def extract_entities(
         maybe_edges[tuple(sorted(k))].extend(v)
     all_entities_data = await asyncio.gather(
         *[
-            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
+            _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
             for k, v in maybe_nodes.items()
         ]
     )
     all_relationships_data = await asyncio.gather(
         *[
-            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
+            _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
             for k, v in maybe_edges.items()
         ]
     )
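Beyond the rename, the hunk above shows the fan-out pattern `extract_entities` relies on: one merge coroutine per grouped entity or relationship key, all dispatched concurrently with `asyncio.gather`. The sketch below illustrates that pattern only; the stub `merge_node` coroutine and its sample data are invented and do not touch any storage backend.

```python
import asyncio
from collections import defaultdict


async def merge_node(entity_name, records):
    # Stub standing in for _merge_nodes_then_upsert: the real coroutine reads the
    # existing node from graph storage, merges descriptions, and upserts the result.
    await asyncio.sleep(0)
    return {"entity_name": entity_name, "mentions": len(records)}


async def demo():
    maybe_nodes = defaultdict(list)
    for name in ["ALICE", "BOB", "ALICE"]:
        maybe_nodes[name].append({"source_id": "chunk-1"})

    # Same shape as the diff: one coroutine per grouped key, gathered concurrently.
    all_entities_data = await asyncio.gather(
        *[merge_node(k, v) for k, v in maybe_nodes.items()]
    )
    print(all_entities_data)


asyncio.run(demo())
```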
@@ -384,7 +384,7 @@ async def extract_entities(
         }
         await relationships_vdb.upsert(data_for_vdb)
 
-    return knwoledge_graph_inst
+    return knowledge_graph_inst
 
 
 async def local_query(
@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
     with open(file_name, "w", encoding="utf-8") as f:
         json.dump(data, f, ensure_ascii=False, indent=4)
 
+
 def xml_to_json(xml_file):
     try:
         tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
         print(f"Root element: {root.tag}")
         print(f"Root attributes: {root.attrib}")
 
-        data = {
-            "nodes": [],
-            "edges": []
-        }
+        data = {"nodes": [], "edges": []}
 
         # Use namespace
-        namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
+        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
 
-        for node in root.findall('.//node', namespace):
+        for node in root.findall(".//node", namespace):
             node_data = {
-                "id": node.get('id').strip('"'),
-                "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
-                "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
-                "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
+                "id": node.get("id").strip('"'),
+                "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
+                if node.find("./data[@key='d0']", namespace) is not None
+                else "",
+                "description": node.find("./data[@key='d1']", namespace).text
+                if node.find("./data[@key='d1']", namespace) is not None
+                else "",
+                "source_id": node.find("./data[@key='d2']", namespace).text
+                if node.find("./data[@key='d2']", namespace) is not None
+                else "",
             }
             data["nodes"].append(node_data)
 
-        for edge in root.findall('.//edge', namespace):
+        for edge in root.findall(".//edge", namespace):
             edge_data = {
-                "source": edge.get('source').strip('"'),
-                "target": edge.get('target').strip('"'),
-                "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
-                "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
-                "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
-                "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
+                "source": edge.get("source").strip('"'),
+                "target": edge.get("target").strip('"'),
+                "weight": float(edge.find("./data[@key='d3']", namespace).text)
+                if edge.find("./data[@key='d3']", namespace) is not None
+                else 0.0,
+                "description": edge.find("./data[@key='d4']", namespace).text
+                if edge.find("./data[@key='d4']", namespace) is not None
+                else "",
+                "keywords": edge.find("./data[@key='d5']", namespace).text
+                if edge.find("./data[@key='d5']", namespace) is not None
+                else "",
+                "source_id": edge.find("./data[@key='d6']", namespace).text
+                if edge.find("./data[@key='d6']", namespace) is not None
+                else "",
             }
             data["edges"].append(edge_data)
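For context, the reformatted `xml_to_json` is typically driven as below. This is a sketch under assumptions: the GraphML path is a placeholder (LightRAG's NetworkX storage usually writes a `graph_chunk_entity_relation.graphml` file into the working directory, but the diff does not show this), it presumes `xml_to_json` returns the `data` dict on success and `None` on failure as the surrounding `try` block suggests, and it reuses `save_data_to_file` from the same script shown earlier in this hunk.

```python
# Illustrative driver for xml_to_json(); file names are assumptions, not taken from the diff.
graph_data = xml_to_json("./dickens/graph_chunk_entity_relation.graphml")
if graph_data is not None:
    print(f"Parsed {len(graph_data['nodes'])} nodes and {len(graph_data['edges'])} edges")
    save_data_to_file(graph_data, "graph_data.json")
else:
    print("Failed to parse the GraphML file")
```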
@@ -50,8 +50,8 @@ def extract_queries(file_path):
 
 async def process_query(query_text, rag_instance, query_param):
     try:
-        result, context = await rag_instance.aquery(query_text, param=query_param)
-        return {"query": query_text, "result": result, "context": context}, None
+        result = await rag_instance.aquery(query_text, param=query_param)
+        return {"query": query_text, "result": result}, None
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}
 
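This change tracks an update to `LightRAG.aquery`, which now returns only the response rather than a `(response, context)` tuple, so callers unpack a single value. A hedged sketch of the updated calling convention using the `process_query` helper from this script; the `working_dir`, query text, and reliance on default model functions are assumptions, not taken from the diff.

```python
import asyncio

from lightrag import LightRAG, QueryParam


async def main():
    rag = LightRAG(working_dir="./dickens")  # assumes the default model functions are usable
    answer, error = await process_query(
        "What are the top themes in this story?", rag, QueryParam(mode="hybrid")
    )
    if error is None:
        print(answer["result"])
    else:
        print(f"Query failed: {error['error']}")


asyncio.run(main())
```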
@@ -1,14 +1,16 @@
 accelerate
 aioboto3
+aiohttp
 graspologic
 hnswlib
 nano-vectordb
 networkx
 ollama
 openai
+pyvis
 tenacity
 tiktoken
 torch
 transformers
 xxhash
-pyvis
+# lmdeploy[all]
85
setup.py
@@ -1,39 +1,88 @@
 import setuptools
+from pathlib import Path
 
-with open("README.md", "r", encoding="utf-8") as fh:
-    long_description = fh.read()
 
-vars2find = ["__author__", "__version__", "__url__"]
-vars2readme = {}
-with open("./lightrag/__init__.py") as f:
+# Reading the long description from README.md
+def read_long_description():
+    try:
+        return Path("README.md").read_text(encoding="utf-8")
+    except FileNotFoundError:
+        return "A description of LightRAG is currently unavailable."
+
+
+# Retrieving metadata from __init__.py
+def retrieve_metadata():
+    vars2find = ["__author__", "__version__", "__url__"]
+    vars2readme = {}
+    try:
+        with open("./lightrag/__init__.py") as f:
             for line in f.readlines():
                 for v in vars2find:
                     if line.startswith(v):
-                        line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
+                        line = (
+                            line.replace(" ", "")
+                            .replace('"', "")
+                            .replace("'", "")
+                            .strip()
+                        )
                         vars2readme[v] = line.split("=")[1]
+    except FileNotFoundError:
+        raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.")
 
-deps = []
-with open("./requirements.txt") as f:
-    for line in f.readlines():
-        if not line.strip():
-            continue
-        deps.append(line.strip())
+    # Checking if all required variables are found
+    missing_vars = [v for v in vars2find if v not in vars2readme]
+    if missing_vars:
+        raise ValueError(
+            f"Missing required metadata variables in __init__.py: {missing_vars}"
+        )
+
+    return vars2readme
+
+
+# Reading dependencies from requirements.txt
+def read_requirements():
+    deps = []
+    try:
+        with open("./requirements.txt") as f:
+            deps = [line.strip() for line in f if line.strip()]
+    except FileNotFoundError:
+        print(
+            "Warning: 'requirements.txt' not found. No dependencies will be installed."
+        )
+    return deps
+
+
+metadata = retrieve_metadata()
+long_description = read_long_description()
+requirements = read_requirements()
 
 setuptools.setup(
     name="lightrag-hku",
-    url=vars2readme["__url__"],
-    version=vars2readme["__version__"],
-    author=vars2readme["__author__"],
+    url=metadata["__url__"],
+    version=metadata["__version__"],
+    author=metadata["__author__"],
     description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    packages=["lightrag"],
+    packages=setuptools.find_packages(
+        exclude=("tests*", "docs*")
+    ),  # Automatically find packages
     classifiers=[
+        "Development Status :: 4 - Beta",
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
+        "Intended Audience :: Developers",
+        "Topic :: Software Development :: Libraries :: Python Modules",
     ],
     python_requires=">=3.9",
-    install_requires=deps,
+    install_requires=requirements,
+    include_package_data=True,  # Includes non-code files from MANIFEST.in
+    project_urls={  # Additional project metadata
+        "Documentation": metadata.get("__url__", ""),
+        "Source": metadata.get("__url__", ""),
+        "Tracker": f"{metadata.get('__url__', '')}/issues"
+        if metadata.get("__url__")
+        else "",
+    },
 )
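`retrieve_metadata()` only succeeds if `lightrag/__init__.py` defines all three variables as simple one-line assignments. For illustration, a metadata block of the shape it parses; the values below are placeholders, not the package's actual version or author.

```python
# lightrag/__init__.py -- placeholder values for illustration only
__version__ = "0.0.0"
__author__ = "Author Name"
__url__ = "https://github.com/HKUDS/LightRAG"

# retrieve_metadata() strips every space and quote from each matching line and splits
# on "=", yielding roughly:
#   {"__version__": "0.0.0", "__author__": "AuthorName", "__url__": "https://github.com/HKUDS/LightRAG"}
# Note that multi-word values lose their internal spaces under this parsing.
```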