From 601df31edf366efb91a1a8b16e754d626c8aba86 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 16 Feb 2025 19:26:57 +0800
Subject: [PATCH] feat: move query-related settings to env file for better
 configuration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Add env vars for token and chunk settings
• Add token count logging for prompts
• Add token count logging for context
• Move hardcoded values to env variables
• Improve logging clarity and consistency
---
 .env.example         | 19 +++++++++++-----
 lightrag/base.py     |  8 ++++---
 lightrag/lightrag.py |  6 ++---
 lightrag/operate.py  | 53 +++++++++++++++++++++++++++++++++++++++-----
 4 files changed, 69 insertions(+), 17 deletions(-)

diff --git a/.env.example b/.env.example
index 2701335a..7057281d 100644
--- a/.env.example
+++ b/.env.example
@@ -27,14 +27,21 @@ TIMEOUT=300
 
 ### RAG Configuration
 MAX_ASYNC=4
-MAX_TOKENS=32768
 EMBEDDING_DIM=1024
 MAX_EMBED_TOKENS=8192
-#HISTORY_TURNS=3
-#CHUNK_SIZE=1200
-#CHUNK_OVERLAP_SIZE=100
-#COSINE_THRESHOLD=0.2
-#TOP_K=60
+### Settings related to query
+HISTORY_TURNS=3
+COSINE_THRESHOLD=0.2
+TOP_K=60
+MAX_TOKEN_TEXT_CHUNK=4000
+MAX_TOKEN_RELATION_DESC=4000
+MAX_TOKEN_ENTITY_DESC=4000
+### Settings related to indexing
+CHUNK_SIZE=1200
+CHUNK_OVERLAP_SIZE=100
+MAX_TOKENS=32768
+MAX_TOKEN_SUMMARY=500
+SUMMARY_LANGUAGE=English
 
 ### LLM Configuration (Use valid host. For local services, you can use host.docker.internal)
 ### Ollama example
diff --git a/lightrag/base.py b/lightrag/base.py
index e75167c4..aa8e6d9e 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -54,13 +54,15 @@ class QueryParam:
     top_k: int = int(os.getenv("TOP_K", "60"))
     """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
 
-    max_token_for_text_unit: int = 4000
+    max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
     """Maximum number of tokens allowed for each retrieved text chunk."""
 
-    max_token_for_global_context: int = 4000
+    max_token_for_global_context: int = int(
+        os.getenv("MAX_TOKEN_RELATION_DESC", "4000")
+    )
     """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
 
-    max_token_for_local_context: int = 4000
+    max_token_for_local_context: int = int(os.getenv("MAX_TOKEN_ENTITY_DESC", "4000"))
     """Maximum number of tokens allocated for entity descriptions in local retrieval."""
 
     hl_keywords: list[str] = field(default_factory=list)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 9f74c917..554cba22 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -263,10 +263,10 @@ class LightRAG:
     """Directory where logs are stored. Defaults to the current working directory."""
 
     # Text chunking
-    chunk_token_size: int = 1200
+    chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
     """Maximum number of tokens per text chunk when splitting documents."""
 
-    chunk_overlap_token_size: int = 100
+    chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
     """Number of overlapping tokens between consecutive text chunks to preserve context."""
 
     tiktoken_model_name: str = "gpt-4o-mini"
@@ -276,7 +276,7 @@ class LightRAG:
     entity_extract_max_gleaning: int = 1
     """Maximum number of entity extraction attempts for ambiguous content."""
 
-    entity_summary_to_max_tokens: int = 500
+    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
     """Maximum number of tokens used for summarizing extracted entities."""
 
     # Node embedding
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 04aad0d4..fb351a71 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -642,9 +642,13 @@ async def kg_query(
         history=history_context,
     )
 
+
     if query_param.only_need_prompt:
         return sys_prompt
 
+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[kg_query]Prompt Tokens: {len_of_prompts}")
+
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -730,6 +734,9 @@ async def extract_keywords_only(
         query=text, examples=examples, language=language, history=history_context
     )
 
+    len_of_prompts = len(encode_string_by_tiktoken(kw_prompt))
+    logger.info(f"[extract_keywords_only]Prompt Tokens: {len_of_prompts}")
+
     # 5. Call the LLM for keyword extraction
     use_model_func = global_config["llm_model_func"]
     result = await use_model_func(kw_prompt, keyword_extraction=True)
@@ -893,7 +900,9 @@ async def mix_kg_vector_query(
                     chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
                 formatted_chunks.append(chunk_text)
 
-            logger.info(f"Truncate {len(chunks)} to {len(formatted_chunks)} chunks")
+            logger.info(
+                f"Truncate text chunks from {len(chunks)} to {len(formatted_chunks)}"
+            )
             return "\n--New Chunk--\n".join(formatted_chunks)
         except Exception as e:
             logger.error(f"Error in get_vector_context: {e}")
@@ -926,6 +935,9 @@ async def mix_kg_vector_query(
     if query_param.only_need_prompt:
         return sys_prompt
 
+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
+
     # 6. Generate response
     response = await use_model_func(
         query,
@@ -1031,7 +1043,7 @@ async def _build_query_context(
     if not entities_context.strip() and not relations_context.strip():
         return None
 
-    return f"""
+    result = f"""
 -----Entities-----
 ```csv
 {entities_context}
@@ -1045,6 +1057,15 @@ async def _build_query_context(
 {text_units_context}
 ```
 """
+    context_tokens = len(encode_string_by_tiktoken(result))
+    entities_tokens = len(encode_string_by_tiktoken(entities_context))
+    relations_tokens = len(encode_string_by_tiktoken(relations_context))
+    text_units_tokens = len(encode_string_by_tiktoken(text_units_context))
+    logger.info(
+        f"Context Tokens - Total: {context_tokens}, Entities: {entities_tokens}, Relations: {relations_tokens}, Chunks: {text_units_tokens}"
+    )
+
+    return result
 
 
 async def _get_node_data(
@@ -1089,7 +1110,7 @@ async def _get_node_data(
         ),
     )
     logger.info(
-        f"Local query uses {len(node_datas)} entites, {len(use_relations)} relations, {len(use_text_units)} text units"
+        f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} chunks"
     )
 
     # build prompt
@@ -1222,6 +1243,10 @@ async def _find_most_related_text_unit_from_entities(
         max_token_size=query_param.max_token_for_text_unit,
     )
 
+    logger.info(
+        f"Truncate text chunks from {len(all_text_units_lookup)} to {len(all_text_units)}"
+    )
+
     all_text_units = [t["data"] for t in all_text_units]
 
     return all_text_units
@@ -1263,6 +1288,9 @@ async def _find_most_related_edges_from_entities(
         key=lambda x: x["description"],
         max_token_size=query_param.max_token_for_global_context,
     )
+
+    logger.info(f"Truncate relations from {len(all_edges)} to {len(all_edges_data)}")
+
     return all_edges_data
 
 
@@ -1310,11 +1338,13 @@ async def _get_edge_data(
     edge_datas = sorted(
         edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
     )
+    len_edge_datas = len(edge_datas)
     edge_datas = truncate_list_by_token_size(
         edge_datas,
         key=lambda x: x["description"],
         max_token_size=query_param.max_token_for_global_context,
    )
+    logger.info(f"Truncate relations from {len_edge_datas} to {len(edge_datas)}")
 
     use_entities, use_text_units = await asyncio.gather(
         _find_most_related_entities_from_relationships(
@@ -1325,7 +1355,7 @@ async def _get_edge_data(
         ),
     )
     logger.info(
-        f"Global query uses {len(use_entities)} entites, {len(edge_datas)} relations, {len(use_text_units)} text units"
+        f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} chunks"
     )
 
     relations_section_list = [
@@ -1414,11 +1444,13 @@ async def _find_most_related_entities_from_relationships(
         for k, n, d in zip(entity_names, node_datas, node_degrees)
     ]
 
+    len_node_datas = len(node_datas)
     node_datas = truncate_list_by_token_size(
         node_datas,
         key=lambda x: x["description"],
         max_token_size=query_param.max_token_for_local_context,
     )
+    logger.info(f"Truncate entities from {len_node_datas} to {len(node_datas)}")
 
     return node_datas
 
@@ -1474,6 +1506,10 @@ async def _find_related_text_unit_from_relationships(
         max_token_size=query_param.max_token_for_text_unit,
     )
 
+    logger.info(
+        f"Truncate text chunks from {len(valid_text_units)} to {len(truncated_text_units)}"
+    )
+
     all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
 
     return all_text_units
@@ -1541,7 +1577,8 @@ async def naive_query(
         logger.warning("No chunks left after truncation")
         return PROMPTS["fail_response"]
 
-    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
+    logger.info(f"Truncate text chunks from {len(chunks)} to {len(maybe_trun_chunks)}")
+
     section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
 
     if query_param.only_need_context:
@@ -1564,6 +1601,9 @@ async def naive_query(
     if query_param.only_need_prompt:
         return sys_prompt
 
+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[naive_query]Prompt Tokens: {len_of_prompts}")
+
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -1706,6 +1746,9 @@ async def kg_query_with_keywords(
     if query_param.only_need_prompt:
         return sys_prompt
 
+    len_of_prompts = len(encode_string_by_tiktoken(query + sys_prompt))
+    logger.info(f"[kg_query_with_keywords]Prompt Tokens: {len_of_prompts}")
+
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
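Note (not part of the patch): every new default in this change uses the same int(os.getenv(...)) idiom. Because dataclass field defaults are evaluated once, when the class body executes at import time, the environment (for example a .env file loaded with python-dotenv's load_dotenv()) must be populated before lightrag.base and lightrag.lightrag are imported; otherwise the built-in fallback values are used. A minimal, self-contained sketch of the idiom (ExampleParam is hypothetical, not a LightRAG class):

    import os
    from dataclasses import dataclass

    os.environ.setdefault("TOP_K", "60")  # stand-in for loading a .env file before import

    @dataclass
    class ExampleParam:
        # Defaults are read from the environment when this class body executes,
        # mirroring the QueryParam fields changed in lightrag/base.py above.
        top_k: int = int(os.getenv("TOP_K", "60"))
        max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))

    print(ExampleParam())  # ExampleParam(top_k=60, max_token_for_text_unit=4000)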