Unified vector retrieval logic for mix and naive queries
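Both `naive_query` and `mix_kg_vector_query` previously carried near-identical copies of the chunk retrieval, filtering, truncation, and formatting logic; this commit moves that logic into a single `_get_vector_context` helper that both paths call. A minimal sketch of the resulting call pattern, taken from the two call sites in the diff below (surrounding function bodies elided):

    # naive_query: the helper replaces the inline retrieval block.
    section = await _get_vector_context(query, chunks_vdb, query_param, tokenizer)
    if section is None:
        return PROMPTS["fail_response"]

    # mix_kg_vector_query: the same helper runs in parallel with KG retrieval.
    kg_context, vector_context = await asyncio.gather(
        get_kg_context(), _get_vector_context(query, chunks_vdb, query_param, tokenizer)
    )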
@@ -1198,66 +1198,10 @@ async def mix_kg_vector_query(
             traceback.print_exc()
             return None
 
-    async def get_vector_context():
-        try:
-            # Reduce top_k for vector search in hybrid mode since we have structured information from KG
-            mix_topk = min(10, query_param.top_k)
-            results = await chunks_vdb.query(
-                query, top_k=mix_topk, ids=query_param.ids
-            )
-            if not results:
-                return None
-
-            valid_chunks = []
-            for result in results:
-                if "content" in result:
-                    # Directly use content from chunks_vdb.query result
-                    chunk_with_time = {
-                        "content": result["content"],
-                        "created_at": result.get("created_at", None),
-                        "file_path": result.get("file_path", None),
-                    }
-                    valid_chunks.append(chunk_with_time)
-
-            if not valid_chunks:
-                return None
-
-            maybe_trun_chunks = truncate_list_by_token_size(
-                valid_chunks,
-                key=lambda x: x["content"],
-                max_token_size=query_param.max_token_for_text_unit,
-                tokenizer=tokenizer,
-            )
-
-            logger.debug(
-                f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-            )
-            logger.info(
-                f"Naive query: {len(maybe_trun_chunks)} chunks, top_k: {mix_topk}"
-            )
-
-            if not maybe_trun_chunks:
-                return None
-
-            # Include time information in content
-            formatted_chunks = []
-            for c in maybe_trun_chunks:
-                chunk_text = "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
-                if c["created_at"]:
-                    chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\r\n\r\n{chunk_text}"
-                formatted_chunks.append(chunk_text)
-
-            logger.debug(
-                f"Truncate chunks from {len(valid_chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-            )
-            return "\r\n\r\n--New Chunk--\r\n\r\n".join(formatted_chunks)
-        except Exception as e:
-            logger.error(f"Error in get_vector_context: {e}")
-            return None
-
     # 3. Execute both retrievals in parallel
     kg_context, vector_context = await asyncio.gather(
-        get_kg_context(), get_vector_context()
+        get_kg_context(), _get_vector_context(query, chunks_vdb, query_param, tokenizer)
     )
 
     # 4. Merge contexts
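Both call sites enforce the token budget through `truncate_list_by_token_size`, whose implementation is not part of this diff. Assuming it keeps the longest prefix of the list whose accumulated token count, measured via `key`, stays within `max_token_size` (a reading inferred from how it is called here, not from its source), a self-contained sketch of that contract:

    from typing import Callable, TypeVar

    T = TypeVar("T")

    def truncate_list_by_token_size_sketch(
        items: list[T],
        key: Callable[[T], str],
        max_token_size: int,
        tokenizer,  # assumed to expose encode(text) -> list[int]
    ) -> list[T]:
        # Keep items in order until the running token total would exceed the budget.
        total = 0
        kept: list[T] = []
        for item in items:
            total += len(tokenizer.encode(key(item)))
            if total > max_token_size:
                break
            kept.append(item)
        return kept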
@@ -2038,44 +1982,12 @@ async def naive_query(
     if cached_response is not None:
         return cached_response
 
-    results = await chunks_vdb.query(
-        query, top_k=query_param.top_k, ids=query_param.ids
-    )
-    if not len(results):
-        return PROMPTS["fail_response"]
-
-    valid_chunks = [result for result in results if "content" in result]
-
-    if not valid_chunks:
-        logger.warning("No valid chunks found after filtering")
-        return PROMPTS["fail_response"]
-
     tokenizer: Tokenizer = global_config["tokenizer"]
-    maybe_trun_chunks = truncate_list_by_token_size(
-        valid_chunks,
-        key=lambda x: x["content"],
-        max_token_size=query_param.max_token_for_text_unit,
-        tokenizer=tokenizer,
-    )
-
-    if not maybe_trun_chunks:
-        logger.warning("No chunks left after truncation")
+    section = await _get_vector_context(query, chunks_vdb, query_param, tokenizer)
+
+    if section is None:
         return PROMPTS["fail_response"]
 
-    logger.debug(
-        f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-    )
-    logger.info(
-        f"Naive query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
-    )
-
-    section = "\r\n\r\n--New Chunk--\r\n\r\n".join(
-        [
-            "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
-            for c in maybe_trun_chunks
-        ]
-    )
-
     if query_param.only_need_context:
         return section
 
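For reference, the `section` string assembled here (and returned verbatim when `query_param.only_need_context` is set) joins chunks with a literal `--New Chunk--` separator. With hypothetical file paths and contents:

    # Hypothetical chunks, purely to illustrate the output format.
    chunks = [
        {"file_path": "docs/a.md", "content": "first chunk text"},
        {"file_path": "docs/b.md", "content": "second chunk text"},
    ]
    section = "\r\n\r\n--New Chunk--\r\n\r\n".join(
        "File path: " + c["file_path"] + "\r\n\r\n" + c["content"] for c in chunks
    )
    # section == 'File path: docs/a.md\r\n\r\nfirst chunk text'
    #            '\r\n\r\n--New Chunk--\r\n\r\n'
    #            'File path: docs/b.md\r\n\r\nsecond chunk text'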
@@ -2287,6 +2199,84 @@ async def kg_query_with_keywords(
     return response
 
 
+async def _get_vector_context(
+    query: str,
+    chunks_vdb: BaseVectorStorage,
+    query_param: QueryParam,
+    tokenizer: Tokenizer,
+) -> str | None:
+    """
+    Retrieve vector context from the vector database.
+
+    This function performs vector search to find relevant text chunks for a query,
+    formats them with file path and creation time information, and truncates
+    the results to fit within token limits.
+
+    Args:
+        query: The query string to search for
+        chunks_vdb: Vector database containing document chunks
+        query_param: Query parameters including top_k and ids
+        tokenizer: Tokenizer for counting tokens
+
+    Returns:
+        Formatted string containing relevant text chunks, or None if no results found
+    """
+    try:
+        # Reduce top_k for vector search in hybrid mode since we have structured information from KG
+        mix_topk = min(10, query_param.top_k) if hasattr(query_param, 'mode') and query_param.mode == 'mix' else query_param.top_k
+        results = await chunks_vdb.query(
+            query, top_k=mix_topk, ids=query_param.ids
+        )
+        if not results:
+            return None
+
+        valid_chunks = []
+        for result in results:
+            if "content" in result:
+                # Directly use content from chunks_vdb.query result
+                chunk_with_time = {
+                    "content": result["content"],
+                    "created_at": result.get("created_at", None),
+                    "file_path": result.get("file_path", None),
+                }
+                valid_chunks.append(chunk_with_time)
+
+        if not valid_chunks:
+            return None
+
+        maybe_trun_chunks = truncate_list_by_token_size(
+            valid_chunks,
+            key=lambda x: x["content"],
+            max_token_size=query_param.max_token_for_text_unit,
+            tokenizer=tokenizer,
+        )
+
+        logger.debug(
+            f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
+        )
+        logger.info(
+            f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {mix_topk}"
+        )
+
+        if not maybe_trun_chunks:
+            return None
+
+        # Include time information in content
+        formatted_chunks = []
+        for c in maybe_trun_chunks:
+            chunk_text = "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
+            if c["created_at"]:
+                chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\r\n\r\n{chunk_text}"
+            formatted_chunks.append(chunk_text)
+
+        logger.debug(
+            f"Truncate chunks from {len(valid_chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
+        )
+        return "\r\n\r\n--New Chunk--\r\n\r\n".join(formatted_chunks)
+    except Exception as e:
+        logger.error(f"Error in _get_vector_context: {e}")
+        return None
+
+
 async def query_with_keywords(
     query: str,
     prompt: str,
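The one mode-dependent branch inside the new helper is the top_k cap: in mix mode the vector search is trimmed to at most 10 results, since the KG side already contributes structured context, while naive mode keeps `query_param.top_k` unchanged. The conditional, isolated with a hypothetical parameter object for demonstration:

    from types import SimpleNamespace

    def effective_top_k(query_param) -> int:
        # Mirrors the conditional in _get_vector_context: cap at 10 only when
        # the param object carries mode == "mix".
        if getattr(query_param, "mode", None) == "mix":
            return min(10, query_param.top_k)
        return query_param.top_k

    print(effective_top_k(SimpleNamespace(mode="mix", top_k=60)))    # -> 10
    print(effective_top_k(SimpleNamespace(mode="naive", top_k=60)))  # -> 60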