Add LlamaIndex LLM implementation module
- Implemented LlamaIndex interface for language model interactions - Added async chat completion support - Included embedding generation functionality - Implemented retry mechanisms for API calls - Added configuration and message formatting utilities - Supports OpenAI-style message handling and external settings
This commit is contained in:
249
lightrag/llm/llama_index_impl.py
Normal file
249
lightrag/llm/llama_index_impl.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
LlamaIndex LLM Interface Module
|
||||
==========================
|
||||
|
||||
This module provides interfaces for interacting with LlamaIndex's language models,
|
||||
including text generation and embedding capabilities.
|
||||
|
||||
Author: Lightrag team
|
||||
Created: 2024-03-19
|
||||
License: MIT License
|
||||
|
||||
Version: 1.0.0
|
||||
|
||||
Change Log:
|
||||
- 1.0.0 (2024-03-19): Initial release
|
||||
* Added async chat completion support
|
||||
* Added embedding generation
|
||||
* Added stream response capability
|
||||
* Added support for external settings configuration
|
||||
* Added OpenAI-style message handling
|
||||
|
||||
Dependencies:
|
||||
- llama_index
|
||||
- numpy
|
||||
- pipmaster
|
||||
- Python >= 3.10
|
||||
|
||||
Usage:
|
||||
from lightrag.llm.llama_index_impl import llama_index_complete, llama_index_embed
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "lightrag Team"
|
||||
__status__ = "Production"
|
||||
|
||||
import pipmaster as pm
|
||||
from core.logging_config import setup_logger
|
||||
from llama_index.core.llms import (
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
ChatResponse,
|
||||
)
|
||||
from typing import List, Optional
|
||||
|
||||
# Install required dependencies
|
||||
if not pm.is_installed("llama-index"):
|
||||
pm.install("llama-index")
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.settings import Settings as LlamaIndexSettings
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from lightrag.utils import (
|
||||
wrap_embedding_func_with_attrs,
|
||||
locate_json_string_body_from_string,
|
||||
)
|
||||
from lightrag.exceptions import (
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
logger = setup_logger("lightrag.llm.llama_index_impl")
|
||||
|
||||
|
||||
def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
|
||||
"""
|
||||
Configure LlamaIndex settings.
|
||||
|
||||
Args:
|
||||
settings: LlamaIndex Settings instance. If None, uses default settings.
|
||||
**kwargs: Additional settings to override/configure
|
||||
"""
|
||||
if settings is None:
|
||||
settings = LlamaIndexSettings()
|
||||
|
||||
# Update settings with any provided kwargs
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(settings, key):
|
||||
setattr(settings, key, value)
|
||||
else:
|
||||
logger.warning(f"Unknown LlamaIndex setting: {key}")
|
||||
|
||||
# Set as global settings
|
||||
LlamaIndexSettings.set_global(settings)
|
||||
return settings
|
||||
|
||||
|
||||
def format_chat_messages(messages):
|
||||
"""Format chat messages into LlamaIndex format."""
|
||||
formatted_messages = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.SYSTEM, content=content)
|
||||
)
|
||||
elif role == "assistant":
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.ASSISTANT, content=content)
|
||||
)
|
||||
elif role == "user":
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.USER, content=content)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unknown role {role}, treating as user message")
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.USER, content=content)
|
||||
)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def llama_index_complete_if_cache(
|
||||
model: str,
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
history_messages: List[dict] = [],
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Complete the prompt using LlamaIndex."""
|
||||
try:
|
||||
# Format messages for chat
|
||||
formatted_messages = []
|
||||
|
||||
# Add system message if provided
|
||||
if system_prompt:
|
||||
formatted_messages.append(
|
||||
ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
|
||||
)
|
||||
|
||||
# Add history messages
|
||||
for msg in history_messages:
|
||||
formatted_messages.append(
|
||||
ChatMessage(
|
||||
role=MessageRole.USER
|
||||
if msg["role"] == "user"
|
||||
else MessageRole.ASSISTANT,
|
||||
content=msg["content"],
|
||||
)
|
||||
)
|
||||
|
||||
# Add current prompt
|
||||
formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))
|
||||
|
||||
# Get LLM instance from kwargs
|
||||
if "llm_instance" not in kwargs:
|
||||
raise ValueError("llm_instance must be provided in kwargs")
|
||||
llm = kwargs["llm_instance"]
|
||||
|
||||
# Get response
|
||||
response: ChatResponse = await llm.achat(messages=formatted_messages)
|
||||
|
||||
# In newer versions, the response is in message.content
|
||||
content = response.message.content
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def llama_index_complete(
|
||||
prompt,
|
||||
system_prompt=None,
|
||||
history_messages=None,
|
||||
keyword_extraction=False,
|
||||
settings: LlamaIndexSettings = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Main completion function for LlamaIndex
|
||||
|
||||
Args:
|
||||
prompt: Input prompt
|
||||
system_prompt: Optional system prompt
|
||||
history_messages: Optional chat history
|
||||
keyword_extraction: Whether to extract keywords from response
|
||||
settings: Optional LlamaIndex settings
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||
result = await llama_index_complete_if_cache(
|
||||
kwargs.get("llm_instance"),
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
**kwargs,
|
||||
)
|
||||
if keyword_extraction:
|
||||
return locate_json_string_body_from_string(result)
|
||||
return result
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(
|
||||
(RateLimitError, APIConnectionError, APITimeoutError)
|
||||
),
|
||||
)
|
||||
async def llama_index_embed(
|
||||
texts: list[str],
|
||||
embed_model: BaseEmbedding = None,
|
||||
settings: LlamaIndexSettings = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings using LlamaIndex
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
embed_model: LlamaIndex embedding model
|
||||
settings: Optional LlamaIndex settings
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
if settings:
|
||||
configure_llama_index(settings)
|
||||
|
||||
if embed_model is None:
|
||||
raise ValueError("embed_model must be provided")
|
||||
|
||||
# LlamaIndex's embed_query returns a list of floats
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
embedding = await embed_model.aembed_query(text)
|
||||
embeddings.append(embedding)
|
||||
|
||||
return np.array(embeddings)
|
Reference in New Issue
Block a user