Support multi-turn conversations
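This change threads a conversation history through the query paths. A rough usage sketch follows; the field names conversation_history, history_turns, and history_window_size come from this diff, while the QueryParam construction and the way the param is handed to the query functions are assumptions about the surrounding API rather than code shown here.

    # Hypothetical multi-turn setup based on the fields this commit reads.
    # Each turn is assumed to be a dict with "role" and "content" keys, since
    # kg_query renders turns as f"{turn['role']}: {turn['content']}".
    history = [
        {"role": "user", "content": "Who founded the company?"},
        {"role": "assistant", "content": "It was founded by Alice Zhang."},
    ]
    query_param = QueryParam(
        mode="hybrid",
        conversation_history=history,  # prior turns, oldest first (assumed ordering)
        history_turns=3,               # recent turns kept by get_conversation_turns
    )
    # query_param is then passed to kg_query / mix_kg_vector_query as usual,
    # and a follow-up question can rely on the earlier turns.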
@@ -21,6 +21,7 @@ from .utils import (
save_to_cache,
CacheData,
statistic_data,
get_conversation_turns,
)
from .base import (
BaseGraphStorage,
@@ -369,7 +370,7 @@ async def extract_entities(

arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache, arg_hash, _prompt, "default"
llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
)
if need_to_restore:
llm_response_cache.global_config = global_config
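The cache_type argument added to compute_args_hash and handle_cache (and the matching CacheData field further down) separates extraction, keyword, and query entries in the LLM response cache. Those helpers live in .utils and are not part of this diff; the following is only a minimal sketch of the namespacing idea, with made-up names.

    # Illustrative only: one way a cache_type-aware key could be derived.
    # The real compute_args_hash/handle_cache in .utils may differ.
    from hashlib import md5

    def demo_args_hash(*args, cache_type: str = "default") -> str:
        # Prefixing with the cache type keeps "query", "keywords", and
        # "default" (extraction) entries from colliding for the same text.
        raw = cache_type + "|" + "|".join(str(a) for a in args)
        return md5(raw.encode("utf-8")).hexdigest()

    demo_args_hash("hybrid", "Who founded the company?", cache_type="query")
    demo_args_hash("hybrid", "Who founded the company?", cache_type="keywords")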
@@ -576,54 +577,19 @@ async def kg_query(
) -> str:
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if cached_response is not None:
return cached_response

example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv
)

# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]

# LLM generate keywords
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
result = await use_model_func(kw_prompt, keyword_extraction=True)
logger.info("kw_prompt result:")
print(result)
try:
# json_text = locate_json_string_body_from_string(result) # handled in use_model_func
match = re.search(r"\{.*\}", result, re.DOTALL)
if match:
result = match.group(0)
keywords_data = json.loads(result)

hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
else:
logger.error("No JSON-like structure found in the result.")
return PROMPTS["fail_response"]

# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e} {result}")
return PROMPTS["fail_response"]

# Handle keywords missing
# Handle empty keywords
if hl_keywords == [] and ll_keywords == []:
logger.warning("low_level_keywords and high_level_keywords is empty")
return PROMPTS["fail_response"]
@@ -660,12 +626,27 @@ async def kg_query(
return context
if context is None:
return PROMPTS["fail_response"]

# Process conversation history
history_context = ""
if query_param.conversation_history:
recent_history = query_param.conversation_history[
-query_param.history_window_size :
]
history_context = "\n".join(
[f"{turn['role']}: {turn['content']}" for turn in recent_history]
)

sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
context_data=context,
response_type=query_param.response_type,
history=history_context,
)

if query_param.only_need_prompt:
return sys_prompt

response = await use_model_func(
query,
system_prompt=sys_prompt,
@@ -693,140 +674,7 @@ async def kg_query(
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
),
)
return response


async def kg_query_with_keywords(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
Then it uses those to build context and produce a final LLM response.
"""

# ---------------------------
# 0) Handle potential cache
# ---------------------------
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
)
if cached_response is not None:
return cached_response

# ---------------------------
# 1) RETRIEVE KEYWORDS FROM query_param
# ---------------------------

# If these fields don't exist, default to empty lists/strings.
hl_keywords = getattr(query_param, "hl_keywords", []) or []
ll_keywords = getattr(query_param, "ll_keywords", []) or []

# If neither has any keywords, you could handle that logic here.
if not hl_keywords and not ll_keywords:
logger.warning(
"No keywords found in query_param. Could default to global mode or fail."
)
return PROMPTS["fail_response"]
if not ll_keywords and query_param.mode in ["local", "hybrid"]:
logger.warning("low_level_keywords is empty, switching to global mode.")
query_param.mode = "global"
if not hl_keywords and query_param.mode in ["global", "hybrid"]:
logger.warning("high_level_keywords is empty, switching to local mode.")
query_param.mode = "local"

# Flatten low-level and high-level keywords if needed
ll_keywords_flat = (
[item for sublist in ll_keywords for item in sublist]
if any(isinstance(i, list) for i in ll_keywords)
else ll_keywords
)
hl_keywords_flat = (
[item for sublist in hl_keywords for item in sublist]
if any(isinstance(i, list) for i in hl_keywords)
else hl_keywords
)

# Join the flattened lists
ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""

keywords = [ll_keywords_str, hl_keywords_str]

logger.info("Using %s mode for query processing", query_param.mode)

# ---------------------------
# 2) BUILD CONTEXT
# ---------------------------
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if not context:
return PROMPTS["fail_response"]

# If only context is needed, return it
if query_param.only_need_context:
return context

# ---------------------------
# 3) BUILD THE SYSTEM PROMPT + CALL LLM
# ---------------------------
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)

if query_param.only_need_prompt:
return sys_prompt

# Now call the LLM with the final system prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)

# Clean up the response
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)

# ---------------------------
# 4) SAVE TO CACHE
# ---------------------------
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
cache_type="query",
),
)
return response
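Because kg_query_with_keywords skips keyword extraction, the caller is expected to have attached hl_keywords and ll_keywords to query_param beforehand, typically via extract_keywords_only. A hedged sketch of that call pattern follows; the attribute assignment mirrors the getattr defaults above, and the storage objects are whatever the caller already holds.

    # Sketch: extract keywords once, attach them, then query without re-extraction.
    hl_keywords, ll_keywords = await extract_keywords_only(
        "How do the two subsidiaries relate?", query_param, global_config, hashing_kv
    )
    query_param.hl_keywords = hl_keywords  # read back via getattr(..., "hl_keywords", [])
    query_param.ll_keywords = ll_keywords
    answer = await kg_query_with_keywords(
        "How do the two subsidiaries relate?",
        knowledge_graph_inst,
        entities_vdb,
        relationships_vdb,
        text_chunks_db,
        query_param,
        global_config,
        hashing_kv=hashing_kv,
    )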
@@ -844,22 +692,21 @@ async def extract_keywords_only(
It ONLY extracts keywords (hl_keywords, ll_keywords).
"""

# 1. Handle cache if needed
args_hash = compute_args_hash(param.mode, text)
# 1. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
)
if cached_response is not None:
# parse the cached_response if it's JSON containing keywords
# or simply return (hl_keywords, ll_keywords) from cached
# Assuming cached_response is in the same JSON structure:
match = re.search(r"\{.*\}", cached_response, re.DOTALL)
if match:
keywords_data = json.loads(match.group(0))
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
return hl_keywords, ll_keywords
return [], []
try:
keywords_data = json.loads(cached_response)
return keywords_data["high_level_keywords"], keywords_data[
"low_level_keywords"
]
except (json.JSONDecodeError, KeyError):
logger.warning(
"Invalid cache format for keywords, proceeding with extraction"
)

# 2. Build the examples
example_number = global_config["addon_params"].get("example_number", None)
@@ -873,15 +720,23 @@ async def extract_keywords_only(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)

# 3. Build the keyword-extraction prompt
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language)
# 3. Process conversation history
history_context = ""
if param.conversation_history:
history_context = get_conversation_turns(
param.conversation_history, param.history_turns
)

# 4. Call the LLM for keyword extraction
# 4. Build the keyword-extraction prompt
kw_prompt = PROMPTS["keywords_extraction"].format(
query=text, examples=examples, language=language, history=history_context
)

# 5. Call the LLM for keyword extraction
use_model_func = global_config["llm_model_func"]
result = await use_model_func(kw_prompt, keyword_extraction=True)

# 5. Parse out JSON from the LLM response
# 6. Parse out JSON from the LLM response
match = re.search(r"\{.*\}", result, re.DOTALL)
if not match:
logger.error("No JSON-like structure found in the result.")
@@ -895,22 +750,225 @@ async def extract_keywords_only(
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])

# 6. Cache the result if needed
# 7. Cache only the processed keywords with cache type
cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=result,
content=json.dumps(cache_data),
prompt=text,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=param.mode,
cache_type="keywords",
),
)
return hl_keywords, ll_keywords

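extract_keywords_only and the query paths below format history through get_conversation_turns, which this diff only imports from .utils. Below is a minimal sketch of what it presumably does, judging by the "role: content" rendering kg_query uses for its own history window; the real helper may group or format turns differently.

    # Assumed behaviour, not the actual .utils implementation.
    def demo_conversation_turns(conversation_history: list, num_turns: int) -> str:
        # Keep only the most recent turns and render them one per line.
        recent = conversation_history[-num_turns:] if num_turns else conversation_history
        return "\n".join(f"{turn['role']}: {turn['content']}" for turn in recent)
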
async def mix_kg_vector_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.

This function performs a hybrid search by:
1. Extracting semantic information from knowledge graph
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
"""
# 1. Cache handling
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash("mix", query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "mix", cache_type="query"
)
if cached_response is not None:
return cached_response

# Process conversation history
history_context = ""
if query_param.conversation_history:
history_context = get_conversation_turns(
query_param.conversation_history, query_param.history_turns
)

# 2. Execute knowledge graph and vector searches in parallel
async def get_kg_context():
try:
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv
)

if not hl_keywords and not ll_keywords:
logger.warning("Both high-level and low-level keywords are empty")
return None

# Convert keyword lists to strings
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

# Set query mode based on available keywords
if not ll_keywords_str and not hl_keywords_str:
return None
elif not ll_keywords_str:
query_param.mode = "global"
elif not hl_keywords_str:
query_param.mode = "local"
else:
query_param.mode = "hybrid"

# Build knowledge graph context
context = await _build_query_context(
[ll_keywords_str, hl_keywords_str],
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)

return context

except Exception as e:
logger.error(f"Error in get_kg_context: {str(e)}")
return None

async def get_vector_context():
# Consider conversation history in vector search
augmented_query = query
if history_context:
augmented_query = f"{history_context}\n{query}"

try:
# Reduce top_k for vector search in hybrid mode since we have structured information from KG
mix_topk = min(10, query_param.top_k)
results = await chunks_vdb.query(augmented_query, top_k=mix_topk)
if not results:
return None

chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)

valid_chunks = []
for chunk, result in zip(chunks, results):
if chunk is not None and "content" in chunk:
# Merge chunk content and time metadata
chunk_with_time = {
"content": chunk["content"],
"created_at": result.get("created_at", None),
}
valid_chunks.append(chunk_with_time)

if not valid_chunks:
return None

maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
)

if not maybe_trun_chunks:
return None

# Include time information in content
formatted_chunks = []
for c in maybe_trun_chunks:
chunk_text = c["content"]
if c["created_at"]:
chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
formatted_chunks.append(chunk_text)

return "\n--New Chunk--\n".join(formatted_chunks)
except Exception as e:
logger.error(f"Error in get_vector_context: {e}")
return None

# 3. Execute both retrievals in parallel
kg_context, vector_context = await asyncio.gather(
get_kg_context(), get_vector_context()
)

# 4. Merge contexts
if kg_context is None and vector_context is None:
return PROMPTS["fail_response"]

if query_param.only_need_context:
return {"kg_context": kg_context, "vector_context": vector_context}

# 5. Construct hybrid prompt
sys_prompt = PROMPTS["mix_rag_response"].format(
kg_context=kg_context
if kg_context
else "No relevant knowledge graph information found",
vector_context=vector_context
if vector_context
else "No relevant text information found",
response_type=query_param.response_type,
history=history_context,
)

if query_param.only_need_prompt:
return sys_prompt

# 6. Generate response
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)

if query_param.stream:
# If this is a streaming response, collect the full response first
full_response = []
async for chunk in response:
full_response.append(chunk)

# Join the chunks into the complete response for caching
response = "".join(full_response)

# Clean up the response content
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)

# 7. Save cache - only cache after the full response has been collected
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode="mix",
cache_type="query",
),
)

return response

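mix_kg_vector_query now drains a streamed response before writing it to the cache, so the cached entry always holds the full text. The pattern in isolation looks roughly like the generic sketch below; the names are illustrative only.

    # Generic version of the stream-then-cache pattern used above.
    async def collect_stream(chunks) -> str:
        parts = []
        async for piece in chunks:  # works for any async iterator of str
            parts.append(piece)
        return "".join(parts)       # clean and cache this complete string afterwards
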
async def _build_query_context(
query: list,
knowledge_graph_inst: BaseGraphStorage,
@@ -1407,9 +1465,9 @@ async def naive_query(
):
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
hashing_kv, args_hash, query, "default", cache_type="query"
)
if cached_response is not None:
return cached_response
@@ -1482,190 +1540,125 @@ async def naive_query(
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
cache_type="query",
),
)

return response


async def mix_kg_vector_query(
query,
async def kg_query_with_keywords(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.

This function performs a hybrid search by:
1. Extracting semantic information from knowledge graph
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
Then it uses those to build context and produce a final LLM response.
"""
# 1. Cache handling

# ---------------------------
# 1) Handle potential cache for query results
# ---------------------------
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash("mix", query)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "mix"
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if cached_response is not None:
return cached_response

# 2. Execute knowledge graph and vector searches in parallel
async def get_kg_context():
try:
# Reuse keyword extraction logic from kg_query
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(
PROMPTS["keywords_extraction_examples"]
):
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
# ---------------------------
# 2) RETRIEVE KEYWORDS FROM query_param
# ---------------------------

language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
# If these fields don't exist, default to empty lists/strings.
hl_keywords = getattr(query_param, "hl_keywords", []) or []
ll_keywords = getattr(query_param, "ll_keywords", []) or []

# Extract keywords using LLM
kw_prompt = PROMPTS["keywords_extraction"].format(
query=query, examples=examples, language=language
)
result = await use_model_func(kw_prompt, keyword_extraction=True)
# If neither has any keywords, you could handle that logic here.
if not hl_keywords and not ll_keywords:
logger.warning(
"No keywords found in query_param. Could default to global mode or fail."
)
return PROMPTS["fail_response"]
if not ll_keywords and query_param.mode in ["local", "hybrid"]:
logger.warning("low_level_keywords is empty, switching to global mode.")
query_param.mode = "global"
if not hl_keywords and query_param.mode in ["global", "hybrid"]:
logger.warning("high_level_keywords is empty, switching to local mode.")
query_param.mode = "local"

match = re.search(r"\{.*\}", result, re.DOTALL)
if not match:
logger.warning(
"No JSON-like structure found in keywords extraction result"
)
return None

result = match.group(0)
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])

if not hl_keywords and not ll_keywords:
logger.warning("Both high-level and low-level keywords are empty")
return None

# Convert keyword lists to strings
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

# Set query mode based on available keywords
if not ll_keywords_str and not hl_keywords_str:
return None
elif not ll_keywords_str:
query_param.mode = "global"
elif not hl_keywords_str:
query_param.mode = "local"
else:
query_param.mode = "hybrid"

# Build knowledge graph context
context = await _build_query_context(
[ll_keywords_str, hl_keywords_str],
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)

return context

except Exception as e:
logger.error(f"Error in get_kg_context: {str(e)}")
return None

async def get_vector_context():
# Reuse vector search logic from naive_query
try:
# Reduce top_k for vector search in hybrid mode since we have structured information from KG
mix_topk = min(10, query_param.top_k)
results = await chunks_vdb.query(query, top_k=mix_topk)
if not results:
return None

chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)

valid_chunks = []
for chunk, result in zip(chunks, results):
if chunk is not None and "content" in chunk:
# Merge chunk content and time metadata
chunk_with_time = {
"content": chunk["content"],
"created_at": result.get("created_at", None),
}
valid_chunks.append(chunk_with_time)

if not valid_chunks:
return None

maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
)

if not maybe_trun_chunks:
return None

# Include time information in content
formatted_chunks = []
for c in maybe_trun_chunks:
chunk_text = c["content"]
if c["created_at"]:
chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
formatted_chunks.append(chunk_text)

return "\n--New Chunk--\n".join(formatted_chunks)
except Exception as e:
logger.error(f"Error in get_vector_context: {e}")
return None

# 3. Execute both retrievals in parallel
kg_context, vector_context = await asyncio.gather(
get_kg_context(), get_vector_context()
# Flatten low-level and high-level keywords if needed
ll_keywords_flat = (
[item for sublist in ll_keywords for item in sublist]
if any(isinstance(i, list) for i in ll_keywords)
else ll_keywords
)
hl_keywords_flat = (
[item for sublist in hl_keywords for item in sublist]
if any(isinstance(i, list) for i in hl_keywords)
else hl_keywords
)

# 4. Merge contexts
if kg_context is None and vector_context is None:
# Join the flattened lists
ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""

keywords = [ll_keywords_str, hl_keywords_str]

logger.info("Using %s mode for query processing", query_param.mode)

# ---------------------------
# 3) BUILD CONTEXT
# ---------------------------
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if not context:
return PROMPTS["fail_response"]

# If only context is needed, return it
if query_param.only_need_context:
return {"kg_context": kg_context, "vector_context": vector_context}
return context

# 5. Construct hybrid prompt
sys_prompt = PROMPTS["mix_rag_response"].format(
kg_context=kg_context
if kg_context
else "No relevant knowledge graph information found",
vector_context=vector_context
if vector_context
else "No relevant text information found",
# ---------------------------
# 4) BUILD THE SYSTEM PROMPT + CALL LLM
# ---------------------------

# Process conversation history
history_context = ""
if query_param.conversation_history:
history_context = get_conversation_turns(
query_param.conversation_history, query_param.history_turns
)

sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context,
response_type=query_param.response_type,
history=history_context,
)

if query_param.only_need_prompt:
return sys_prompt

# 6. Generate response
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)

if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -1677,7 +1670,7 @@ async def mix_kg_vector_query(
.strip()
)

# 7. Save cache
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
@@ -1687,8 +1680,8 @@ async def mix_kg_vector_query(
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode="mix",
mode=query_param.mode,
cache_type="query",
),
)

return response