feat: trimming the model’s reasoning

ultrageopro
2025-02-06 22:56:17 +03:00
parent 9db1db2b38
commit 19ee3d109c
3 changed files with 55 additions and 2 deletions


@@ -338,6 +338,12 @@ rag = LightRAG(
There is a fully functional example, `examples/lightrag_ollama_demo.py`, that uses the `gemma2:2b` model, runs only 4 requests in parallel, and sets the context size to 32k.
#### Using "Thinking" Models (e.g., DeepSeek)
To return only the model's final response, with the reasoning section stripped out, pass `reasoning_tag` in `llm_model_kwargs`.
For example, for DeepSeek models, `reasoning_tag` should be set to `think`, since they wrap their reasoning in `<think>...</think>`. Note that the reasoning is only trimmed for non-streaming responses.
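Assuming the Ollama setup from the demo above, a minimal sketch of how this could be wired up (the model name, working directory, and embedding configuration are placeholders, not part of this change):

```python
from lightrag import LightRAG
from lightrag.llm import ollama_model_complete

rag = LightRAG(
    working_dir="./rag_storage",          # placeholder directory
    llm_model_func=ollama_model_complete,
    llm_model_name="deepseek-r1:8b",      # any "thinking" model served by Ollama
    llm_model_kwargs={
        # Strip the <think>...</think> reasoning block and keep only the answer.
        "reasoning_tag": "think",
    },
    # embedding_func=...  # configure as in examples/lightrag_ollama_demo.py
)
```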
#### Low RAM GPUs
To run this experiment on a low-RAM GPU, select a small model and tune the context window (increasing the context increases memory consumption). For example, running this Ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations in `book.txt`.
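The reduced-context setup described above could be expressed roughly as follows; the `options`/`num_ctx` keys mirror the Ollama demo, and only the 26k value and `gemma2:2b` come from the paragraph above:

```python
rag = LightRAG(
    working_dir="./rag_storage",          # placeholder directory
    llm_model_func=ollama_model_complete,
    llm_model_name="gemma2:2b",
    llm_model_kwargs={
        "host": "http://localhost:11434",
        # Smaller context window to fit a ~6 GB GPU; larger contexts use more memory.
        "options": {"num_ctx": 26000},
    },
    # embedding_func=...  # as in the full demo
)
```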


@@ -66,6 +66,7 @@ from lightrag.exceptions import (
    RateLimitError,
    APITimeoutError,
)
from lightrag.utils import extract_reasoning
import numpy as np
from typing import Union
@@ -85,6 +86,7 @@ async def ollama_model_if_cache(
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    stream = True if kwargs.get("stream") else False
    reasoning_tag = kwargs.pop("reasoning_tag", None)
    kwargs.pop("max_tokens", None)
    # kwargs.pop("response_format", None) # allow json
    host = kwargs.pop("host", None)
@@ -105,7 +107,7 @@ async def ollama_model_if_cache(
    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
        """cannot cache stream response"""
        """cannot cache stream response and process reasoning"""

        async def inner():
            async for chunk in response:
@@ -113,7 +115,19 @@ async def ollama_model_if_cache(
        return inner()
    else:
        return response["message"]["content"]
        model_response = response["message"]["content"]

        """
        If the model wraps its reasoning in a specific tag,
        that content is not needed in the final response
        and can simply be trimmed away.
        """

        return (
            model_response
            if reasoning_tag is None
            else extract_reasoning(model_response, reasoning_tag).response_content
        )
async def ollama_model_complete(
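For reference, a direct non-streaming call to `ollama_model_if_cache` with the new argument might look like the sketch below; the model name, prompt, and host are placeholders, and a local Ollama server is assumed:

```python
import asyncio

from lightrag.llm import ollama_model_if_cache


async def main() -> None:
    answer = await ollama_model_if_cache(
        "deepseek-r1:8b",                    # placeholder "thinking" model
        "Briefly explain what a knowledge graph is.",
        system_prompt=None,
        history_messages=[],
        host="http://localhost:11434",
        reasoning_tag="think",               # trim the <think>...</think> section
    )
    print(answer)                            # only the final answer remains


asyncio.run(main())
```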


@@ -11,6 +11,7 @@ from functools import wraps
from hashlib import md5
from typing import Any, Union, List, Optional
import xml.etree.ElementTree as ET
import bs4
import numpy as np
import tiktoken
@@ -64,6 +65,13 @@ class EmbeddingFunc:
        return await self.func(*args, **kwargs)


@dataclass
class ReasoningResponse:
    reasoning_content: Optional[str]
    response_content: str
    tag: str
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
"""Locate the JSON string body from a string"""
try:
@@ -666,3 +674,28 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
        )
    return "\n".join(formatted_turns)
def extract_reasoning(response: str, tag: str) -> ReasoningResponse:
    """Extract the reasoning section and the following section from the LLM response.

    Args:
        response: LLM response
        tag: Tag to extract

    Returns:
        ReasoningResponse: Reasoning section and the following section
    """
    soup = bs4.BeautifulSoup(response, "html.parser")

    reasoning_section = soup.find(tag)
    if reasoning_section is None:
        return ReasoningResponse(None, response, tag)
    reasoning_content = reasoning_section.get_text().strip()

    after_reasoning_section = reasoning_section.next_sibling
    if after_reasoning_section is None:
        return ReasoningResponse(reasoning_content, "", tag)

    after_reasoning_content = after_reasoning_section.get_text().strip()
    return ReasoningResponse(reasoning_content, after_reasoning_content, tag)
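A quick sanity check of how `extract_reasoning` splits a DeepSeek-style reply (the sample string is invented for illustration):

```python
from lightrag.utils import extract_reasoning

sample = (
    "<think>The user wants a one-line summary, so keep it short.</think>\n"
    "LightRAG combines a knowledge graph with vector retrieval."
)

result = extract_reasoning(sample, "think")
print(result.reasoning_content)  # "The user wants a one-line summary, so keep it short."
print(result.response_content)   # "LightRAG combines a knowledge graph with vector retrieval."
```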