Merge pull request #797 from danielaskdd/add-env-settings

Add the token size truncation for local query and token size setting by env
This commit is contained in:
zrguo
2025-02-17 15:00:07 +08:00
committed by GitHub
11 changed files with 142 additions and 41 deletions

View File

@@ -18,6 +18,7 @@
### Logging level
LOG_LEVEL=INFO
VERBOSE=False
### Optional Timeout
TIMEOUT=300
@@ -27,14 +28,21 @@ TIMEOUT=300
### RAG Configuration
MAX_ASYNC=4
MAX_TOKENS=32768
EMBEDDING_DIM=1024
MAX_EMBED_TOKENS=8192
#HISTORY_TURNS=3
#CHUNK_SIZE=1200
#CHUNK_OVERLAP_SIZE=100
#COSINE_THRESHOLD=0.2
#TOP_K=60
### Settings relative to query
HISTORY_TURNS=3
COSINE_THRESHOLD=0.2
TOP_K=60
MAX_TOKEN_TEXT_CHUNK=4000
MAX_TOKEN_RELATION_DESC=4000
MAX_TOKEN_ENTITY_DESC=4000
### Settings relative to indexing
CHUNK_SIZE=1200
CHUNK_OVERLAP_SIZE=100
MAX_TOKENS=32768
MAX_TOKEN_SUMMARY=500
SUMMARY_LANGUAGE=English
### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
### Ollama example

2
.gitignore vendored
View File

@@ -5,7 +5,7 @@ __pycache__/
.eggs/
*.tgz
*.tar.gz
*.ini # Remove config.ini from repo
*.ini
# Virtual Environment
.venv/

View File

@@ -222,6 +222,7 @@ You can select storage implementation by enviroment variables or command line a
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --timeout | None | Timeout in seconds (useful when using slow AI). Use None for infinite timeout |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
| --verbose | False | Verbose debug output (True, Flase) |
| --key | None | API key for authentication. Protects lightrag server against unauthorized access |
| --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |

View File

@@ -133,8 +133,8 @@ def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
if value is None:
return default
if isinstance(value_type, bool):
return value.lower() in ("true", "1", "yes")
if value_type is bool:
return value.lower() in ("true", "1", "yes", "t", "on")
try:
return value_type(value)
except ValueError:
@@ -236,6 +236,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
ASCIIColors.white(" ├─ Log Level: ", end="")
ASCIIColors.yellow(f"{args.log_level}")
ASCIIColors.white(" ├─ Verbose Debug: ", end="")
ASCIIColors.yellow(f"{args.verbose}")
ASCIIColors.white(" └─ Timeout: ", end="")
ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
@@ -565,6 +567,13 @@ def parse_args() -> argparse.Namespace:
help="Prefix of the namespace",
)
parser.add_argument(
"--verbose",
type=bool,
default=get_env_value("VERBOSE", False, bool),
help="Verbose debug output(default: from env or false)",
)
args = parser.parse_args()
# conver relative path to absolute path
@@ -768,6 +777,11 @@ temp_prefix = "__tmp_" # prefix for temporary files
def create_app(args):
# Initialize verbose debug setting
from lightrag.utils import set_verbose_debug
set_verbose_debug(args.verbose)
global global_top_k
global_top_k = args.top_k # save top_k from args

View File

@@ -11,6 +11,7 @@ from fastapi.responses import StreamingResponse
import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
from lightrag.utils import encode_string_by_tiktoken
from dotenv import load_dotenv
@@ -111,18 +112,9 @@ class OllamaTagResponse(BaseModel):
def estimate_tokens(text: str) -> int:
"""Estimate the number of tokens in text
Chinese characters: approximately 1.5 tokens per character
English characters: approximately 0.25 tokens per character
"""
# Use regex to match Chinese and non-Chinese characters separately
chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
# Calculate estimated token count
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
return int(tokens)
"""Estimate the number of tokens in text using tiktoken"""
tokens = encode_string_by_tiktoken(text)
return len(tokens)
def parse_query_mode(query: str) -> tuple[str, SearchMode]:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import os
from dotenv import load_dotenv
from dataclasses import dataclass, field
from enum import Enum
from typing import (
@@ -9,12 +10,12 @@ from typing import (
TypedDict,
TypeVar,
)
import numpy as np
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
load_dotenv()
class TextChunkSchema(TypedDict):
tokens: int
@@ -54,13 +55,15 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
max_token_for_text_unit: int = 4000
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_global_context: int = 4000
max_token_for_global_context: int = int(
os.getenv("MAX_TOKEN_RELATION_DESC", "4000")
)
"""Maximum number of tokens allocated for relationship descriptions in global retrieval."""
max_token_for_local_context: int = 4000
max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000"))
"""Maximum number of tokens allocated for entity descriptions in local retrieval."""
hl_keywords: list[str] = field(default_factory=list)

View File

@@ -268,10 +268,10 @@ class LightRAG:
"""Directory where logs are stored. Defaults to the current working directory."""
# Text chunking
chunk_token_size: int = 1200
chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = 100
chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = "gpt-4o-mini"
@@ -281,7 +281,7 @@ class LightRAG:
entity_extract_max_gleaning: int = 1
"""Maximum number of entity extraction attempts for ambiguous content."""
entity_summary_to_max_tokens: int = 500
entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
"""Maximum number of tokens used for summarizing extracted entities."""
# Node embedding

View File

@@ -40,9 +40,10 @@ __version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
from ..utils import verbose_debug, VERBOSE_DEBUG
import sys
import os
import logging
if sys.version_info < (3, 9):
from typing import AsyncIterator
@@ -110,6 +111,11 @@ async def openai_complete_if_cache(
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
# Set openai logger level to INFO when VERBOSE_DEBUG is off
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
logging.getLogger("openai").setLevel(logging.INFO)
openai_async_client = (
AsyncOpenAI(default_headers=default_headers, api_key=api_key)
if base_url is None
@@ -125,13 +131,11 @@ async def openai_complete_if_cache(
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# 添加日志输出
logger.debug("===== Query Input to LLM =====")
logger.debug("===== Sending Query to LLM =====")
logger.debug(f"Model: {model} Base URL: {base_url}")
logger.debug(f"Additional kwargs: {kwargs}")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
# logger.debug(f"Messages: {messages}")
verbose_debug(f"Query: {prompt}")
verbose_debug(f"System prompt: {system_prompt}")
try:
if "response_format" in kwargs:

View File

@@ -43,6 +43,7 @@ __status__ = "Production"
import sys
import re
import json
from ..utils import verbose_debug
if sys.version_info < (3, 9):
pass
@@ -119,7 +120,7 @@ async def zhipu_complete_if_cache(
# Add debug logging
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
verbose_debug(f"System prompt: {system_prompt}")
# Remove unsupported kwargs
kwargs = {

View File

@@ -687,6 +687,9 @@ async def kg_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
response = await use_model_func(
query,
system_prompt=sys_prompt,
@@ -772,6 +775,9 @@ async def extract_keywords_only(
query=text, examples=examples, language=language, history=history_context
)
len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
# 5. Call the LLM for keyword extraction
use_model_func = global_config["llm_model_func"]
result = await use_model_func(kw_prompt, keyword_extraction=True)
@@ -935,7 +941,9 @@ async def mix_kg_vector_query(
chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
formatted_chunks.append(chunk_text)
logger.info(f"Truncate {len(chunks)} to {len(formatted_chunks)} chunks")
logger.debug(
f"Truncate chunks from {len(chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
)
return "\n--New Chunk--\n".join(formatted_chunks)
except Exception as e:
logger.error(f"Error in get_vector_context: {e}")
@@ -968,6 +976,9 @@ async def mix_kg_vector_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
# 6. Generate response
response = await use_model_func(
query,
@@ -1073,7 +1084,7 @@ async def _build_query_context(
if not entities_context.strip() and not relations_context.strip():
return None
return f"""
result = f"""
-----Entities-----
```csv
{entities_context}
@@ -1087,6 +1098,15 @@ async def _build_query_context(
{text_units_context}
```
"""
contex_tokens = len(encode_string_by_tiktoken(result))
entities_tokens = len(encode_string_by_tiktoken(entities_context))
relations_tokens = len(encode_string_by_tiktoken(relations_context))
text_units_tokens = len(encode_string_by_tiktoken(text_units_context))
logger.debug(
f"Context Tokens - Total: {contex_tokens}, Entities: {entities_tokens}, Relations: {relations_tokens}, Chunks: {text_units_tokens}"
)
return result
async def _get_node_data(
@@ -1130,8 +1150,19 @@ async def _get_node_data(
node_datas, query_param, knowledge_graph_inst
),
)
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_local_context,
)
logger.debug(
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
)
logger.info(
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
)
# build prompt
@@ -1264,6 +1295,10 @@ async def _find_most_related_text_unit_from_entities(
max_token_size=query_param.max_token_for_text_unit,
)
logger.debug(
f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
)
all_text_units = [t["data"] for t in all_text_units]
return all_text_units
@@ -1305,6 +1340,11 @@ async def _find_most_related_edges_from_entities(
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_global_context,
)
logger.debug(
f"Truncate relations from {len(all_edges)} to {len(all_edges_data)} (max tokens:{query_param.max_token_for_global_context})"
)
return all_edges_data
@@ -1352,11 +1392,15 @@ async def _get_edge_data(
edge_datas = sorted(
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
)
len_edge_datas = len(edge_datas)
edge_datas = truncate_list_by_token_size(
edge_datas,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_global_context,
)
logger.debug(
f"Truncate relations from {len_edge_datas} to {len(edge_datas)} (max tokens:{query_param.max_token_for_global_context})"
)
use_entities, use_text_units = await asyncio.gather(
_find_most_related_entities_from_relationships(
@@ -1367,7 +1411,7 @@ async def _get_edge_data(
),
)
logger.info(
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
)
relations_section_list = [
@@ -1456,11 +1500,15 @@ async def _find_most_related_entities_from_relationships(
for k, n, d in zip(entity_names, node_datas, node_degrees)
]
len_node_datas = len(node_datas)
node_datas = truncate_list_by_token_size(
node_datas,
key=lambda x: x["description"],
max_token_size=query_param.max_token_for_local_context,
)
logger.debug(
f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
)
return node_datas
@@ -1516,6 +1564,10 @@ async def _find_related_text_unit_from_relationships(
max_token_size=query_param.max_token_for_text_unit,
)
logger.debug(
f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
)
all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
return all_text_units
@@ -1583,7 +1635,10 @@ async def naive_query(
logger.warning("No chunks left after truncation")
return PROMPTS["fail_response"]
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
logger.debug(
f"Truncate chunks from {len(chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
)
section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context:
@@ -1606,6 +1661,9 @@ async def naive_query(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
logger.info(f"[naive_query]Prompt Tokens: {len_of_prompts}")
response = await use_model_func(
query,
system_prompt=sys_prompt,
@@ -1748,6 +1806,9 @@ async def kg_query_with_keywords(
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
response = await use_model_func(
query,
system_prompt=sys_prompt,

View File

@@ -20,6 +20,23 @@ import tiktoken
from lightrag.prompt import PROMPTS
VERBOSE_DEBUG = os.getenv("VERBOSE", "false").lower() == "true"
def verbose_debug(msg: str, *args, **kwargs):
"""Function for outputting detailed debug information.
When VERBOSE_DEBUG=True, outputs the complete message.
When VERBOSE_DEBUG=False, outputs only the first 30 characters.
"""
if VERBOSE_DEBUG:
logger.debug(msg, *args, **kwargs)
def set_verbose_debug(enabled: bool):
"""Enable or disable verbose debug output"""
global VERBOSE_DEBUG
VERBOSE_DEBUG = enabled
class UnlimitedSemaphore:
"""A context manager that allows unlimited access."""