feat: trim the model’s reasoning
@@ -66,6 +66,7 @@ from lightrag.exceptions import (
     RateLimitError,
     APITimeoutError,
 )
+from lightrag.utils import extract_reasoning
 import numpy as np
 from typing import Union
 
@@ -85,6 +86,7 @@ async def ollama_model_if_cache(
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     stream = True if kwargs.get("stream") else False
     reasoning_tag = kwargs.pop("reasoning_tag", None)
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -105,7 +107,7 @@ async def ollama_model_if_cache(
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
-        """cannot cache stream response"""
+        """cannot cache stream response and process reasoning"""
 
         async def inner():
             async for chunk in response:
@@ -113,7 +115,19 @@ async def ollama_model_if_cache(
 
         return inner()
     else:
-        return response["message"]["content"]
+        model_response = response["message"]["content"]
+
+        """
+        If the model also wraps its thoughts in a specific tag,
+        this information is not needed for the final
+        response and can simply be trimmed.
+        """
+
+        return (
+            model_response
+            if reasoning_tag is None
+            else extract_reasoning(model_response, reasoning_tag).response_content
+        )
 
 
 async def ollama_model_complete(
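Note: callers opt in to trimming by passing a `reasoning_tag` keyword argument; when it is absent the behaviour is unchanged, and streamed responses are passed through untouched, since reasoning cannot be processed chunk by chunk. A minimal caller sketch, assuming the function lives in lightrag.llm and takes the model name and prompt positionally; the model name, prompt, and host below are invented for illustration and are not part of this commit:

    import asyncio

    from lightrag.llm import ollama_model_if_cache

    async def main() -> None:
        # deepseek-style reasoning models wrap their chain-of-thought in
        # <think> tags; reasoning_tag="think" strips that block from the answer.
        answer = await ollama_model_if_cache(
            "deepseek-r1:14b",              # assumed model name
            "What is 2 + 2?",               # assumed prompt
            host="http://localhost:11434",  # assumed local Ollama host
            reasoning_tag="think",
        )
        print(answer)

    asyncio.run(main())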
@@ -11,6 +11,7 @@ from functools import wraps
 from hashlib import md5
 from typing import Any, Union, List, Optional
 import xml.etree.ElementTree as ET
+import bs4
 
 import numpy as np
 import tiktoken
@@ -64,6 +65,13 @@ class EmbeddingFunc:
         return await self.func(*args, **kwargs)
 
 
+@dataclass
+class ReasoningResponse:
+    reasoning_content: str
+    response_content: str
+    tag: str
+
+
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
     """Locate the JSON string body from a string"""
     try:
@@ -666,3 +674,28 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) -> str:
         )
 
     return "\n".join(formatted_turns)
+
+
+def extract_reasoning(response: str, tag: str) -> ReasoningResponse:
+    """Extract the reasoning section and the following section from the LLM response.
+
+    Args:
+        response: LLM response
+        tag: Tag to extract
+    Returns:
+        ReasoningResponse: Reasoning section and following section
+
+    """
+    soup = bs4.BeautifulSoup(response, "html.parser")
+
+    reasoning_section = soup.find(tag)
+    if reasoning_section is None:
+        return ReasoningResponse(None, response, tag)
+    reasoning_content = reasoning_section.get_text().strip()
+
+    after_reasoning_section = reasoning_section.next_sibling
+    if after_reasoning_section is None:
+        return ReasoningResponse(reasoning_content, "", tag)
+    after_reasoning_content = after_reasoning_section.get_text().strip()
+
+    return ReasoningResponse(reasoning_content, after_reasoning_content, tag)
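Note on fallbacks: if the tag never appears, the helper returns the whole response untouched as response_content, with reasoning_content left as None; if nothing follows the closing tag, response_content is the empty string. A quick demonstration of the new helper; the sample response strings are invented, and beautifulsoup4 must be installed:

    from lightrag.utils import extract_reasoning

    raw = "<think>The user wants a sum: 2 + 2 = 4.</think>The answer is 4."
    result = extract_reasoning(raw, "think")

    print(result.reasoning_content)  # -> The user wants a sum: 2 + 2 = 4.
    print(result.response_content)   # -> The answer is 4.

    # No reasoning tag present: the full response survives untrimmed.
    print(extract_reasoning("Just an answer.", "think").response_content)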