支持多轮对话

This commit is contained in:
Magic_yuan
2025-01-24 18:59:24 +08:00
parent 3d93df4049
commit 5719aa8882
5 changed files with 479 additions and 364 deletions

View File

@@ -21,6 +21,7 @@ from .utils import (
save_to_cache,
CacheData,
statistic_data,
get_conversation_turns,
)
from .base import (
BaseGraphStorage,
@@ -369,7 +370,7 @@ async def extract_entities(
arg_hash = compute_args_hash(_prompt)
cached_return, _1, _2, _3 = await handle_cache(
llm_response_cache, arg_hash, _prompt, "default"
llm_response_cache, arg_hash, _prompt, "default", cache_type="default"
)
if need_to_restore:
llm_response_cache.global_config = global_config
@@ -576,54 +577,19 @@ async def kg_query(
) -> str:
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if cached_response is not None:
return cached_response
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv
)
# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]
# LLM generate keywords
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
result = await use_model_func(kw_prompt, keyword_extraction=True)
logger.info("kw_prompt result:")
print(result)
try:
# json_text = locate_json_string_body_from_string(result) # handled in use_model_func
match = re.search(r"\{.*\}", result, re.DOTALL)
if match:
result = match.group(0)
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
else:
logger.error("No JSON-like structure found in the result.")
return PROMPTS["fail_response"]
# Handle parsing error
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e} {result}")
return PROMPTS["fail_response"]
# Handdle keywords missing
# Handle empty keywords
if hl_keywords == [] and ll_keywords == []:
logger.warning("low_level_keywords and high_level_keywords is empty")
return PROMPTS["fail_response"]
@@ -660,12 +626,27 @@ async def kg_query(
return context
if context is None:
return PROMPTS["fail_response"]
# Process conversation history
history_context = ""
if query_param.conversation_history:
recent_history = query_param.conversation_history[
-query_param.history_window_size :
]
history_context = "\n".join(
[f"{turn['role']}: {turn['content']}" for turn in recent_history]
)
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
context_data=context,
response_type=query_param.response_type,
history=history_context,
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
@@ -693,140 +674,7 @@ async def kg_query(
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
),
)
return response
async def kg_query_with_keywords(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
Then it uses those to build context and produce a final LLM response.
"""
# ---------------------------
# 0) Handle potential cache
# ---------------------------
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
)
if cached_response is not None:
return cached_response
# ---------------------------
# 1) RETRIEVE KEYWORDS FROM query_param
# ---------------------------
# If these fields don't exist, default to empty lists/strings.
hl_keywords = getattr(query_param, "hl_keywords", []) or []
ll_keywords = getattr(query_param, "ll_keywords", []) or []
# If neither has any keywords, you could handle that logic here.
if not hl_keywords and not ll_keywords:
logger.warning(
"No keywords found in query_param. Could default to global mode or fail."
)
return PROMPTS["fail_response"]
if not ll_keywords and query_param.mode in ["local", "hybrid"]:
logger.warning("low_level_keywords is empty, switching to global mode.")
query_param.mode = "global"
if not hl_keywords and query_param.mode in ["global", "hybrid"]:
logger.warning("high_level_keywords is empty, switching to local mode.")
query_param.mode = "local"
# Flatten low-level and high-level keywords if needed
ll_keywords_flat = (
[item for sublist in ll_keywords for item in sublist]
if any(isinstance(i, list) for i in ll_keywords)
else ll_keywords
)
hl_keywords_flat = (
[item for sublist in hl_keywords for item in sublist]
if any(isinstance(i, list) for i in hl_keywords)
else hl_keywords
)
# Join the flattened lists
ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
keywords = [ll_keywords_str, hl_keywords_str]
logger.info("Using %s mode for query processing", query_param.mode)
# ---------------------------
# 2) BUILD CONTEXT
# ---------------------------
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if not context:
return PROMPTS["fail_response"]
# If only context is needed, return it
if query_param.only_need_context:
return context
# ---------------------------
# 3) BUILD THE SYSTEM PROMPT + CALL LLM
# ---------------------------
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context, response_type=query_param.response_type
)
if query_param.only_need_prompt:
return sys_prompt
# Now call the LLM with the final system prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)
# Clean up the response
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
# ---------------------------
# 4) SAVE TO CACHE
# ---------------------------
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
cache_type="query",
),
)
return response
@@ -844,22 +692,21 @@ async def extract_keywords_only(
It ONLY extracts keywords (hl_keywords, ll_keywords).
"""
# 1. Handle cache if needed
args_hash = compute_args_hash(param.mode, text)
# 1. Handle cache if needed - add cache type for keywords
args_hash = compute_args_hash(param.mode, text, cache_type="keywords")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, text, param.mode
hashing_kv, args_hash, text, param.mode, cache_type="keywords"
)
if cached_response is not None:
# parse the cached_response if its JSON containing keywords
# or simply return (hl_keywords, ll_keywords) from cached
# Assuming cached_response is in the same JSON structure:
match = re.search(r"\{.*\}", cached_response, re.DOTALL)
if match:
keywords_data = json.loads(match.group(0))
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
return hl_keywords, ll_keywords
return [], []
try:
keywords_data = json.loads(cached_response)
return keywords_data["high_level_keywords"], keywords_data[
"low_level_keywords"
]
except (json.JSONDecodeError, KeyError):
logger.warning(
"Invalid cache format for keywords, proceeding with extraction"
)
# 2. Build the examples
example_number = global_config["addon_params"].get("example_number", None)
@@ -873,15 +720,23 @@ async def extract_keywords_only(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
# 3. Build the keyword-extraction prompt
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language)
# 3. Process conversation history
history_context = ""
if param.conversation_history:
history_context = get_conversation_turns(
param.conversation_history, param.history_turns
)
# 4. Call the LLM for keyword extraction
# 4. Build the keyword-extraction prompt
kw_prompt = PROMPTS["keywords_extraction"].format(
query=text, examples=examples, language=language, history=history_context
)
# 5. Call the LLM for keyword extraction
use_model_func = global_config["llm_model_func"]
result = await use_model_func(kw_prompt, keyword_extraction=True)
# 5. Parse out JSON from the LLM response
# 6. Parse out JSON from the LLM response
match = re.search(r"\{.*\}", result, re.DOTALL)
if not match:
logger.error("No JSON-like structure found in the result.")
@@ -895,22 +750,225 @@ async def extract_keywords_only(
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
# 6. Cache the result if needed
# 7. Cache only the processed keywords with cache type
cache_data = {"high_level_keywords": hl_keywords, "low_level_keywords": ll_keywords}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=result,
content=json.dumps(cache_data),
prompt=text,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode=param.mode,
cache_type="keywords",
),
)
return hl_keywords, ll_keywords
async def mix_kg_vector_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
This function performs a hybrid search by:
1. Extracting semantic information from knowledge graph
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
"""
# 1. Cache handling
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash("mix", query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "mix", cache_type="query"
)
if cached_response is not None:
return cached_response
# Process conversation history
history_context = ""
if query_param.conversation_history:
history_context = get_conversation_turns(
query_param.conversation_history, query_param.history_turns
)
# 2. Execute knowledge graph and vector searches in parallel
async def get_kg_context():
try:
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv
)
if not hl_keywords and not ll_keywords:
logger.warning("Both high-level and low-level keywords are empty")
return None
# Convert keyword lists to strings
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# Set query mode based on available keywords
if not ll_keywords_str and not hl_keywords_str:
return None
elif not ll_keywords_str:
query_param.mode = "global"
elif not hl_keywords_str:
query_param.mode = "local"
else:
query_param.mode = "hybrid"
# Build knowledge graph context
context = await _build_query_context(
[ll_keywords_str, hl_keywords_str],
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
return context
except Exception as e:
logger.error(f"Error in get_kg_context: {str(e)}")
return None
async def get_vector_context():
# Consider conversation history in vector search
augmented_query = query
if history_context:
augmented_query = f"{history_context}\n{query}"
try:
# Reduce top_k for vector search in hybrid mode since we have structured information from KG
mix_topk = min(10, query_param.top_k)
results = await chunks_vdb.query(augmented_query, top_k=mix_topk)
if not results:
return None
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
valid_chunks = []
for chunk, result in zip(chunks, results):
if chunk is not None and "content" in chunk:
# Merge chunk content and time metadata
chunk_with_time = {
"content": chunk["content"],
"created_at": result.get("created_at", None),
}
valid_chunks.append(chunk_with_time)
if not valid_chunks:
return None
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
)
if not maybe_trun_chunks:
return None
# Include time information in content
formatted_chunks = []
for c in maybe_trun_chunks:
chunk_text = c["content"]
if c["created_at"]:
chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
formatted_chunks.append(chunk_text)
return "\n--New Chunk--\n".join(formatted_chunks)
except Exception as e:
logger.error(f"Error in get_vector_context: {e}")
return None
# 3. Execute both retrievals in parallel
kg_context, vector_context = await asyncio.gather(
get_kg_context(), get_vector_context()
)
# 4. Merge contexts
if kg_context is None and vector_context is None:
return PROMPTS["fail_response"]
if query_param.only_need_context:
return {"kg_context": kg_context, "vector_context": vector_context}
# 5. Construct hybrid prompt
sys_prompt = PROMPTS["mix_rag_response"].format(
kg_context=kg_context
if kg_context
else "No relevant knowledge graph information found",
vector_context=vector_context
if vector_context
else "No relevant text information found",
response_type=query_param.response_type,
history=history_context,
)
if query_param.only_need_prompt:
return sys_prompt
# 6. Generate response
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)
if query_param.stream:
# 如果是流式响应,先收集完整响应
full_response = []
async for chunk in response:
full_response.append(chunk)
# 将完整响应组合起来用于缓存
response = "".join(full_response)
# 清理响应内容
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
# 7. Save cache - 只有在收集完整响应后才缓存
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode="mix",
cache_type="query",
),
)
return response
async def _build_query_context(
query: list,
knowledge_graph_inst: BaseGraphStorage,
@@ -1407,9 +1465,9 @@ async def naive_query(
):
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
hashing_kv, args_hash, query, "default", cache_type="query"
)
if cached_response is not None:
return cached_response
@@ -1482,190 +1540,125 @@ async def naive_query(
min_val=min_val,
max_val=max_val,
mode=query_param.mode,
cache_type="query",
),
)
return response
async def mix_kg_vector_query(
query,
async def kg_query_with_keywords(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
This function performs a hybrid search by:
1. Extracting semantic information from knowledge graph
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
Refactored kg_query that does NOT extract keywords by itself.
It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
Then it uses those to build context and produce a final LLM response.
"""
# 1. Cache handling
# ---------------------------
# 1) Handle potential cache for query results
# ---------------------------
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash("mix", query)
args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "mix"
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if cached_response is not None:
return cached_response
# 2. Execute knowledge graph and vector searches in parallel
async def get_kg_context():
try:
# Reuse keyword extraction logic from kg_query
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(
PROMPTS["keywords_extraction_examples"]
):
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
# ---------------------------
# 2) RETRIEVE KEYWORDS FROM query_param
# ---------------------------
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
# If these fields don't exist, default to empty lists/strings.
hl_keywords = getattr(query_param, "hl_keywords", []) or []
ll_keywords = getattr(query_param, "ll_keywords", []) or []
# Extract keywords using LLM
kw_prompt = PROMPTS["keywords_extraction"].format(
query=query, examples=examples, language=language
)
result = await use_model_func(kw_prompt, keyword_extraction=True)
# If neither has any keywords, you could handle that logic here.
if not hl_keywords and not ll_keywords:
logger.warning(
"No keywords found in query_param. Could default to global mode or fail."
)
return PROMPTS["fail_response"]
if not ll_keywords and query_param.mode in ["local", "hybrid"]:
logger.warning("low_level_keywords is empty, switching to global mode.")
query_param.mode = "global"
if not hl_keywords and query_param.mode in ["global", "hybrid"]:
logger.warning("high_level_keywords is empty, switching to local mode.")
query_param.mode = "local"
match = re.search(r"\{.*\}", result, re.DOTALL)
if not match:
logger.warning(
"No JSON-like structure found in keywords extraction result"
)
return None
result = match.group(0)
keywords_data = json.loads(result)
hl_keywords = keywords_data.get("high_level_keywords", [])
ll_keywords = keywords_data.get("low_level_keywords", [])
if not hl_keywords and not ll_keywords:
logger.warning("Both high-level and low-level keywords are empty")
return None
# Convert keyword lists to strings
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# Set query mode based on available keywords
if not ll_keywords_str and not hl_keywords_str:
return None
elif not ll_keywords_str:
query_param.mode = "global"
elif not hl_keywords_str:
query_param.mode = "local"
else:
query_param.mode = "hybrid"
# Build knowledge graph context
context = await _build_query_context(
[ll_keywords_str, hl_keywords_str],
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
return context
except Exception as e:
logger.error(f"Error in get_kg_context: {str(e)}")
return None
async def get_vector_context():
# Reuse vector search logic from naive_query
try:
# Reduce top_k for vector search in hybrid mode since we have structured information from KG
mix_topk = min(10, query_param.top_k)
results = await chunks_vdb.query(query, top_k=mix_topk)
if not results:
return None
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
valid_chunks = []
for chunk, result in zip(chunks, results):
if chunk is not None and "content" in chunk:
# Merge chunk content and time metadata
chunk_with_time = {
"content": chunk["content"],
"created_at": result.get("created_at", None),
}
valid_chunks.append(chunk_with_time)
if not valid_chunks:
return None
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
)
if not maybe_trun_chunks:
return None
# Include time information in content
formatted_chunks = []
for c in maybe_trun_chunks:
chunk_text = c["content"]
if c["created_at"]:
chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\n{chunk_text}"
formatted_chunks.append(chunk_text)
return "\n--New Chunk--\n".join(formatted_chunks)
except Exception as e:
logger.error(f"Error in get_vector_context: {e}")
return None
# 3. Execute both retrievals in parallel
kg_context, vector_context = await asyncio.gather(
get_kg_context(), get_vector_context()
# Flatten low-level and high-level keywords if needed
ll_keywords_flat = (
[item for sublist in ll_keywords for item in sublist]
if any(isinstance(i, list) for i in ll_keywords)
else ll_keywords
)
hl_keywords_flat = (
[item for sublist in hl_keywords for item in sublist]
if any(isinstance(i, list) for i in hl_keywords)
else hl_keywords
)
# 4. Merge contexts
if kg_context is None and vector_context is None:
# Join the flattened lists
ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
keywords = [ll_keywords_str, hl_keywords_str]
logger.info("Using %s mode for query processing", query_param.mode)
# ---------------------------
# 3) BUILD CONTEXT
# ---------------------------
context = await _build_query_context(
keywords,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
)
if not context:
return PROMPTS["fail_response"]
# If only context is needed, return it
if query_param.only_need_context:
return {"kg_context": kg_context, "vector_context": vector_context}
return context
# 5. Construct hybrid prompt
sys_prompt = PROMPTS["mix_rag_response"].format(
kg_context=kg_context
if kg_context
else "No relevant knowledge graph information found",
vector_context=vector_context
if vector_context
else "No relevant text information found",
# ---------------------------
# 4) BUILD THE SYSTEM PROMPT + CALL LLM
# ---------------------------
# Process conversation history
history_context = ""
if query_param.conversation_history:
history_context = get_conversation_turns(
query_param.conversation_history, query_param.history_turns
)
sys_prompt_temp = PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context,
response_type=query_param.response_type,
history=history_context,
)
if query_param.only_need_prompt:
return sys_prompt
# 6. Generate response
response = await use_model_func(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
@@ -1677,7 +1670,7 @@ async def mix_kg_vector_query(
.strip()
)
# 7. Save cache
# Save to cache
await save_to_cache(
hashing_kv,
CacheData(
@@ -1687,8 +1680,8 @@ async def mix_kg_vector_query(
quantized=quantized,
min_val=min_val,
max_val=max_val,
mode="mix",
mode=query_param.mode,
cache_type="query",
),
)
return response