diff --git a/README.md b/README.md
index 97d6279c..432261f7 100644
--- a/README.md
+++ b/README.md
@@ -312,7 +312,45 @@ rag = LightRAG(
 In order to run this experiment on a low-RAM GPU you should select a small model and tune the context window (increasing the context increases memory consumption). For example, running this Ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations in `book.txt`.
+
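+Concretely, the setup described above corresponds to something like the following. This is a minimal sketch that reuses the helpers from the Ollama example above (`ollama_model_complete` and its `embedding_func`) and assumes the `llm_model_kwargs`/`num_ctx` option documented there; adjust the value to whatever fits your GPU.
+
+```python
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=ollama_model_complete,
+    llm_model_name="gemma2:2b",
+    # Cap the context window so the model fits in ~6 GB of VRAM;
+    # larger contexts increase memory consumption.
+    llm_model_kwargs={"options": {"num_ctx": 26000}},
+    # embedding_func: unchanged from the Ollama example above
+    embedding_func=embedding_func,
+)
+```
+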
+## Wrappers
+
+LightRAG supports integration with various frameworks and model providers through wrappers. These wrappers provide a consistent interface while abstracting away the specifics of each framework.
+
+### Current Wrappers
+
+1. **LlamaIndex** (`wrapper/llama_index_impl.py`):
+   - Integrates with OpenAI and other providers through LlamaIndex
+   - Supports both direct API access and proxy services like LiteLLM
+   - Provides consistent interfaces for embeddings and completions
+   - See the [LlamaIndex Wrapper Documentation](lightrag/wrapper/Readme.md) for detailed setup and examples
+
+### Example Usage
+
+```python
+# Using LlamaIndex with direct OpenAI access
+from lightrag import LightRAG
+from lightrag.utils import EmbeddingFunc
+from lightrag.wrapper.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.llms.openai import OpenAI
+
+rag = LightRAG(
+    working_dir="your/path",
+    llm_model_func=llm_model_func,  # LlamaIndex-compatible completion function (see the sketch below)
+    embedding_func=EmbeddingFunc(   # LlamaIndex-compatible embedding function
+        embedding_dim=1536,
+        max_token_size=8192,
+        func=lambda texts: llama_index_embed(texts, embed_model=embed_model),
+    ),
+)
+```
+
+For detailed documentation and examples, see:
+
+- [LlamaIndex Wrapper Documentation](lightrag/wrapper/Readme.md)
+- [Direct OpenAI Example](examples/lightrag_api_llamaindex_direct_demo_simplified.py)
+- [LiteLLM Proxy Example](examples/lightrag_api_llamaindex_litellm_demo_simplified.py)
+
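+The `llm_model_func` and `embed_model` referenced in the example above are supplied by you. A minimal sketch, following the pattern used in the bundled examples (the model names are illustrative):
+
+```python
+from lightrag.wrapper.llama_index_impl import llama_index_complete_if_cache
+from llama_index.llms.openai import OpenAI
+from llama_index.embeddings.openai import OpenAIEmbedding
+
+# LlamaIndex embedding model handed to llama_index_embed above
+embed_model = OpenAIEmbedding(model="text-embedding-3-small", api_key="your-openai-key")
+
+# LlamaIndex-compatible completion function handed to LightRAG
+async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
+    # Create the LLM lazily and pass it to the wrapper via kwargs
+    if "llm_instance" not in kwargs:
+        kwargs["llm_instance"] = OpenAI(model="gpt-4", api_key="your-openai-key")
+    return await llama_index_complete_if_cache(
+        kwargs["llm_instance"],
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+```
+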
 Conversation History Support
diff --git a/examples/lightrag_api_llamaindex_direct_demo_simplified.py b/examples/lightrag_api_llamaindex_direct_demo_simplified.py
new file mode 100644
index 00000000..50dfec96
--- /dev/null
+++ b/examples/lightrag_api_llamaindex_direct_demo_simplified.py
@@ -0,0 +1,98 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.wrapper.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
+from lightrag.utils import EmbeddingFunc
+from llama_index.llms.openai import OpenAI
+from llama_index.embeddings.openai import OpenAIEmbedding
+
+# Configure working directory
+DEFAULT_RAG_DIR = "index_default"
+WORKING_DIR = os.environ.get("RAG_DIR", DEFAULT_RAG_DIR)
+print(f"WORKING_DIR: {WORKING_DIR}")
+
+# Model configuration
+LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
+print(f"LLM_MODEL: {LLM_MODEL}")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")
+print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
+EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
+print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+
+# OpenAI configuration
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your-api-key-here")
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+# Initialize LLM function
+async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
+    try:
+        # Initialize OpenAI if not in kwargs
+        if "llm_instance" not in kwargs:
+            llm_instance = OpenAI(
+                model=LLM_MODEL,
+                api_key=OPENAI_API_KEY,
+                temperature=0.7,
+            )
+            kwargs["llm_instance"] = llm_instance
+
+        response = await llama_index_complete_if_cache(
+            kwargs["llm_instance"],
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+        )
+        return response
+    except Exception as e:
+        print(f"LLM request failed: {str(e)}")
+        raise
+
+# Initialize embedding function
+async def embedding_func(texts):
+    try:
+        embed_model = OpenAIEmbedding(
+            model=EMBEDDING_MODEL,
+            api_key=OPENAI_API_KEY,
+        )
+        return await llama_index_embed(texts, embed_model=embed_model)
+    except Exception as e:
+        print(f"Embedding failed: {str(e)}")
+        raise
+
+# Get embedding dimension
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    print(f"embedding_dim={embedding_dim}")
+    return embedding_dim
+
+# Initialize RAG instance
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=asyncio.run(get_embedding_dim()),
+        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
+        func=embedding_func,
+    ),
+)
+
+# Insert example text
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Test different query modes
+print("\nNaive Search:")
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+
+print("\nLocal Search:")
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+
+print("\nGlobal Search:")
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+
+print("\nHybrid Search:")
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
\ No newline at end of file
diff --git a/examples/lightrag_api_llamaindex_litellm_demo_simplified.py b/examples/lightrag_api_llamaindex_litellm_demo_simplified.py
new file mode 100644
index 00000000..11bdeba8
--- /dev/null
+++ b/examples/lightrag_api_llamaindex_litellm_demo_simplified.py
@@ -0,0 +1,102 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.wrapper.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
+from lightrag.utils import EmbeddingFunc
+from llama_index.llms.litellm import LiteLLM
+from llama_index.embeddings.litellm import LiteLLMEmbedding
+
+# Configure working directory
+DEFAULT_RAG_DIR = "index_default"
+WORKING_DIR = os.environ.get("RAG_DIR", DEFAULT_RAG_DIR)
+print(f"WORKING_DIR: {WORKING_DIR}")
+
+# Model configuration
+LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4o")
+print(f"LLM_MODEL: {LLM_MODEL}")
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "embedding-model")
+print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
+EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
+print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
+
+# LiteLLM configuration
+LITELLM_URL = os.environ.get("LITELLM_URL", "http://localhost:4000")
+print(f"LITELLM_URL: {LITELLM_URL}")
+LITELLM_KEY = os.environ.get("LITELLM_KEY", "sk-1234")
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+# Initialize LLM function
+async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
+    try:
+        # Initialize LiteLLM if not in kwargs
+        if "llm_instance" not in kwargs:
+            llm_instance = LiteLLM(
+                model=f"openai/{LLM_MODEL}",  # Format: "provider/model_name"
+                api_base=LITELLM_URL,
+                api_key=LITELLM_KEY,
+                temperature=0.7,
+            )
+            kwargs["llm_instance"] = llm_instance
+
+        response = await llama_index_complete_if_cache(
+            kwargs["llm_instance"],
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+        )
+        return response
+    except Exception as e:
+        print(f"LLM request failed: {str(e)}")
+        raise
+
+# Initialize embedding function
+async def embedding_func(texts):
+    try:
+        embed_model = LiteLLMEmbedding(
+            model_name=f"openai/{EMBEDDING_MODEL}",
+            api_base=LITELLM_URL,
+            api_key=LITELLM_KEY,
+        )
+        return await llama_index_embed(texts, embed_model=embed_model)
+    except Exception as e:
+        print(f"Embedding failed: {str(e)}")
+        raise
+
+# Get embedding dimension
+async def get_embedding_dim():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    embedding_dim = embedding.shape[1]
+    print(f"embedding_dim={embedding_dim}")
+    return embedding_dim
+
+# Initialize RAG instance
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=asyncio.run(get_embedding_dim()),
+        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
+        func=embedding_func,
+    ),
+)
+
+# Insert example text
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Test different query modes
+print("\nNaive Search:")
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+
+print("\nLocal Search:")
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+
+print("\nGlobal Search:")
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+
+print("\nHybrid Search:")
+print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
\ No newline at end of file
diff --git a/lightrag/wrapper/Readme.md b/lightrag/wrapper/Readme.md
new file mode 100644
index 00000000..ece56458
--- /dev/null
+++ b/lightrag/wrapper/Readme.md
@@ -0,0 +1,177 @@
+## Wrapper Directory
+
+The `wrapper` directory contains integrations with different frameworks. These wrappers provide a consistent interface to LightRAG while abstracting away the specifics of each framework.
+
+## Wrapper Directory Structure
+
+```
+lightrag/
+├── wrapper/                  # Wrappers for different model providers and frameworks
+│   ├── llama_index_impl.py   # LlamaIndex integration for embeddings and completions
+│   └── ...                   # Other framework wrappers
+├── kg/                       # Knowledge graph implementations
+├── utils/                    # Utility functions and helpers
+└── ...
+```
+
+Current wrappers:
+
+1. **LlamaIndex** (`wrapper/llama_index_impl.py`):
+   - Provides integration with OpenAI and other providers through LlamaIndex
+   - Supports both direct API access and proxy services like LiteLLM
+   - Handles embeddings and completions with consistent interfaces
+   - See example implementations:
+     - [Direct OpenAI Usage](../examples/lightrag_api_llamaindex_direct_demo_simplified.py)
+     - [LiteLLM Proxy Usage](../examples/lightrag_api_llamaindex_litellm_demo_simplified.py)
+
+## Using LlamaIndex
+
+LightRAG supports LlamaIndex for embeddings and completions in two ways: direct OpenAI usage or through a LiteLLM proxy.
+
+### Setup
+
+First, install the required dependencies. For direct OpenAI usage:
+
+```bash
+pip install llama-index llama-index-llms-openai llama-index-embeddings-openai
+```
+
+For the LiteLLM proxy path:
+
+```bash
+pip install llama-index llama-index-llms-litellm llama-index-embeddings-litellm
+```
+
+### Standard OpenAI Usage
+
+```python
+from lightrag import LightRAG
+from lightrag.wrapper.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.llms.openai import OpenAI
+from lightrag.utils import EmbeddingFunc, logger
+
+# Initialize with direct OpenAI access
+async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
+    try:
+        # Initialize OpenAI if not in kwargs
+        if "llm_instance" not in kwargs:
+            llm_instance = OpenAI(
+                model="gpt-4",
+                api_key="your-openai-key",
+                temperature=0.7,
+            )
+            kwargs["llm_instance"] = llm_instance
+
+        response = await llama_index_complete_if_cache(
+            kwargs["llm_instance"],
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+        )
+        return response
+    except Exception as e:
+        logger.error(f"LLM request failed: {str(e)}")
+        raise
+
+# Initialize LightRAG with OpenAI
+rag = LightRAG(
+    working_dir="your/path",
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1536,
+        max_token_size=8192,
+        func=lambda texts: llama_index_embed(
+            texts,
+            embed_model=OpenAIEmbedding(
+                model="text-embedding-3-large",
+                api_key="your-openai-key",
+            ),
+        ),
+    ),
+)
+```
+
+### Using LiteLLM Proxy
+
+Routing through a LiteLLM proxy lets you:
+
+1. Use any LLM provider through LiteLLM
+2. Leverage LlamaIndex's embedding and completion capabilities
+3. Maintain consistent configuration across services
+
+```python
+from lightrag import LightRAG
+from lightrag.wrapper.llama_index_impl import llama_index_complete_if_cache, llama_index_embed
+from llama_index.llms.litellm import LiteLLM
+from llama_index.embeddings.litellm import LiteLLMEmbedding
+from lightrag.utils import EmbeddingFunc, logger
+
+# Initialize with LiteLLM proxy
+async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
+    try:
+        # Initialize LiteLLM if not in kwargs
+        if "llm_instance" not in kwargs:
+            # `settings` is your own configuration object
+            # (see "Configuration Object" at the end of this document)
+            llm_instance = LiteLLM(
+                model=f"openai/{settings.LLM_MODEL}",  # Format: "provider/model_name"
+                api_base=settings.LITELLM_URL,
+                api_key=settings.LITELLM_KEY,
+                temperature=0.7,
+            )
+            kwargs["llm_instance"] = llm_instance
+
+        response = await llama_index_complete_if_cache(
+            kwargs["llm_instance"],
+            prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+        )
+        return response
+    except Exception as e:
+        logger.error(f"LLM request failed: {str(e)}")
+        raise
+
+# Initialize LightRAG with LiteLLM
+rag = LightRAG(
+    working_dir="your/path",
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1536,
+        max_token_size=8192,
+        func=lambda texts: llama_index_embed(
+            texts,
+            embed_model=LiteLLMEmbedding(
+                model_name=f"openai/{settings.EMBEDDING_MODEL}",
+                api_base=settings.LITELLM_URL,
+                api_key=settings.LITELLM_KEY,
+            ),
+        ),
+    ),
+)
+```
+
+### Environment Variables
+
+For OpenAI direct usage:
+
+```bash
+OPENAI_API_KEY=your-openai-key
+```
+
+For LiteLLM proxy:
+
+```bash
+# LiteLLM Configuration
+LITELLM_URL=http://litellm:4000
+LITELLM_KEY=your-litellm-key
+
+# Model Configuration
+LLM_MODEL=gpt-4
+EMBEDDING_MODEL=text-embedding-3-large
+EMBEDDING_MAX_TOKEN_SIZE=8192
+```
+
+### Key Differences
+
+1. **Direct OpenAI**:
+   - Simpler setup
+   - Direct API access
+   - Requires OpenAI API key
+
+2. **LiteLLM Proxy**:
+   - Model provider agnostic
+   - Centralized API key management
+   - Support for multiple providers
+   - Better cost control and monitoring
+
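+### Configuration Object
+
+The LiteLLM snippets above read values from a `settings` object that is not defined in this file; it stands for whatever configuration mechanism your application uses. A minimal, hypothetical stand-in that pulls the environment variables listed above could look like this:
+
+```python
+import os
+
+class Settings:
+    """Hypothetical settings holder; replace with your own config system."""
+    LITELLM_URL = os.environ.get("LITELLM_URL", "http://litellm:4000")
+    LITELLM_KEY = os.environ.get("LITELLM_KEY", "your-litellm-key")
+    LLM_MODEL = os.environ.get("LLM_MODEL", "gpt-4")
+    EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
+    EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", "8192"))
+
+settings = Settings()
+```
+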
diff --git a/lightrag/wrapper/__init__.py b/lightrag/wrapper/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lightrag/wrapper/llama_index_impl.py b/lightrag/wrapper/llama_index_impl.py
new file mode 100644
index 00000000..f79dade5
--- /dev/null
+++ b/lightrag/wrapper/llama_index_impl.py
@@ -0,0 +1,207 @@
+from typing import List, Optional
+
+import pipmaster as pm
+
+# Install required dependencies before importing llama_index
+if not pm.is_installed("llama-index"):
+    pm.install("llama-index")
+
+from llama_index.core.llms import (
+    ChatMessage,
+    MessageRole,
+    ChatResponse,
+)
+from llama_index.core.embeddings import BaseEmbedding
+from llama_index.core.settings import Settings as LlamaIndexSettings
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
+from lightrag.utils import (
+    logger,
+    wrap_embedding_func_with_attrs,
+    locate_json_string_body_from_string,
+)
+from lightrag.exceptions import (
+    APIConnectionError,
+    RateLimitError,
+    APITimeoutError,
+)
+import numpy as np
+
+
+def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
+    """
+    Configure LlamaIndex settings.
+
+    Args:
+        settings: LlamaIndex Settings instance. If None, the global Settings singleton is used.
+        **kwargs: Additional settings to override/configure
+    """
+    if settings is None:
+        # llama_index exposes Settings as a global singleton instance
+        settings = LlamaIndexSettings
+
+    # Update settings with any provided kwargs
+    for key, value in kwargs.items():
+        if hasattr(settings, key):
+            setattr(settings, key, value)
+        else:
+            logger.warning(f"Unknown LlamaIndex setting: {key}")
+
+    return settings
+
+
+def format_chat_messages(messages):
+    """Format chat messages into LlamaIndex format."""
+    formatted_messages = []
+
+    for msg in messages:
+        role = msg.get("role", "user")
+        content = msg.get("content", "")
+
+        if role == "system":
+            formatted_messages.append(
+                ChatMessage(role=MessageRole.SYSTEM, content=content)
+            )
+        elif role == "assistant":
+            formatted_messages.append(
+                ChatMessage(role=MessageRole.ASSISTANT, content=content)
+            )
+        elif role == "user":
+            formatted_messages.append(
+                ChatMessage(role=MessageRole.USER, content=content)
+            )
+        else:
+            logger.warning(f"Unknown role {role}, treating as user message")
+            formatted_messages.append(
+                ChatMessage(role=MessageRole.USER, content=content)
+            )
+
+    return formatted_messages
+
+
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type(
+        (RateLimitError, APIConnectionError, APITimeoutError)
+    ),
+)
+async def llama_index_complete_if_cache(
+    model: str,
+    prompt: str,
+    system_prompt: Optional[str] = None,
+    history_messages: List[dict] = [],
+    **kwargs,
+) -> str:
+    """Complete the prompt using LlamaIndex."""
+    try:
+        # Format messages for chat
+        formatted_messages = []
+
+        # Add system message if provided
+        if system_prompt:
+            formatted_messages.append(
+                ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
+            )
+
+        # Add history messages
+        for msg in history_messages:
+            formatted_messages.append(
+                ChatMessage(
+                    role=MessageRole.USER
+                    if msg["role"] == "user"
+                    else MessageRole.ASSISTANT,
+                    content=msg["content"],
+                )
+            )
+
+        # Add current prompt
+        formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))
+
+        # Get LLM instance from kwargs
+        if "llm_instance" not in kwargs:
+            raise ValueError("llm_instance must be provided in kwargs")
+        llm = kwargs["llm_instance"]
+
+        # Get response
+        response: ChatResponse = await llm.achat(messages=formatted_messages)
+
+        # In newer versions, the response is in message.content
+        content = response.message.content
+        return content
+
+    except Exception as e:
+        logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
+        raise
+
+
+async def llama_index_complete(
+    prompt,
+    system_prompt=None,
+    history_messages=None,
+    keyword_extraction=False,
+    settings: LlamaIndexSettings = None,
+    **kwargs,
+) -> str:
+    """
+    Main completion function for LlamaIndex
+
+    Args:
+        prompt: Input prompt
+        system_prompt: Optional system prompt
+        history_messages: Optional chat history
+        keyword_extraction: Whether to extract keywords from response
+        settings: Optional LlamaIndex settings
+        **kwargs: Additional arguments
+    """
+    if history_messages is None:
+        history_messages = []
+
+    # Drop any duplicate flag from kwargs so it is not forwarded to the LLM call;
+    # the explicit keyword_extraction parameter takes precedence.
+    kwargs.pop("keyword_extraction", None)
+    result = await llama_index_complete_if_cache(
+        kwargs.get("llm_instance"),
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+    if keyword_extraction:
+        return locate_json_string_body_from_string(result)
+    return result
+
+
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type(
+        (RateLimitError, APIConnectionError, APITimeoutError)
+    ),
+)
+async def llama_index_embed(
+    texts: list[str],
+    embed_model: BaseEmbedding = None,
+    settings: LlamaIndexSettings = None,
+    **kwargs,
+) -> np.ndarray:
+    """
+    Generate embeddings using LlamaIndex
+
+    Args:
+        texts: List of texts to embed
+        embed_model: LlamaIndex embedding model
+        settings: Optional LlamaIndex settings
+        **kwargs: Additional arguments
+    """
+    if settings:
+        configure_llama_index(settings)
+
+    if embed_model is None:
+        raise ValueError("embed_model must be provided")
+
+    # Embed the whole batch through the public async batch API
+    embeddings = await embed_model.aget_text_embedding_batch(texts)
+    return np.array(embeddings)