feat: trim the model’s reasoning from its final response

This commit is contained in:
ultrageopro
2025-02-06 22:56:17 +03:00
parent 9db1db2b38
commit 19ee3d109c
3 changed files with 55 additions and 2 deletions

View File

@@ -66,6 +66,7 @@ from lightrag.exceptions import (
RateLimitError,
APITimeoutError,
)
from lightrag.utils import extract_reasoning
import numpy as np
from typing import Union
@@ -85,6 +86,7 @@ async def ollama_model_if_cache(
**kwargs,
) -> Union[str, AsyncIterator[str]]:
stream = True if kwargs.get("stream") else False
reasoning_tag = kwargs.pop("reasoning_tag", None)
kwargs.pop("max_tokens", None)
# kwargs.pop("response_format", None) # allow json
host = kwargs.pop("host", None)
@@ -105,7 +107,7 @@ async def ollama_model_if_cache(
response = await ollama_client.chat(model=model, messages=messages, **kwargs)
if stream:
"""cannot cache stream response"""
"""cannot cache stream response and process reasoning"""
async def inner():
async for chunk in response:
@@ -113,7 +115,19 @@ async def ollama_model_if_cache(
return inner()
else:
return response["message"]["content"]
model_response = response["message"]["content"]
"""
If the model also wraps its thoughts in a specific tag,
this information is not needed for the final
response and can simply be trimmed.
"""
return (
model_response
if reasoning_tag is None
else extract_reasoning(model_response, reasoning_tag).response_content
)
async def ollama_model_complete(

View File

@@ -11,6 +11,7 @@ from functools import wraps
from hashlib import md5
from typing import Any, Union, List, Optional
import xml.etree.ElementTree as ET
import bs4
import numpy as np
import tiktoken
@@ -64,6 +65,13 @@ class EmbeddingFunc:
return await self.func(*args, **kwargs)
@dataclass
class ReasoningResponse:
    """Result of splitting an LLM response into reasoning and answer parts.

    Produced by ``extract_reasoning``.
    """

    # Text inside the reasoning tag; None when the tag was not found.
    reasoning_content: Optional[str]
    # The remaining response text that follows the reasoning section.
    response_content: str
    # The tag name that was searched for (e.g. "think").
    tag: str
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
try:
@@ -666,3 +674,28 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
)
return "\n".join(formatted_turns)
def extract_reasoning(response: str, tag: str) -> ReasoningResponse:
    """Split an LLM response into its reasoning section and the final answer.

    Args:
        response: Raw LLM response text.
        tag: Name of the tag wrapping the reasoning (e.g. "think").

    Returns:
        ReasoningResponse: the reasoning text (None when the tag is absent)
        and the response content that follows the reasoning section.
    """
    soup = bs4.BeautifulSoup(response, "html.parser")
    reasoning_section = soup.find(tag)
    if reasoning_section is None:
        # No reasoning tag found: the whole response is the final answer.
        return ReasoningResponse(None, response, tag)
    reasoning_content = reasoning_section.get_text().strip()
    # Collect ALL nodes after the reasoning tag, not just the first sibling;
    # otherwise any markup inside the answer truncates it
    # (e.g. "<think>x</think>hello <b>world</b>" would yield only "hello").
    after_parts = []
    sibling = reasoning_section.next_sibling
    while sibling is not None:
        after_parts.append(
            sibling.get_text() if hasattr(sibling, "get_text") else str(sibling)
        )
        sibling = sibling.next_sibling
    return ReasoningResponse(reasoning_content, "".join(after_parts).strip(), tag)