feat: move query-related settings to env file for better configuration

• Add env vars for token and chunk settings
• Add token count logging for prompts
• Add token count logging for context
• Move hardcoded values to env variables
• Improve logging clarity and consistency
Author: yangdx
Date: 2025-02-16 19:26:57 +08:00
parent 8fdbcb0d3f
commit 601df31edf
4 changed files with 69 additions and 17 deletions


@@ -27,14 +27,21 @@ TIMEOUT=300
 ### RAG Configuration
 MAX_ASYNC=4
-MAX_TOKENS=32768
 EMBEDDING_DIM=1024
 MAX_EMBED_TOKENS=8192
-#HISTORY_TURNS=3
-#CHUNK_SIZE=1200
-#CHUNK_OVERLAP_SIZE=100
-#COSINE_THRESHOLD=0.2
-#TOP_K=60
+### Settings relative to query
+HISTORY_TURNS=3
+COSINE_THRESHOLD=0.2
+TOP_K=60
+MAX_TOKEN_TEXT_CHUNK = 4000
+MAX_TOKEN_RELATION_DESC = 4000
+MAX_TOKEN_ENTITY_DESC = 4000
+### Settings relative to indexing
+CHUNK_SIZE=1200
+CHUNK_OVERLAP_SIZE=100
+MAX_TOKENS=32768
+MAX_TOKEN_SUMMARY=500
+SUMMARY_LANGUAGE=English
 ### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
 ### Ollama example
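
Every new key mirrors an os.getenv fallback introduced in the code below, so any key left unset keeps the previous hardcoded behavior. A minimal sketch of the lookup pattern, assuming the .env file is loaded with python-dotenv before the settings modules are imported (the loading mechanism itself is not part of this diff):

import os

from dotenv import load_dotenv  # assumption: python-dotenv is installed

load_dotenv()  # pulls TOP_K, MAX_TOKEN_TEXT_CHUNK, ... into the process environment

# Same pattern the diff uses for each setting: env value if present, else the old hardcoded default.
top_k = int(os.getenv("TOP_K", "60"))
max_token_text_chunk = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
print(top_k, max_token_text_chunk)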


@@ -54,13 +54,15 @@ class QueryParam:
     top_k: int = int(os.getenv("TOP_K", "60"))
     """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""

-    max_token_for_text_unit: int = 4000
+    max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
     """Maximum number of tokens allowed for each retrieved text chunk."""

-    max_token_for_global_context: int = 4000
+    max_token_for_global_context: int = int(
+        os.getenv("MAX_TOKEN_RELATION_DESC", "4000")
+    )
     """Maximum number of tokens allocated for relationship descriptions in global retrieval."""

-    max_token_for_local_context: int = 4000
+    max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000"))
     """Maximum number of tokens allocated for entity descriptions in local retrieval."""

     hl_keywords: list[str] = field(default_factory=list)
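
Because these are dataclass field defaults, each os.getenv call runs once, when the class body is evaluated at import time, not per query. A quick check of that behavior, assuming QueryParam can be imported from lightrag.base (the module path is not shown in this diff):

import os

# Must be in the environment (or loaded from .env) before the module defining
# QueryParam is imported; otherwise the "4000" fallback is baked into the default.
os.environ["MAX_TOKEN_TEXT_CHUNK"] = "2000"

from lightrag.base import QueryParam  # assumed import path

print(QueryParam().max_token_for_text_unit)  # expected: 2000 with the variable set in time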


@@ -263,10 +263,10 @@ class LightRAG:
     """Directory where logs are stored. Defaults to the current working directory."""

     # Text chunking
-    chunk_token_size: int = 1200
+    chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
     """Maximum number of tokens per text chunk when splitting documents."""

-    chunk_overlap_token_size: int = 100
+    chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
     """Number of overlapping tokens between consecutive text chunks to preserve context."""

     tiktoken_model_name: str = "gpt-4o-mini"
@@ -276,7 +276,7 @@ class LightRAG:
     entity_extract_max_gleaning: int = 1
     """Maximum number of entity extraction attempts for ambiguous content."""

-    entity_summary_to_max_tokens: int = 500
+    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
     """Maximum number of tokens used for summarizing extracted entities."""

     # Node embedding
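
With CHUNK_SIZE and CHUNK_OVERLAP_SIZE now configurable, it helps to have a rough feel for the trade-off. Assuming the usual sliding-window chunking, where the window advances by chunk_size minus overlap tokens per chunk, the chunk count scales roughly as in this back-of-the-envelope sketch (not the project's exact splitter):

import math

def approx_chunk_count(doc_tokens: int, chunk_size: int = 1200, overlap: int = 100) -> int:
    # One chunk per window step; handling of the final partial window
    # depends on the actual splitter implementation.
    step = chunk_size - overlap
    return max(1, math.ceil(doc_tokens / step))

print(approx_chunk_count(10_000))           # defaults 1200/100 -> about 10 chunks
print(approx_chunk_count(10_000, 800, 80))  # smaller chunks -> about 14 chunks

Larger overlap preserves more cross-chunk context but produces more chunks, and therefore more embedding and extraction work.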


@@ -642,9 +642,13 @@ async def kg_query(
         history=history_context,
     )

     if query_param.only_need_prompt:
         return sys_prompt

+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[kg_query]Prompt Tokens: {len_of_prompts}")
+
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
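
The new logging measures the combined query and system prompt with encode_string_by_tiktoken from the project's utils. A standalone sketch of the same measurement using tiktoken directly (the encoding name here is an assumption for illustration; the real helper may derive it from the configured tiktoken model):

import tiktoken

def count_tokens(text: str) -> int:
    # cl100k_base is assumed here; swap in the encoding that matches your model.
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

sys_prompt = "---Role--- You are a helpful assistant ..."
query = "What entities are related to X?"
print(f"[kg_query]Prompt Tokens: {count_tokens(query + sys_prompt)}")
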
@@ -730,6 +734,9 @@ async def extract_keywords_only(
         query=text, examples=examples, language=language, history=history_context
     )

+    len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
+    logger.info(f"[kg_query]Prompt Tokens: {len_of_prompts}")
+
     # 5. Call the LLM for keyword extraction
     use_model_func = global_config["llm_model_func"]
     result = await use_model_func(kw_prompt, keyword_extraction=True)
@@ -893,7 +900,9 @@ async def mix_kg_vector_query(
                     chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
                 formatted_chunks.append(chunk_text)

-            logger.info(f"Truncate {len(chunks)} to {len(formatted_chunks)} chunks")
+            logger.info(
+                f"Truncate text chunks from {len(chunks)} to {len(formatted_chunks)}"
+            )
             return "\n--New Chunk--\n".join(formatted_chunks)
         except Exception as e:
             logger.error(f"Error in get_vector_context: {e}")
@@ -926,6 +935,9 @@ async def mix_kg_vector_query(
     if query_param.only_need_prompt:
         return sys_prompt

+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
+
     # 6. Generate response
     response = await use_model_func(
         query,
@@ -1031,7 +1043,7 @@ async def _build_query_context(
     if not entities_context.strip() and not relations_context.strip():
         return None

-    return f"""
+    result = f"""
 -----Entities-----
 ```csv
 {entities_context}
@@ -1045,6 +1057,15 @@
 {text_units_context}
 ```
 """
+
+    contex_tokens = len(encode_string_by_tiktoken(result))
+    entities_tokens = len(encode_string_by_tiktoken(entities_context))
+    relations_tokens = len(encode_string_by_tiktoken(relations_context))
+    text_units_tokens = len(encode_string_by_tiktoken(text_units_context))
+    logger.info(
+        f"Context Tokens - Total: {contex_tokens}, Entities: {entities_tokens}, Relations: {relations_tokens}, Chunks: {text_units_tokens}"
+    )
+    return result


 async def _get_node_data(
@@ -1089,7 +1110,7 @@ async def _get_node_data(
         ),
     )
     logger.info(
-        f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
+        f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} chunks"
     )

     # build prompt
@@ -1222,6 +1243,10 @@ async def _find_most_related_text_unit_from_entities(
         max_token_size=query_param.max_token_for_text_unit,
     )

+    logger.info(
+        f"Truncate text chunks from {len(all_text_units_lookup)} to {len(all_text_units)}"
+    )
+
     all_text_units = [t["data"] for t in all_text_units]

     return all_text_units
@@ -1263,6 +1288,9 @@ async def _find_most_related_edges_from_entities(
         key=lambda x: x["description"],
         max_token_size=query_param.max_token_for_global_context,
     )

+    logger.info(f"Truncate relations from {len(all_edges)} to {len(all_edges_data)}")
+
     return all_edges_data
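
Each of these truncation sites now reports how many items survived its token budget (MAX_TOKEN_TEXT_CHUNK, MAX_TOKEN_RELATION_DESC, or MAX_TOKEN_ENTITY_DESC). For reference, a simplified sketch of what a truncate_list_by_token_size-style helper does; the real implementation lives in the project's utils and may differ in detail:

from typing import Callable, TypeVar

import tiktoken

T = TypeVar("T")
_enc = tiktoken.get_encoding("cl100k_base")  # encoding choice is an assumption

def truncate_list_by_token_size(
    items: list[T], key: Callable[[T], str], max_token_size: int
) -> list[T]:
    # Keep items in order until their cumulative token count would exceed the budget;
    # the item that overflows the budget is dropped along with everything after it.
    total = 0
    kept: list[T] = []
    for item in items:
        total += len(_enc.encode(key(item)))
        if total > max_token_size:
            break
        kept.append(item)
    return kept
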
@@ -1310,11 +1338,13 @@ async def _get_edge_data(
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
+    len_edge_datas = len(edge_datas)
     edge_datas = truncate_list_by_token_size(
         edge_datas,
         key=lambda x: x["description"],
         max_token_size=query_param.max_token_for_global_context,
     )
+    logger.info(f"Truncate relations from {len_edge_datas} to {len(edge_datas)}")

     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
@@ -1325,7 +1355,7 @@ async def _get_edge_data(
         ),
     )
     logger.info(
-        f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
+        f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} chunks"
     )

     relations_section_list = [
@@ -1414,11 +1444,13 @@ async def _find_most_related_entities_from_relationships(
         for k, n, d in zip(entity_names, node_datas, node_degrees)
     ]

+    len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"],
         max_token_size=query_param.max_token_for_local_context,
     )
+    logger.info(f"Truncate entities from {len_node_datas} to {len(node_datas)}")

     return node_datas
@@ -1474,6 +1506,10 @@ async def _find_related_text_unit_from_relationships(
         max_token_size=query_param.max_token_for_text_unit,
     )

+    logger.info(
+        f"Truncate text chunks from {len(valid_text_units)} to {len(truncated_text_units)}"
+    )
+
     all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]

     return all_text_units
@@ -1541,7 +1577,8 @@ async def naive_query(
         logger.warning("No chunks left after truncation")
         return PROMPTS["fail_response"]

-    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
+    logger.info(f"Truncate text chunks from {len(chunks)} to {len(maybe_trun_chunks)}")
+
     section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])

     if query_param.only_need_context:
@@ -1564,6 +1601,9 @@ async def naive_query(
     if query_param.only_need_prompt:
         return sys_prompt

+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[naive_query]Prompt Tokens: {len_of_prompts}")
+
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -1706,6 +1746,9 @@ async def kg_query_with_keywords(
     if query_param.only_need_prompt:
         return sys_prompt

+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
+
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,