Merge remote-tracking branch 'origin/main'

# Conflicts:
#	lightrag/llm.py
#	lightrag/operate.py
@@ -4,8 +4,7 @@ import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any, Optional
-from dataclasses import dataclass
+from typing import List, Dict, Callable, Any, Union
 
 import aioboto3
 import aiohttp
@@ -37,6 +36,13 @@ from .utils import (
     get_best_cached_response,
 )
 
+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
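Note on the version gate above: PEP 585 made collections.abc.AsyncIterator subscriptable only from Python 3.9 onward, so older interpreters need the typing fallback. A minimal sketch of the failure mode the gate avoids (illustrative, not part of the diff):

    import collections.abc

    try:
        # Works on Python 3.9+; on 3.8 this raises
        # TypeError: 'ABCMeta' object is not subscriptable
        Chunks = collections.abc.AsyncIterator[str]
    except TypeError:
        from typing import AsyncIterator  # portable fallback
        Chunks = AsyncIterator[str]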
@@ -397,7 +403,8 @@ async def ollama_model_if_cache(
     system_prompt=None,
     history_messages=[],
     **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
+    stream = True if kwargs.get("stream") else False
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
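The return annotation widens from str to Union[str, AsyncIterator[str]]: with stream=True the function now hands back an async iterator of chunks instead of a finished string. A standalone sketch of the same pattern (the function name and body here are illustrative, not from llm.py):

    from typing import AsyncIterator, Union

    async def complete(prompt: str, **kwargs) -> Union[str, AsyncIterator[str]]:
        stream = True if kwargs.get("stream") else False
        if stream:
            async def gen() -> AsyncIterator[str]:
                for piece in prompt.split():  # stand-in for model chunks
                    yield piece
            return gen()
        return prompt  # stand-in for a full completion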
@@ -422,7 +429,31 @@ async def ollama_model_if_cache(
         return cached_response
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
-
+    if stream:
+        """ cannot cache stream response """
+
+        async def inner():
+            async for chunk in response:
+                yield chunk["message"]["content"]
+
+        return inner()
+    else:
+        result = response["message"]["content"]
+        # Save to cache
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=result,
+                model=model,
+                prompt=prompt,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=mode,
+            ),
+        )
+        return result
     result = response["message"]["content"]
 
     # Save to cache
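Both branches of the new if/else return, so the trailing result/# Save to cache lines still visible below appear to be leftovers of the conflicted merge. Callers now have to handle both return shapes; a hedged consumer sketch (consume is a hypothetical helper, standing in for code that calls ollama_model_if_cache, whose full signature is not visible in this diff):

    from typing import AsyncIterator, Union

    async def consume(resp: Union[str, AsyncIterator[str]]) -> str:
        if isinstance(resp, str):  # non-stream path: finished completion
            return resp
        parts = []                 # stream path: drain the async iterator
        async for chunk in resp:
            parts.append(chunk)
        return "".join(parts)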
@@ -697,7 +728,7 @@ async def hf_model_complete(
 
 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["format"] = "json"
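The keyword-extraction branch relies on the Ollama chat API's format="json" option, which constrains the model to emit valid JSON. A hedged sketch of that mechanism in isolation (model name and prompt are placeholders):

    import json
    import ollama

    resp = ollama.chat(
        model="qwen2",
        messages=[{"role": "user", "content": "List keywords as JSON."}],
        format="json",  # same option the diff sets via kwargs["format"]
    )
    keywords = json.loads(resp["message"]["content"])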