From 910a7a89360a570279173d7a2028f4a1a555ace6 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Wed, 7 May 2025 03:47:09 +0800
Subject: [PATCH] Unified vector retrieval logic for mix and naive queries

---
 lightrag/operate.py | 174 +++++++++++++++++++++----------------------
 1 file changed, 82 insertions(+), 92 deletions(-)

diff --git a/lightrag/operate.py b/lightrag/operate.py
index c49d8368..e383d686 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1198,66 +1198,10 @@ async def mix_kg_vector_query(
             traceback.print_exc()
             return None
 
-    async def get_vector_context():
-        try:
-            # Reduce top_k for vector search in hybrid mode since we have structured information from KG
-            mix_topk = min(10, query_param.top_k)
-            results = await chunks_vdb.query(
-                query, top_k=mix_topk, ids=query_param.ids
-            )
-            if not results:
-                return None
-
-            valid_chunks = []
-            for result in results:
-                if "content" in result:
-                    # Directly use content from chunks_vdb.query result
-                    chunk_with_time = {
-                        "content": result["content"],
-                        "created_at": result.get("created_at", None),
-                        "file_path": result.get("file_path", None),
-                    }
-                    valid_chunks.append(chunk_with_time)
-
-            if not valid_chunks:
-                return None
-
-            maybe_trun_chunks = truncate_list_by_token_size(
-                valid_chunks,
-                key=lambda x: x["content"],
-                max_token_size=query_param.max_token_for_text_unit,
-                tokenizer=tokenizer,
-            )
-
-            logger.debug(
-                f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-            )
-            logger.info(
-                f"Naive query: {len(maybe_trun_chunks)} chunks, top_k: {mix_topk}"
-            )
-
-            if not maybe_trun_chunks:
-                return None
-
-            # Include time information in content
-            formatted_chunks = []
-            for c in maybe_trun_chunks:
-                chunk_text = "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
-                if c["created_at"]:
-                    chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\r\n\r\n{chunk_text}"
-                formatted_chunks.append(chunk_text)
-
-            logger.debug(
-                f"Truncate chunks from {len(valid_chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-            )
-            return "\r\n\r\n--New Chunk--\r\n\r\n".join(formatted_chunks)
-        except Exception as e:
-            logger.error(f"Error in get_vector_context: {e}")
-            return None
-
     # 3. Execute both retrievals in parallel
     kg_context, vector_context = await asyncio.gather(
-        get_kg_context(), get_vector_context()
+        get_kg_context(), _get_vector_context(query, chunks_vdb, query_param, tokenizer)
     )
 
     # 4. Merge contexts
@@ -2038,44 +1982,12 @@ async def naive_query(
     if cached_response is not None:
         return cached_response
 
-    results = await chunks_vdb.query(
-        query, top_k=query_param.top_k, ids=query_param.ids
-    )
-    if not len(results):
-        return PROMPTS["fail_response"]
-
-    valid_chunks = [result for result in results if "content" in result]
-
-    if not valid_chunks:
-        logger.warning("No valid chunks found after filtering")
-        return PROMPTS["fail_response"]
-
     tokenizer: Tokenizer = global_config["tokenizer"]
-    maybe_trun_chunks = truncate_list_by_token_size(
-        valid_chunks,
-        key=lambda x: x["content"],
-        max_token_size=query_param.max_token_for_text_unit,
-        tokenizer=tokenizer,
-    )
-
-    if not maybe_trun_chunks:
-        logger.warning("No chunks left after truncation")
+    section = await _get_vector_context(query, chunks_vdb, query_param, tokenizer)
+
+    if section is None:
         return PROMPTS["fail_response"]
 
-    logger.debug(
-        f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-    )
-    logger.info(
-        f"Naive query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
-    )
-
-    section = "\r\n\r\n--New Chunk--\r\n\r\n".join(
-        [
-            "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
-            for c in maybe_trun_chunks
-        ]
-    )
-
     if query_param.only_need_context:
         return section
 
@@ -2287,6 +2199,84 @@ async def kg_query_with_keywords(
     return response
 
 
+async def _get_vector_context(
+    query: str,
+    chunks_vdb: BaseVectorStorage,
+    query_param: QueryParam,
+    tokenizer: Tokenizer,
+) -> str | None:
+    """
+    Retrieve vector context from the vector database.
+
+    This function performs vector search to find relevant text chunks for a query,
+    formats them with file path and creation time information, and truncates
+    the results to fit within token limits.
+
+    Args:
+        query: The query string to search for
+        chunks_vdb: Vector database containing document chunks
+        query_param: Query parameters including top_k and ids
+        tokenizer: Tokenizer for counting tokens
+
+    Returns:
+        Formatted string containing relevant text chunks, or None if no results found
+    """
+    try:
+        # Reduce top_k for vector search in hybrid mode since we have structured information from KG
+        mix_topk = min(10, query_param.top_k) if hasattr(query_param, 'mode') and query_param.mode == 'mix' else query_param.top_k
+        results = await chunks_vdb.query(
+            query, top_k=mix_topk, ids=query_param.ids
+        )
+        if not results:
+            return None
+
+        valid_chunks = []
+        for result in results:
+            if "content" in result:
+                # Directly use content from chunks_vdb.query result
+                chunk_with_time = {
+                    "content": result["content"],
+                    "created_at": result.get("created_at", None),
+                    "file_path": result.get("file_path", None),
+                }
+                valid_chunks.append(chunk_with_time)
+
+        if not valid_chunks:
+            return None
+
+        maybe_trun_chunks = truncate_list_by_token_size(
+            valid_chunks,
+            key=lambda x: x["content"],
+            max_token_size=query_param.max_token_for_text_unit,
+            tokenizer=tokenizer,
+        )
+
+        logger.debug(
+            f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
+        )
+        logger.info(
+            f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {mix_topk}"
+        )
+
+        if not maybe_trun_chunks:
+            return None
+
+        # Include time information in content
+        formatted_chunks = []
+        for c in maybe_trun_chunks:
+            chunk_text = "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
+            if c["created_at"]:
+                chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\r\n\r\n{chunk_text}"
+            formatted_chunks.append(chunk_text)
+
+        logger.debug(
+            f"Truncate chunks from {len(valid_chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
+        )
+        return "\r\n\r\n--New Chunk--\r\n\r\n".join(formatted_chunks)
+    except Exception as e:
+        logger.error(f"Error in _get_vector_context: {e}")
+        return None
+
+
 async def query_with_keywords(
     query: str,
     prompt: str,
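
A minimal usage sketch (not part of the patch, assuming stand-in stubs for LightRAG internals): it mirrors how the unified helper is consumed after this change, with mix_kg_vector_query gathering it alongside the KG retrieval and naive_query awaiting it directly and mapping a None result to the failure response. The *_stub functions, the sample file path, and the query string below are hypothetical placeholders for the real chunks_vdb, QueryParam, Tokenizer, and PROMPTS objects.

# Illustrative sketch only; stubs stand in for LightRAG internals so the
# control flow runs on its own (Python 3.10+ for the "str | None" syntax).
import asyncio


async def _get_vector_context_stub(query: str) -> str | None:
    # Stand-in for _get_vector_context(query, chunks_vdb, query_param, tokenizer);
    # the real helper returns formatted chunks, or None when nothing is retrieved.
    return f"File path: docs/example.txt\r\n\r\nchunks relevant to {query!r}"


async def get_kg_context_stub() -> str | None:
    # Stand-in for the knowledge-graph retrieval inside mix_kg_vector_query.
    return "-----Entities-----"


async def mix_query(query: str):
    # Mix path: KG and vector retrieval still run in parallel, as in the patch.
    kg_context, vector_context = await asyncio.gather(
        get_kg_context_stub(), _get_vector_context_stub(query)
    )
    return kg_context, vector_context


async def naive_query(query: str) -> str:
    # Naive path: the helper is the whole retrieval step; None becomes the
    # failure response (PROMPTS["fail_response"] in the real code).
    section = await _get_vector_context_stub(query)
    return "fail_response" if section is None else section


if __name__ == "__main__":
    print(asyncio.run(naive_query("unified vector retrieval")))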