Refactor: Unify naive context to JSON format
- Merges 'mix' mode query handling into the 'hybrid' query path, simplifying query logic by removing the dedicated `mix_kg_vector_query` function
- Standardizes vector search results by using a JSON string format to build the context
- Fixes a bug in `query_with_keywords` so that `hl_keywords` and `ll_keywords` are correctly passed to `kg_query_with_keywords`
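Note on the new context format: the sketch below illustrates, in simplified form, the JSON-style "Document Chunks" context that naive mode (and mix mode via `_get_vector_context`) now builds: chunks are tabulated as [id, content, file_path] rows, converted to dicts, and serialized with `json.dumps`. The helper names in the sketch are illustrative only and are not part of the library.

import json


def chunks_to_context(chunks: list[dict]) -> list[dict]:
    # Tabulate chunks the way _get_node_data/_get_edge_data do: a header row,
    # then one [id, content, file_path] row per chunk, with 1-based ids.
    rows = [["id", "content", "file_path"]]
    for i, chunk in enumerate(chunks):
        rows.append([i + 1, chunk["content"], chunk.get("file_path", "unknown_source")])
    header, *data = rows
    return [dict(zip(header, map(str, row))) for row in data]


def naive_context_string(chunks: list[dict]) -> str:
    # naive_query now embeds the chunks as a single JSON block instead of
    # "-------New Chunk-------" separated plain text.
    text_units_str = json.dumps(chunks_to_context(chunks), ensure_ascii=False)
    return "---Document Chunks---\n\n```json\n" + text_units_str + "\n```\n"


if __name__ == "__main__":
    demo = [{"content": "LightRAG indexes documents.", "file_path": "docs/intro.md"}]
    print(naive_context_string(demo))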
@@ -53,7 +53,6 @@ from .operate import (
     extract_entities,
     merge_nodes_and_edges,
     kg_query,
-    mix_kg_vector_query,
     naive_query,
     query_with_keywords,
 )
@@ -1437,8 +1436,10 @@ class LightRAG:
         """
         # If a custom model is provided in param, temporarily update global config
         global_config = asdict(self)
+        # Save original query for vector search
+        param.original_query = query

-        if param.mode in ["local", "global", "hybrid"]:
+        if param.mode in ["local", "global", "hybrid", "mix"]:
             response = await kg_query(
                 query.strip(),
                 self.chunk_entity_relation_graph,
@@ -1447,8 +1448,9 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 global_config,
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
+                hashing_kv=self.llm_response_cache,
                 system_prompt=system_prompt,
+                chunks_vdb=self.chunks_vdb,
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -1457,20 +1459,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 global_config,
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
-                system_prompt=system_prompt,
-            )
-        elif param.mode == "mix":
-            response = await mix_kg_vector_query(
-                query.strip(),
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.chunks_vdb,
-                self.text_chunks,
-                param,
-                global_config,
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
+                hashing_kv=self.llm_response_cache,
                 system_prompt=system_prompt,
             )
         elif param.mode == "bypass":
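A hedged caller-side sketch of the unified routing after this change (the `QueryParam` fields and `aquery` call shape are taken from this diff and the public API as I understand it; treat the exact values as assumptions):

# Hypothetical usage sketch: "mix" now rides the same kg_query path as
# "local"/"global"/"hybrid"; only "naive" and "bypass" branch elsewhere.
from lightrag import LightRAG, QueryParam  # import path assumed


async def ask(rag: LightRAG, question: str) -> str:
    # kg_query now receives chunks_vdb and folds vector-retrieved chunks
    # into the knowledge-graph context itself, so no dedicated mix function
    # is needed on the caller side.
    param = QueryParam(mode="mix", top_k=20)
    return await rag.aquery(question, param=param)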
@@ -2,7 +2,6 @@ from __future__ import annotations
 from functools import partial

 import asyncio
-import traceback
 import json
 import re
 import os
@@ -859,6 +858,7 @@ async def kg_query(
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
+    chunks_vdb: BaseVectorStorage = None,
 ) -> str | AsyncIterator[str]:
     if query_param.model_func:
         use_model_func = query_param.model_func
@@ -911,6 +911,7 @@ async def kg_query(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        chunks_vdb,
     )

     if query_param.only_need_context:
@@ -1110,182 +1111,17 @@ async def extract_keywords_only(
     return hl_keywords, ll_keywords


-async def mix_kg_vector_query(
-    query: str,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    chunks_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage,
-    query_param: QueryParam,
-    global_config: dict[str, str],
-    hashing_kv: BaseKVStorage | None = None,
-    system_prompt: str | None = None,
-) -> str | AsyncIterator[str]:
-    """
-    Hybrid retrieval implementation combining knowledge graph and vector search.
-
-    This function performs a hybrid search by:
-    1. Extracting semantic information from knowledge graph
-    2. Retrieving relevant text chunks through vector similarity
-    3. Combining both results for comprehensive answer generation
-    """
-    # get tokenizer
-    tokenizer: Tokenizer = global_config["tokenizer"]
-
-    if query_param.model_func:
-        use_model_func = query_param.model_func
-    else:
-        use_model_func = global_config["llm_model_func"]
-        # Apply higher priority (5) to query relation LLM function
-        use_model_func = partial(use_model_func, _priority=5)
-
-    # 1. Cache handling
-    args_hash = compute_args_hash("mix", query, cache_type="query")
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, query, "mix", cache_type="query"
-    )
-    if cached_response is not None:
-        return cached_response
-
-    # Process conversation history
-    history_context = ""
-    if query_param.conversation_history:
-        history_context = get_conversation_turns(
-            query_param.conversation_history, query_param.history_turns
-        )
-
-    # 2. Execute knowledge graph and vector searches in parallel
-    async def _get_kg_context():
-        try:
-            hl_keywords, ll_keywords = await get_keywords_from_query(
-                query, query_param, global_config, hashing_kv
-            )
-
-            if not hl_keywords and not ll_keywords:
-                logger.warning("Both high-level and low-level keywords are empty")
-                return None
-
-            # Convert keyword lists to strings
-            ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
-            hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
-
-            # Set query mode based on available keywords
-            if not ll_keywords_str and not hl_keywords_str:
-                return None
-            elif not ll_keywords_str:
-                query_param.mode = "global"
-            elif not hl_keywords_str:
-                query_param.mode = "local"
-            else:
-                query_param.mode = "hybrid"
-
-            # Build knowledge graph context
-            context = await _build_query_context(
-                ll_keywords_str,
-                hl_keywords_str,
-                knowledge_graph_inst,
-                entities_vdb,
-                relationships_vdb,
-                text_chunks_db,
-                query_param,
-            )
-
-            return context
-
-        except Exception as e:
-            logger.error(f"Error in _get_kg_context: {str(e)}")
-            traceback.print_exc()
-            return None
-
-    # 3. Execute both retrievals in parallel
-    kg_context, vector_context = await asyncio.gather(
-        _get_kg_context(), _get_vector_context(query, chunks_vdb, query_param, tokenizer)
-    )
-
-    # 4. Merge contexts
-    if kg_context is None and vector_context is None:
-        return PROMPTS["fail_response"]
-
-    if query_param.only_need_context:
-        context_str = f"""\r\n\r\n-----Knowledge Graph Context-----\r\n\r\n
-{kg_context if kg_context else "No relevant knowledge graph information found"}
-
-\r\n\r\n-----Vector Context-----\r\n\r\n
-{vector_context if vector_context else "No relevant text information found"}
-""".strip()
-        return context_str
-
-    # 5. Construct hybrid prompt
-    sys_prompt = (
-        system_prompt if system_prompt else PROMPTS["mix_rag_response"]
-    ).format(
-        kg_context=kg_context
-        if kg_context
-        else "No relevant knowledge graph information found",
-        vector_context=vector_context
-        if vector_context
-        else "No relevant text information found",
-        response_type=query_param.response_type,
-        history=history_context,
-    )
-
-    if query_param.only_need_prompt:
-        return sys_prompt
-
-    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
-    logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
-
-    # 6. Generate response
-    response = await use_model_func(
-        query,
-        system_prompt=sys_prompt,
-        stream=query_param.stream,
-    )
-
-    # Clean up response content
-    if isinstance(response, str) and len(response) > len(sys_prompt):
-        response = (
-            response.replace(sys_prompt, "")
-            .replace("user", "")
-            .replace("model", "")
-            .replace(query, "")
-            .replace("<system>", "")
-            .replace("</system>", "")
-            .strip()
-        )
-
-    if hashing_kv.global_config.get("enable_llm_cache"):
-        # 7. Save cache - Only cache after collecting complete response
-        await save_to_cache(
-            hashing_kv,
-            CacheData(
-                args_hash=args_hash,
-                content=response,
-                prompt=query,
-                quantized=quantized,
-                min_val=min_val,
-                max_val=max_val,
-                mode="mix",
-                cache_type="query",
-            ),
-        )
-
-    return response
-
-
 async def _get_vector_context(
     query: str,
     chunks_vdb: BaseVectorStorage,
     query_param: QueryParam,
     tokenizer: Tokenizer,
-) -> str | None:
+) -> tuple[list, list, list] | None:
     """
     Retrieve vector context from the vector database.

     This function performs vector search to find relevant text chunks for a query,
-    formats them with file path and creation time information, and truncates
-    the results to fit within token limits.
+    formats them with file path and creation time information.

     Args:
         query: The query string to search for
@@ -1294,18 +1130,15 @@ async def _get_vector_context(
        tokenizer: Tokenizer for counting tokens

    Returns:
-        Formatted string containing relevant text chunks, or None if no results found
+        Tuple (empty_entities, empty_relations, text_units) for combine_contexts,
+        compatible with _get_edge_data and _get_node_data format
    """
    try:
-        # Reduce top_k for vector search in hybrid mode since we have structured information from KG
-        mix_topk = (
-            min(10, query_param.top_k)
-            if hasattr(query_param, "mode") and query_param.mode == "mix"
-            else query_param.top_k
+        results = await chunks_vdb.query(
+            query, top_k=query_param.top_k, ids=query_param.ids
        )
-        results = await chunks_vdb.query(query, top_k=mix_topk, ids=query_param.ids)
        if not results:
-            return None
+            return [], [], []

        valid_chunks = []
        for result in results:
@@ -1314,12 +1147,12 @@ async def _get_vector_context(
            chunk_with_time = {
                "content": result["content"],
                "created_at": result.get("created_at", None),
-                "file_path": result.get("file_path", None),
+                "file_path": result.get("file_path", "unknown_source"),
            }
            valid_chunks.append(chunk_with_time)

        if not valid_chunks:
-            return None
+            return [], [], []

        maybe_trun_chunks = truncate_list_by_token_size(
            valid_chunks,
@@ -1331,26 +1164,37 @@ async def _get_vector_context(
        logger.debug(
            f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
        )
-        logger.info(f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {mix_topk}")
+        logger.info(
+            f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
+        )

        if not maybe_trun_chunks:
-            return None
+            return [], [], []

-        # Include time information in content
-        formatted_chunks = []
-        for c in maybe_trun_chunks:
-            chunk_text = "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
-            if c["created_at"]:
-                chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\r\n\r\n{chunk_text}"
-            formatted_chunks.append(chunk_text)
-
-        logger.debug(
-            f"Truncate chunks from {len(valid_chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-        )
-        return "\r\n\r\n-------New Chunk-------\r\n\r\n".join(formatted_chunks)
+        # Create empty entities and relations contexts
+        entities_context = []
+        relations_context = []
+
+        # Create text_units_context in the same format as _get_edge_data and _get_node_data
+        text_units_section_list = [["id", "content", "file_path"]]
+
+        for i, chunk in enumerate(maybe_trun_chunks):
+            # Add to text_units_section_list
+            text_units_section_list.append(
+                [
+                    i + 1,  # id
+                    chunk["content"],  # content
+                    chunk["file_path"],  # file_path
+                ]
+            )
+
+        # Convert to dictionary format using list_of_list_to_dict
+        text_units_context = list_of_list_to_dict(text_units_section_list)
+
+        return entities_context, relations_context, text_units_context
    except Exception as e:
        logger.error(f"Error in _get_vector_context: {e}")
-        return None
+        return [], [], []


async def _build_query_context(
@@ -1361,8 +1205,11 @@ async def _build_query_context(
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
+    chunks_vdb: BaseVectorStorage = None,  # Add chunks_vdb parameter for mix mode
):
-    logger.info(f"Process {os.getpid()} buidling query context...")
+    logger.info(f"Process {os.getpid()} building query context...")
+
+    # Handle local and global modes as before
    if query_param.mode == "local":
        entities_context, relations_context, text_units_context = await _get_node_data(
            ll_keywords,
@@ -1379,7 +1226,7 @@ async def _build_query_context(
            text_chunks_db,
            query_param,
        )
-    else:  # hybrid mode
+    else:  # hybrid or mix mode
        ll_data = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
@@ -1407,10 +1254,43 @@ async def _build_query_context(
            hl_text_units_context,
        ) = hl_data

-        entities_context, relations_context, text_units_context = combine_contexts(
-            [hl_entities_context, ll_entities_context],
-            [hl_relations_context, ll_relations_context],
-            [hl_text_units_context, ll_text_units_context],
+        # Initialize vector data with empty lists
+        vector_entities_context, vector_relations_context, vector_text_units_context = (
+            [],
+            [],
+            [],
+        )
+
+        # Only get vector data if in mix mode
+        if query_param.mode == "mix" and hasattr(query_param, "original_query"):
+            # Get tokenizer from text_chunks_db
+            tokenizer = text_chunks_db.global_config.get("tokenizer")
+
+            # Get vector context in triple format
+            vector_data = await _get_vector_context(
+                query_param.original_query,  # We need to pass the original query
+                chunks_vdb,
+                query_param,
+                tokenizer,
+            )
+
+            # If vector_data is not None, unpack it
+            if vector_data is not None:
+                (
+                    vector_entities_context,
+                    vector_relations_context,
+                    vector_text_units_context,
+                ) = vector_data
+
+        # Combine and deduplicate the entities, relationships, and sources
+        entities_context = process_combine_contexts(
+            hl_entities_context, ll_entities_context, vector_entities_context
+        )
+        relations_context = process_combine_contexts(
+            hl_relations_context, ll_relations_context, vector_relations_context
+        )
+        text_units_context = process_combine_contexts(
+            hl_text_units_context, ll_text_units_context, vector_text_units_context
        )
    # not necessary to use LLM to generate a response
    if not entities_context and not relations_context:
@@ -1539,7 +1419,7 @@ async def _get_node_data(

        entites_section_list.append(
            [
-                i,
+                i + 1,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
@@ -1574,7 +1454,7 @@ async def _get_node_data(

        relations_section_list.append(
            [
-                i,
+                i + 1,
                e["src_tgt"][0],
                e["src_tgt"][1],
                e["description"],
@@ -1590,7 +1470,7 @@ async def _get_node_data(
    text_units_section_list = [["id", "content", "file_path"]]
    for i, t in enumerate(use_text_units):
        text_units_section_list.append(
-            [i, t["content"], t.get("file_path", "unknown_source")]
+            [i + 1, t["content"], t.get("file_path", "unknown_source")]
        )
    text_units_context = list_of_list_to_dict(text_units_section_list)
    return entities_context, relations_context, text_units_context
@@ -1859,7 +1739,7 @@ async def _get_edge_data(

        relations_section_list.append(
            [
-                i,
+                i + 1,
                e["src_id"],
                e["tgt_id"],
                e["description"],
@@ -1886,7 +1766,7 @@ async def _get_edge_data(

        entites_section_list.append(
            [
-                i,
+                i + 1,
                n["entity_name"],
                n.get("entity_type", "UNKNOWN"),
                n.get("description", "UNKNOWN"),
@@ -1899,7 +1779,9 @@ async def _get_edge_data(

    text_units_section_list = [["id", "content", "file_path"]]
    for i, t in enumerate(use_text_units):
-        text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
+        text_units_section_list.append(
+            [i + 1, t["content"], t.get("file_path", "unknown")]
+        )
    text_units_context = list_of_list_to_dict(text_units_section_list)
    return entities_context, relations_context, text_units_context
@@ -2016,25 +1898,6 @@ async def _find_related_text_unit_from_relationships(
    return all_text_units


-def combine_contexts(entities, relationships, sources):
-    # Function to extract entities, relationships, and sources from context strings
-    hl_entities, ll_entities = entities[0], entities[1]
-    hl_relationships, ll_relationships = relationships[0], relationships[1]
-    hl_sources, ll_sources = sources[0], sources[1]
-    # Combine and deduplicate the entities
-    combined_entities = process_combine_contexts(hl_entities, ll_entities)
-
-    # Combine and deduplicate the relationships
-    combined_relationships = process_combine_contexts(
-        hl_relationships, ll_relationships
-    )
-
-    # Combine and deduplicate the sources
-    combined_sources = process_combine_contexts(hl_sources, ll_sources)
-
-    return combined_entities, combined_relationships, combined_sources
-
-
 async def naive_query(
    query: str,
    chunks_vdb: BaseVectorStorage,
@@ -2060,14 +1923,24 @@ async def naive_query(
        return cached_response

    tokenizer: Tokenizer = global_config["tokenizer"]
-    section = await _get_vector_context(query, chunks_vdb, query_param, tokenizer)

-    if section is None:
+    _, _, text_units_context = await _get_vector_context(
+        query, chunks_vdb, query_param, tokenizer
+    )
+
+    if text_units_context is None or len(text_units_context) == 0:
        return PROMPTS["fail_response"]

+    text_units_str = json.dumps(text_units_context, ensure_ascii=False)
    if query_param.only_need_context:
-        return section
+        return f"""
+---Document Chunks---
+
+```json
+{text_units_str}
+```
+
+"""
    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
@@ -2077,7 +1950,7 @@ async def naive_query(

    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
-        content_data=section,
+        content_data=text_units_str,
        response_type=query_param.response_type,
        history=history_context,
    )
@@ -2134,6 +2007,9 @@ async def kg_query_with_keywords(
    query_param: QueryParam,
    global_config: dict[str, str],
    hashing_kv: BaseKVStorage | None = None,
+    ll_keywords: list[str] = [],
+    hl_keywords: list[str] = [],
+    chunks_vdb: BaseVectorStorage | None = None,
) -> str | AsyncIterator[str]:
    """
    Refactored kg_query that does NOT extract keywords by itself.
@@ -2147,9 +2023,6 @@ async def kg_query_with_keywords(
        # Apply higher priority (5) to query relation LLM function
        use_model_func = partial(use_model_func, _priority=5)

-    # ---------------------------
-    # 1) Handle potential cache for query results
-    # ---------------------------
    args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -2157,14 +2030,6 @@ async def kg_query_with_keywords(
    if cached_response is not None:
        return cached_response

-    # ---------------------------
-    # 2) RETRIEVE KEYWORDS FROM query_param
-    # ---------------------------
-
-    # If these fields don't exist, default to empty lists/strings.
-    hl_keywords = getattr(query_param, "hl_keywords", []) or []
-    ll_keywords = getattr(query_param, "ll_keywords", []) or []
-
    # If neither has any keywords, you could handle that logic here.
    if not hl_keywords and not ll_keywords:
        logger.warning(
@@ -2178,25 +2043,9 @@ async def kg_query_with_keywords(
        logger.warning("high_level_keywords is empty, switching to local mode.")
        query_param.mode = "local"

-    # Flatten low-level and high-level keywords if needed
-    ll_keywords_flat = (
-        [item for sublist in ll_keywords for item in sublist]
-        if any(isinstance(i, list) for i in ll_keywords)
-        else ll_keywords
-    )
-    hl_keywords_flat = (
-        [item for sublist in hl_keywords for item in sublist]
-        if any(isinstance(i, list) for i in hl_keywords)
-        else hl_keywords
-    )
-
-    # Join the flattened lists
-    ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
-    hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
-
-    # ---------------------------
-    # 3) BUILD CONTEXT
-    # ---------------------------
+    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
+    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
+
    context = await _build_query_context(
        ll_keywords_str,
        hl_keywords_str,
@@ -2205,18 +2054,14 @@ async def kg_query_with_keywords(
        relationships_vdb,
        text_chunks_db,
        query_param,
+        chunks_vdb=chunks_vdb,
    )
    if not context:
        return PROMPTS["fail_response"]

-    # If only context is needed, return it
    if query_param.only_need_context:
        return context

-    # ---------------------------
-    # 4) BUILD THE SYSTEM PROMPT + CALL LLM
-    # ---------------------------
-
    # Process conversation history
    history_context = ""
    if query_param.conversation_history:
@@ -2258,7 +2103,6 @@ async def kg_query_with_keywords(
    )

    if hashing_kv.global_config.get("enable_llm_cache"):
-        # 7. Save cache - only cache after the complete response has been collected
        await save_to_cache(
            hashing_kv,
            CacheData(
@@ -2319,12 +2163,15 @@ async def query_with_keywords(
    )

    # Create a new string with the prompt and the keywords
-    ll_keywords_str = ", ".join(ll_keywords)
-    hl_keywords_str = ", ".join(hl_keywords)
-    formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
+    keywords_str = ", ".join(ll_keywords + hl_keywords)
+    formatted_question = (
+        f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
+    )
+
+    param.original_query = query

    # Use appropriate query method based on mode
-    if param.mode in ["local", "global", "hybrid"]:
+    if param.mode in ["local", "global", "hybrid", "mix"]:
        return await kg_query_with_keywords(
            formatted_question,
            knowledge_graph_inst,
@@ -2334,6 +2181,9 @@ async def query_with_keywords(
            param,
            global_config,
            hashing_kv=hashing_kv,
+            hl_keywords=hl_keywords,
+            ll_keywords=ll_keywords,
+            chunks_vdb=chunks_vdb,
        )
    elif param.mode == "naive":
        return await naive_query(
@@ -2344,17 +2194,5 @@ async def query_with_keywords(
            global_config,
            hashing_kv=hashing_kv,
        )
-    elif param.mode == "mix":
-        return await mix_kg_vector_query(
-            formatted_question,
-            knowledge_graph_inst,
-            entities_vdb,
-            relationships_vdb,
-            chunks_vdb,
-            text_chunks_db,
-            param,
-            global_config,
-            hashing_kv=hashing_kv,
-        )
    else:
        raise ValueError(f"Unknown mode {param.mode}")
@@ -721,19 +721,19 @@ def truncate_list_by_token_size(

 def list_of_list_to_dict(data: list[list[str]]) -> list[dict[str, str]]:
    """Convert a 2D string list (table-like data) into a list of dictionaries.

    The first row is treated as header containing field names. Subsequent rows become
    dictionary entries where keys come from header and values from row data.

    Args:
        data: 2D string array where first row contains headers and rest are data rows.
            Minimum 2 columns required in data rows (rows with <2 elements are skipped).

    Returns:
        List of dictionaries where each dict represents a data row with:
        - Keys: Header values from first row
        - Values: Corresponding row values (empty string if missing)

    Example:
        Input: [["Name","Age"], ["Alice","23"], ["Bob"]]
        Output: [{"Name":"Alice","Age":"23"}, {"Name":"Bob","Age":""}]
@@ -822,21 +822,33 @@ def xml_to_json(xml_file):
    return None


-def process_combine_contexts(
-    hl_context: list[dict[str, str]], ll_context: list[dict[str, str]]
-):
+def process_combine_contexts(*context_lists):
+    """
+    Combine multiple context lists and remove duplicate content
+
+    Args:
+        *context_lists: Any number of context lists
+
+    Returns:
+        Combined context list with duplicates removed
+    """
    seen_content = {}
    combined_data = []

-    for item in hl_context + ll_context:
-        content_dict = {k: v for k, v in item.items() if k != "id"}
-        content_key = tuple(sorted(content_dict.items()))
-        if content_key not in seen_content:
-            seen_content[content_key] = item
-            combined_data.append(item)
+    # Iterate through all input context lists
+    for context_list in context_lists:
+        if not context_list:  # Skip empty lists
+            continue
+        for item in context_list:
+            content_dict = {k: v for k, v in item.items() if k != "id"}
+            content_key = tuple(sorted(content_dict.items()))
+            if content_key not in seen_content:
+                seen_content[content_key] = item
+                combined_data.append(item)

+    # Reassign IDs
    for i, item in enumerate(combined_data):
-        item["id"] = str(i)
+        item["id"] = str(i + 1)

    return combined_data
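A small usage sketch of the variadic `process_combine_contexts` above (the sample records are made up; the import path assumes the helper lives in `lightrag.utils`):

# Any number of context lists can be merged; duplicates are detected by their
# content with "id" ignored, and ids are reassigned starting from "1".
from lightrag.utils import process_combine_contexts  # import path assumed

hl = [{"id": "1", "entity": "Scrooge", "description": "A miser"}]
ll = [
    {"id": "1", "entity": "Scrooge", "description": "A miser"},  # duplicate of hl[0]
    {"id": "2", "entity": "Marley", "description": "A ghost"},
]
vector = []  # empty lists are skipped

merged = process_combine_contexts(hl, ll, vector)
# merged == [
#     {"entity": "Scrooge", "description": "A miser", "id": "1"},
#     {"entity": "Marley", "description": "A ghost", "id": "2"},
# ]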