Merge pull request #797 from danielaskdd/add-env-settings

Add the token size truncation for local query and token size setting by env
zrguo · 2025-02-17 15:00:07 +08:00 · committed by GitHub
11 changed files with 142 additions and 41 deletions

View File

@@ -18,6 +18,7 @@
### Logging level
LOG_LEVEL=INFO
+VERBOSE=False

### Optional Timeout
TIMEOUT=300
@@ -27,14 +28,21 @@ TIMEOUT=300
### RAG Configuration
MAX_ASYNC=4
-MAX_TOKENS=32768
EMBEDDING_DIM=1024
MAX_EMBED_TOKENS=8192
-#HISTORY_TURNS=3
-#CHUNK_SIZE=1200
-#CHUNK_OVERLAP_SIZE=100
-#COSINE_THRESHOLD=0.2
-#TOP_K=60
+### Settings relative to query
+HISTORY_TURNS=3
+COSINE_THRESHOLD=0.2
+TOP_K=60
+MAX_TOKEN_TEXT_CHUNK=4000
+MAX_TOKEN_RELATION_DESC=4000
+MAX_TOKEN_ENTITY_DESC=4000
+### Settings relative to indexing
+CHUNK_SIZE=1200
+CHUNK_OVERLAP_SIZE=100
+MAX_TOKENS=32768
+MAX_TOKEN_SUMMARY=500
+SUMMARY_LANGUAGE=English

### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
### Ollama example
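
The new query-side and indexing-side variables are consumed as plain `os.getenv` lookups with the defaults shown here (see the `QueryParam` and `LightRAG` hunks further down). A minimal sketch of how they resolve; this standalone script is illustrative and not part of the PR:

```python
# Illustrative only: resolve the new .env settings the same way the code below does,
# via os.getenv with the defaults listed above.
import os

from dotenv import load_dotenv  # python-dotenv, the same loader added next to QueryParam below

load_dotenv()  # pick up .env from the working directory

query_limits = {
    "TOP_K": int(os.getenv("TOP_K", "60")),
    "MAX_TOKEN_TEXT_CHUNK": int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")),
    "MAX_TOKEN_RELATION_DESC": int(os.getenv("MAX_TOKEN_RELATION_DESC", "4000")),
    "MAX_TOKEN_ENTITY_DESC": int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000")),
}
print(query_limits)
```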

.gitignore vendored
View File

@@ -5,7 +5,7 @@ __pycache__/
.eggs/
*.tgz
*.tar.gz
-*.ini # Remove config.ini from repo
+*.ini

# Virtual Environment
.venv/

View File

@@ -222,6 +222,7 @@ You can select storage implementation by enviroment variables or command line a
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --timeout | None | Timeout in seconds (useful when using slow AI). Use None for infinite timeout |
| --log-level | INFO | Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) |
+| --verbose | False | Verbose debug output (True, Flase) |
| --key | None | API key for authentication. Protects lightrag server against unauthorized access |
| --ssl | False | Enable HTTPS |
| --ssl-certfile | None | Path to SSL certificate file (required if --ssl is enabled) |

View File

@@ -133,8 +133,8 @@ def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
    if value is None:
        return default
-    if isinstance(value_type, bool):
-        return value.lower() in ("true", "1", "yes")
+    if value_type is bool:
+        return value.lower() in ("true", "1", "yes", "t", "on")
    try:
        return value_type(value)
    except ValueError:
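
The replaced check matters: `isinstance(value_type, bool)` asks whether the type object itself is a bool instance, which is never true, so boolean environment values fell through to `value_type(value)` and any non-empty string (including "false") became True. `value_type is bool` takes the intended branch. A self-contained sketch of the corrected helper; the env lookup line is assumed, since only the lines above appear in the hunk:

```python
import os
from typing import Any


def get_env_value(env_key: str, default: Any, value_type: type = str) -> Any:
    """Read env_key and coerce it to value_type, falling back to default."""
    value = os.getenv(env_key)  # assumed lookup; not shown in the hunk
    if value is None:
        return default
    if value_type is bool:  # the old isinstance(value_type, bool) was always False
        return value.lower() in ("true", "1", "yes", "t", "on")
    try:
        return value_type(value)
    except ValueError:
        return default


os.environ["VERBOSE"] = "on"
print(get_env_value("VERBOSE", False, bool))  # True
```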
@@ -236,6 +236,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
    ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}")
    ASCIIColors.white("    ├─ Log Level: ", end="")
    ASCIIColors.yellow(f"{args.log_level}")
+    ASCIIColors.white("    ├─ Verbose Debug: ", end="")
+    ASCIIColors.yellow(f"{args.verbose}")
    ASCIIColors.white("    └─ Timeout: ", end="")
    ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
@@ -565,6 +567,13 @@ def parse_args() -> argparse.Namespace:
        help="Prefix of the namespace",
    )

+    parser.add_argument(
+        "--verbose",
+        type=bool,
+        default=get_env_value("VERBOSE", False, bool),
+        help="Verbose debug output(default: from env or false)",
+    )

    args = parser.parse_args()

    # conver relative path to absolute path
@@ -768,6 +777,11 @@ temp_prefix = "__tmp_" # prefix for temporary files
def create_app(args):
+    # Initialize verbose debug setting
+    from lightrag.utils import set_verbose_debug
+
+    set_verbose_debug(args.verbose)
+
    global global_top_k
    global_top_k = args.top_k  # save top_k from args
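
Taken together, the three hunks above wire the flag end to end: `--verbose` defaults to the `VERBOSE` environment value, is echoed on the splash screen, and is pushed into `lightrag.utils` when the app is created. One caveat: argparse's `type=bool` converts the raw string, so an explicit `--verbose False` on the command line still parses as True (any non-empty string is truthy); the env-driven default is unaffected. A hedged sketch of the same wiring that sidesteps the pitfall with `store_true`, which is a deviation from the hunk above:

```python
# Sketch of the start-up wiring; uses action="store_true" rather than type=bool
# to avoid the bool("False") pitfall noted above.
import argparse
import os

from lightrag.utils import set_verbose_debug

parser = argparse.ArgumentParser()
parser.add_argument(
    "--verbose",
    action="store_true",
    default=os.getenv("VERBOSE", "false").lower() in ("true", "1", "yes", "t", "on"),
    help="Verbose debug output (default: from env or false)",
)
args = parser.parse_args()
set_verbose_debug(args.verbose)  # same hook create_app() installs above
```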

View File

@@ -11,6 +11,7 @@ from fastapi.responses import StreamingResponse
import asyncio
from ascii_colors import trace_exception
from lightrag import LightRAG, QueryParam
+from lightrag.utils import encode_string_by_tiktoken
from dotenv import load_dotenv
@@ -111,18 +112,9 @@ class OllamaTagResponse(BaseModel):
def estimate_tokens(text: str) -> int:
-    """Estimate the number of tokens in text
-    Chinese characters: approximately 1.5 tokens per character
-    English characters: approximately 0.25 tokens per character
-    """
-    # Use regex to match Chinese and non-Chinese characters separately
-    chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
-    non_chinese_chars = len(re.findall(r"[^\u4e00-\u9fff]", text))
-
-    # Calculate estimated token count
-    tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
-
-    return int(tokens)
+    """Estimate the number of tokens in text using tiktoken"""
+    tokens = encode_string_by_tiktoken(text)
+    return len(tokens)


def parse_query_mode(query: str) -> tuple[str, SearchMode]:
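
The old heuristic (roughly 1.5 tokens per Chinese character, 0.25 per other character) is replaced by real tokenization through the project's `encode_string_by_tiktoken` wrapper. A stand-alone approximation using tiktoken directly; the encoding name is an assumption for illustration, since LightRAG resolves it from its configured model name:

```python
import tiktoken


def estimate_tokens(text: str) -> int:
    """Count tokens with tiktoken instead of the old character-ratio heuristic."""
    encoder = tiktoken.get_encoding("cl100k_base")  # encoding assumed for the sketch
    return len(encoder.encode(text))


print(estimate_tokens("LightRAG mixes English and 中文 in one query"))
```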

View File

@@ -1,6 +1,7 @@
from __future__ import annotations

import os
+from dotenv import load_dotenv
from dataclasses import dataclass, field
from enum import Enum
from typing import (
@@ -9,12 +10,12 @@ from typing import (
    TypedDict,
    TypeVar,
)

import numpy as np

from .utils import EmbeddingFunc
from .types import KnowledgeGraph

+load_dotenv()

class TextChunkSchema(TypedDict):
    tokens: int
@@ -54,13 +55,15 @@ class QueryParam:
    top_k: int = int(os.getenv("TOP_K", "60"))
    """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""

-    max_token_for_text_unit: int = 4000
+    max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
    """Maximum number of tokens allowed for each retrieved text chunk."""

-    max_token_for_global_context: int = 4000
+    max_token_for_global_context: int = int(
+        os.getenv("MAX_TOKEN_RELATION_DESC", "4000")
+    )
    """Maximum number of tokens allocated for relationship descriptions in global retrieval."""

-    max_token_for_local_context: int = 4000
+    max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000"))
    """Maximum number of tokens allocated for entity descriptions in local retrieval."""

    hl_keywords: list[str] = field(default_factory=list)
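
Because these defaults are evaluated when the dataclass is defined, the variables must be present in the environment (or in `.env`, via the `load_dotenv()` call added above) before `lightrag` is imported. A usage sketch with illustrative values:

```python
# Illustrative: override the query-time token budgets before importing lightrag.
import os

os.environ["MAX_TOKEN_TEXT_CHUNK"] = "6000"
os.environ["MAX_TOKEN_ENTITY_DESC"] = "3000"

from lightrag import QueryParam  # import after the environment is prepared

param = QueryParam()
print(param.max_token_for_text_unit)      # 6000
print(param.max_token_for_local_context)  # 3000
```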

View File

@@ -268,10 +268,10 @@ class LightRAG:
"""Directory where logs are stored. Defaults to the current working directory.""" """Directory where logs are stored. Defaults to the current working directory."""
# Text chunking # Text chunking
chunk_token_size: int = 1200 chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
"""Maximum number of tokens per text chunk when splitting documents.""" """Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = 100 chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
"""Number of overlapping tokens between consecutive text chunks to preserve context.""" """Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = "gpt-4o-mini" tiktoken_model_name: str = "gpt-4o-mini"
@@ -281,7 +281,7 @@ class LightRAG:
    entity_extract_max_gleaning: int = 1
    """Maximum number of entity extraction attempts for ambiguous content."""

-    entity_summary_to_max_tokens: int = 500
+    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
    """Maximum number of tokens used for summarizing extracted entities."""

    # Node embedding
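
The indexing-side limits follow the same pattern: evaluated once at import time from `CHUNK_SIZE`, `CHUNK_OVERLAP_SIZE`, and `MAX_TOKEN_SUMMARY`. A small check of the resulting dataclass defaults without constructing a full `LightRAG` instance; values are illustrative:

```python
import dataclasses
import os

os.environ["CHUNK_SIZE"] = "1000"
os.environ["CHUNK_OVERLAP_SIZE"] = "80"
os.environ["MAX_TOKEN_SUMMARY"] = "400"

from lightrag import LightRAG  # import after the environment is prepared

wanted = {"chunk_token_size", "chunk_overlap_token_size", "entity_summary_to_max_tokens"}
defaults = {f.name: f.default for f in dataclasses.fields(LightRAG) if f.name in wanted}
print(defaults)  # expected: 1000, 80, 400 for the three fields above
```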

View File

@@ -40,9 +40,10 @@ __version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"

+from ..utils import verbose_debug, VERBOSE_DEBUG

import sys
import os
+import logging

if sys.version_info < (3, 9):
    from typing import AsyncIterator
@@ -110,6 +111,11 @@ async def openai_complete_if_cache(
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
# Set openai logger level to INFO when VERBOSE_DEBUG is off
if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
logging.getLogger("openai").setLevel(logging.INFO)
openai_async_client = ( openai_async_client = (
AsyncOpenAI(default_headers=default_headers, api_key=api_key) AsyncOpenAI(default_headers=default_headers, api_key=api_key)
if base_url is None if base_url is None
@@ -125,13 +131,11 @@ async def openai_complete_if_cache(
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

-    # 添加日志输出
-    logger.debug("===== Query Input to LLM =====")
+    logger.debug("===== Sending Query to LLM =====")
    logger.debug(f"Model: {model}   Base URL: {base_url}")
    logger.debug(f"Additional kwargs: {kwargs}")
-    logger.debug(f"Query: {prompt}")
-    logger.debug(f"System prompt: {system_prompt}")
-    # logger.debug(f"Messages: {messages}")
+    verbose_debug(f"Query: {prompt}")
+    verbose_debug(f"System prompt: {system_prompt}")

    try:
        if "response_format" in kwargs:

View File

@@ -43,6 +43,7 @@ __status__ = "Production"
import sys
import re
import json
+from ..utils import verbose_debug

if sys.version_info < (3, 9):
    pass
@@ -119,7 +120,7 @@ async def zhipu_complete_if_cache(
    # Add debug logging
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
-    logger.debug(f"System prompt: {system_prompt}")
+    verbose_debug(f"System prompt: {system_prompt}")

    # Remove unsupported kwargs
    kwargs = {

View File

@@ -687,6 +687,9 @@ async def kg_query(
    if query_param.only_need_prompt:
        return sys_prompt

+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
+
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
@@ -772,6 +775,9 @@ async def extract_keywords_only(
        query=text, examples=examples, language=language, history=history_context
    )

+    len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
+    logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
+
    # 5. Call the LLM for keyword extraction
    use_model_func = global_config["llm_model_func"]
    result = await use_model_func(kw_prompt, keyword_extraction=True)
@@ -935,7 +941,9 @@ async def mix_kg_vector_query(
                    chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
                formatted_chunks.append(chunk_text)

-            logger.info(f"Truncate {len(chunks)} to {len(formatted_chunks)} chunks")
+            logger.debug(
+                f"Truncate chunks from {len(chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
+            )
            return "\n--New Chunk--\n".join(formatted_chunks)
        except Exception as e:
            logger.error(f"Error in get_vector_context: {e}")
@@ -968,6 +976,9 @@ async def mix_kg_vector_query(
    if query_param.only_need_prompt:
        return sys_prompt

+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
+
    # 6. Generate response
    response = await use_model_func(
        query,
@@ -1073,7 +1084,7 @@ async def _build_query_context(
    if not entities_context.strip() and not relations_context.strip():
        return None

-    return f"""
+    result = f"""
-----Entities-----
```csv
{entities_context}
@@ -1087,6 +1098,15 @@ async def _build_query_context(
{text_units_context}
```
"""
+
+    contex_tokens = len(encode_string_by_tiktoken(result))
+    entities_tokens = len(encode_string_by_tiktoken(entities_context))
+    relations_tokens = len(encode_string_by_tiktoken(relations_context))
+    text_units_tokens = len(encode_string_by_tiktoken(text_units_context))
+    logger.debug(
+        f"Context Tokens - Total: {contex_tokens}, Entities: {entities_tokens}, Relations: {relations_tokens}, Chunks: {text_units_tokens}"
+    )
+    return result


async def _get_node_data(
@@ -1130,8 +1150,19 @@ async def _get_node_data(
            node_datas, query_param, knowledge_graph_inst
        ),
    )

+    len_node_datas = len(node_datas)
+    node_datas = truncate_list_by_token_size(
+        node_datas,
+        key=lambda x: x["description"],
+        max_token_size=query_param.max_token_for_local_context,
+    )
+    logger.debug(
+        f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
+    )
+
    logger.info(
-        f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
+        f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
    )

    # build prompt
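
The entity list is now capped the same way relations and chunks already were, via `truncate_list_by_token_size`. Its implementation is not part of this diff; a hedged sketch of the behavior the call sites rely on (keep items in order until the tokenized key text exceeds the budget), with the encoding assumed:

```python
from __future__ import annotations

from typing import Any, Callable

import tiktoken


def truncate_list_by_token_size_sketch(
    items: list[Any], key: Callable[[Any], str], max_token_size: int
) -> list[Any]:
    """Illustrative only; LightRAG's own helper may differ in details."""
    enc = tiktoken.get_encoding("cl100k_base")  # encoding assumed for the sketch
    total, kept = 0, []
    for item in items:
        total += len(enc.encode(key(item) or ""))
        if total > max_token_size:
            break
        kept.append(item)
    return kept


data = [{"description": "short entity description"}, {"description": "word " * 5000}]
print(len(truncate_list_by_token_size_sketch(data, key=lambda x: x["description"], max_token_size=500)))  # 1
```
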
@@ -1264,6 +1295,10 @@ async def _find_most_related_text_unit_from_entities(
        max_token_size=query_param.max_token_for_text_unit,
    )

+    logger.debug(
+        f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
+    )
+
    all_text_units = [t["data"] for t in all_text_units]

    return all_text_units
@@ -1305,6 +1340,11 @@ async def _find_most_related_edges_from_entities(
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_global_context,
    )

+    logger.debug(
+        f"Truncate relations from {len(all_edges)} to {len(all_edges_data)} (max tokens:{query_param.max_token_for_global_context})"
+    )
+
    return all_edges_data
@@ -1352,11 +1392,15 @@ async def _get_edge_data(
    edge_datas = sorted(
        edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
    )
+    len_edge_datas = len(edge_datas)
    edge_datas = truncate_list_by_token_size(
        edge_datas,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_global_context,
    )
+    logger.debug(
+        f"Truncate relations from {len_edge_datas} to {len(edge_datas)} (max tokens:{query_param.max_token_for_global_context})"
+    )

    use_entities, use_text_units = await asyncio.gather(
        _find_most_related_entities_from_relationships(
@@ -1367,7 +1411,7 @@ async def _get_edge_data(
        ),
    )

    logger.info(
-        f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
+        f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
    )

    relations_section_list = [
@@ -1456,11 +1500,15 @@ async def _find_most_related_entities_from_relationships(
        for k, n, d in zip(entity_names, node_datas, node_degrees)
    ]

+    len_node_datas = len(node_datas)
    node_datas = truncate_list_by_token_size(
        node_datas,
        key=lambda x: x["description"],
        max_token_size=query_param.max_token_for_local_context,
    )
+    logger.debug(
+        f"Truncate entities from {len_node_datas} to {len(node_datas)} (max tokens:{query_param.max_token_for_local_context})"
+    )

    return node_datas
@@ -1516,6 +1564,10 @@ async def _find_related_text_unit_from_relationships(
        max_token_size=query_param.max_token_for_text_unit,
    )

+    logger.debug(
+        f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
+    )
+
    all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]

    return all_text_units
@@ -1583,7 +1635,10 @@ async def naive_query(
        logger.warning("No chunks left after truncation")
        return PROMPTS["fail_response"]

-    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
+    logger.debug(
+        f"Truncate chunks from {len(chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
+    )
+
    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])

    if query_param.only_need_context:
@@ -1606,6 +1661,9 @@ async def naive_query(
    if query_param.only_need_prompt:
        return sys_prompt

+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[naive_query]Prompt Tokens: {len_of_prompts}")
+
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
@@ -1748,6 +1806,9 @@ async def kg_query_with_keywords(
    if query_param.only_need_prompt:
        return sys_prompt

+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.debug(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
+
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
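
Most of the new accounting (prompt token counts, truncation before/after sizes with their budgets) is logged at DEBUG, so it stays silent at the default INFO level. A quick way to surface it during troubleshooting; the logger name `lightrag` is assumed from the package:

```python
import logging

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("lightrag").setLevel(logging.DEBUG)
# Subsequent queries then emit lines in the formats added above, e.g.
# "[kg_query]Prompt Tokens: ..." and "Truncate chunks from N to M (max tokens:...)".
```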

View File

@@ -20,6 +20,23 @@ import tiktoken
from lightrag.prompt import PROMPTS

+VERBOSE_DEBUG = os.getenv("VERBOSE", "false").lower() == "true"
+
+
+def verbose_debug(msg: str, *args, **kwargs):
+    """Function for outputting detailed debug information.
+    When VERBOSE_DEBUG=True, outputs the complete message.
+    When VERBOSE_DEBUG=False, outputs only the first 30 characters.
+    """
+    if VERBOSE_DEBUG:
+        logger.debug(msg, *args, **kwargs)
+
+
+def set_verbose_debug(enabled: bool):
+    """Enable or disable verbose debug output"""
+    global VERBOSE_DEBUG
+    VERBOSE_DEBUG = enabled
+

class UnlimitedSemaphore:
    """A context manager that allows unlimited access."""