From 19ee3d109c10d45e9501d7f3232db527e02a8c4b Mon Sep 17 00:00:00 2001
From: ultrageopro
Date: Thu, 6 Feb 2025 22:56:17 +0300
Subject: [PATCH] =?UTF-8?q?feat:=20trimming=20the=20model=E2=80=99s=20reas?=
 =?UTF-8?q?oning?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md              |  6 ++++++
 lightrag/llm/ollama.py | 18 ++++++++++++++++--
 lightrag/utils.py      | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index d9315458..456d9a72 100644
--- a/README.md
+++ b/README.md
@@ -338,6 +338,12 @@ rag = LightRAG(
 
 There fully functional example `examples/lightrag_ollama_demo.py` that utilizes `gemma2:2b` model, runs only 4 requests in parallel and set context size to 32k.
 
+#### Using "Thinking" Models (e.g., DeepSeek)
+
+To strip the model's reasoning and return only its final response, pass `reasoning_tag` in `llm_model_kwargs`.
+
+For example, for DeepSeek models, `reasoning_tag` should be set to `think`.
+
 #### Low RAM GPUs
 
 In order to run this experiment on low RAM GPU you should select small model and tune context window (increasing context increase memory consumption). For example, running this ollama example on repurposed mining GPU with 6Gb of RAM required to set context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations on `book.txt`.
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 19f560e7..3541bd67 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -66,6 +66,7 @@ from lightrag.exceptions import (
     RateLimitError,
     APITimeoutError,
 )
+from lightrag.utils import extract_reasoning
 import numpy as np
 from typing import Union
 
@@ -85,6 +86,7 @@ async def ollama_model_if_cache(
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     stream = True if kwargs.get("stream") else False
+    reasoning_tag = kwargs.pop("reasoning_tag", None)
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
@@ -105,7 +107,7 @@ async def ollama_model_if_cache(
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
 
     if stream:
-        """cannot cache stream response"""
+        """cannot cache stream response and process reasoning"""
 
         async def inner():
             async for chunk in response:
@@ -113,7 +115,19 @@ async def ollama_model_if_cache(
 
         return inner()
     else:
-        return response["message"]["content"]
+        model_response = response["message"]["content"]
+
+        """
+        If the model wraps its reasoning in a specific tag,
+        that content is not needed in the final response
+        and can simply be trimmed.
+        """
+
+        return (
+            model_response
+            if reasoning_tag is None
+            else extract_reasoning(model_response, reasoning_tag).response_content
+        )
 
 
 async def ollama_model_complete(
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 3a69513b..ed0b6c06 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -11,6 +11,7 @@ from functools import wraps
 from hashlib import md5
 from typing import Any, Union, List, Optional
 import xml.etree.ElementTree as ET
+import bs4
 
 import numpy as np
 import tiktoken
@@ -64,6 +65,13 @@ class EmbeddingFunc:
         return await self.func(*args, **kwargs)
 
 
+@dataclass
+class ReasoningResponse:
+    reasoning_content: Optional[str]
+    response_content: str
+    tag: str
+
+
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
     """Locate the JSON string body from a string"""
     try:
@@ -666,3 +674,28 @@ def get_conversation_turns(conversation_history: list[dict], num_turns: int) ->
     )
 
     return "\n".join(formatted_turns)
+
+
+def extract_reasoning(response: str, tag: str) -> ReasoningResponse:
+    """Extract the reasoning section and the following section from the LLM response.
+
+    Args:
+        response: LLM response
+        tag: Tag to extract
+    Returns:
+        ReasoningResponse: Reasoning section and following section
+
+    """
+    soup = bs4.BeautifulSoup(response, "html.parser")
+
+    reasoning_section = soup.find(tag)
+    if reasoning_section is None:
+        return ReasoningResponse(None, response, tag)
+    reasoning_content = reasoning_section.get_text().strip()
+
+    after_reasoning_section = reasoning_section.next_sibling
+    if after_reasoning_section is None:
+        return ReasoningResponse(reasoning_content, "", tag)
+    after_reasoning_content = after_reasoning_section.get_text().strip()
+
+    return ReasoningResponse(reasoning_content, after_reasoning_content, tag)
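
A usage sketch for the new `reasoning_tag` option (illustrative only, not part of the patch): the working directory, model name, host and context size below are placeholders, and the embedding setup is left to `examples/lightrag_ollama_demo.py`.

```python
# Illustrative sketch: pass reasoning_tag through llm_model_kwargs so the
# Ollama binding trims the model's <think>...</think> block from final answers.
# Working directory, model name, host and num_ctx are placeholder values.
from lightrag import LightRAG
from lightrag.llm.ollama import ollama_model_complete

rag = LightRAG(
    working_dir="./local_rag",  # placeholder path
    llm_model_func=ollama_model_complete,
    llm_model_name="deepseek-r1:7b",  # any "thinking" model served by Ollama
    llm_model_kwargs={
        "host": "http://localhost:11434",
        "options": {"num_ctx": 32768},
        "reasoning_tag": "think",  # strip <think>...</think> from the response
    },
    # embedding_func=...  # configure as in examples/lightrag_ollama_demo.py
)
```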
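The `extract_reasoning` helper can also be exercised on its own. A small sketch of the expected behaviour on a made-up DeepSeek-style completion:

```python
# Small sketch of what extract_reasoning returns for a DeepSeek-style reply;
# the sample text is made up for illustration.
from lightrag.utils import extract_reasoning

raw = "<think>The user wants a one-line answer.</think>Paris is the capital of France."

result = extract_reasoning(raw, "think")
print(result.reasoning_content)  # The user wants a one-line answer.
print(result.response_content)   # Paris is the capital of France.

# Without the tag, the original text passes through untouched.
plain = extract_reasoning("No reasoning here.", "think")
print(plain.reasoning_content)   # None
print(plain.response_content)    # No reasoning here.
```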