feat: trimming the model’s reasoning
@@ -338,6 +338,12 @@ rag = LightRAG(
 There is a fully functional example, `examples/lightrag_ollama_demo.py`, that uses the `gemma2:2b` model, runs only 4 requests in parallel, and sets the context size to 32k.

+#### Using "Thinking" Models (e.g., DeepSeek)
+
+To return only the model's response, you can pass a `reasoning_tag` in `llm_model_kwargs`.
+
+For example, for DeepSeek models, `reasoning_tag` should be set to `think`.
+
 #### Low RAM GPUs

 To run this experiment on a low-RAM GPU, select a small model and tune the context window (increasing the context increases memory consumption). For example, running this Ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations in `book.txt`.

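For context, here is a minimal sketch of how the new option could be wired into a LightRAG instance. It is modeled loosely on `examples/lightrag_ollama_demo.py`; the model name, host, and embedding settings below are illustrative assumptions, and only `reasoning_tag` comes from this commit.

```python
# Hypothetical configuration sketch; only `reasoning_tag` is new in this commit.
from lightrag import LightRAG
from lightrag.llm import ollama_model_complete, ollama_embedding  # ollama_embedding assumed from the demo
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=ollama_model_complete,
    llm_model_name="deepseek-r1:7b",        # any model that wraps its reasoning in <think>...</think>
    llm_model_max_async=4,
    llm_model_kwargs={
        "host": "http://localhost:11434",
        "options": {"num_ctx": 32768},       # context window option passed through to Ollama
        "reasoning_tag": "think",            # new: trim the <think> block from the final answer
    },
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        func=lambda texts: ollama_embedding(
            texts, embed_model="nomic-embed-text", host="http://localhost:11434"
        ),
    ),
)
```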
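The low-RAM advice in the same hunk boils down to the same `llm_model_kwargs`, just with a smaller context window. A tiny sketch, assuming the context size is forwarded to Ollama through `options.num_ctx` as the demo does; the values are the ones quoted in the README, not anything this commit changes.

```python
# Hypothetical kwargs for a 6 GB GPU: a small model plus a reduced context window.
low_ram_kwargs = {
    "host": "http://localhost:11434",
    "options": {"num_ctx": 26000},  # ~26k context kept gemma2:2b within 6 GB in the README's experiment
}
```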
@@ -66,6 +66,7 @@ from lightrag.exceptions import (
     RateLimitError,
     APITimeoutError,
 )
+from lightrag.utils import extract_reasoning
 import numpy as np
 from typing import Union

@@ -85,6 +86,7 @@ async def ollama_model_if_cache(
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     stream = True if kwargs.get("stream") else False
+    reasoning_tag = kwargs.pop("reasoning_tag", None)
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -105,7 +107,7 @@ async def ollama_model_if_cache(

     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
-        """cannot cache stream response"""
+        """cannot cache stream response and process reasoning"""

         async def inner():
             async for chunk in response:
@@ -113,7 +115,19 @@ async def ollama_model_if_cache(

         return inner()
     else:
-        return response["message"]["content"]
+        model_response = response["message"]["content"]
+
+        """
+        If the model also wraps its thoughts in a specific tag,
+        this information is not needed for the final
+        response and can simply be trimmed.
+        """
+
+        return (
+            model_response
+            if reasoning_tag is None
+            else extract_reasoning(model_response, reasoning_tag).response_content
+        )


 async def ollama_model_complete(
@@ -11,6 +11,7 @@ from functools import wraps
 from hashlib import md5
 from typing import Any, Union, List, Optional
 import xml.etree.ElementTree as ET
+import bs4

 import numpy as np
 import tiktoken
@@ -64,6 +65,13 @@ class EmbeddingFunc:
         return await self.func(*args, **kwargs)


+@dataclass
+class ReasoningResponse:
+    reasoning_content: str
+    response_content: str
+    tag: str
+
+
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
     """Locate the JSON string body from a string"""
     try:
@@ -666,3 +674,28 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
         )

     return "\n".join(formatted_turns)
+
+
+def extract_reasoning(response: str, tag: str) -> ReasoningResponse:
+    """Extract the reasoning section and the following section from the LLM response.
+
+    Args:
+        response: LLM response
+        tag: Tag to extract
+    Returns:
+        ReasoningResponse: Reasoning section and following section
+
+    """
+    soup = bs4.BeautifulSoup(response, "html.parser")
+
+    reasoning_section = soup.find(tag)
+    if reasoning_section is None:
+        return ReasoningResponse(None, response, tag)
+    reasoning_content = reasoning_section.get_text().strip()
+
+    after_reasoning_section = reasoning_section.next_sibling
+    if after_reasoning_section is None:
+        return ReasoningResponse(reasoning_content, "", tag)
+    after_reasoning_content = after_reasoning_section.get_text().strip()
+
+    return ReasoningResponse(reasoning_content, after_reasoning_content, tag)
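Finally, a short usage sketch of the new helper. It assumes the package is installed with this change, so that `extract_reasoning` is importable from `lightrag.utils` exactly as the `llm.py` hunk above imports it; the sample strings are made up.

```python
from lightrag.utils import extract_reasoning

# A DeepSeek-style reply that wraps its chain of thought in a <think> tag.
raw = "<think>The user said hello, so I should greet them back.</think>Hello! How can I help?"

result = extract_reasoning(raw, "think")
print(result.reasoning_content)  # -> "The user said hello, so I should greet them back."
print(result.response_content)   # -> "Hello! How can I help?"

# With no tag present, the whole reply is returned untouched.
print(extract_reasoning("Just an answer.", "think").response_content)  # -> "Just an answer."
```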