Merge remote-tracking branch 'origin/main'

# Conflicts:
#	lightrag/llm.py
#	lightrag/operate.py
magicyuan876
2024-12-06 15:06:00 +08:00
6 changed files with 198 additions and 6 deletions

lightrag/base.py

@@ -19,6 +19,7 @@ class QueryParam:
     only_need_context: bool = False
     only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
+    stream: bool = False
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
     # Number of document chunks to retrieve.
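A caller opts into streaming through this new flag. A minimal construction sketch (illustrative usage, not part of this commit; the end-to-end flow appears in the `operate.py` hunk at the bottom):

```python
from lightrag import QueryParam

# stream=True asks the LLM layer to return an async iterator of text
# chunks instead of one fully materialized string.
param = QueryParam(mode="hybrid", stream=True, top_k=60)
```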

lightrag/kg/oracle_impl.py

@@ -143,7 +143,7 @@ class OracleDB:
         data = None
         return data

-    async def execute(self, sql: str, data: list | dict = None):
+    async def execute(self, sql: str, data: Union[list, dict] = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
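The `list | dict` union syntax (PEP 604) only parses on Python 3.10+, so the annotation is rewritten with `typing.Union` to keep the module importable on older interpreters. A standalone sketch of the equivalence (not from the repo):

```python
from typing import Union

# data: list | dict        -- PEP 604, requires Python >= 3.10
# data: Union[list, dict]  -- equivalent, also works on 3.8/3.9
async def execute(sql: str, data: Union[list, dict] = None):
    ...
```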

View File

@@ -4,8 +4,7 @@
 import json
 import os
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any, Optional
-from dataclasses import dataclass
+from typing import List, Dict, Callable, Any, Union
 import aioboto3
 import aiohttp
@@ -37,6 +36,13 @@ from .utils import (
     get_best_cached_response,
 )

+import sys
+
+if sys.version_info < (3, 9):
+    from typing import AsyncIterator
+else:
+    from collections.abc import AsyncIterator
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
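The version gate is needed because subscripting the `collections.abc` class (`AsyncIterator[str]`) only works from Python 3.9 (PEP 585); on 3.8 the `typing` variant must be used. A toy example of the annotation either import supports (illustrative only):

```python
import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator  # subscriptable on 3.8
else:
    from collections.abc import AsyncIterator  # PEP 585, 3.9+

async def stream_words() -> AsyncIterator[str]:
    # A trivial async generator matching the annotation.
    for word in ("hello", "world"):
        yield word
```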
@@ -397,7 +403,8 @@ async def ollama_model_if_cache(
     system_prompt=None,
     history_messages=[],
     **kwargs,
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
+    stream = True if kwargs.get("stream") else False
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -422,7 +429,31 @@ async def ollama_model_if_cache(
         return cached_response

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)

+    if stream:
+        """ cannot cache stream response """
+        async def inner():
+            async for chunk in response:
+                yield chunk["message"]["content"]
+
+        return inner()
+    else:
+        result = response["message"]["content"]
+        # Save to cache
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=result,
+                model=model,
+                prompt=prompt,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=mode,
+            ),
+        )
+        return result

     result = response["message"]["content"]
     # Save to cache
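After this hunk, `ollama_model_if_cache` returns a plain `str` for non-streaming calls (cache hits included) and an async generator when `stream=True`, so callers must branch on the type. A hedged consumption sketch (the `response` value is assumed to come from the function above):

```python
async def consume(response) -> str:
    # response is either a full string or the async generator from inner().
    if isinstance(response, str):
        return response
    parts = []
    async for token in response:  # chunks arrive as the model generates them
        parts.append(token)
    return "".join(parts)
```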
@@ -697,7 +728,7 @@ async def hf_model_complete(
 async def ollama_model_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
+) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["format"] = "json"

lightrag/operate.py

@@ -536,9 +536,10 @@ async def kg_query(
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
+        stream=query_param.stream,
         mode=query_param.mode,
     )
-    if len(response) > len(sys_prompt):
+    if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
             response.replace(sys_prompt, "")
             .replace("user", "")