diff --git a/examples/lightrag_bedrock_demo.py b/examples/lightrag_bedrock_demo.py
index 36ec3857..c515922e 100644
--- a/examples/lightrag_bedrock_demo.py
+++ b/examples/lightrag_bedrock_demo.py
@@ -3,46 +3,39 @@ LightRAG meets Amazon Bedrock ⛰️
 """
 
 import os
+import logging
 
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import bedrock_complete, bedrock_embedding
 from lightrag.utils import EmbeddingFunc
 
-WORKING_DIR = "./dickens"
+logging.getLogger("aiobotocore").setLevel(logging.WARNING)
 
+WORKING_DIR = "./dickens"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=bedrock_complete,
-    llm_model_name="anthropic.claude-3-haiku-20240307-v1:0",
-    node2vec_params = {
-        'dimensions': 1024,
-        'num_walks': 10,
-        'walk_length': 40,
-        'window_size': 2,
-        'iterations': 3,
-        'random_seed': 3
-    },
+    llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
     embedding_func=EmbeddingFunc(
         embedding_dim=1024,
         max_token_size=8192,
-        func=lambda texts: bedrock_embedding(texts)
+        func=bedrock_embedding
     )
 )
 
-with open("./book.txt") as f:
+with open("./book.txt", 'r', encoding='utf-8') as f:
     rag.insert(f.read())
 
-# Naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
-
-# Local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
-
-# Global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
-
-# Hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+for mode in ["naive", "local", "global", "hybrid"]:
+    print("\n+-" + "-" * len(mode) + "-+")
+    print(f"| {mode.capitalize()} |")
+    print("+-" + "-" * len(mode) + "-+\n")
+    print(
+        rag.query(
+            "What are the top themes in this story?",
+            param=QueryParam(mode=mode)
+        )
+    )
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 8fc0da2e..48defb4d 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -1,6 +1,9 @@
 import os
+import copy
 import json
+import botocore
 import aioboto3
+import botocore.errorfactory
 import numpy as np
 import ollama
 from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -50,43 +53,70 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
+
+class BedrockError(Exception):
+    """Generic error for issues related to Amazon Bedrock"""
+
+
 @retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, max=60),
+    retry=retry_if_exception_type((BedrockError)),
 )
 async def bedrock_complete_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], base_url=None,
+    model, prompt, system_prompt=None, history_messages=[],
     aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None,
     **kwargs
 ) -> str:
     os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
     os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
     os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
 
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-
+    # Fix message history format
     messages = []
-    messages.extend(history_messages)
+    for history_message in history_messages:
+        message = copy.copy(history_message)
+        message['content'] = [{'text': message['content']}]
+        messages.append(message)
+
+    # Add user prompt
     messages.append({'role': "user", 'content': [{'text': prompt}]})
 
+    # Initialize Converse API arguments
     args = {
         'modelId': model,
         'messages': messages
     }
 
+    # Define system prompt
     if system_prompt:
         args['system'] = [{'text': system_prompt}]
 
+    # Map and set up inference parameters
+    inference_params_map = {
+        'max_tokens': "maxTokens",
+        'top_p': "topP",
+        'stop_sequences': "stopSequences"
+    }
+    if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
+        args['inferenceConfig'] = {}
+        for param in inference_params:
+            args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
+
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     if hashing_kv is not None:
         args_hash = compute_args_hash(model, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
         if if_cache_return is not None:
             return if_cache_return["return"]
 
+    # Call model via Converse API
     session = aioboto3.Session()
     async with session.client("bedrock-runtime") as bedrock_async_client:
-        response = await bedrock_async_client.converse(**args, **kwargs)
+        try:
+            response = await bedrock_async_client.converse(**args, **kwargs)
+        except Exception as e:
+            raise BedrockError(e)
 
         if hashing_kv is not None:
             await hashing_kv.upsert({
@@ -200,7 +230,7 @@ async def bedrock_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     return await bedrock_complete_if_cache(
-        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
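
Reviewer note: a standalone sketch of the history reshaping this patch introduces in `bedrock_complete_if_cache`. OpenAI-style string content is wrapped into the list-of-blocks shape the Bedrock Converse API expects. The sample history below is invented for illustration; the loop mirrors the patched code.

```python
import copy

history = [
    {"role": "user", "content": "Who wrote A Christmas Carol?"},
    {"role": "assistant", "content": "Charles Dickens."},
]

messages = []
for history_message in history:
    # Shallow-copy so the caller's history list is not mutated in place.
    message = copy.copy(history_message)
    # Converse expects content as a list of blocks, not a bare string.
    message["content"] = [{"text": message["content"]}]
    messages.append(message)

print(messages)
# [{'role': 'user', 'content': [{'text': 'Who wrote A Christmas Carol?'}]},
#  {'role': 'assistant', 'content': [{'text': 'Charles Dickens.'}]}]
```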
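A minimal caller-side sketch of the new `inferenceConfig` mapping, assuming valid AWS credentials are already in the environment and this patch is applied; the asyncio wrapper is not part of the patch.

```python
import asyncio

from lightrag.llm import bedrock_complete

async def main():
    # Snake_case kwargs are remapped internally to Converse's camelCase
    # inferenceConfig keys: max_tokens -> maxTokens, top_p -> topP.
    # temperature passes through unchanged, which is already a valid key.
    answer = await bedrock_complete(
        "What are the top themes in this story?",
        max_tokens=512,
        top_p=0.9,
    )
    print(answer)

asyncio.run(main())
```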
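The retry policy now keys off a single exception type rather than the OpenAI error classes. A compact sketch of that normalization pattern, with an invented `failing_call` standing in for the Converse call:

```python
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock"""

# Same policy as the patch: up to 5 attempts, exponential backoff capped
# at 60 seconds, retrying only on BedrockError.
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=60),
    retry=retry_if_exception_type(BedrockError),
)
def failing_call():
    try:
        raise TimeoutError("simulated throttling")  # stand-in for converse()
    except Exception as e:
        # Wrapping every failure keeps the retry predicate to one type.
        raise BedrockError(e)
```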