Fixed retry strategy, message history and inference params; Cleaned up Bedrock example

João Galego
2024-10-18 16:50:02 +01:00
parent 1fc55b18d5
commit 37d713a5c8
2 changed files with 55 additions and 32 deletions

View File

@@ -3,46 +3,39 @@ LightRAG meets Amazon Bedrock ⛰️
 """
 
 import os
+import logging
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import bedrock_complete, bedrock_embedding
 from lightrag.utils import EmbeddingFunc
 
+logging.getLogger("aiobotocore").setLevel(logging.WARNING)
+
 WORKING_DIR = "./dickens"
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=bedrock_complete,
-    llm_model_name="anthropic.claude-3-haiku-20240307-v1:0",
-    node2vec_params = {
-        'dimensions': 1024,
-        'num_walks': 10,
-        'walk_length': 40,
-        'window_size': 2,
-        'iterations': 3,
-        'random_seed': 3
-    },
+    llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
     embedding_func=EmbeddingFunc(
         embedding_dim=1024,
         max_token_size=8192,
-        func=lambda texts: bedrock_embedding(texts)
+        func=bedrock_embedding
     )
 )
 
-with open("./book.txt") as f:
+with open("./book.txt", 'r', encoding='utf-8') as f:
     rag.insert(f.read())
 
-# Naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
-
-# Local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
-
-# Global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
-
-# Hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+for mode in ["naive", "local", "global", "hybrid"]:
+    print("\n+-" + "-" * len(mode) + "-+")
+    print(f"| {mode.capitalize()} |")
+    print("+-" + "-" * len(mode) + "-+\n")
+    print(
+        rag.query(
+            "What are the top themes in this story?",
+            param=QueryParam(mode=mode)
+        )
+    )
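
The four copy-pasted per-mode query blocks become a single loop over the query modes, with a small ASCII banner separating the answers. A minimal sketch of just the banner logic, runnable on its own without a LightRAG instance, assuming the same mode list:

# Sketch: the per-mode banner printed by the refactored demo.
for mode in ["naive", "local", "global", "hybrid"]:
    # The box width tracks the mode name, e.g. a 9-column box around "naive".
    print("\n+-" + "-" * len(mode) + "-+")
    print(f"| {mode.capitalize()} |")
    print("+-" + "-" * len(mode) + "-+\n")

For mode "naive" this prints "+-------+", "| Naive |", "+-------+", so each answer in the demo output is labeled with the retrieval mode that produced it.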

View File

@@ -1,6 +1,9 @@
 import os
+import copy
 import json
+import botocore
 import aioboto3
+import botocore.errorfactory
 import numpy as np
 import ollama
 from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -50,43 +53,70 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
 
+class BedrockError(Exception):
+    """Generic error for issues related to Amazon Bedrock"""
+
+
 @retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, max=60),
+    retry=retry_if_exception_type((BedrockError)),
 )
 async def bedrock_complete_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], base_url=None,
+    model, prompt, system_prompt=None, history_messages=[],
     aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
 ) -> str:
     os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
     os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
     os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
 
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-
+    # Fix message history format
     messages = []
-    messages.extend(history_messages)
+    for history_message in history_messages:
+        message = copy.copy(history_message)
+        message['content'] = [{'text': message['content']}]
+        messages.append(message)
+
+    # Add user prompt
     messages.append({'role': "user", 'content': [{'text': prompt}]})
 
+    # Initialize Converse API arguments
     args = {
         'modelId': model,
         'messages': messages
     }
 
+    # Define system prompt
     if system_prompt:
         args['system'] = [{'text': system_prompt}]
 
+    # Map and set up inference parameters
+    inference_params_map = {
+        'max_tokens': "maxTokens",
+        'top_p': "topP",
+        'stop_sequences': "stopSequences"
+    }
+    if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
+        args['inferenceConfig'] = {}
+        for param in inference_params:
+            args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
+
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     if hashing_kv is not None:
         args_hash = compute_args_hash(model, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
         if if_cache_return is not None:
             return if_cache_return["return"]
 
+    # Call model via Converse API
     session = aioboto3.Session()
     async with session.client("bedrock-runtime") as bedrock_async_client:
-        response = await bedrock_async_client.converse(**args, **kwargs)
+        try:
+            response = await bedrock_async_client.converse(**args, **kwargs)
+        except Exception as e:
+            raise BedrockError(e)
 
     if hashing_kv is not None:
         await hashing_kv.upsert({
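
To make the new request-building logic concrete, here is a small offline sketch that mirrors the hunk above: it converts OpenAI-style history messages into Converse content blocks and maps snake_case inference parameters onto the Converse API's camelCase inferenceConfig keys. The history and kwargs values are invented for illustration; no AWS call is made:

import copy
import json

# Hypothetical inputs, shaped like what bedrock_complete_if_cache receives.
history = [
    {'role': "user", 'content': "Hi"},
    {'role': "assistant", 'content': "Hello!"}
]
kwargs = {'max_tokens': 512, 'temperature': 0.2, 'top_p': 0.9}

# Converse expects each message's content as a list of content blocks,
# not a bare string; copy.copy keeps the caller's dicts untouched.
messages = []
for history_message in history:
    message = copy.copy(history_message)
    message['content'] = [{'text': message['content']}]
    messages.append(message)
messages.append({'role': "user", 'content': [{'text': "What are the top themes?"}]})

# snake_case kwargs -> camelCase inferenceConfig keys; 'temperature' has no
# map entry, so .get(param, param) passes it through unchanged.
inference_params_map = {
    'max_tokens': "maxTokens",
    'top_p': "topP",
    'stop_sequences': "stopSequences"
}
args = {'modelId': "anthropic.claude-3-haiku-20240307-v1:0", 'messages': messages}
if (inference_params := list(set(kwargs) & {'max_tokens', 'temperature', 'top_p', 'stop_sequences'})):
    args['inferenceConfig'] = {
        inference_params_map.get(param, param): kwargs.pop(param)
        for param in inference_params
    }

print(json.dumps(args, indent=2))
# inferenceConfig comes out as {"maxTokens": 512, "temperature": 0.2, "topP": 0.9}
# (key order may vary, since the parameter names pass through a set).

Popping the consumed parameters out of kwargs matters: whatever remains is forwarded verbatim to converse(), so unmapped keys would otherwise reach boto3 and fail validation.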
@@ -200,7 +230,7 @@ async def bedrock_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     return await bedrock_complete_if_cache(
-        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
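
The retry policy now only triggers on BedrockError, which the try/except around converse() guarantees by wrapping any failure from the call. A rough standalone sketch of how these tenacity settings behave, using a simulated flaky call (flaky_converse and its failure count are made up for illustration):

import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock"""

attempts = 0

@retry(
    stop=stop_after_attempt(5),                   # give up after 5 attempts
    wait=wait_exponential(multiplier=1, max=60),  # roughly doubling waits, capped at 60s
    retry=retry_if_exception_type(BedrockError),  # anything else propagates immediately
)
async def flaky_converse():
    global attempts
    attempts += 1
    if attempts < 3:
        # Stand-in for a throttled or failed Converse call,
        # wrapped in BedrockError as in the hunk above.
        raise BedrockError("ThrottlingException (simulated)")
    return "ok"

print(asyncio.run(flaky_converse()))  # succeeds on the 3rd attempt after two backoff waits

Compared to the old policy (3 attempts, 4-10s waits, retrying only on OpenAI exception types that a Bedrock call never raises), this retries on the right error type and gives throttled requests considerably more room to recover.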