diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml new file mode 100644 index 00000000..7c12e0a2 --- /dev/null +++ b/.github/workflows/linting.yaml @@ -0,0 +1,30 @@ +name: Linting and Formatting + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + lint-and-format: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit + + - name: Run pre-commit + run: pre-commit run --all-files diff --git a/README.md b/README.md index 76535d19..15696b57 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ - +

@@ -28,6 +28,11 @@ This repository hosts the code of LightRAG. The structure of this code is based - [x] [2024.10.16]πŸŽ―πŸŽ―πŸ“’πŸ“’LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! - [x] [2024.10.15]πŸŽ―πŸŽ―πŸ“’πŸ“’LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)! +## Algorithm Flowchart + +![LightRAG_Self excalidraw](https://github.com/user-attachments/assets/aa5c4892-2e44-49e6-a116-2403ed80a1a3) + + ## Install * Install from source (Recommend) @@ -58,8 +63,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete ######### # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() -# import nest_asyncio -# nest_asyncio.apply() +# import nest_asyncio +# nest_asyncio.apply() ######### WORKING_DIR = "./dickens" @@ -157,7 +162,7 @@ rag = LightRAG(

Using Ollama Models - + * If you want to use Ollama models, you only need to set LightRAG as follows: ```python @@ -204,7 +209,25 @@ ollama create -f Modelfile qwen2m
+### Query Param + +```python +class QueryParam: + mode: Literal["local", "global", "hybrid", "naive"] = "global" + only_need_context: bool = False + response_type: str = "Multiple Paragraphs" + # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. + top_k: int = 60 + # Number of tokens for the original chunks. + max_token_for_text_unit: int = 4000 + # Number of tokens for the relationship descriptions + max_token_for_global_context: int = 4000 + # Number of tokens for the entity descriptions + max_token_for_local_context: int = 4000 +``` + ### Batch Insert + ```python # Batch Insert: Insert multiple texts at once rag.insert(["TEXT1", "TEXT2",...]) @@ -214,7 +237,15 @@ rag.insert(["TEXT1", "TEXT2",...]) ```python # Incremental Insert: Insert new documents into an existing LightRAG instance -rag = LightRAG(working_dir="./dickens") +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, + ), +) with open("./newText.txt") as f: rag.insert(f.read()) @@ -310,8 +341,8 @@ def main(): SET e.entity_type = node.entity_type, e.description = node.description, e.source_id = node.source_id, - e.displayName = node.id - REMOVE e:Entity + e.displayName = node.id + REMOVE e:Entity WITH e, node CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode RETURN count(*) @@ -364,7 +395,7 @@ def main(): except Exception as e: print(f"Error occurred: {e}") - + finally: driver.close() @@ -374,6 +405,125 @@ if __name__ == "__main__": +## API Server Implementation + +LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests. + +### Setting up the API Server +
+<details>
+<summary>Click to expand setup instructions</summary>
+
+1. First, ensure you have the required dependencies:
+```bash
+pip install fastapi uvicorn pydantic
+```
+
+2. Set up your environment variables:
+```bash
+export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
+```
+
+3. Run the API server:
+```bash
+python examples/lightrag_api_openai_compatible_demo.py
+```
+
+The server will start on `http://0.0.0.0:8020`.
+</details>
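As a quick check that the server came up, a small script can poll the `/health` endpoint before issuing real requests. The following is a minimal sketch, assuming the third-party `requests` package is installed (it is not listed among the demo's dependencies) and the server runs on the default host and port shown above.

```python
import time

import requests  # third-party; install separately with: pip install requests

BASE_URL = "http://127.0.0.1:8020"  # assumed default host/port from the demo


def wait_for_server(timeout_s: float = 30.0) -> bool:
    """Poll /health until the API answers or the timeout expires."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{BASE_URL}/health", timeout=2).status_code == 200:
                return True
        except requests.ConnectionError:
            pass  # server not accepting connections yet
        time.sleep(1)
    return False


if __name__ == "__main__":
    print("server ready" if wait_for_server() else "server did not come up in time")
```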
+
+### API Endpoints
+
+The API server provides the following endpoints:
+
+#### 1. Query Endpoint
+
+<details>
+<summary>Click to view Query endpoint details</summary>
+
+- **URL:** `/query`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "query": "Your question here",
+    "mode": "hybrid"  // Can be "naive", "local", "global", or "hybrid"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/query" \
+     -H "Content-Type: application/json" \
+     -d '{"query": "What are the main themes?", "mode": "hybrid"}'
+```
+</details>
+
+#### 2. Insert Text Endpoint
+
+<details>
+<summary>Click to view Insert Text endpoint details</summary>
+
+- **URL:** `/insert`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "text": "Your text content here"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert" \
+     -H "Content-Type: application/json" \
+     -d '{"text": "Content to be inserted into RAG"}'
+```
+</details>
+
+#### 3. Insert File Endpoint
+
+<details>
+<summary>Click to view Insert File endpoint details</summary>
+
+- **URL:** `/insert_file`
+- **Method:** POST
+- **Body:**
+```json
+{
+    "file_path": "path/to/your/file.txt"
+}
+```
+- **Example:**
+```bash
+curl -X POST "http://127.0.0.1:8020/insert_file" \
+     -H "Content-Type: application/json" \
+     -d '{"file_path": "./book.txt"}'
+```
+</details>
+
+#### 4. Health Check Endpoint
+
+<details>
+<summary>Click to view Health Check endpoint details</summary>
+
+- **URL:** `/health`
+- **Method:** GET
+- **Example:**
+```bash
+curl -X GET "http://127.0.0.1:8020/health"
+```
+</details>
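The curl examples above translate directly into a small Python client for scripted use. The following is a minimal sketch, assuming the `requests` package is installed and the server is reachable at the default address; the payload fields and response shape mirror the endpoint descriptions and the `Response` model used by the demo server.

```python
import requests  # third-party; install separately with: pip install requests

BASE_URL = "http://127.0.0.1:8020"  # assumed default host/port from the demo


def insert_text(text: str) -> dict:
    """POST /insert with a JSON body of the form {"text": ...}."""
    resp = requests.post(f"{BASE_URL}/insert", json={"text": text}, timeout=120)
    resp.raise_for_status()
    # Responses follow the demo's Response model: {"status": ..., "data": ..., "message": ...}
    return resp.json()


def query(question: str, mode: str = "hybrid") -> dict:
    """POST /query; mode can be "naive", "local", "global", or "hybrid"."""
    resp = requests.post(
        f"{BASE_URL}/query", json={"query": question, "mode": mode}, timeout=120
    )
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    print(insert_text("Content to be inserted into RAG"))
    print(query("What are the main themes?", mode="hybrid"))
```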
+
+### Configuration
+
+The API server can be configured using environment variables:
+- `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
+- API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
+
+### Error Handling
+
+<details>
+<summary>Click to view error handling details</summary>
+
+The API includes comprehensive error handling:
+- File not found errors (404)
+- Processing errors (500)
+- Supports multiple file encodings (UTF-8 and GBK)
+</details>
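Because errors are raised as FastAPI `HTTPException`s, failed requests carry a JSON body with a `detail` field alongside the status code, which a client can surface instead of failing opaquely. The following is an illustrative sketch under the same assumptions as above (the `requests` package, default address); note that the file path is interpreted on the server's filesystem.

```python
import requests  # third-party; install separately with: pip install requests

BASE_URL = "http://127.0.0.1:8020"  # assumed default host/port from the demo


def insert_file(file_path: str) -> dict:
    """POST /insert_file; the path is resolved on the server's filesystem."""
    resp = requests.post(
        f"{BASE_URL}/insert_file", json={"file_path": file_path}, timeout=300
    )
    if resp.status_code == 404:
        # The server returns 404 when the path does not exist on its side.
        raise FileNotFoundError(resp.json().get("detail", "file not found"))
    if resp.status_code >= 500:
        # Processing failures surface as HTTP 500 with an explanatory detail string.
        raise RuntimeError(resp.json().get("detail", "server error"))
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    try:
        print(insert_file("./book.txt"))
    except (FileNotFoundError, RuntimeError) as exc:
        print(f"Insert failed: {exc}")
```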
+ ## Evaluation ### Dataset The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain). @@ -629,6 +779,7 @@ def extract_queries(file_path): β”‚ β”œβ”€β”€ lightrag_ollama_demo.py β”‚ β”œβ”€β”€ lightrag_openai_compatible_demo.py β”‚ β”œβ”€β”€ lightrag_openai_demo.py +β”‚ β”œβ”€β”€ lightrag_siliconcloud_demo.py β”‚ └── vram_management_demo.py β”œβ”€β”€ lightrag β”‚ β”œβ”€β”€ __init__.py diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py index b455e6de..11279b3a 100644 --- a/examples/graph_visual_with_html.py +++ b/examples/graph_visual_with_html.py @@ -3,17 +3,17 @@ from pyvis.network import Network import random # Load the GraphML file -G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml') +G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml") # Create a Pyvis network -net = Network(notebook=True) +net = Network(height="100vh", notebook=True) # Convert NetworkX graph to Pyvis network net.from_nx(G) # Add colors to nodes for node in net.nodes: - node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) + node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF)) # Save and display the network -net.show('knowledge_graph.html') \ No newline at end of file +net.show("knowledge_graph.html") diff --git a/examples/graph_visual_with_neo4j.py b/examples/graph_visual_with_neo4j.py index 22dde368..7377f21c 100644 --- a/examples/graph_visual_with_neo4j.py +++ b/examples/graph_visual_with_neo4j.py @@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687" NEO4J_USERNAME = "neo4j" NEO4J_PASSWORD = "your_password" + def convert_xml_to_json(xml_path, output_path): """Converts XML file to JSON and saves the output.""" if not os.path.exists(xml_path): @@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path): json_data = xml_to_json(xml_path) if json_data: - with open(output_path, 'w', encoding='utf-8') as f: + with open(output_path, "w", encoding="utf-8") as f: json.dump(json_data, f, ensure_ascii=False, indent=2) print(f"JSON file created: {output_path}") return json_data @@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path): print("Failed to create JSON data") return None + def process_in_batches(tx, query, data, batch_size): """Process data in batches and execute the given query.""" for i in range(0, len(data), batch_size): - batch = data[i:i + batch_size] + batch = data[i : i + batch_size] tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch}) + def main(): # Paths - xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml') - json_file = os.path.join(WORKING_DIR, 'graph_data.json') + xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml") + json_file = os.path.join(WORKING_DIR, "graph_data.json") # Convert XML to JSON json_data = convert_xml_to_json(xml_file, json_file) @@ -46,8 +49,8 @@ def main(): return # Load nodes and edges - nodes = json_data.get('nodes', []) - edges = json_data.get('edges', []) + nodes = json_data.get("nodes", []) + edges = json_data.get("edges", []) # Neo4j queries create_nodes_query = """ @@ -56,8 +59,8 @@ def main(): SET e.entity_type = node.entity_type, e.description = node.description, e.source_id = node.source_id, - e.displayName = node.id - REMOVE e:Entity + e.displayName = node.id + REMOVE e:Entity WITH e, node CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode RETURN count(*) @@ -100,19 +103,24 @@ def main(): # Execute queries in 
batches with driver.session() as session: # Insert nodes in batches - session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES) + session.execute_write( + process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES + ) # Insert edges in batches - session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES) + session.execute_write( + process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES + ) # Set displayName and labels session.run(set_displayname_and_labels_query) except Exception as e: print(f"Error occurred: {e}") - + finally: driver.close() + if __name__ == "__main__": main() diff --git a/examples/lightrag_api_openai_compatible_demo.py b/examples/lightrag_api_openai_compatible_demo.py new file mode 100644 index 00000000..2cd262bb --- /dev/null +++ b/examples/lightrag_api_openai_compatible_demo.py @@ -0,0 +1,164 @@ +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import os +from lightrag import LightRAG, QueryParam +from lightrag.llm import openai_complete_if_cache, openai_embedding +from lightrag.utils import EmbeddingFunc +import numpy as np +from typing import Optional +import asyncio +import nest_asyncio + +# Apply nest_asyncio to solve event loop issues +nest_asyncio.apply() + +DEFAULT_RAG_DIR = "index_default" +app = FastAPI(title="LightRAG API", description="API for RAG operations") + +# Configure working directory +WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}") +print(f"WORKING_DIR: {WORKING_DIR}") +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + +# LLM model function + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key="YOUR_API_KEY", + base_url="YourURL/v1", + **kwargs, + ) + + +# Embedding function + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embedding( + texts, + model="text-embedding-3-large", + api_key="YOUR_API_KEY", + base_url="YourURL/v1", + ) + + +# Initialize RAG instance +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=3072, max_token_size=8192, func=embedding_func + ), +) + +# Data models + + +class QueryRequest(BaseModel): + query: str + mode: str = "hybrid" + + +class InsertRequest(BaseModel): + text: str + + +class InsertFileRequest(BaseModel): + file_path: str + + +class Response(BaseModel): + status: str + data: Optional[str] = None + message: Optional[str] = None + + +# API routes + + +@app.post("/query", response_model=Response) +async def query_endpoint(request: QueryRequest): + try: + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode)) + ) + return Response(status="success", data=result) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/insert", response_model=Response) +async def insert_endpoint(request: InsertRequest): + try: + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lambda: rag.insert(request.text)) + return Response(status="success", message="Text inserted successfully") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/insert_file", response_model=Response) +async def insert_file(request: InsertFileRequest): + 
try: + # Check if file exists + if not os.path.exists(request.file_path): + raise HTTPException( + status_code=404, detail=f"File not found: {request.file_path}" + ) + + # Read file content + try: + with open(request.file_path, "r", encoding="utf-8") as f: + content = f.read() + except UnicodeDecodeError: + # If UTF-8 decoding fails, try other encodings + with open(request.file_path, "r", encoding="gbk") as f: + content = f.read() + + # Insert file content + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, lambda: rag.insert(content)) + + return Response( + status="success", + message=f"File content from {request.file_path} inserted successfully", + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/health") +async def health_check(): + return {"status": "healthy"} + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8020) + +# Usage example +# To run the server, use the following command in your terminal: +# python lightrag_api_openai_compatible_demo.py + +# Example requests: +# 1. Query: +# curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}' + +# 2. Insert text: +# curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}' + +# 3. Insert file: +# curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}' + +# 4. Health check: +# curl -X GET "http://127.0.0.1:8020/health" diff --git a/examples/lightrag_hf_demo.py b/examples/lightrag_hf_demo.py index 87312307..91033e50 100644 --- a/examples/lightrag_hf_demo.py +++ b/examples/lightrag_hf_demo.py @@ -30,7 +30,7 @@ rag = LightRAG( ) -with open("./book.txt") as f: +with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search diff --git a/examples/lightrag_lmdeploy_demo.py b/examples/lightrag_lmdeploy_demo.py new file mode 100644 index 00000000..aeb96f71 --- /dev/null +++ b/examples/lightrag_lmdeploy_demo.py @@ -0,0 +1,75 @@ +import os + +from lightrag import LightRAG, QueryParam +from lightrag.llm import lmdeploy_model_if_cache, hf_embedding +from lightrag.utils import EmbeddingFunc +from transformers import AutoModel, AutoTokenizer + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +async def lmdeploy_model_complete( + prompt=None, system_prompt=None, history_messages=[], **kwargs +) -> str: + model_name = kwargs["hashing_kv"].global_config["llm_model_name"] + return await lmdeploy_model_if_cache( + model_name, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + ## please specify chat_template if your local path does not follow original HF file name, + ## or model_name is a pytorch model on huggingface.co, + ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py + ## for a list of chat_template available in lmdeploy. + chat_template="llama3", + # model_format ='awq', # if you are using awq quantization model. + # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8. 
+ **kwargs, + ) + + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=lmdeploy_model_complete, + llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model + embedding_func=EmbeddingFunc( + embedding_dim=384, + max_token_size=5000, + func=lambda texts: hf_embedding( + texts, + tokenizer=AutoTokenizer.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ), + embed_model=AutoModel.from_pretrained( + "sentence-transformers/all-MiniLM-L6-v2" + ), + ), + ), +) + + +with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py index 6070131f..0a704024 100644 --- a/examples/lightrag_ollama_demo.py +++ b/examples/lightrag_ollama_demo.py @@ -28,7 +28,7 @@ rag = LightRAG( ), ) -with open("./book.txt") as f: +with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search diff --git a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py index fbad1190..2470fc00 100644 --- a/examples/lightrag_openai_compatible_demo.py +++ b/examples/lightrag_openai_compatible_demo.py @@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray: ) +async def get_embedding_dim(): + test_text = ["This is a test sentence."] + embedding = await embedding_func(test_text) + embedding_dim = embedding.shape[1] + return embedding_dim + + # function test async def test_funcs(): result = await llm_model_func("How are you?") @@ -43,37 +50,59 @@ async def test_funcs(): print("embedding_func: ", result) -asyncio.run(test_funcs()) +# asyncio.run(test_funcs()) -rag = LightRAG( - working_dir=WORKING_DIR, - llm_model_func=llm_model_func, - embedding_func=EmbeddingFunc( - embedding_dim=4096, max_token_size=8192, func=embedding_func - ), -) +async def main(): + try: + embedding_dimension = await get_embedding_dim() + print(f"Detected embedding dimension: {embedding_dimension}") + + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, + ), + ) + + with open("./book.txt", "r", encoding="utf-8") as f: + rag.insert(f.read()) + + # Perform naive search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) + + # Perform local search + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + # Perform global search + print( + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="global"), + ) + ) + + # Perform hybrid search + print( + rag.query( + "What are the top themes in this story?", + param=QueryParam(mode="hybrid"), + ) + ) + except Exception as e: + print(f"An error occurred: {e}") -with open("./book.txt") as f: - rag.insert(f.read()) - -# Perform naive search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) -) - -# 
Perform local search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) -) - -# Perform global search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) -) - -# Perform hybrid search -print( - rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) -) +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index a6e7f3b2..29bc75ca 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -15,7 +15,7 @@ rag = LightRAG( ) -with open("./book.txt") as f: +with open("./book.txt", "r", encoding="utf-8") as f: rag.insert(f.read()) # Perform naive search diff --git a/examples/lightrag_siliconcloud_demo.py b/examples/lightrag_siliconcloud_demo.py new file mode 100644 index 00000000..a73f16c5 --- /dev/null +++ b/examples/lightrag_siliconcloud_demo.py @@ -0,0 +1,79 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding +from lightrag.utils import EmbeddingFunc +import numpy as np + +WORKING_DIR = "./dickens" + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + "Qwen/Qwen2.5-7B-Instruct", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=os.getenv("SILICONFLOW_API_KEY"), + base_url="https://api.siliconflow.cn/v1/", + **kwargs, + ) + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await siliconcloud_embedding( + texts, + model="netease-youdao/bce-embedding-base_v1", + api_key=os.getenv("SILICONFLOW_API_KEY"), + max_token_size=512, + ) + + +# function test +async def test_funcs(): + result = await llm_model_func("How are you?") + print("llm_model_func: ", result) + + result = await embedding_func(["How are you?"]) + print("embedding_func: ", result) + + +asyncio.run(test_funcs()) + + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=768, max_token_size=512, func=embedding_func + ), +) + + +with open("./book.txt") as f: + rag.insert(f.read()) + +# Perform naive search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")) +) + +# Perform local search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="local")) +) + +# Perform global search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="global")) +) + +# Perform hybrid search +print( + rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")) +) diff --git a/examples/vram_management_demo.py b/examples/vram_management_demo.py index ec750254..c173b913 100644 --- a/examples/vram_management_demo.py +++ b/examples/vram_management_demo.py @@ -27,11 +27,12 @@ rag = LightRAG( # Read all .txt files from the TEXT_FILES_DIR directory texts = [] for filename in os.listdir(TEXT_FILES_DIR): - if filename.endswith('.txt'): + if filename.endswith(".txt"): file_path = os.path.join(TEXT_FILES_DIR, filename) - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: texts.append(file.read()) + # Batch insert texts into LightRAG with a retry mechanism def insert_texts_with_retry(rag, texts, 
retries=3, delay=5): for _ in range(retries): @@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5): rag.insert(texts) return except Exception as e: - print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...") + print( + f"Error occurred during insertion: {e}. Retrying in {delay} seconds..." + ) time.sleep(delay) raise RuntimeError("Failed to insert texts after multiple retries.") + insert_texts_with_retry(rag, texts) # Perform different types of queries and handle potential errors try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="naive") + ) + ) except Exception as e: print(f"Error performing naive search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) except Exception as e: print(f"Error performing local search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) except Exception as e: print(f"Error performing global search: {e}") try: - print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="hybrid") + ) + ) except Exception as e: print(f"Error performing hybrid search: {e}") + # Function to clear VRAM resources def clear_vram(): os.system("sudo nvidia-smi --gpu-reset") + # Regularly clear VRAM to prevent overflow clear_vram_interval = 3600 # Clear once every hour start_time = time.time() diff --git a/lightrag/__init__.py b/lightrag/__init__.py index db81e005..8e76a260 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "0.0.7" +__version__ = "0.0.8" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" diff --git a/lightrag/base.py b/lightrag/base.py index 50be4f62..cecd5edd 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -18,9 +18,13 @@ class QueryParam: mode: Literal["local", "global", "hybrid", "naive"] = "global" only_need_context: bool = False response_type: str = "Multiple Paragraphs" + # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. top_k: int = 60 + # Number of tokens for the original chunks. 
max_token_for_text_unit: int = 4000 + # Number of tokens for the relationship descriptions max_token_for_global_context: int = 4000 + # Number of tokens for the entity descriptions max_token_for_local_context: int = 4000 diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index d4b1eaa1..955651fb 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -209,7 +209,7 @@ class LightRAG: logger.info("[Entity Extraction]...") maybe_new_kg = await extract_entities( inserting_chunks, - knwoledge_graph_inst=self.chunk_entity_relation_graph, + knowledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, relationships_vdb=self.relationships_vdb, global_config=asdict(self), diff --git a/lightrag/llm.py b/lightrag/llm.py index aa818995..f4045e80 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -1,10 +1,23 @@ import os import copy +from functools import lru_cache import json import aioboto3 +import aiohttp import numpy as np import ollama -from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout + +from openai import ( + AsyncOpenAI, + APIConnectionError, + RateLimitError, + Timeout, + AsyncAzureOpenAI, +) + +import base64 +import struct + from tenacity import ( retry, stop_after_attempt, @@ -13,6 +26,8 @@ from tenacity import ( ) from transformers import AutoTokenizer, AutoModelForCausalLM import torch +from pydantic import BaseModel, Field +from typing import List, Dict, Callable, Any from .base import BaseKVStorage from .utils import compute_args_hash, wrap_embedding_func_with_attrs @@ -62,6 +77,55 @@ async def openai_complete_if_cache( return response.choices[0].message.content +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) +async def azure_openai_complete_if_cache( + model, + prompt, + system_prompt=None, + history_messages=[], + base_url=None, + api_key=None, + **kwargs, +): + if api_key: + os.environ["AZURE_OPENAI_API_KEY"] = api_key + if base_url: + os.environ["AZURE_OPENAI_ENDPOINT"] = base_url + + openai_async_client = AsyncAzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION"), + ) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + if prompt is not None: + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response.choices[0].message.content, "model": model}} + ) + return response.choices[0].message.content + + class BedrockError(Exception): """Generic error for issues related to Amazon Bedrock""" @@ -151,15 +215,25 @@ async def bedrock_complete_if_cache( return response["output"]["message"]["content"][0]["text"] +@lru_cache(maxsize=1) +def initialize_hf_model(model_name): + hf_tokenizer = AutoTokenizer.from_pretrained( + model_name, device_map="auto", trust_remote_code=True + ) + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, 
device_map="auto", trust_remote_code=True + ) + if hf_tokenizer.pad_token is None: + hf_tokenizer.pad_token = hf_tokenizer.eos_token + + return hf_model, hf_tokenizer + + async def hf_model_if_cache( model, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: model_name = model - hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto") - if hf_tokenizer.pad_token is None: - # print("use eos token") - hf_tokenizer.pad_token = hf_tokenizer.eos_token - hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") + hf_model, hf_tokenizer = initialize_hf_model(model_name) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: @@ -208,10 +282,13 @@ async def hf_model_if_cache( input_ids = hf_tokenizer( input_prompt, return_tensors="pt", padding=True, truncation=True ).to("cuda") + inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()} output = hf_model.generate( - **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True + **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True + ) + response_text = hf_tokenizer.decode( + output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True ) - response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True) if hashing_kv is not None: await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}}) return response_text @@ -249,6 +326,135 @@ async def ollama_model_if_cache( return result +@lru_cache(maxsize=1) +def initialize_lmdeploy_pipeline( + model, + tp=1, + chat_template=None, + log_level="WARNING", + model_format="hf", + quant_policy=0, +): + from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig + + lmdeploy_pipe = pipeline( + model_path=model, + backend_config=TurbomindEngineConfig( + tp=tp, model_format=model_format, quant_policy=quant_policy + ), + chat_template_config=ChatTemplateConfig(model_name=chat_template) + if chat_template + else None, + log_level="WARNING", + ) + return lmdeploy_pipe + + +async def lmdeploy_model_if_cache( + model, + prompt, + system_prompt=None, + history_messages=[], + chat_template=None, + model_format="hf", + quant_policy=0, + **kwargs, +) -> str: + """ + Args: + model (str): The path to the model. + It could be one of the following options: + - i) A local directory path of a turbomind model which is + converted by `lmdeploy convert` command or download + from ii) and iii). + - ii) The model_id of a lmdeploy-quantized model hosted + inside a model repo on huggingface.co, such as + "InternLM/internlm-chat-20b-4bit", + "lmdeploy/llama2-chat-70b-4bit", etc. + - iii) The model_id of a model hosted inside a model repo + on huggingface.co, such as "internlm/internlm-chat-7b", + "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" + and so on. + chat_template (str): needed when model is a pytorch model on + huggingface.co, such as "internlm-chat-7b", + "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on, + and when the model name of local path did not match the original model name in HF. + tp (int): tensor parallel + prompt (Union[str, List[str]]): input texts to be completed. + do_preprocess (bool): whether pre-process the messages. Default to + True, which means chat_template will be applied. + skip_special_tokens (bool): Whether or not to remove special tokens + in the decoding. Default to be True. + do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise. 
+ Default to be False, which means greedy decoding will be applied. + """ + try: + import lmdeploy + from lmdeploy import version_info, GenerationConfig + except Exception: + raise ImportError("Please install lmdeploy before intialize lmdeploy backend.") + + kwargs.pop("response_format", None) + max_new_tokens = kwargs.pop("max_tokens", 512) + tp = kwargs.pop("tp", 1) + skip_special_tokens = kwargs.pop("skip_special_tokens", True) + do_preprocess = kwargs.pop("do_preprocess", True) + do_sample = kwargs.pop("do_sample", False) + gen_params = kwargs + + version = version_info + if do_sample is not None and version < (0, 6, 0): + raise RuntimeError( + "`do_sample` parameter is not supported by lmdeploy until " + f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}" + ) + else: + do_sample = True + gen_params.update(do_sample=do_sample) + + lmdeploy_pipe = initialize_lmdeploy_pipeline( + model=model, + tp=tp, + chat_template=chat_template, + model_format=model_format, + quant_policy=quant_policy, + log_level="WARNING", + ) + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + gen_config = GenerationConfig( + skip_special_tokens=skip_special_tokens, + max_new_tokens=max_new_tokens, + **gen_params, + ) + + response = "" + async for res in lmdeploy_pipe.generate( + messages, + gen_config=gen_config, + do_preprocess=do_preprocess, + stream_response=False, + session_id=1, + ): + response += res.response + + if hashing_kv is not None: + await hashing_kv.upsert({args_hash: {"return": response, "model": model}}) + return response + + async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -273,6 +479,18 @@ async def gpt_4o_mini_complete( ) +async def azure_openai_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await azure_openai_complete_if_cache( + "conversation-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + async def bedrock_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -314,7 +532,7 @@ async def ollama_model_complete( @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), + wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), ) async def openai_embedding( @@ -335,6 +553,73 @@ async def openai_embedding( return np.array([dp.embedding for dp in response.data]) +@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) +async def azure_openai_embedding( + texts: list[str], + model: str = "text-embedding-3-small", + base_url: str = None, + api_key: str = None, +) -> np.ndarray: + if api_key: + os.environ["AZURE_OPENAI_API_KEY"] = api_key + if base_url: + os.environ["AZURE_OPENAI_ENDPOINT"] = base_url + + openai_async_client = AsyncAzureOpenAI( + 
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_key=os.getenv("AZURE_OPENAI_API_KEY"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION"), + ) + + response = await openai_async_client.embeddings.create( + model=model, input=texts, encoding_format="float" + ) + return np.array([dp.embedding for dp in response.data]) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=60), + retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), +) +async def siliconcloud_embedding( + texts: list[str], + model: str = "netease-youdao/bce-embedding-base_v1", + base_url: str = "https://api.siliconflow.cn/v1/embeddings", + max_token_size: int = 512, + api_key: str = None, +) -> np.ndarray: + if api_key and not api_key.startswith("Bearer "): + api_key = "Bearer " + api_key + + headers = {"Authorization": api_key, "Content-Type": "application/json"} + + truncate_texts = [text[0:max_token_size] for text in texts] + + payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"} + + base64_strings = [] + async with aiohttp.ClientSession() as session: + async with session.post(base_url, headers=headers, json=payload) as response: + content = await response.json() + if "code" in content: + raise ValueError(content) + base64_strings = [item["embedding"] for item in content["data"]] + + embeddings = [] + for string in base64_strings: + decode_bytes = base64.b64decode(string) + n = len(decode_bytes) // 4 + float_array = struct.unpack("<" + "f" * n, decode_bytes) + embeddings.append(float_array) + return np.array(embeddings) + + # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) # @retry( # stop=stop_after_attempt(3), @@ -427,6 +712,85 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra return embed_text +class Model(BaseModel): + """ + This is a Pydantic model class named 'Model' that is used to define a custom language model. + + Attributes: + gen_func (Callable[[Any], str]): A callable function that generates the response from the language model. + The function should take any argument and return a string. + kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. + This could include parameters such as the model name, API key, etc. + + Example usage: + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}) + + In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model. + The 'kwargs' dictionary contains the model name and API key to be passed to the function. + """ + + gen_func: Callable[[Any], str] = Field( + ..., + description="A function that generates the response from the llm. The response must be a string", + ) + kwargs: Dict[str, Any] = Field( + ..., + description="The arguments to pass to the callable function. Eg. the api key, model name, etc", + ) + + class Config: + arbitrary_types_allowed = True + + +class MultiModel: + """ + Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier. + Could also be used for spliting across diffrent models or providers. + + Attributes: + models (List[Model]): A list of language models to be used. 
+ + Usage example: + ```python + models = [ + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}), + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}), + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}), + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}), + Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}), + ] + multi_model = MultiModel(models) + rag = LightRAG( + llm_model_func=multi_model.llm_model_func + / ..other args + ) + ``` + """ + + def __init__(self, models: List[Model]): + self._models = models + self._current_model = 0 + + def _next_model(self): + self._current_model = (self._current_model + 1) % len(self._models) + return self._models[self._current_model] + + async def llm_model_func( + self, prompt, system_prompt=None, history_messages=[], **kwargs + ) -> str: + kwargs.pop("model", None) # stop from overwriting the custom model name + next_model = self._next_model() + args = dict( + prompt=prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + **next_model.kwargs, + ) + + return await next_model.gen_func(**args) + + if __name__ == "__main__": import asyncio diff --git a/lightrag/operate.py b/lightrag/operate.py index a0729cd8..8a6820f5 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction( async def _merge_nodes_then_upsert( entity_name: str, nodes_data: list[dict], - knwoledge_graph_inst: BaseGraphStorage, + knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): already_entitiy_types = [] already_source_ids = [] already_description = [] - already_node = await knwoledge_graph_inst.get_node(entity_name) + already_node = await knowledge_graph_inst.get_node(entity_name) if already_node is not None: already_entitiy_types.append(already_node["entity_type"]) already_source_ids.extend( @@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert( description=description, source_id=source_id, ) - await knwoledge_graph_inst.upsert_node( + await knowledge_graph_inst.upsert_node( entity_name, node_data=node_data, ) @@ -172,7 +172,7 @@ async def _merge_edges_then_upsert( src_id: str, tgt_id: str, edges_data: list[dict], - knwoledge_graph_inst: BaseGraphStorage, + knowledge_graph_inst: BaseGraphStorage, global_config: dict, ): already_weights = [] @@ -180,8 +180,8 @@ async def _merge_edges_then_upsert( already_description = [] already_keywords = [] - if await knwoledge_graph_inst.has_edge(src_id, tgt_id): - already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id) + if await knowledge_graph_inst.has_edge(src_id, tgt_id): + already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) already_weights.append(already_edge["weight"]) already_source_ids.extend( split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) @@ -202,8 +202,8 @@ async def _merge_edges_then_upsert( set([dp["source_id"] for dp in edges_data] + already_source_ids) ) for need_insert_id in [src_id, tgt_id]: - if not (await knwoledge_graph_inst.has_node(need_insert_id)): - await knwoledge_graph_inst.upsert_node( + if not (await knowledge_graph_inst.has_node(need_insert_id)): + await knowledge_graph_inst.upsert_node( need_insert_id, node_data={ "source_id": source_id, @@ -214,7 
+214,7 @@ async def _merge_edges_then_upsert( description = await _handle_entity_relation_summary( (src_id, tgt_id), description, global_config ) - await knwoledge_graph_inst.upsert_edge( + await knowledge_graph_inst.upsert_edge( src_id, tgt_id, edge_data=dict( @@ -237,7 +237,7 @@ async def _merge_edges_then_upsert( async def extract_entities( chunks: dict[str, TextChunkSchema], - knwoledge_graph_inst: BaseGraphStorage, + knowledge_graph_inst: BaseGraphStorage, entity_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage, global_config: dict, @@ -341,13 +341,13 @@ async def extract_entities( maybe_edges[tuple(sorted(k))].extend(v) all_entities_data = await asyncio.gather( *[ - _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config) + _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config) for k, v in maybe_nodes.items() ] ) all_relationships_data = await asyncio.gather( *[ - _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config) + _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config) for k, v in maybe_edges.items() ] ) @@ -384,7 +384,7 @@ async def extract_entities( } await relationships_vdb.upsert(data_for_vdb) - return knwoledge_graph_inst + return knowledge_graph_inst async def local_query( diff --git a/lightrag/utils.py b/lightrag/utils.py index 9a68c16b..0da4a51a 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -185,6 +185,7 @@ def save_data_to_file(data, file_name): with open(file_name, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=4) + def xml_to_json(xml_file): try: tree = ET.parse(xml_file) @@ -194,31 +195,42 @@ def xml_to_json(xml_file): print(f"Root element: {root.tag}") print(f"Root attributes: {root.attrib}") - data = { - "nodes": [], - "edges": [] - } + data = {"nodes": [], "edges": []} # Use namespace - namespace = {'': 'http://graphml.graphdrawing.org/xmlns'} + namespace = {"": "http://graphml.graphdrawing.org/xmlns"} - for node in root.findall('.//node', namespace): + for node in root.findall(".//node", namespace): node_data = { - "id": node.get('id').strip('"'), - "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "", - "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "", - "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else "" + "id": node.get("id").strip('"'), + "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') + if node.find("./data[@key='d0']", namespace) is not None + else "", + "description": node.find("./data[@key='d1']", namespace).text + if node.find("./data[@key='d1']", namespace) is not None + else "", + "source_id": node.find("./data[@key='d2']", namespace).text + if node.find("./data[@key='d2']", namespace) is not None + else "", } data["nodes"].append(node_data) - for edge in root.findall('.//edge', namespace): + for edge in root.findall(".//edge", namespace): edge_data = { - "source": edge.get('source').strip('"'), - "target": edge.get('target').strip('"'), - "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0, - "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "", - "keywords": edge.find("./data[@key='d5']", namespace).text if 
edge.find("./data[@key='d5']", namespace) is not None else "", - "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else "" + "source": edge.get("source").strip('"'), + "target": edge.get("target").strip('"'), + "weight": float(edge.find("./data[@key='d3']", namespace).text) + if edge.find("./data[@key='d3']", namespace) is not None + else 0.0, + "description": edge.find("./data[@key='d4']", namespace).text + if edge.find("./data[@key='d4']", namespace) is not None + else "", + "keywords": edge.find("./data[@key='d5']", namespace).text + if edge.find("./data[@key='d5']", namespace) is not None + else "", + "source_id": edge.find("./data[@key='d6']", namespace).text + if edge.find("./data[@key='d6']", namespace) is not None + else "", } data["edges"].append(edge_data) diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py index 2be5ea5c..5e2ef778 100644 --- a/reproduce/Step_3_openai_compatible.py +++ b/reproduce/Step_3_openai_compatible.py @@ -50,8 +50,8 @@ def extract_queries(file_path): async def process_query(query_text, rag_instance, query_param): try: - result, context = await rag_instance.aquery(query_text, param=query_param) - return {"query": query_text, "result": result, "context": context}, None + result = await rag_instance.aquery(query_text, param=query_param) + return {"query": query_text, "result": result}, None except Exception as e: return None, {"query": query_text, "error": str(e)} diff --git a/requirements.txt b/requirements.txt index 9cc5b7e9..6b0e025a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,16 @@ accelerate aioboto3 +aiohttp graspologic hnswlib nano-vectordb networkx ollama openai +pyvis tenacity tiktoken torch transformers xxhash -pyvis \ No newline at end of file +# lmdeploy[all] diff --git a/setup.py b/setup.py index 47222420..1b1f65f0 100644 --- a/setup.py +++ b/setup.py @@ -1,39 +1,88 @@ import setuptools - -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() +from pathlib import Path -vars2find = ["__author__", "__version__", "__url__"] -vars2readme = {} -with open("./lightrag/__init__.py") as f: - for line in f.readlines(): - for v in vars2find: - if line.startswith(v): - line = line.replace(" ", "").replace('"', "").replace("'", "").strip() - vars2readme[v] = line.split("=")[1] +# Reading the long description from README.md +def read_long_description(): + try: + return Path("README.md").read_text(encoding="utf-8") + except FileNotFoundError: + return "A description of LightRAG is currently unavailable." 
-deps = [] -with open("./requirements.txt") as f: - for line in f.readlines(): - if not line.strip(): - continue - deps.append(line.strip()) + +# Retrieving metadata from __init__.py +def retrieve_metadata(): + vars2find = ["__author__", "__version__", "__url__"] + vars2readme = {} + try: + with open("./lightrag/__init__.py") as f: + for line in f.readlines(): + for v in vars2find: + if line.startswith(v): + line = ( + line.replace(" ", "") + .replace('"', "") + .replace("'", "") + .strip() + ) + vars2readme[v] = line.split("=")[1] + except FileNotFoundError: + raise FileNotFoundError("Metadata file './lightrag/__init__.py' not found.") + + # Checking if all required variables are found + missing_vars = [v for v in vars2find if v not in vars2readme] + if missing_vars: + raise ValueError( + f"Missing required metadata variables in __init__.py: {missing_vars}" + ) + + return vars2readme + + +# Reading dependencies from requirements.txt +def read_requirements(): + deps = [] + try: + with open("./requirements.txt") as f: + deps = [line.strip() for line in f if line.strip()] + except FileNotFoundError: + print( + "Warning: 'requirements.txt' not found. No dependencies will be installed." + ) + return deps + + +metadata = retrieve_metadata() +long_description = read_long_description() +requirements = read_requirements() setuptools.setup( name="lightrag-hku", - url=vars2readme["__url__"], - version=vars2readme["__version__"], - author=vars2readme["__author__"], + url=metadata["__url__"], + version=metadata["__version__"], + author=metadata["__author__"], description="LightRAG: Simple and Fast Retrieval-Augmented Generation", long_description=long_description, long_description_content_type="text/markdown", - packages=["lightrag"], + packages=setuptools.find_packages( + exclude=("tests*", "docs*") + ), # Automatically find packages classifiers=[ + "Development Status :: 4 - Beta", "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries :: Python Modules", ], python_requires=">=3.9", - install_requires=deps, + install_requires=requirements, + include_package_data=True, # Includes non-code files from MANIFEST.in + project_urls={ # Additional project metadata + "Documentation": metadata.get("__url__", ""), + "Source": metadata.get("__url__", ""), + "Tracker": f"{metadata.get('__url__', '')}/issues" + if metadata.get("__url__") + else "", + }, )