Merge remote-tracking branch 'origin/main' and fix syntax
lightrag/llm.py (288 changed lines)
@@ -4,8 +4,7 @@ import json
import os
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any, Union, Optional
from dataclasses import dataclass
from typing import List, Dict, Callable, Any, Union
import aioboto3
import aiohttp
import numpy as np

@@ -27,13 +26,9 @@ from tenacity import (
)
from transformers import AutoTokenizer, AutoModelForCausalLM

from .base import BaseKVStorage
from .utils import (
    compute_args_hash,
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    quantize_embedding,
    get_best_cached_response,
)

import sys

@@ -66,23 +61,13 @@ async def openai_complete_if_cache(
    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )

    kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    if "response_format" in kwargs:
        response = await openai_async_client.beta.chat.completions.parse(
            model=model, messages=messages, **kwargs

@@ -108,22 +93,6 @@ async def openai_complete_if_cache(
    content = response.choices[0].message.content
    if r"\u" in content:
        content = content.encode("utf-8").decode("unicode_escape")

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return content
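For orientation, every `*_complete_if_cache` function touched in this diff follows the same cache-around-completion flow visible in the hunks above: pop `hashing_kv` from kwargs, check the cache with `handle_cache` before calling the provider, and persist the result with `save_to_cache` afterwards. Below is a minimal, self-contained sketch of that flow; the in-memory dict standing in for the `BaseKVStorage`-backed `hashing_kv`, the placeholder model call, and the simplified `compute_args_hash` are assumptions for illustration only.

import asyncio
import hashlib
import json

_cache: dict = {}  # stand-in for the hashing_kv storage, keyed by mode

def compute_args_hash(*args) -> str:
    # simplified stand-in for the compute_args_hash imported from .utils
    return hashlib.md5(json.dumps(args, default=str).encode()).hexdigest()

async def fake_model_call(messages) -> str:
    # placeholder for the real provider call (OpenAI, Azure, Bedrock, Ollama, ...)
    return "answer to: " + messages[-1]["content"]

async def complete_if_cache(model, prompt, system_prompt=None, mode="default"):
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    # Handle cache: return early on a hit, as the hunks above do via handle_cache()
    args_hash = compute_args_hash(model, messages)
    mode_cache = _cache.get(mode, {})
    if args_hash in mode_cache:
        return mode_cache[args_hash]["return"]

    content = await fake_model_call(messages)

    # Save to cache, mirroring the CacheData fields written by save_to_cache()
    _cache.setdefault(mode, {})[args_hash] = {
        "return": content,
        "model": model,
        "original_prompt": prompt,
    }
    return content

print(asyncio.run(complete_if_cache("demo-model", "What is LightRAG?")))

The real functions below differ mainly in the provider call and in how the response text is extracted before caching.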
@@ -154,10 +123,7 @@ async def azure_openai_complete_if_cache(
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    mode = kwargs.pop("mode", "default")

    kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

@@ -165,34 +131,11 @@ async def azure_openai_complete_if_cache(
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})

    # Handle cache
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    content = response.choices[0].message.content

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return content
@@ -224,7 +167,7 @@ async def bedrock_complete_if_cache(
    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
        "AWS_SESSION_TOKEN", aws_session_token
    )

    kwargs.pop("hashing_kv", None)
    # Fix message history format
    messages = []
    for history_message in history_messages:

@@ -234,15 +177,6 @@ async def bedrock_complete_if_cache(

    # Add user prompt
    messages.append({"role": "user", "content": [{"text": prompt}]})
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    # Initialize Converse API arguments
    args = {"modelId": model, "messages": messages}

@@ -265,15 +199,6 @@ async def bedrock_complete_if_cache(
            args["inferenceConfig"][inference_params_map.get(param, param)] = (
                kwargs.pop(param)
            )
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    # Call model via Converse API
    session = aioboto3.Session()

@@ -283,21 +208,6 @@ async def bedrock_complete_if_cache(
    except Exception as e:
        raise BedrockError(e)

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response["output"]["message"]["content"][0]["text"],
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return response["output"]["message"]["content"][0]["text"]
@@ -329,22 +239,12 @@ async def hf_model_if_cache(
) -> str:
    model_name = model
    hf_model, hf_tokenizer = initialize_hf_model(model_name)
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    kwargs.pop("hashing_kv", None)
    input_prompt = ""
    try:
        input_prompt = hf_tokenizer.apply_chat_template(

@@ -389,21 +289,6 @@ async def hf_model_if_cache(
        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
    )

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response_text,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return response_text
@@ -424,25 +309,14 @@ async def ollama_model_if_cache(
    # kwargs.pop("response_format", None) # allow json
    host = kwargs.pop("host", None)
    timeout = kwargs.pop("timeout", None)

    kwargs.pop("hashing_kv", None)
    ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        """cannot cache stream response"""

@@ -453,22 +327,7 @@ async def ollama_model_if_cache(

        return inner()
    else:
        result = response["message"]["content"]
        # Save to cache
        await save_to_cache(
            hashing_kv,
            CacheData(
                args_hash=args_hash,
                content=result,
                model=model,
                prompt=prompt,
                quantized=quantized,
                min_val=min_val,
                max_val=max_val,
                mode=mode,
            ),
        )
        return result
    return response["message"]["content"]


@lru_cache(maxsize=1)
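The ollama branch above returns an async generator (`inner()`) when `stream=True`, which is why streamed responses bypass the cache: the content only exists once the caller has iterated the chunks. A self-contained sketch of that shape follows; the fake chunk source and the simplified function are assumptions, with the non-streaming branch omitted.

import asyncio

async def fake_stream():
    # stand-in for the chunks an Ollama streaming chat call would yield
    for piece in ["Light", "RAG", "!"]:
        yield {"message": {"content": piece}}

async def model_if_cache(stream: bool):
    response = fake_stream()

    if stream:
        """cannot cache stream response"""

        async def inner():
            async for chunk in response:
                yield chunk["message"]["content"]

        return inner()
    # non-streaming branch omitted; that is the path that gets cached

async def main():
    async for text in await model_if_cache(stream=True):
        print(text, end="")
    print()

asyncio.run(main())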
@@ -543,7 +402,7 @@ async def lmdeploy_model_if_cache(
        from lmdeploy import version_info, GenerationConfig
    except Exception:
        raise ImportError("Please install lmdeploy before intialize lmdeploy backend.")

    kwargs.pop("hashing_kv", None)
    kwargs.pop("response_format", None)
    max_new_tokens = kwargs.pop("max_tokens", 512)
    tp = kwargs.pop("tp", 1)

@@ -575,19 +434,9 @@ async def lmdeploy_model_if_cache(
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Handle cache
    mode = kwargs.pop("mode", "default")
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    gen_config = GenerationConfig(
        skip_special_tokens=skip_special_tokens,
        max_new_tokens=max_new_tokens,

@@ -603,22 +452,6 @@ async def lmdeploy_model_if_cache(
        session_id=1,
    ):
        response += res.response

    # Save to cache
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=response,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )

    return response
@@ -779,6 +612,40 @@ async def openai_embedding(
    return np.array([dp.embedding for dp in response.data])


async def fetch_data(url, headers, data):
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=data) as response:
            response_json = await response.json()
            data_list = response_json.get("data", [])
            return data_list


async def jina_embedding(
    texts: list[str],
    dimensions: int = 1024,
    late_chunking: bool = False,
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    if api_key:
        os.environ["JINA_API_KEY"] = api_key
    url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"""Bearer {os.environ["JINA_API_KEY"]}""",
    }
    data = {
        "model": "jina-embeddings-v3",
        "normalized": True,
        "embedding_type": "float",
        "dimensions": f"{dimensions}",
        "late_chunking": late_chunking,
        "input": texts,
    }
    data_list = await fetch_data(url, headers, data)
    return np.array([dp["embedding"] for dp in data_list])
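A short usage sketch for the jina_embedding helper added above; the sample sentences are placeholders, and a valid key is assumed to be available either via the api_key argument or the JINA_API_KEY environment variable.

import asyncio
from lightrag.llm import jina_embedding  # this module

async def main():
    vectors = await jina_embedding(
        ["LightRAG indexes documents into a graph.", "A second sample sentence."],
        dimensions=1024,
    )
    print(vectors.shape)  # expected: (2, 1024)

asyncio.run(main())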
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
@retry(
    stop=stop_after_attempt(3),
@@ -1064,77 +931,6 @@ class MultiModel:
        return await next_model.gen_func(**args)


async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
    """Generic cache handling function"""
    if hashing_kv is None:
        return None, None, None, None

    # Get embedding cache configuration
    embedding_cache_config = hashing_kv.global_config.get(
        "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95}
    )
    is_embedding_cache_enabled = embedding_cache_config["enabled"]

    quantized = min_val = max_val = None
    if is_embedding_cache_enabled:
        # Use embedding cache
        embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
        current_embedding = await embedding_model_func([prompt])
        quantized, min_val, max_val = quantize_embedding(current_embedding[0])
        best_cached_response = await get_best_cached_response(
            hashing_kv,
            current_embedding[0],
            similarity_threshold=embedding_cache_config["similarity_threshold"],
            mode=mode,
        )
        if best_cached_response is not None:
            return best_cached_response, None, None, None
    else:
        # Use regular cache
        mode_cache = await hashing_kv.get_by_id(mode) or {}
        if args_hash in mode_cache:
            return mode_cache[args_hash]["return"], None, None, None

    return None, quantized, min_val, max_val
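When the embedding cache is enabled, handle_cache embeds the prompt and asks get_best_cached_response for the closest stored prompt above a similarity threshold (defaulting to 0.95 in the config shown above). A self-contained sketch of that kind of lookup using cosine similarity follows; the in-memory cached entries and the helper names are assumptions, not the actual get_best_cached_response implementation.

import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def best_cached_response(query_emb, cached, similarity_threshold=0.95):
    # cached: list of (embedding, response) pairs standing in for the KV store
    best_score, best_resp = -1.0, None
    for emb, resp in cached:
        score = cosine(query_emb, emb)
        if score > best_score:
            best_score, best_resp = score, resp
    return best_resp if best_score >= similarity_threshold else None

cached = [
    (np.array([1.0, 0.0]), "cached answer A"),
    (np.array([0.0, 1.0]), "cached answer B"),
]
print(best_cached_response(np.array([0.99, 0.05]), cached))  # close to A -> "cached answer A"
print(best_cached_response(np.array([0.7, 0.7]), cached))    # below threshold -> None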
@dataclass
class CacheData:
    args_hash: str
    content: str
    model: str
    prompt: str
    quantized: Optional[np.ndarray] = None
    min_val: Optional[float] = None
    max_val: Optional[float] = None
    mode: str = "default"


async def save_to_cache(hashing_kv, cache_data: CacheData):
    if hashing_kv is None:
        return

    mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}

    mode_cache[cache_data.args_hash] = {
        "return": cache_data.content,
        "model": cache_data.model,
        "embedding": (
            cache_data.quantized.tobytes().hex()
            if cache_data.quantized is not None
            else None
        ),
        "embedding_shape": (
            cache_data.quantized.shape if cache_data.quantized is not None else None
        ),
        "embedding_min": cache_data.min_val,
        "embedding_max": cache_data.max_val,
        "original_prompt": cache_data.prompt,
    }

    await hashing_kv.upsert({cache_data.mode: mode_cache})
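The save_to_cache body above stores the quantized embedding as a hex string together with its shape and min/max, which is enough to rebuild the array later. A small self-contained sketch of that round trip; the uint8 min-max quantization scheme and the sample vector are assumptions, only the stored field names come from the code above.

import numpy as np

embedding = np.random.rand(8).astype(np.float32)  # sample vector (assumption)

# quantize to uint8, keeping min/max so the scaling can be undone later
min_val, max_val = float(embedding.min()), float(embedding.max())
quantized = np.round((embedding - min_val) / (max_val - min_val) * 255).astype(np.uint8)

entry = {
    "embedding": quantized.tobytes().hex(),
    "embedding_shape": quantized.shape,
    "embedding_min": min_val,
    "embedding_max": max_val,
}

# decode: hex -> bytes -> uint8 array -> rescale back to float
restored_q = np.frombuffer(
    bytes.fromhex(entry["embedding"]), dtype=np.uint8
).reshape(entry["embedding_shape"])
restored = restored_q.astype(np.float32) / 255 * (max_val - min_val) + min_val

print(np.allclose(embedding, restored, atol=(max_val - min_val) / 255))  # True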
if __name__ == "__main__":
    import asyncio