diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml
new file mode 100644
index 00000000..7c12e0a2
--- /dev/null
+++ b/.github/workflows/linting.yaml
@@ -0,0 +1,30 @@
+name: Linting and Formatting
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ lint-and-format:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.x'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install pre-commit
+
+ - name: Run pre-commit
+ run: pre-commit run --all-files
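+
+      # Note: the same checks can be run locally before pushing (a sketch; it
+      # assumes the repository ships a .pre-commit-config.yaml at its root):
+      #   pip install pre-commit
+      #   pre-commit run --all-files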
diff --git a/README.md b/README.md
index 76535d19..15696b57 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
-
+
@@ -28,6 +28,11 @@ This repository hosts the code of LightRAG. The structure of this code is based
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+## Algorithm Flowchart
+
+
+
+
## Install
* Install from source (Recommend)
@@ -58,8 +63,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
-# import nest_asyncio
-# nest_asyncio.apply()
+# import nest_asyncio
+# nest_asyncio.apply()
#########
WORKING_DIR = "./dickens"
@@ -157,7 +162,7 @@ rag = LightRAG(
Using Ollama Models
-
+
* If you want to use Ollama models, you only need to set LightRAG as follows:
```python
@@ -204,7 +209,25 @@ ollama create -f Modelfile qwen2m
+### Query Param
+
+```python
+class QueryParam:
+ mode: Literal["local", "global", "hybrid", "naive"] = "global"
+ only_need_context: bool = False
+ response_type: str = "Multiple Paragraphs"
+ # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
+ top_k: int = 60
+ # Number of tokens for the original chunks.
+ max_token_for_text_unit: int = 4000
+ # Number of tokens for the relationship descriptions
+ max_token_for_global_context: int = 4000
+ # Number of tokens for the entity descriptions
+ max_token_for_local_context: int = 4000
+```
+
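+A minimal usage sketch (assuming `rag` is an already-initialized `LightRAG` instance; the overridden values below are purely illustrative):
+
+```python
+# Illustrative only: tighten retrieval and request a shorter answer format.
+print(
+    rag.query(
+        "What are the top themes in this story?",
+        param=QueryParam(mode="local", top_k=40, response_type="Single Paragraph"),
+    )
+)
+```
+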
### Batch Insert
+
```python
# Batch Insert: Insert multiple texts at once
rag.insert(["TEXT1", "TEXT2",...])
@@ -214,7 +237,15 @@ rag.insert(["TEXT1", "TEXT2",...])
```python
# Incremental Insert: Insert new documents into an existing LightRAG instance
-rag = LightRAG(working_dir="./dickens")
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=embedding_dimension,
+ max_token_size=8192,
+ func=embedding_func,
+ ),
+)
with open("./newText.txt") as f:
rag.insert(f.read())
@@ -310,8 +341,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
- e.displayName = node.id
- REMOVE e:Entity
+ e.displayName = node.id
+ REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -364,7 +395,7 @@ def main():
except Exception as e:
print(f"Error occurred: {e}")
-
+
finally:
driver.close()
@@ -374,6 +405,125 @@ if __name__ == "__main__":
+## API Server Implementation
+
+LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
+
+### Setting up the API Server
+
+
+1. First, ensure you have the required dependencies:
+```bash
+pip install fastapi uvicorn pydantic
+```
+
+2. Set up your environment variables:
+```bash
+export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
+```
+
+3. Run the API server:
+```bash
+python examples/lightrag_api_openai_compatible_demo.py
+```
+
+The server will start on `http://0.0.0.0:8020`.
+
+
+### API Endpoints
+
+The API server provides the following endpoints:
+
+#### 1. Query Endpoint
+
+
+- **URL:** `/query`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "query": "Your question here",
+ "mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/query" \
+ -H "Content-Type: application/json" \
+ -d '{"query": "What are the main themes?", "mode": "hybrid"}'
+```
+
+
+#### 2. Insert Text Endpoint
+
+
+- **URL:** `/insert`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "text": "Your text content here"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert" \
+ -H "Content-Type: application/json" \
+ -d '{"text": "Content to be inserted into RAG"}'
+```
+
+
+#### 3. Insert File Endpoint
+
+
+- **URL:** `/insert_file`
+- **Method:** POST
+- **Body:**
+```json
+{
+ "file_path": "path/to/your/file.txt"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert_file" \
+ -H "Content-Type: application/json" \
+ -d '{"file_path": "./book.txt"}'
+```
+
+
+#### 4. Health Check Endpoint
+
+
+- **URL:** `/health`
+- **Method:** GET
+- **Example:**
+```bash
+curl -X GET "http://127.0.0.1:8020/health"
+```
+
+
+### Configuration
+
+The API server can be configured using environment variables:
+- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
+- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
+
+### Error Handling
+
+
+The API includes comprehensive error handling:
+- File not found errors (404)
+- Processing errors (500)
+- Automatic encoding fallback for file reads (UTF-8, then GBK)
+
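+For reference, the UTF-8/GBK fallback used by the `/insert_file` endpoint in the demo server follows this pattern (a simplified sketch, not the full endpoint handler):
+
+```python
+# Try UTF-8 first; if decoding fails, retry with GBK before giving up.
+try:
+    with open(file_path, "r", encoding="utf-8") as f:
+        content = f.read()
+except UnicodeDecodeError:
+    with open(file_path, "r", encoding="gbk") as f:
+        content = f.read()
+```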
+
## Evaluation
### Dataset
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -629,6 +779,7 @@ def extract_queries(file_path):
│   ├── lightrag_ollama_demo.py
│   ├── lightrag_openai_compatible_demo.py
│   ├── lightrag_openai_demo.py
+│   ├── lightrag_siliconcloud_demo.py
│   └── vram_management_demo.py
├── lightrag
│   ├── __init__.py
diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py
index b455e6de..11279b3a 100644
--- a/examples/graph_visual_with_html.py
+++ b/examples/graph_visual_with_html.py
@@ -3,17 +3,17 @@ from pyvis.network import Network
import random
# Load the GraphML file
-G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
+G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
# Create a Pyvis network
-net = Network(notebook=True)
+net = Network(height="100vh", notebook=True)
# Convert NetworkX graph to Pyvis network
net.from_nx(G)
# Add colors to nodes
for node in net.nodes:
- node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
+ node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
# Save and display the network
-net.show('knowledge_graph.html')
\ No newline at end of file
+net.show("knowledge_graph.html")
diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py
index 22dde368..7377f21c 100644
--- a/examples/graph_visual_with_neo4j.py
+++ b/examples/graph_visual_with_neo4j.py
@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "your_password"
+
def convert_xml_to_json(xml_path, output_path):
"""Converts XML file to JSON and saves the output."""
if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
json_data = xml_to_json(xml_path)
if json_data:
- with open(output_path, 'w', encoding='utf-8') as f:
+ with open(output_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False, indent=2)
print(f"JSON file created: {output_path}")
return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
print("Failed to create JSON data")
return None
+
def process_in_batches(tx, query, data, batch_size):
"""Process data in batches and execute the given query."""
for i in range(0, len(data), batch_size):
- batch = data[i:i + batch_size]
+ batch = data[i : i + batch_size]
tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
+
def main():
# Paths
- xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
- json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+ xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
+ json_file = os.path.join(WORKING_DIR, "graph_data.json")
# Convert XML to JSON
json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
return
# Load nodes and edges
- nodes = json_data.get('nodes', [])
- edges = json_data.get('edges', [])
+ nodes = json_data.get("nodes", [])
+ edges = json_data.get("edges", [])
# Neo4j queries
create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
SET e.entity_type = node.entity_type,
e.description = node.description,
e.source_id = node.source_id,
- e.displayName = node.id
- REMOVE e:Entity
+ e.displayName = node.id
+ REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
RETURN count(*)
@@ -100,19 +103,24 @@ def main():
# Execute queries in batches
with driver.session() as session:
# Insert nodes in batches
- session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+ session.execute_write(
+ process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
+ )
# Insert edges in batches
- session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+ session.execute_write(
+ process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
+ )
# Set displayName and labels
session.run(set_displayname_and_labels_query)
except Exception as e:
print(f"Error occurred: {e}")
-
+
finally:
driver.close()
+
if __name__ == "__main__":
main()
diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py
new file mode 100644
index 00000000..2cd262bb
--- /dev/null
+++ b/examples/lightrag_api_openai_compatible_demo.py
@@ -0,0 +1,164 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+import os
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+from typing import Optional
+import asyncio
+import nest_asyncio
+
+# Apply nest_asyncio to solve event loop issues
+nest_asyncio.apply()
+
+DEFAULT_RAG_DIR = "index_default"
+app = FastAPI(title="LightRAG API", description="API for RAG operations")
+
+# Configure working directory
+WORKING_DIR = os.environ.get("RAG_DIR", DEFAULT_RAG_DIR)
+print(f"WORKING_DIR: {WORKING_DIR}")
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+# LLM model function
+
+
+async def llm_model_func(
+ prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ return await openai_complete_if_cache(
+ "gpt-4o-mini",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ api_key="YOUR_API_KEY",
+ base_url="YourURL/v1",
+ **kwargs,
+ )
+
+
+# Embedding function
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+ return await openai_embedding(
+ texts,
+ model="text-embedding-3-large",
+ api_key="YOUR_API_KEY",
+ base_url="YourURL/v1",
+ )
+
+
+# Initialize RAG instance
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=3072, max_token_size=8192, func=embedding_func
+ ),
+)
+
+# Data models
+
+
+class QueryRequest(BaseModel):
+ query: str
+ mode: str = "hybrid"
+
+
+class InsertRequest(BaseModel):
+ text: str
+
+
+class InsertFileRequest(BaseModel):
+ file_path: str
+
+
+class Response(BaseModel):
+ status: str
+ data: Optional[str] = None
+ message: Optional[str] = None
+
+
+# API routes
+
+
+@app.post("/query", response_model=Response)
+async def query_endpoint(request: QueryRequest):
+ try:
+ loop = asyncio.get_event_loop()
+ result = await loop.run_in_executor(
+ None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
+ )
+ return Response(status="success", data=result)
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/insert", response_model=Response)
+async def insert_endpoint(request: InsertRequest):
+ try:
+ loop = asyncio.get_event_loop()
+ await loop.run_in_executor(None, lambda: rag.insert(request.text))
+ return Response(status="success", message="Text inserted successfully")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/insert_file", response_model=Response)
+async def insert_file(request: InsertFileRequest):
+ try:
+ # Check if file exists
+ if not os.path.exists(request.file_path):
+ raise HTTPException(
+ status_code=404, detail=f"File not found: {request.file_path}"
+ )
+
+ # Read file content
+ try:
+ with open(request.file_path, "r", encoding="utf-8") as f:
+ content = f.read()
+ except UnicodeDecodeError:
+ # If UTF-8 decoding fails, try other encodings
+ with open(request.file_path, "r", encoding="gbk") as f:
+ content = f.read()
+
+ # Insert file content
+ loop = asyncio.get_event_loop()
+ await loop.run_in_executor(None, lambda: rag.insert(content))
+
+ return Response(
+ status="success",
+ message=f"File content from {request.file_path} inserted successfully",
+ )
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/health")
+async def health_check():
+ return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+ import uvicorn
+
+ uvicorn.run(app, host="0.0.0.0", port=8020)
+
+# Usage example
+# To run the server, use the following command in your terminal:
+# python lightrag_api_openai_compatible_demo.py
+
+# Example requests:
+# 1. Query:
+# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
+
+# 2. Insert text:
+# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
+
+# 3. Insert file:
+# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
+
+# 4. Health check:
+# curl -X GET "http://127.0.0.1:8020/health"
diff --git a/examples/lightrag_hf_demo.py b/examples/lightrag_hf_demo.py
index 87312307..91033e50 100644
--- a/examples/lightrag_hf_demo.py
+++ b/examples/lightrag_hf_demo.py
@@ -30,7 +30,7 @@ rag = LightRAG(
)
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py
new file mode 100644
index 00000000..aeb96f71
--- /dev/null
+++ b/examples/lightrag_lmdeploy_demo.py
@@ -0,0 +1,75 @@
+import os
+
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
+from lightrag.utils import EmbeddingFunc
+from transformers import AutoModel, AutoTokenizer
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+
+async def lmdeploy_model_complete(
+ prompt=None, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ return await lmdeploy_model_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+        ## Specify chat_template if your local model path does not follow the
+        ## original HF naming, or if model_name is a PyTorch model hosted on
+        ## huggingface.co. See https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
+        ## for the list of chat templates available in lmdeploy.
+ chat_template="llama3",
+        # model_format="awq",  # if you are using an AWQ-quantized model
+        # quant_policy=8,  # enable online KV cache quantization: 4 = KV int4, 8 = KV int8
+ **kwargs,
+ )
+
+
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=lmdeploy_model_complete,
+    llm_model_name="meta-llama/Llama-3.1-8B-Instruct",  # use the full local path when loading a local model
+ embedding_func=EmbeddingFunc(
+ embedding_dim=384,
+ max_token_size=5000,
+ func=lambda texts: hf_embedding(
+ texts,
+ tokenizer=AutoTokenizer.from_pretrained(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ ),
+ embed_model=AutoModel.from_pretrained(
+ "sentence-transformers/all-MiniLM-L6-v2"
+ ),
+ ),
+ ),
+)
+
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+ rag.insert(f.read())
+
+# Perform naive search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py
index 6070131f..0a704024 100644
--- a/examples/lightrag_ollama_demo.py
+++ b/examples/lightrag_ollama_demo.py
@@ -28,7 +28,7 @@ rag = LightRAG(
),
)
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py
index fbad1190..2470fc00 100644
--- a/examples/lightrag_openai_compatible_demo.py
+++ b/examples/lightrag_openai_compatible_demo.py
@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
)
+async def get_embedding_dim():
+ test_text = ["This is a test sentence."]
+ embedding = await embedding_func(test_text)
+ embedding_dim = embedding.shape[1]
+ return embedding_dim
+
+
# function test
async def test_funcs():
result = await llm_model_func("How are you?")
@@ -43,37 +50,59 @@ async def test_funcs():
print("embedding_func: ", result)
-asyncio.run(test_funcs())
+# asyncio.run(test_funcs())
-rag = LightRAG(
- working_dir=WORKING_DIR,
- llm_model_func=llm_model_func,
- embedding_func=EmbeddingFunc(
- embedding_dim=4096, max_token_size=8192, func=embedding_func
- ),
-)
+async def main():
+ try:
+ embedding_dimension = await get_embedding_dim()
+ print(f"Detected embedding dimension: {embedding_dimension}")
+
+ rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=embedding_dimension,
+ max_token_size=8192,
+ func=embedding_func,
+ ),
+ )
+
+ with open("./book.txt", "r", encoding="utf-8") as f:
+ rag.insert(f.read())
+
+ # Perform naive search
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
+ )
+ )
+
+ # Perform local search
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="local")
+ )
+ )
+
+ # Perform global search
+ print(
+ rag.query(
+ "What are the top themes in this story?",
+ param=QueryParam(mode="global"),
+ )
+ )
+
+ # Perform hybrid search
+ print(
+ rag.query(
+ "What are the top themes in this story?",
+ param=QueryParam(mode="hybrid"),
+ )
+ )
+ except Exception as e:
+ print(f"An error occurred: {e}")
-with open("./book.txt") as f:
- rag.insert(f.read())
-
-# Perform naive search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
-)
-
-# Perform local search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
-)
-
-# Perform global search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
-)
-
-# Perform hybrid search
-print(
- rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
-)
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py
index a6e7f3b2..29bc75ca 100644
--- a/examples/lightrag_openai_demo.py
+++ b/examples/lightrag_openai_demo.py
@@ -15,7 +15,7 @@ rag = LightRAG(
)
-with open("./book.txt") as f:
+with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py
new file mode 100644
index 00000000..a73f16c5
--- /dev/null
+++ b/examples/lightrag_siliconcloud_demo.py
@@ -0,0 +1,79 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+
+WORKING_DIR = "./dickens"
+
+if not os.path.exists(WORKING_DIR):
+ os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+ prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ return await openai_complete_if_cache(
+ "Qwen/Qwen2.5-7B-Instruct",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ api_key=os.getenv("SILICONFLOW_API_KEY"),
+ base_url="https://api.siliconflow.cn/v1/",
+ **kwargs,
+ )
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+ return await siliconcloud_embedding(
+ texts,
+ model="netease-youdao/bce-embedding-base_v1",
+ api_key=os.getenv("SILICONFLOW_API_KEY"),
+ max_token_size=512,
+ )
+
+
+# function test
+async def test_funcs():
+ result = await llm_model_func("How are you?")
+ print("llm_model_func: ", result)
+
+ result = await embedding_func(["How are you?"])
+ print("embedding_func: ", result)
+
+
+asyncio.run(test_funcs())
+
+
+rag = LightRAG(
+ working_dir=WORKING_DIR,
+ llm_model_func=llm_model_func,
+ embedding_func=EmbeddingFunc(
+ embedding_dim=768, max_token_size=512, func=embedding_func
+ ),
+)
+
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+ rag.insert(f.read())
+
+# Perform naive search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py
index ec750254..c173b913 100644
--- a/examples/vram_management_demo.py
+++ b/examples/vram_management_demo.py
@@ -27,11 +27,12 @@ rag = LightRAG(
# Read all .txt files from the TEXT_FILES_DIR directory
texts = []
for filename in os.listdir(TEXT_FILES_DIR):
- if filename.endswith('.txt'):
+ if filename.endswith(".txt"):
file_path = os.path.join(TEXT_FILES_DIR, filename)
- with open(file_path, 'r', encoding='utf-8') as file:
+ with open(file_path, "r", encoding="utf-8") as file:
texts.append(file.read())
+
# Batch insert texts into LightRAG with a retry mechanism
def insert_texts_with_retry(rag, texts, retries=3, delay=5):
for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
rag.insert(texts)
return
except Exception as e:
- print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
+ print(
+ f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
+ )
time.sleep(delay)
raise RuntimeError("Failed to insert texts after multiple retries.")
+
insert_texts_with_retry(rag, texts)
# Perform different types of queries and handle potential errors
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
+ )
+ )
except Exception as e:
print(f"Error performing naive search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="local")
+ )
+ )
except Exception as e:
print(f"Error performing local search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="global")
+ )
+ )
except Exception as e:
print(f"Error performing global search: {e}")
try:
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+ print(
+ rag.query(
+ "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+ )
+ )
except Exception as e:
print(f"Error performing hybrid search: {e}")
+
# Function to clear VRAM resources
def clear_vram():
os.system("sudo nvidia-smi --gpu-reset")
+
# Regularly clear VRAM to prevent overflow
clear_vram_interval = 3600 # Clear once every hour
start_time = time.time()
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index db81e005..8e76a260 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
-__version__ = "0.0.7"
+__version__ = "0.0.8"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/base.py b/lightrag/base.py
index 50be4f62..cecd5edd 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -18,9 +18,13 @@ class QueryParam:
mode: Literal["local", "global", "hybrid", "naive"] = "global"
only_need_context: bool = False
response_type: str = "Multiple Paragraphs"
+ # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
top_k: int = 60
+ # Number of tokens for the original chunks.
max_token_for_text_unit: int = 4000
+ # Number of tokens for the relationship descriptions
max_token_for_global_context: int = 4000
+ # Number of tokens for the entity descriptions
max_token_for_local_context: int = 4000
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index d4b1eaa1..955651fb 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -209,7 +209,7 @@ class LightRAG:
logger.info("[Entity Extraction]...")
maybe_new_kg = await extract_entities(
inserting_chunks,
- knwoledge_graph_inst=self.chunk_entity_relation_graph,
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
diff --git a/lightrag/llm.py b/lightrag/llm.py
index aa818995..f4045e80 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,10 +1,23 @@
import os
import copy
+from functools import lru_cache
import json
import aioboto3
+import aiohttp
import numpy as np
import ollama
-from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
+
+from openai import (
+ AsyncOpenAI,
+ APIConnectionError,
+ RateLimitError,
+ Timeout,
+ AsyncAzureOpenAI,
+)
+
+import base64
+import struct
+
from tenacity import (
retry,
stop_after_attempt,
@@ -13,6 +26,8 @@ from tenacity import (
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
+from pydantic import BaseModel, Field
+from typing import List, Dict, Callable, Any
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
@@ -62,6 +77,55 @@ async def openai_complete_if_cache(
return response.choices[0].message.content
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def azure_openai_complete_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ base_url=None,
+ api_key=None,
+ **kwargs,
+):
+ if api_key:
+ os.environ["AZURE_OPENAI_API_KEY"] = api_key
+ if base_url:
+ os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+
+ openai_async_client = AsyncAzureOpenAI(
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+ )
+
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ if prompt is not None:
+ messages.append({"role": "user", "content": prompt})
+ if hashing_kv is not None:
+ args_hash = compute_args_hash(model, messages)
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
+ if if_cache_return is not None:
+ return if_cache_return["return"]
+
+ response = await openai_async_client.chat.completions.create(
+ model=model, messages=messages, **kwargs
+ )
+
+ if hashing_kv is not None:
+ await hashing_kv.upsert(
+ {args_hash: {"return": response.choices[0].message.content, "model": model}}
+ )
+ return response.choices[0].message.content
+
+
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
@@ -151,15 +215,25 @@ async def bedrock_complete_if_cache(
return response["output"]["message"]["content"][0]["text"]
+@lru_cache(maxsize=1)
+def initialize_hf_model(model_name):
+ hf_tokenizer = AutoTokenizer.from_pretrained(
+ model_name, device_map="auto", trust_remote_code=True
+ )
+ hf_model = AutoModelForCausalLM.from_pretrained(
+ model_name, device_map="auto", trust_remote_code=True
+ )
+ if hf_tokenizer.pad_token is None:
+ hf_tokenizer.pad_token = hf_tokenizer.eos_token
+
+ return hf_model, hf_tokenizer
+
+
async def hf_model_if_cache(
model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
model_name = model
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
- if hf_tokenizer.pad_token is None:
- # print("use eos token")
- hf_tokenizer.pad_token = hf_tokenizer.eos_token
- hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
+ hf_model, hf_tokenizer = initialize_hf_model(model_name)
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
@@ -208,10 +282,13 @@ async def hf_model_if_cache(
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
).to("cuda")
+ inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
- **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
+        **inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True
+ )
+ response_text = hf_tokenizer.decode(
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
- response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
if hashing_kv is not None:
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
return response_text
@@ -249,6 +326,135 @@ async def ollama_model_if_cache(
return result
+@lru_cache(maxsize=1)
+def initialize_lmdeploy_pipeline(
+ model,
+ tp=1,
+ chat_template=None,
+ log_level="WARNING",
+ model_format="hf",
+ quant_policy=0,
+):
+ from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
+
+ lmdeploy_pipe = pipeline(
+ model_path=model,
+ backend_config=TurbomindEngineConfig(
+ tp=tp, model_format=model_format, quant_policy=quant_policy
+ ),
+ chat_template_config=ChatTemplateConfig(model_name=chat_template)
+ if chat_template
+ else None,
+        log_level=log_level,
+ )
+ return lmdeploy_pipe
+
+
+async def lmdeploy_model_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ chat_template=None,
+ model_format="hf",
+ quant_policy=0,
+ **kwargs,
+) -> str:
+ """
+ Args:
+ model (str): The path to the model.
+ It could be one of the following options:
+            - i) A local directory path of a turbomind model which has been
+                converted by the `lmdeploy convert` command or downloaded
+                from ii) and iii).
+ - ii) The model_id of a lmdeploy-quantized model hosted
+ inside a model repo on huggingface.co, such as
+ "InternLM/internlm-chat-20b-4bit",
+ "lmdeploy/llama2-chat-70b-4bit", etc.
+ - iii) The model_id of a model hosted inside a model repo
+ on huggingface.co, such as "internlm/internlm-chat-7b",
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
+ and so on.
+        chat_template (str): needed when the model is a PyTorch model hosted on
+            huggingface.co, such as "internlm-chat-7b", "Qwen-7B-Chat",
+            "Baichuan2-7B-Chat" and so on, or when the model name of a local
+            path does not match the original model name in HF.
+        tp (int): tensor parallelism degree.
+        prompt (Union[str, List[str]]): input texts to be completed.
+        do_preprocess (bool): whether to pre-process the messages. Defaults to
+            True, which means chat_template will be applied.
+        skip_special_tokens (bool): whether to remove special tokens during
+            decoding. Defaults to True.
+        do_sample (bool): whether to use sampling; greedy decoding is used
+            otherwise. Defaults to False.
+ """
+ try:
+ import lmdeploy
+ from lmdeploy import version_info, GenerationConfig
+ except Exception:
+        raise ImportError(
+            "Please install lmdeploy before initializing the lmdeploy backend."
+        )
+
+ kwargs.pop("response_format", None)
+ max_new_tokens = kwargs.pop("max_tokens", 512)
+ tp = kwargs.pop("tp", 1)
+ skip_special_tokens = kwargs.pop("skip_special_tokens", True)
+ do_preprocess = kwargs.pop("do_preprocess", True)
+ do_sample = kwargs.pop("do_sample", False)
+ gen_params = kwargs
+
+    version = version_info
+    if do_sample and version < (0, 6, 0):
+        raise RuntimeError(
+            "`do_sample` parameter is not supported by lmdeploy until "
+            f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
+        )
+    if version >= (0, 6, 0):
+        # `do_sample` is only accepted by GenerationConfig on lmdeploy >= 0.6.0.
+        gen_params.update(do_sample=do_sample)
+
+ lmdeploy_pipe = initialize_lmdeploy_pipeline(
+ model=model,
+ tp=tp,
+ chat_template=chat_template,
+ model_format=model_format,
+ quant_policy=quant_policy,
+ log_level="WARNING",
+ )
+
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+ if hashing_kv is not None:
+ args_hash = compute_args_hash(model, messages)
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
+ if if_cache_return is not None:
+ return if_cache_return["return"]
+
+ gen_config = GenerationConfig(
+ skip_special_tokens=skip_special_tokens,
+ max_new_tokens=max_new_tokens,
+ **gen_params,
+ )
+
+ response = ""
+ async for res in lmdeploy_pipe.generate(
+ messages,
+ gen_config=gen_config,
+ do_preprocess=do_preprocess,
+ stream_response=False,
+ session_id=1,
+ ):
+ response += res.response
+
+ if hashing_kv is not None:
+ await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
+ return response
+
+
async def gpt_4o_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -273,6 +479,18 @@ async def gpt_4o_mini_complete(
)
+async def azure_openai_complete(
+ prompt, system_prompt=None, history_messages=[], **kwargs
+) -> str:
+ return await azure_openai_complete_if_cache(
+ "conversation-4o-mini",
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+
+
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
@@ -314,7 +532,7 @@ async def ollama_model_complete(
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(
@@ -335,6 +553,73 @@ async def openai_embedding(
return np.array([dp.embedding for dp in response.data])
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def azure_openai_embedding(
+ texts: list[str],
+ model: str = "text-embedding-3-small",
+ base_url: str = None,
+ api_key: str = None,
+) -> np.ndarray:
+ if api_key:
+ os.environ["AZURE_OPENAI_API_KEY"] = api_key
+ if base_url:
+ os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
+
+ openai_async_client = AsyncAzureOpenAI(
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+ )
+
+ response = await openai_async_client.embeddings.create(
+ model=model, input=texts, encoding_format="float"
+ )
+ return np.array([dp.embedding for dp in response.data])
+
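+# Usage sketch (comment only): wiring the Azure helpers above into LightRAG,
+# mirroring the OpenAI demos. It assumes AZURE_OPENAI_API_KEY,
+# AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_VERSION are set, and that the
+# deployment name hard-coded in azure_openai_complete matches a deployment in
+# your Azure resource.
+#
+#   rag = LightRAG(
+#       working_dir="./dickens",
+#       llm_model_func=azure_openai_complete,
+#       embedding_func=EmbeddingFunc(
+#           embedding_dim=1536, max_token_size=8192, func=azure_openai_embedding
+#       ),
+#   )
+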
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=60),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def siliconcloud_embedding(
+ texts: list[str],
+ model: str = "netease-youdao/bce-embedding-base_v1",
+ base_url: str = "https://api.siliconflow.cn/v1/embeddings",
+ max_token_size: int = 512,
+ api_key: str = None,
+) -> np.ndarray:
+ if api_key and not api_key.startswith("Bearer "):
+ api_key = "Bearer " + api_key
+
+ headers = {"Authorization": api_key, "Content-Type": "application/json"}
+
+ truncate_texts = [text[0:max_token_size] for text in texts]
+
+ payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
+
+ base64_strings = []
+ async with aiohttp.ClientSession() as session:
+ async with session.post(base_url, headers=headers, json=payload) as response:
+ content = await response.json()
+ if "code" in content:
+ raise ValueError(content)
+ base64_strings = [item["embedding"] for item in content["data"]]
+
+ embeddings = []
+ for string in base64_strings:
+ decode_bytes = base64.b64decode(string)
+ n = len(decode_bytes) // 4
+ float_array = struct.unpack("<" + "f" * n, decode_bytes)
+ embeddings.append(float_array)
+ return np.array(embeddings)
+
+
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
# stop=stop_after_attempt(3),
@@ -427,6 +712,85 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra
return embed_text
+class Model(BaseModel):
+ """
+ This is a Pydantic model class named 'Model' that is used to define a custom language model.
+
+ Attributes:
+ gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
+ The function should take any argument and return a string.
+ kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
+ This could include parameters such as the model name, API key, etc.
+
+ Example usage:
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
+
+ In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
+ The 'kwargs' dictionary contains the model name and API key to be passed to the function.
+ """
+
+ gen_func: Callable[[Any], str] = Field(
+ ...,
+ description="A function that generates the response from the llm. The response must be a string",
+ )
+ kwargs: Dict[str, Any] = Field(
+ ...,
+ description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
+ )
+
+ class Config:
+ arbitrary_types_allowed = True
+
+
+class MultiModel:
+ """
+    Distributes the load across multiple language models. Useful for working
+    around low rate limits with certain API providers, especially on free tiers.
+    It can also be used to split requests across different models or providers.
+
+ Attributes:
+ models (List[Model]): A list of language models to be used.
+
+ Usage example:
+ ```python
+ models = [
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
+ Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
+ ]
+ multi_model = MultiModel(models)
+ rag = LightRAG(
+        llm_model_func=multi_model.llm_model_func,
+        # ...other args
+ )
+ ```
+ """
+
+ def __init__(self, models: List[Model]):
+ self._models = models
+ self._current_model = 0
+
+ def _next_model(self):
+ self._current_model = (self._current_model + 1) % len(self._models)
+ return self._models[self._current_model]
+
+ async def llm_model_func(
+ self, prompt, system_prompt=None, history_messages=[], **kwargs
+ ) -> str:
+        kwargs.pop("model", None)  # prevent overwriting the custom model name
+ next_model = self._next_model()
+ args = dict(
+ prompt=prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ **next_model.kwargs,
+ )
+
+ return await next_model.gen_func(**args)
+
+
if __name__ == "__main__":
import asyncio
diff --git a/lightrag/operate.py b/lightrag/operate.py
index a0729cd8..8a6820f5 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
async def _merge_nodes_then_upsert(
entity_name: str,
nodes_data: list[dict],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_entitiy_types = []
already_source_ids = []
already_description = []
- already_node = await knwoledge_graph_inst.get_node(entity_name)
+ already_node = await knowledge_graph_inst.get_node(entity_name)
if already_node is not None:
already_entitiy_types.append(already_node["entity_type"])
already_source_ids.extend(
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
description=description,
source_id=source_id,
)
- await knwoledge_graph_inst.upsert_node(
+ await knowledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
src_id: str,
tgt_id: str,
edges_data: list[dict],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
global_config: dict,
):
already_weights = []
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
already_description = []
already_keywords = []
- if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
- already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
+ if await knowledge_graph_inst.has_edge(src_id, tgt_id):
+ already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
already_weights.append(already_edge["weight"])
already_source_ids.extend(
split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
set([dp["source_id"] for dp in edges_data] + already_source_ids)
)
for need_insert_id in [src_id, tgt_id]:
- if not (await knwoledge_graph_inst.has_node(need_insert_id)):
- await knwoledge_graph_inst.upsert_node(
+ if not (await knowledge_graph_inst.has_node(need_insert_id)):
+ await knowledge_graph_inst.upsert_node(
need_insert_id,
node_data={
"source_id": source_id,
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
description = await _handle_entity_relation_summary(
(src_id, tgt_id), description, global_config
)
- await knwoledge_graph_inst.upsert_edge(
+ await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
edge_data=dict(
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
async def extract_entities(
chunks: dict[str, TextChunkSchema],
- knwoledge_graph_inst: BaseGraphStorage,
+ knowledge_graph_inst: BaseGraphStorage,
entity_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
global_config: dict,
@@ -341,13 +341,13 @@ async def extract_entities(
maybe_edges[tuple(sorted(k))].extend(v)
all_entities_data = await asyncio.gather(
*[
- _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
+ _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
for k, v in maybe_nodes.items()
]
)
all_relationships_data = await asyncio.gather(
*[
- _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
+ _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
for k, v in maybe_edges.items()
]
)
@@ -384,7 +384,7 @@ async def extract_entities(
}
await relationships_vdb.upsert(data_for_vdb)
- return knwoledge_graph_inst
+ return knowledge_graph_inst
async def local_query(
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 9a68c16b..0da4a51a 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
+
def xml_to_json(xml_file):
try:
tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
print(f"Root element: {root.tag}")
print(f"Root attributes: {root.attrib}")
- data = {
- "nodes": [],
- "edges": []
- }
+ data = {"nodes": [], "edges": []}
# Use namespace
- namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
+ namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
- for node in root.findall('.//node', namespace):
+ for node in root.findall(".//node", namespace):
node_data = {
- "id": node.get('id').strip('"'),
- "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
- "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
- "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
+ "id": node.get("id").strip('"'),
+ "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
+ if node.find("./data[@key='d0']", namespace) is not None
+ else "",
+ "description": node.find("./data[@key='d1']", namespace).text
+ if node.find("./data[@key='d1']", namespace) is not None
+ else "",
+ "source_id": node.find("./data[@key='d2']", namespace).text
+ if node.find("./data[@key='d2']", namespace) is not None
+ else "",
}
data["nodes"].append(node_data)
- for edge in root.findall('.//edge', namespace):
+ for edge in root.findall(".//edge", namespace):
edge_data = {
- "source": edge.get('source').strip('"'),
- "target": edge.get('target').strip('"'),
- "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
- "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
- "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
- "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
+ "source": edge.get("source").strip('"'),
+ "target": edge.get("target").strip('"'),
+ "weight": float(edge.find("./data[@key='d3']", namespace).text)
+ if edge.find("./data[@key='d3']", namespace) is not None
+ else 0.0,
+ "description": edge.find("./data[@key='d4']", namespace).text
+ if edge.find("./data[@key='d4']", namespace) is not None
+ else "",
+ "keywords": edge.find("./data[@key='d5']", namespace).text
+ if edge.find("./data[@key='d5']", namespace) is not None
+ else "",
+ "source_id": edge.find("./data[@key='d6']", namespace).text
+ if edge.find("./data[@key='d6']", namespace) is not None
+ else "",
}
data["edges"].append(edge_data)
diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py
index 2be5ea5c..5e2ef778 100644
--- a/reproduce/Step_3_openai_compatible.py
+++ b/reproduce/Step_3_openai_compatible.py
@@ -50,8 +50,8 @@ def extract_queries(file_path):
async def process_query(query_text, rag_instance, query_param):
try:
- result, context = await rag_instance.aquery(query_text, param=query_param)
- return {"query": query_text, "result": result, "context": context}, None
+ result = await rag_instance.aquery(query_text, param=query_param)
+ return {"query": query_text, "result": result}, None
except Exception as e:
return None, {"query": query_text, "error": str(e)}
diff --git a/requirements.txt b/requirements.txt
index 9cc5b7e9..6b0e025a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,16 @@
accelerate
aioboto3
+aiohttp
graspologic
hnswlib
nano-vectordb
networkx
ollama
openai
+pyvis
tenacity
tiktoken
torch
transformers
xxhash
-pyvis
\ No newline at end of file
+# lmdeploy[all]
diff --git a/setup.py b/setup.py
index 47222420..1b1f65f0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,39 +1,88 @@
import setuptools
-
-with open("README.md", "r", encoding="utf-8") as fh:
- long_description = fh.read()
+from pathlib import Path
-vars2find = ["__author__", "__version__", "__url__"]
-vars2readme = {}
-with open("./lightrag/__init__.py") as f:
- for line in f.readlines():
- for v in vars2find:
- if line.startswith(v):
- line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
- vars2readme[v] = line.split("=")[1]
+# Reading the long description from README.md
+def read_long_description():
+ try:
+ return Path("README.md").read_text(encoding="utf-8")
+ except FileNotFoundError:
+ return "A description of LightRAG is currently unavailable."
-deps = []
-with open("./requirements.txt") as f:
- for line in f.readlines():
- if not line.strip():
- continue
- deps.append(line.strip())
+
+# Retrieving metadata from __init__.py
+def retrieve_metadata():
+ vars2find = ["__author__", "__version__", "__url__"]
+ vars2readme = {}
+ try:
+ with open("./lightrag/__init__.py") as f:
+ for line in f.readlines():
+ for v in vars2find:
+ if line.startswith(v):
+ line = (
+ line.replace(" ", "")
+ .replace('"', "")
+ .replace("'", "")
+ .strip()
+ )
+ vars2readme[v] = line.split("=")[1]
+ except FileNotFoundError:
+ raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.")
+
+ # Checking if all required variables are found
+ missing_vars = [v for v in vars2find if v not in vars2readme]
+ if missing_vars:
+ raise ValueError(
+ f"Missing required metadata variables in __init__.py: {missing_vars}"
+ )
+
+ return vars2readme
+
+
+# Reading dependencies from requirements.txt
+def read_requirements():
+ deps = []
+ try:
+ with open("./requirements.txt") as f:
+ deps = [line.strip() for line in f if line.strip()]
+ except FileNotFoundError:
+ print(
+ "Warning: 'requirements.txt' not found. No dependencies will be installed."
+ )
+ return deps
+
+
+metadata = retrieve_metadata()
+long_description = read_long_description()
+requirements = read_requirements()
setuptools.setup(
name="lightrag-hku",
- url=vars2readme["__url__"],
- version=vars2readme["__version__"],
- author=vars2readme["__author__"],
+ url=metadata["__url__"],
+ version=metadata["__version__"],
+ author=metadata["__author__"],
description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
long_description=long_description,
long_description_content_type="text/markdown",
- packages=["lightrag"],
+ packages=setuptools.find_packages(
+ exclude=("tests*", "docs*")
+ ), # Automatically find packages
classifiers=[
+ "Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
+ "Intended Audience :: Developers",
+ "Topic :: Software Development :: Libraries :: Python Modules",
],
python_requires=">=3.9",
- install_requires=deps,
+ install_requires=requirements,
+ include_package_data=True, # Includes non-code files from MANIFEST.in
+ project_urls={ # Additional project metadata
+ "Documentation": metadata.get("__url__", ""),
+ "Source": metadata.get("__url__", ""),
+ "Tracker": f"{metadata.get('__url__', '')}/issues"
+ if metadata.get("__url__")
+ else "",
+ },
)