diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index e9cb0926..2145fcb1 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -53,7 +53,6 @@ from .operate import (
     extract_entities,
     merge_nodes_and_edges,
     kg_query,
-    mix_kg_vector_query,
     naive_query,
     query_with_keywords,
 )
@@ -1437,8 +1436,10 @@ class LightRAG:
         """
         # If a custom model is provided in param, temporarily update global config
         global_config = asdict(self)
+        # Save original query for vector search
+        param.original_query = query

-        if param.mode in ["local", "global", "hybrid"]:
+        if param.mode in ["local", "global", "hybrid", "mix"]:
             response = await kg_query(
                 query.strip(),
                 self.chunk_entity_relation_graph,
@@ -1447,8 +1448,9 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 global_config,
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
+                hashing_kv=self.llm_response_cache,
                 system_prompt=system_prompt,
+                chunks_vdb=self.chunks_vdb,
             )
         elif param.mode == "naive":
             response = await naive_query(
@@ -1457,20 +1459,7 @@ class LightRAG:
                 self.text_chunks,
                 param,
                 global_config,
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
-                system_prompt=system_prompt,
-            )
-        elif param.mode == "mix":
-            response = await mix_kg_vector_query(
-                query.strip(),
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.chunks_vdb,
-                self.text_chunks,
-                param,
-                global_config,
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
+                hashing_kv=self.llm_response_cache,
                 system_prompt=system_prompt,
             )
         elif param.mode == "bypass":
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 2a4137a0..0ff485a8 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2,7 +2,6 @@ from __future__ import annotations

 from functools import partial
 import asyncio
-import traceback
 import json
 import re
 import os
@@ -859,6 +858,7 @@ async def kg_query(
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
+    chunks_vdb: BaseVectorStorage = None,
 ) -> str | AsyncIterator[str]:
     if query_param.model_func:
         use_model_func = query_param.model_func
@@ -911,6 +911,7 @@ async def kg_query(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        chunks_vdb,
     )

     if query_param.only_need_context:
@@ -1110,182 +1111,17 @@ async def extract_keywords_only(
     return hl_keywords, ll_keywords


-async def mix_kg_vector_query(
-    query: str,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    chunks_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage,
-    query_param: QueryParam,
-    global_config: dict[str, str],
-    hashing_kv: BaseKVStorage | None = None,
-    system_prompt: str | None = None,
-) -> str | AsyncIterator[str]:
-    """
-    Hybrid retrieval implementation combining knowledge graph and vector search.
-
-    This function performs a hybrid search by:
-    1. Extracting semantic information from knowledge graph
-    2. Retrieving relevant text chunks through vector similarity
-    3. Combining both results for comprehensive answer generation
-    """
-    # get tokenizer
-    tokenizer: Tokenizer = global_config["tokenizer"]
-
-    if query_param.model_func:
-        use_model_func = query_param.model_func
-    else:
-        use_model_func = global_config["llm_model_func"]
-        # Apply higher priority (5) to query relation LLM function
-        use_model_func = partial(use_model_func, _priority=5)
-
-    # 1. Cache handling
-    args_hash = compute_args_hash("mix", query, cache_type="query")
-    cached_response, quantized, min_val, max_val = await handle_cache(
-        hashing_kv, args_hash, query, "mix", cache_type="query"
-    )
-    if cached_response is not None:
-        return cached_response
-
-    # Process conversation history
-    history_context = ""
-    if query_param.conversation_history:
-        history_context = get_conversation_turns(
-            query_param.conversation_history, query_param.history_turns
-        )
-
-    # 2. Execute knowledge graph and vector searches in parallel
-    async def _get_kg_context():
-        try:
-            hl_keywords, ll_keywords = await get_keywords_from_query(
-                query, query_param, global_config, hashing_kv
-            )
-
-            if not hl_keywords and not ll_keywords:
-                logger.warning("Both high-level and low-level keywords are empty")
-                return None
-
-            # Convert keyword lists to strings
-            ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
-            hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
-
-            # Set query mode based on available keywords
-            if not ll_keywords_str and not hl_keywords_str:
-                return None
-            elif not ll_keywords_str:
-                query_param.mode = "global"
-            elif not hl_keywords_str:
-                query_param.mode = "local"
-            else:
-                query_param.mode = "hybrid"
-
-            # Build knowledge graph context
-            context = await _build_query_context(
-                ll_keywords_str,
-                hl_keywords_str,
-                knowledge_graph_inst,
-                entities_vdb,
-                relationships_vdb,
-                text_chunks_db,
-                query_param,
-            )
-
-            return context
-
-        except Exception as e:
-            logger.error(f"Error in _get_kg_context: {str(e)}")
-            traceback.print_exc()
-            return None
-
-    # 3. Execute both retrievals in parallel
-    kg_context, vector_context = await asyncio.gather(
-        _get_kg_context(), _get_vector_context(query, chunks_vdb, query_param, tokenizer)
-    )
-
-    # 4. Merge contexts
-    if kg_context is None and vector_context is None:
-        return PROMPTS["fail_response"]
-
-    if query_param.only_need_context:
-        context_str = f"""\r\n\r\n-----Knowledge Graph Context-----\r\n\r\n
-{kg_context if kg_context else "No relevant knowledge graph information found"}
-
-\r\n\r\n-----Vector Context-----\r\n\r\n
-{vector_context if vector_context else "No relevant text information found"}
-""".strip()
-        return context_str
-
-    # 5. Construct hybrid prompt
-    sys_prompt = (
-        system_prompt if system_prompt else PROMPTS["mix_rag_response"]
-    ).format(
-        kg_context=kg_context
-        if kg_context
-        else "No relevant knowledge graph information found",
-        vector_context=vector_context
-        if vector_context
-        else "No relevant text information found",
-        response_type=query_param.response_type,
-        history=history_context,
-    )
-
-    if query_param.only_need_prompt:
-        return sys_prompt
-
-    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
-    logger.debug(f"[mix_kg_vector_query]Prompt Tokens: {len_of_prompts}")
-
-    # 6. Generate response
-    response = await use_model_func(
-        query,
-        system_prompt=sys_prompt,
-        stream=query_param.stream,
-    )
-
-    # Clean up response content
-    if isinstance(response, str) and len(response) > len(sys_prompt):
-        response = (
-            response.replace(sys_prompt, "")
-            .replace("user", "")
-            .replace("model", "")
-            .replace(query, "")
-            .replace("<system>", "")
-            .replace("</system>", "")
-            .strip()
-        )
-
-    if hashing_kv.global_config.get("enable_llm_cache"):
-        # 7. Save cache - Only cache after collecting complete response
-        await save_to_cache(
-            hashing_kv,
-            CacheData(
-                args_hash=args_hash,
-                content=response,
-                prompt=query,
-                quantized=quantized,
-                min_val=min_val,
-                max_val=max_val,
-                mode="mix",
-                cache_type="query",
-            ),
-        )
-
-    return response
-
-
 async def _get_vector_context(
     query: str,
     chunks_vdb: BaseVectorStorage,
     query_param: QueryParam,
     tokenizer: Tokenizer,
-) -> str | None:
+) -> tuple[list, list, list] | None:
     """
     Retrieve vector context from the vector database.

     This function performs vector search to find relevant text chunks for a query,
-    formats them with file path and creation time information, and truncates
-    the results to fit within token limits.
+    formats them with file path and creation time information.

     Args:
         query: The query string to search for
@@ -1294,18 +1130,15 @@ async def _get_vector_context(
         tokenizer: Tokenizer for counting tokens

     Returns:
-        Formatted string containing relevant text chunks, or None if no results found
+        Tuple (empty_entities, empty_relations, text_units) for combine_contexts,
+        compatible with _get_edge_data and _get_node_data format
     """
     try:
-        # Reduce top_k for vector search in hybrid mode since we have structured information from KG
-        mix_topk = (
-            min(10, query_param.top_k)
-            if hasattr(query_param, "mode") and query_param.mode == "mix"
-            else query_param.top_k
+        results = await chunks_vdb.query(
+            query, top_k=query_param.top_k, ids=query_param.ids
         )
-        results = await chunks_vdb.query(query, top_k=mix_topk, ids=query_param.ids)
         if not results:
-            return None
+            return [], [], []

         valid_chunks = []
         for result in results:
@@ -1314,12 +1147,12 @@ async def _get_vector_context(
             chunk_with_time = {
                 "content": result["content"],
                 "created_at": result.get("created_at", None),
-                "file_path": result.get("file_path", None),
+                "file_path": result.get("file_path", "unknown_source"),
             }
             valid_chunks.append(chunk_with_time)

         if not valid_chunks:
-            return None
+            return [], [], []

         maybe_trun_chunks = truncate_list_by_token_size(
             valid_chunks,
@@ -1331,26 +1164,37 @@ async def _get_vector_context(
         logger.debug(
             f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
         )
-        logger.info(f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {mix_topk}")
+        logger.info(
+            f"Vector query: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
+        )

         if not maybe_trun_chunks:
-            return None
+            return [], [], []

-        # Include time information in content
-        formatted_chunks = []
-        for c in maybe_trun_chunks:
-            chunk_text = "File path: " + c["file_path"] + "\r\n\r\n" + c["content"]
-            if c["created_at"]:
-                chunk_text = f"[Created at: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(c['created_at']))}]\r\n\r\n{chunk_text}"
-            formatted_chunks.append(chunk_text)
+        # Create empty entities and relations contexts
+        entities_context = []
+        relations_context = []

-        logger.debug(
-            f"Truncate chunks from {len(valid_chunks)} to {len(formatted_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-        )
-        return "\r\n\r\n-------New Chunk-------\r\n\r\n".join(formatted_chunks)
+        # Create text_units_context in the same format as _get_edge_data and _get_node_data
+        text_units_section_list = [["id", "content", "file_path"]]
+
+        for i, chunk in enumerate(maybe_trun_chunks):
+            # Add to text_units_section_list
+            text_units_section_list.append(
+                [
+                    i + 1,  # id
+                    chunk["content"],  # content
+                    chunk["file_path"],  # file_path
+                ]
+            )
+
+        # Convert to dictionary format using list_of_list_to_dict
+        text_units_context = list_of_list_to_dict(text_units_section_list)
+
+        return entities_context, relations_context, text_units_context
     except Exception as e:
         logger.error(f"Error in _get_vector_context: {e}")
-        return None
+        return [], [], []


 async def _build_query_context(
@@ -1361,8 +1205,11 @@ async def _build_query_context(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
+    chunks_vdb: BaseVectorStorage = None,  # Add chunks_vdb parameter for mix mode
 ):
-    logger.info(f"Process {os.getpid()} buidling query context...")
+    logger.info(f"Process {os.getpid()} building query context...")
+
+    # Handle local and global modes as before
     if query_param.mode == "local":
         entities_context, relations_context, text_units_context = await _get_node_data(
             ll_keywords,
@@ -1379,7 +1226,7 @@ async def _build_query_context(
             text_chunks_db,
             query_param,
         )
-    else:  # hybrid mode
+    else:  # hybrid or mix mode
         ll_data = await _get_node_data(
             ll_keywords,
             knowledge_graph_inst,
@@ -1407,10 +1254,43 @@ async def _build_query_context(
             hl_text_units_context,
         ) = hl_data
-        entities_context, relations_context, text_units_context = combine_contexts(
-            [hl_entities_context, ll_entities_context],
-            [hl_relations_context, ll_relations_context],
-            [hl_text_units_context, ll_text_units_context],
+        # Initialize vector data with empty lists
+        vector_entities_context, vector_relations_context, vector_text_units_context = (
+            [],
+            [],
+            [],
+        )
+
+        # Only get vector data if in mix mode
+        if query_param.mode == "mix" and hasattr(query_param, "original_query"):
+            # Get tokenizer from text_chunks_db
+            tokenizer = text_chunks_db.global_config.get("tokenizer")
+
+            # Get vector context in triple format
+            vector_data = await _get_vector_context(
+                query_param.original_query,  # We need to pass the original query
+                chunks_vdb,
+                query_param,
+                tokenizer,
+            )
+
+            # If vector_data is not None, unpack it
+            if vector_data is not None:
+                (
+                    vector_entities_context,
+                    vector_relations_context,
+                    vector_text_units_context,
+                ) = vector_data
+
+        # Combine and deduplicate the entities, relationships, and sources
+        entities_context = process_combine_contexts(
+            hl_entities_context, ll_entities_context, vector_entities_context
+        )
+        relations_context = process_combine_contexts(
+            hl_relations_context, ll_relations_context, vector_relations_context
+        )
+        text_units_context = process_combine_contexts(
+            hl_text_units_context, ll_text_units_context, vector_text_units_context
         )

     # not necessary to use LLM to generate a response
     if not entities_context and not relations_context:
@@ -1539,7 +1419,7 @@ async def _get_node_data(
         entites_section_list.append(
             [
-                i,
+                i + 1,
                 n["entity_name"],
                 n.get("entity_type", "UNKNOWN"),
                 n.get("description", "UNKNOWN"),
@@ -1574,7 +1454,7 @@ async def _get_node_data(
         relations_section_list.append(
             [
-                i,
+                i + 1,
                 e["src_tgt"][0],
                 e["src_tgt"][1],
                 e["description"],
@@ -1590,7 +1470,7 @@ async def _get_node_data(
     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
         text_units_section_list.append(
-            [i, t["content"], t.get("file_path", "unknown_source")]
+            [i + 1, t["content"], t.get("file_path", "unknown_source")]
         )
     text_units_context = list_of_list_to_dict(text_units_section_list)
     return entities_context, relations_context, text_units_context
@@ -1859,7 +1739,7 @@ async def _get_edge_data(
         relations_section_list.append(
             [
-                i,
+                i + 1,
                 e["src_id"],
                 e["tgt_id"],
                 e["description"],
@@ -1886,7 +1766,7 @@ async def _get_edge_data(
         entites_section_list.append(
             [
-                i,
+                i + 1,
                 n["entity_name"],
                 n.get("entity_type", "UNKNOWN"),
                 n.get("description", "UNKNOWN"),
@@ -1899,7 +1779,9 @@ async def _get_edge_data(
     text_units_section_list = [["id", "content", "file_path"]]
     for i, t in enumerate(use_text_units):
-        text_units_section_list.append([i, t["content"], t.get("file_path", "unknown")])
+        text_units_section_list.append(
+            [i + 1, t["content"], t.get("file_path", "unknown")]
+        )
     text_units_context = list_of_list_to_dict(text_units_section_list)
     return entities_context, relations_context, text_units_context

@@ -2016,25 +1898,6 @@ async def _find_related_text_unit_from_relationships(
     return all_text_units


-def combine_contexts(entities, relationships, sources):
-    # Function to extract entities, relationships, and sources from context strings
-    hl_entities, ll_entities = entities[0], entities[1]
-    hl_relationships, ll_relationships = relationships[0], relationships[1]
-    hl_sources, ll_sources = sources[0], sources[1]
-    # Combine and deduplicate the entities
-    combined_entities = process_combine_contexts(hl_entities, ll_entities)
-
-    # Combine and deduplicate the relationships
-    combined_relationships = process_combine_contexts(
-        hl_relationships, ll_relationships
-    )
-
-    # Combine and deduplicate the sources
-    combined_sources = process_combine_contexts(hl_sources, ll_sources)
-
-    return combined_entities, combined_relationships, combined_sources
-
-
 async def naive_query(
     query: str,
     chunks_vdb: BaseVectorStorage,
@@ -2060,14 +1923,24 @@ async def naive_query(
         return cached_response

     tokenizer: Tokenizer = global_config["tokenizer"]
-    section = await _get_vector_context(query, chunks_vdb, query_param, tokenizer)
-    if section is None:
+    _, _, text_units_context = await _get_vector_context(
+        query, chunks_vdb, query_param, tokenizer
+    )
+
+    if text_units_context is None or len(text_units_context) == 0:
         return PROMPTS["fail_response"]

+    text_units_str = json.dumps(text_units_context, ensure_ascii=False)
     if query_param.only_need_context:
-        return section
+        return f"""
+---Document Chunks---
+
+```json
+{text_units_str}
+```
+
+"""

     # Process conversation history
     history_context = ""
     if query_param.conversation_history:
@@ -2077,7 +1950,7 @@ async def naive_query(
     sys_prompt_temp = system_prompt if system_prompt else PROMPTS["naive_rag_response"]
     sys_prompt = sys_prompt_temp.format(
-        content_data=section,
+        content_data=text_units_str,
         response_type=query_param.response_type,
         history=history_context,
     )
@@ -2134,6 +2007,9 @@ async def kg_query_with_keywords(
     query_param: QueryParam,
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
+    ll_keywords: list[str] = [],
+    hl_keywords: list[str] = [],
+    chunks_vdb: BaseVectorStorage | None = None,
 ) -> str | AsyncIterator[str]:
     """
     Refactored kg_query that does NOT extract keywords by itself.
@@ -2147,9 +2023,6 @@ async def kg_query_with_keywords(
     # Apply higher priority (5) to query relation LLM function
     use_model_func = partial(use_model_func, _priority=5)

-    # ---------------------------
-    # 1) Handle potential cache for query results
-    # ---------------------------
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -2157,14 +2030,6 @@ async def kg_query_with_keywords(
     )
     if cached_response is not None:
         return cached_response
-    # ---------------------------
-    # 2) RETRIEVE KEYWORDS FROM query_param
-    # ---------------------------
-
-    # If these fields don't exist, default to empty lists/strings.
-    hl_keywords = getattr(query_param, "hl_keywords", []) or []
-    ll_keywords = getattr(query_param, "ll_keywords", []) or []
-
     # If neither has any keywords, you could handle that logic here.
     if not hl_keywords and not ll_keywords:
         logger.warning(
@@ -2178,25 +2043,9 @@ async def kg_query_with_keywords(
         logger.warning("high_level_keywords is empty, switching to local mode.")
         query_param.mode = "local"

-    # Flatten low-level and high-level keywords if needed
-    ll_keywords_flat = (
-        [item for sublist in ll_keywords for item in sublist]
-        if any(isinstance(i, list) for i in ll_keywords)
-        else ll_keywords
-    )
-    hl_keywords_flat = (
-        [item for sublist in hl_keywords for item in sublist]
-        if any(isinstance(i, list) for i in hl_keywords)
-        else hl_keywords
-    )
+    ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
+    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

-    # Join the flattened lists
-    ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
-    hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
-
-    # ---------------------------
-    # 3) BUILD CONTEXT
-    # ---------------------------
     context = await _build_query_context(
         ll_keywords_str,
         hl_keywords_str,
@@ -2205,18 +2054,14 @@ async def kg_query_with_keywords(
         relationships_vdb,
         text_chunks_db,
         query_param,
+        chunks_vdb=chunks_vdb,
     )
     if not context:
         return PROMPTS["fail_response"]

-    # If only context is needed, return it
     if query_param.only_need_context:
         return context

-    # ---------------------------
-    # 4) BUILD THE SYSTEM PROMPT + CALL LLM
-    # ---------------------------
-
     # Process conversation history
     history_context = ""
     if query_param.conversation_history:
@@ -2258,7 +2103,6 @@ async def kg_query_with_keywords(
     )

     if hashing_kv.global_config.get("enable_llm_cache"):
-        # 7. Save cache - only cache after the complete response has been collected
         await save_to_cache(
             hashing_kv,
             CacheData(
@@ -2319,12 +2163,15 @@ async def query_with_keywords(
     )

     # Create a new string with the prompt and the keywords
-    ll_keywords_str = ", ".join(ll_keywords)
-    hl_keywords_str = ", ".join(hl_keywords)
-    formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
+    keywords_str = ", ".join(ll_keywords + hl_keywords)
+    formatted_question = (
+        f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
+    )
+
+    param.original_query = query

     # Use appropriate query method based on mode
-    if param.mode in ["local", "global", "hybrid"]:
+    if param.mode in ["local", "global", "hybrid", "mix"]:
         return await kg_query_with_keywords(
             formatted_question,
             knowledge_graph_inst,
@@ -2334,6 +2181,9 @@ async def query_with_keywords(
             param,
             global_config,
             hashing_kv=hashing_kv,
+            hl_keywords=hl_keywords,
+            ll_keywords=ll_keywords,
+            chunks_vdb=chunks_vdb,
         )
     elif param.mode == "naive":
         return await naive_query(
@@ -2344,17 +2194,5 @@ async def query_with_keywords(
             global_config,
             hashing_kv=hashing_kv,
         )
-    elif param.mode == "mix":
-        return await mix_kg_vector_query(
-            formatted_question,
-            knowledge_graph_inst,
-            entities_vdb,
-            relationships_vdb,
-            chunks_vdb,
-            text_chunks_db,
-            param,
-            global_config,
-            hashing_kv=hashing_kv,
-        )
     else:
         raise ValueError(f"Unknown mode {param.mode}")
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 2ed831b5..7b4920eb 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -721,19 +721,19 @@ def truncate_list_by_token_size(
 def list_of_list_to_dict(data: list[list[str]]) -> list[dict[str, str]]:
     """Convert a 2D string list (table-like data) into a list of dictionaries.
-    
+
     The first row is treated as header containing field names.
     Subsequent rows become dictionary entries where keys come from header
     and values from row data.
-    
+
     Args:
         data: 2D string array where first row contains headers and rest are data rows.
               Minimum 2 columns required in data rows (rows with <2 elements are skipped).
-    
+
     Returns:
         List of dictionaries where each dict represents a data row with:
         - Keys: Header values from first row
         - Values: Corresponding row values (empty string if missing)
-    
+
     Example:
         Input:  [["Name","Age"], ["Alice","23"], ["Bob"]]
         Output: [{"Name":"Alice","Age":"23"}, {"Name":"Bob","Age":""}]
@@ -822,21 +822,33 @@ def xml_to_json(xml_file):
         return None


-def process_combine_contexts(
-    hl_context: list[dict[str, str]], ll_context: list[dict[str, str]]
-):
+def process_combine_contexts(*context_lists):
+    """
+    Combine multiple context lists and remove duplicate content
+
+    Args:
+        *context_lists: Any number of context lists
+
+    Returns:
+        Combined context list with duplicates removed
+    """
     seen_content = {}
     combined_data = []

-    for item in hl_context + ll_context:
-        content_dict = {k: v for k, v in item.items() if k != "id"}
-        content_key = tuple(sorted(content_dict.items()))
-        if content_key not in seen_content:
-            seen_content[content_key] = item
-            combined_data.append(item)
+    # Iterate through all input context lists
+    for context_list in context_lists:
+        if not context_list:  # Skip empty lists
+            continue
+        for item in context_list:
+            content_dict = {k: v for k, v in item.items() if k != "id"}
+            content_key = tuple(sorted(content_dict.items()))
+            if content_key not in seen_content:
+                seen_content[content_key] = item
+                combined_data.append(item)

+    # Reassign IDs
     for i, item in enumerate(combined_data):
-        item["id"] = str(i)
+        item["id"] = str(i + 1)

     return combined_data
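
With this patch, "mix" mode goes through the same kg_query path as "local", "global", and "hybrid": the caller only sets the mode, aquery stores the raw query on param.original_query, and kg_query pulls chunks from chunks_vdb. A minimal usage sketch follows; the rag instance, its storage setup, the demo function name, and the top_k value are assumptions for illustration, not part of this patch:

```python
import asyncio

from lightrag import LightRAG, QueryParam


async def demo(rag: LightRAG) -> None:
    # "mix" now merges KG entities/relations with vector-retrieved chunks
    # inside kg_query; aquery() sets param.original_query internally, so the
    # caller does nothing special beyond choosing the mode.
    param = QueryParam(mode="mix", top_k=20)
    answer = await rag.aquery("How are entity A and entity B related?", param=param)
    print(answer)
```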
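
The merge itself hinges on the two utils.py helpers touched above. The sketch below re-implements their documented behavior for illustration only (it is not the library code): list_of_list_to_dict turns a header-plus-rows table into row dicts, and the now-variadic process_combine_contexts deduplicates any number of context lists on their non-"id" fields, then renumbers ids from 1.

```python
def list_of_list_to_dict(data):
    # First row is the header; later rows become dicts, padding missing cells.
    header, *rows = data
    return [
        {key: (row[i] if i < len(row) else "") for i, key in enumerate(header)}
        for row in rows
    ]


def process_combine_contexts(*context_lists):
    # Variadic merge: accept any number of context lists (hl, ll, vector, ...),
    # drop rows whose non-"id" fields repeat, then renumber ids starting at 1.
    seen, combined = set(), []
    for context_list in context_lists:
        for item in context_list or []:  # skip None/empty lists
            key = tuple(sorted((k, v) for k, v in item.items() if k != "id"))
            if key not in seen:
                seen.add(key)
                combined.append(item)
    for i, item in enumerate(combined):
        item["id"] = str(i + 1)  # ids are 1-based after this patch
    return combined


hl = list_of_list_to_dict([["id", "content"], ["1", "alpha"], ["2", "beta"]])
ll = list_of_list_to_dict([["id", "content"], ["1", "beta"], ["2", "gamma"]])
vec = list_of_list_to_dict([["id", "content"], ["1", "alpha"]])
print(process_combine_contexts(hl, ll, vec))
# [{'id': '1', 'content': 'alpha'}, {'id': '2', 'content': 'beta'}, {'id': '3', 'content': 'gamma'}]
```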