Refactor: Unify naive context to JSON format

- Merges 'mix' mode query handling into 'hybrid' mode, simplifying query logic by removing the dedicated `mix_kg_vector_query` function
- Standardizes vector search result by using JSON string format to build context
- Fixes a bug in `query_with_keywords` ensuring `hl_keywords` and `ll_keywords` are correctly passed to `kg_query_with_keywords`
This commit is contained in:
yangdx
2025-05-07 17:42:14 +08:00
parent 59771b60df
commit 156244e260
3 changed files with 148 additions and 309 deletions

View File

@@ -53,7 +53,6 @@ from .operate import (
extract_entities,
merge_nodes_and_edges,
kg_query,
mix_kg_vector_query,
naive_query,
query_with_keywords,
)
@@ -1437,8 +1436,10 @@ class LightRAG:
"""
# If a custom model is provided in param, temporarily update global config
global_config = asdict(self)
# Save original query for vector search
param.original_query = query
if param.mode in ["local", "global", "hybrid"]:
if param.mode in ["local", "global", "hybrid", "mix"]:
response = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
@@ -1447,8 +1448,9 @@ class LightRAG:
self.text_chunks,
param,
global_config,
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
chunks_vdb=self.chunks_vdb,
)
elif param.mode == "naive":
response = await naive_query(
@@ -1457,20 +1459,7 @@ class LightRAG:
self.text_chunks,
param,
global_config,
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
system_prompt=system_prompt,
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.text_chunks,
param,
global_config,
hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
)
elif param.mode == "bypass":

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
from functools import partial
import asyncio
import traceback
import json
import re
import os
@@ -859,6 +858,7 @@ async def kg_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
) -> str | AsyncIterator[str]:
if query_param.model_func:
use_model_func = query_param.model_func
@@ -911,6 +911,7 @@ async def kg_query(
relationships_vdb,
text_chunks_db,
query_param,
chunks_vdb,
)
if query_param.only_need_context:
@@ -1110,182 +1111,17 @@ async def extract_keywords_only(
return hl_keywords, ll_keywords
async def mix_kg_vector_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
This function performs a hybrid search by:
1. Extracting semantic information from knowledge graph
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
"""
# get tokenizer
tokenizer: Tokenizer = global_config["tokenizer"]
if query_param.model_func:
use_model_func = query_param.model_func
else:
use_model_func = global_config["llm_model_func"]
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
# 1. Cache handling
args_hash = compute_args_hash("mix", query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "mix", cache_type="query"
)
if cached_response is not None:
return cached_response
# Process conversation history
history_context = ""
if query_param.conversation_history:
history_context = get_conversation_turns(
query_param.conversation_history, query_param.history_turns
)
# 2. Execute knowledge graph and vector searches in parallel
async def _get_kg_context():
try:
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
)
if not hl_keywords and not ll_keywords:
logger.warning("Both high-level and low-level keywords are empty")
return None
# Convert keyword lists to strings
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# Set query mode based on available keywords
if not ll_keywords_str and not hl_keywords_str:
return None
elif not ll_keywords_str:
query_param.mode = "global"
elif not hl_keywords_str:
query_param.mode = "local"
else:
query_param.mode = "hybrid"
# Build knowledge graph context
context = await _build_query_context(
ll_keywords_str,
hl_keywords_str,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
return context
except Exception as e:
logger.error(f"Error in _get_kg_context: {str(e)}")
traceback.print_exc()
return None
# 3. Execute both retrievals in parallel
kg_context, vector_context = await asyncio.gather(
_get_kg_context(), _get_vector_context(query, chunks_vdb, query_param, tokenizer)
)
# 4. Merge contexts
if kg_context is None and vector_context is None:
return PROMPTS["fail_response"]
if query_param.only_need_context:
context_str = f"""\r\n\r\n-----Knowledge Graph Context-----\r\n\r\n
{kg_context if kg_context else "No relevant knowledge graph information found"}
\r\n\r\n-----Vector Context-----\r\n\r\n
{vector_context if vector_context else "No relevant text information found"}
""".strip()
return context_str
# 5. Construct hybrid prompt
sys_prompt = (
system_prompt if system_prompt else PROMPTS["mix_rag_response"]
).format(
kg_context=kg_context
if kg_context
else "No relevant knowledge graph information found",
vector_context=vector_context
if vector_context
else "No relevant text information found",
response_type=query_param.response_type,
history=history_context,
)
if query_param.only_need_prompt:
return sys_prompt
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
# 6. Generate response
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)
# Clean up response content
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
if hashing_kv.global_config.get("enable_llm_cache"):
# 7. Save cache - Only cache after collecting complete response
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode="mix",
cache_type="query",
),
)
return response
async def _get_vector_context(
query: str,
chunks_vdb: BaseVectorStorage,
query_param: QueryParam,
tokenizer: Tokenizer,
) -> str | None:
) -> tuple[list, list, list] | None:
"""
Retrieve vector context from the vector database.
This function performs vector search to find relevant text chunks for a query,
formats them with file path and creation time information, and truncates
the results to fit within token limits.
formats them with file path and creation time information.
Args:
query: The query string to search for
@@ -1294,18 +1130,15 @@ async def _get_vector_context(
tokenizer: Tokenizer for counting tokens
Returns:
Formatted string containing relevant text chunks, or None if no results found
Tuple (empty_entities, empty_relations, text_units) for combine_contexts,
compatible with _get_edge_data and _get_node_data format
"""
try:
# Reduce top_k for vector search in hybrid mode since we have structured information from KG
mix_topk = (
min(10, query_param.top_k)
if hasattr(query_param, "mode") and query_param.mode == "mix"
else query_param.top_k
results = await chunks_vdb.query(
query, top_k=query_param.top_k, ids=query_param.ids
)
results = await chunks_vdb.query(query, top_k=mix_topk, ids=query_param.ids)
if not results:
return None
return [], [], []
valid_chunks = []
for result in results:
@@ -1314,12 +1147,12 @@ async def _get_vector_context(
chunk_with_time = {
"content": result["content"],
"created_at": result.get("created_at", None),
"file_path": result.get("file_path", None),
"file_path": result.get("file_path", "unknown_source"),
}
valid_chunks.append(chunk_with_time)
if not valid_chunks:
return None
return [], [], []
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
@@ -1331,26 +1164,37 @@ async def _get_vector_context(
logger.debug(
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
)
logger.info(f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {mix_topk}")
logger.info(
f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
)
if not maybe_trun_chunks:
return None
return [], [], []
# Include time information in content
formatted_chunks = []
for c in maybe_trun_chunks:
chunk_text = "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
if c["created_at"]:
chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\r\n\r\n{chunk_text}"
formatted_chunks.append(chunk_text)
# Create empty entities and relations contexts
entities_context = []
relations_context = []
logger.debug(
f"Truncate chunks from {len(valid_chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
# Create text_units_context in the same format as _get_edge_data and _get_node_data
text_units_section_list = [["id", "content", "file_path"]]
for i, chunk in enumerate(maybe_trun_chunks):
# Add to text_units_section_list
text_units_section_list.append(
[
i + 1, # id
chunk["content"], # content
chunk["file_path"], # file_path
]
)
return "\r\n\r\n-------New Chunk-------\r\n\r\n".join(formatted_chunks)
# Convert to dictionary format using list_of_list_to_dict
text_units_context = list_of_list_to_dict(text_units_section_list)
return entities_context, relations_context, text_units_context
except Exception as e:
logger.error(f"Error in _get_vector_context: {e}")
return None
return [], [], []
async def _build_query_context(
@@ -1361,8 +1205,11 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode
):
logger.info(f"Process {os.getpid()} buidling query context...")
logger.info(f"Process {os.getpid()} building query context...")
# Handle local and global modes as before
if query_param.mode == "local":
entities_context, relations_context, text_units_context = await _get_node_data(
ll_keywords,
@@ -1379,7 +1226,7 @@ async def _build_query_context(
text_chunks_db,
query_param,
)
else: # hybrid mode
else: # hybrid or mix mode
ll_data = await _get_node_data(
ll_keywords,
knowledge_graph_inst,
@@ -1407,10 +1254,43 @@ async def _build_query_context(
hl_text_units_context,
) = hl_data
entities_context, relations_context, text_units_context = combine_contexts(
[hl_entities_context, ll_entities_context],
[hl_relations_context, ll_relations_context],
[hl_text_units_context, ll_text_units_context],
# Initialize vector data with empty lists
vector_entities_context, vector_relations_context, vector_text_units_context = (
[],
[],
[],
)
# Only get vector data if in mix mode
if query_param.mode == "mix" and hasattr(query_param, "original_query"):
# Get tokenizer from text_chunks_db
tokenizer = text_chunks_db.global_config.get("tokenizer")
# Get vector context in triple format
vector_data = await _get_vector_context(
query_param.original_query, # We need to pass the original query
chunks_vdb,
query_param,
tokenizer,
)
# If vector_data is not None, unpack it
if vector_data is not None:
(
vector_entities_context,
vector_relations_context,
vector_text_units_context,
) = vector_data
# Combine and deduplicate the entities, relationships, and sources
entities_context = process_combine_contexts(
hl_entities_context, ll_entities_context, vector_entities_context
)
relations_context = process_combine_contexts(
hl_relations_context, ll_relations_context, vector_relations_context
)
text_units_context = process_combine_contexts(
hl_text_units_context, ll_text_units_context, vector_text_units_context
)
# not necessary to use LLM to generate a response
if not entities_context and not relations_context:
@@ -1539,7 +1419,7 @@ async def _get_node_data(
entites_section_list.append(
[
i,
i + 1,
n["entity_name"],
n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"),
@@ -1574,7 +1454,7 @@ async def _get_node_data(
relations_section_list.append(
[
i,
i + 1,
e["src_tgt"][0],
e["src_tgt"][1],
e["description"],
@@ -1590,7 +1470,7 @@ async def _get_node_data(
text_units_section_list = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append(
[i, t["content"], t.get("file_path", "unknown_source")]
[i + 1, t["content"], t.get("file_path", "unknown_source")]
)
text_units_context = list_of_list_to_dict(text_units_section_list)
return entities_context, relations_context, text_units_context
@@ -1859,7 +1739,7 @@ async def _get_edge_data(
relations_section_list.append(
[
i,
i + 1,
e["src_id"],
e["tgt_id"],
e["description"],
@@ -1886,7 +1766,7 @@ async def _get_edge_data(
entites_section_list.append(
[
i,
i + 1,
n["entity_name"],
n.get("entity_type", "UNKNOWN"),
n.get("description", "UNKNOWN"),
@@ -1899,7 +1779,9 @@ async def _get_edge_data(
text_units_section_list = [["id", "content", "file_path"]]
for i, t in enumerate(use_text_units):
text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
text_units_section_list.append(
[i + 1, t["content"], t.get("file_path", "unknown")]
)
text_units_context = list_of_list_to_dict(text_units_section_list)
return entities_context, relations_context, text_units_context
@@ -2016,25 +1898,6 @@ async def _find_related_text_unit_from_relationships(
return all_text_units
def combine_contexts(entities, relationships, sources):
# Function to extract entities, relationships, and sources from context strings
hl_entities, ll_entities = entities[0], entities[1]
hl_relationships, ll_relationships = relationships[0], relationships[1]
hl_sources, ll_sources = sources[0], sources[1]
# Combine and deduplicate the entities
combined_entities = process_combine_contexts(hl_entities, ll_entities)
# Combine and deduplicate the relationships
combined_relationships = process_combine_contexts(
hl_relationships, ll_relationships
)
# Combine and deduplicate the sources
combined_sources = process_combine_contexts(hl_sources, ll_sources)
return combined_entities, combined_relationships, combined_sources
async def naive_query(
query: str,
chunks_vdb: BaseVectorStorage,
@@ -2060,14 +1923,24 @@ async def naive_query(
return cached_response
tokenizer: Tokenizer = global_config["tokenizer"]
section = await _get_vector_context(query, chunks_vdb, query_param, tokenizer)
if section is None:
_, _, text_units_context = await _get_vector_context(
query, chunks_vdb, query_param, tokenizer
)
if text_units_context is None or len(text_units_context) == 0:
return PROMPTS["fail_response"]
text_units_str = json.dumps(text_units_context, ensure_ascii=False)
if query_param.only_need_context:
return section
return f"""
---Document Chunks---
```json
{text_units_str}
```
"""
# Process conversation history
history_context = ""
if query_param.conversation_history:
@@ -2077,7 +1950,7 @@ async def naive_query(
sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"]
sys_prompt = sys_prompt_temp.format(
content_data=section,
content_data=text_units_str,
response_type=query_param.response_type,
history=history_context,
)
@@ -2134,6 +2007,9 @@ async def kg_query_with_keywords(
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
ll_keywords: list[str] = [],
hl_keywords: list[str] = [],
chunks_vdb: BaseVectorStorage | None = None,
) -> str | AsyncIterator[str]:
"""
Refactored kg_query that does NOT extract keywords by itself.
@@ -2147,9 +2023,6 @@ async def kg_query_with_keywords(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
# ---------------------------
# 1) Handle potential cache for query results
# ---------------------------
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -2157,14 +2030,6 @@ async def kg_query_with_keywords(
if cached_response is not None:
return cached_response
# ---------------------------
# 2) RETRIEVE KEYWORDS FROM query_param
# ---------------------------
# If these fields don't exist, default to empty lists/strings.
hl_keywords = getattr(query_param, "hl_keywords", []) or []
ll_keywords = getattr(query_param, "ll_keywords", []) or []
# If neither has any keywords, you could handle that logic here.
if not hl_keywords and not ll_keywords:
logger.warning(
@@ -2178,25 +2043,9 @@ async def kg_query_with_keywords(
logger.warning("high_level_keywords is empty, switching to local mode.")
query_param.mode = "local"
# Flatten low-level and high-level keywords if needed
ll_keywords_flat = (
[item for sublist in ll_keywords for item in sublist]
if any(isinstance(i, list) for i in ll_keywords)
else ll_keywords
)
hl_keywords_flat = (
[item for sublist in hl_keywords for item in sublist]
if any(isinstance(i, list) for i in hl_keywords)
else hl_keywords
)
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# Join the flattened lists
ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
# ---------------------------
# 3) BUILD CONTEXT
# ---------------------------
context = await _build_query_context(
ll_keywords_str,
hl_keywords_str,
@@ -2205,18 +2054,14 @@ async def kg_query_with_keywords(
relationships_vdb,
text_chunks_db,
query_param,
chunks_vdb=chunks_vdb,
)
if not context:
return PROMPTS["fail_response"]
# If only context is needed, return it
if query_param.only_need_context:
return context
# ---------------------------
# 4) BUILD THE SYSTEM PROMPT + CALL LLM
# ---------------------------
# Process conversation history
history_context = ""
if query_param.conversation_history:
@@ -2258,7 +2103,6 @@ async def kg_query_with_keywords(
)
if hashing_kv.global_config.get("enable_llm_cache"):
# 7. Save cache - 只有在收集完整响应后才缓存
await save_to_cache(
hashing_kv,
CacheData(
@@ -2319,12 +2163,15 @@ async def query_with_keywords(
)
# Create a new string with the prompt and the keywords
ll_keywords_str = ", ".join(ll_keywords)
hl_keywords_str = ", ".join(hl_keywords)
formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
keywords_str = ", ".join(ll_keywords + hl_keywords)
formatted_question = (
f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
)
param.original_query = query
# Use appropriate query method based on mode
if param.mode in ["local", "global", "hybrid"]:
if param.mode in ["local", "global", "hybrid", "mix"]:
return await kg_query_with_keywords(
formatted_question,
knowledge_graph_inst,
@@ -2334,6 +2181,9 @@ async def query_with_keywords(
param,
global_config,
hashing_kv=hashing_kv,
hl_keywords=hl_keywords,
ll_keywords=ll_keywords,
chunks_vdb=chunks_vdb,
)
elif param.mode == "naive":
return await naive_query(
@@ -2344,17 +2194,5 @@ async def query_with_keywords(
global_config,
hashing_kv=hashing_kv,
)
elif param.mode == "mix":
return await mix_kg_vector_query(
formatted_question,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
chunks_vdb,
text_chunks_db,
param,
global_config,
hashing_kv=hashing_kv,
)
else:
raise ValueError(f"Unknown mode {param.mode}")

View File

@@ -822,21 +822,33 @@ def xml_to_json(xml_file):
return None
def process_combine_contexts(
hl_context: list[dict[str, str]], ll_context: list[dict[str, str]]
):
def process_combine_contexts(*context_lists):
"""
Combine multiple context lists and remove duplicate content
Args:
*context_lists: Any number of context lists
Returns:
Combined context list with duplicates removed
"""
seen_content = {}
combined_data = []
for item in hl_context + ll_context:
# Iterate through all input context lists
for context_list in context_lists:
if not context_list: # Skip empty lists
continue
for item in context_list:
content_dict = {k: v for k, v in item.items() if k != "id"}
content_key = tuple(sorted(content_dict.items()))
if content_key not in seen_content:
seen_content[content_key] = item
combined_data.append(item)
# Reassign IDs
for i, item in enumerate(combined_data):
item["id"] = str(i)
item["id"] = str(i + 1)
return combined_data