diff --git a/lightrag/base.py b/lightrag/base.py
index 94a39cf3..7b3504d0 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -31,6 +31,8 @@ class QueryParam:
     max_token_for_global_context: int = 4000
     # Number of tokens for the entity descriptions
     max_token_for_local_context: int = 4000
+    hl_keywords: list[str] = field(default_factory=list)
+    ll_keywords: list[str] = field(default_factory=list)
 
 
 @dataclass
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 596fbdbf..e8859071 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -17,6 +17,8 @@ from .operate import (
     kg_query,
     naive_query,
     mix_kg_vector_query,
+    extract_keywords_only,
+    kg_query_with_keywords,
 )
 
 from .utils import (
@@ -753,6 +755,114 @@ class LightRAG:
         await self._query_done()
         return response
 
+    def query_with_separate_keyword_extraction(
+        self,
+        query: str,
+        prompt: str,
+        param: QueryParam = QueryParam(),
+    ):
+        """
+        1. Extract keywords from 'query' using the new helper in operate.py.
+        2. Run the standard aquery() flow with the final prompt (formatted_question).
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(
+            self.aquery_with_separate_keyword_extraction(query, prompt, param)
+        )
+
+    async def aquery_with_separate_keyword_extraction(
+        self,
+        query: str,
+        prompt: str,
+        param: QueryParam = QueryParam(),
+    ):
+        """
+        1. Calls extract_keywords_only() to get HL/LL keywords from 'query'.
+        2. Calls kg_query_with_keywords(), naive_query(), or mix_kg_vector_query()
+           as the main query, injecting the newly extracted keywords.
+        """
+
+        # ---------------------
+        # STEP 1: Keyword Extraction
+        # ---------------------
+        # extract_keywords_only(...) returns (hl_keywords, ll_keywords).
+        hl_keywords, ll_keywords = await extract_keywords_only(
+            text=query,
+            param=param,
+            global_config=asdict(self),
+            hashing_kv=self.llm_response_cache
+            or self.key_string_value_json_storage_cls(
+                namespace="llm_response_cache",
+                global_config=asdict(self),
+                embedding_func=None,
+            ),
+        )
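+        # Stash the extracted keywords on the QueryParam so the downstream
+        # retrieval path (kg_query_with_keywords) can pick them up without
+        # re-running keyword extraction.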
+        param.hl_keywords = hl_keywords
+        param.ll_keywords = ll_keywords
+
+        # ---------------------
+        # STEP 2: Final Query Logic
+        # ---------------------
+
+        # Create a new string with the prompt and the keywords
+        ll_keywords_str = ", ".join(ll_keywords)
+        hl_keywords_str = ", ".join(hl_keywords)
+        formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
+
+        if param.mode in ["local", "global", "hybrid"]:
+            response = await kg_query_with_keywords(
+                formatted_question,
+                self.chunk_entity_relation_graph,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.text_chunks,
+                param,
+                asdict(self),
+                hashing_kv=self.llm_response_cache
+                if self.llm_response_cache
+                and hasattr(self.llm_response_cache, "global_config")
+                else self.key_string_value_json_storage_cls(
+                    namespace="llm_response_cache",
+                    global_config=asdict(self),
+                    embedding_func=None,
+                ),
+            )
+        elif param.mode == "naive":
+            response = await naive_query(
+                formatted_question,
+                self.chunks_vdb,
+                self.text_chunks,
+                param,
+                asdict(self),
+                hashing_kv=self.llm_response_cache
+                if self.llm_response_cache
+                and hasattr(self.llm_response_cache, "global_config")
+                else self.key_string_value_json_storage_cls(
+                    namespace="llm_response_cache",
+                    global_config=asdict(self),
+                    embedding_func=None,
+                ),
+            )
+        elif param.mode == "mix":
+            response = await mix_kg_vector_query(
+                formatted_question,
+                self.chunk_entity_relation_graph,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.text_chunks,
+                param,
+                asdict(self),
+                hashing_kv=self.llm_response_cache
+                if self.llm_response_cache
+                and hasattr(self.llm_response_cache, "global_config")
+                else self.key_string_value_json_storage_cls(
+                    namespace="llm_response_cache",
+                    global_config=asdict(self),
+                    embedding_func=None,
+                ),
+            )
+        else:
+            raise ValueError(f"Unknown mode {param.mode}")
+
+        await self._query_done()
+        return response
+
     async def _query_done(self):
         tasks = []
         for storage_inst in [self.llm_response_cache]:
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 7216c07f..f4993873 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -680,6 +680,206 @@ async def kg_query(
     )
     return response
 
+
+async def kg_query_with_keywords(
+    query: str,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage[TextChunkSchema],
+    query_param: QueryParam,
+    global_config: dict,
+    hashing_kv: BaseKVStorage = None,
+) -> str:
+    """
+    Refactored kg_query that does NOT extract keywords by itself.
+    It expects hl_keywords and ll_keywords to be set in query_param,
+    or defaults to empty. It then uses those to build context and
+    produce the final LLM response.
+    """
+
+    # ---------------------------
+    # 0) Handle potential cache
+    # ---------------------------
+    use_model_func = global_config["llm_model_func"]
+    args_hash = compute_args_hash(query_param.mode, query)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, query, query_param.mode
+    )
+    if cached_response is not None:
+        return cached_response
+
+    # ---------------------------
+    # 1) RETRIEVE KEYWORDS FROM query_param
+    # ---------------------------
+
+    # If these fields don't exist, default to empty lists.
+    hl_keywords = getattr(query_param, "hl_keywords", []) or []
+    ll_keywords = getattr(query_param, "ll_keywords", []) or []
+
+    # If neither level has any keywords, bail out with the failure response.
+    if not hl_keywords and not ll_keywords:
+        logger.warning("No keywords found in query_param. Could default to global mode or fail.")
+        return PROMPTS["fail_response"]
+    if not ll_keywords and query_param.mode in ["local", "hybrid"]:
+        logger.warning("low_level_keywords is empty, switching to global mode.")
+        query_param.mode = "global"
+    if not hl_keywords and query_param.mode in ["global", "hybrid"]:
+        logger.warning("high_level_keywords is empty, switching to local mode.")
+        query_param.mode = "local"
+
+    # Flatten low-level and high-level keywords if needed (they may arrive
+    # as nested lists, e.g. when parsed from cached JSON)
+    ll_keywords_flat = (
+        [item for sublist in ll_keywords for item in sublist]
+        if any(isinstance(i, list) for i in ll_keywords)
+        else ll_keywords
+    )
+    hl_keywords_flat = (
+        [item for sublist in hl_keywords for item in sublist]
+        if any(isinstance(i, list) for i in hl_keywords)
+        else hl_keywords
+    )
+
+    # Join the flattened lists
+    ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
+    hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
+
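+    # NOTE: _build_query_context expects the keywords as [low-level, high-level],
+    # in that order.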
+    keywords = [ll_keywords_str, hl_keywords_str]
+
+    logger.info("Using %s mode for query processing", query_param.mode)
+
+    # ---------------------------
+    # 2) BUILD CONTEXT
+    # ---------------------------
+    context = await _build_query_context(
+        keywords,
+        knowledge_graph_inst,
+        entities_vdb,
+        relationships_vdb,
+        text_chunks_db,
+        query_param,
+    )
+    if not context:
+        return PROMPTS["fail_response"]
+
+    # If only the context is needed, return it
+    if query_param.only_need_context:
+        return context
+
+    # ---------------------------
+    # 3) BUILD THE SYSTEM PROMPT + CALL LLM
+    # ---------------------------
+    sys_prompt_temp = PROMPTS["rag_response"]
+    sys_prompt = sys_prompt_temp.format(
+        context_data=context, response_type=query_param.response_type
+    )
+
+    if query_param.only_need_prompt:
+        return sys_prompt
+
+    # Now call the LLM with the final system prompt
+    response = await use_model_func(
+        query,
+        system_prompt=sys_prompt,
+        stream=query_param.stream,
+    )
+
+    # Clean up the response: strip the echoed prompt and special tokens
+    if isinstance(response, str) and len(response) > len(sys_prompt):
+        response = (
+            response.replace(sys_prompt, "")
+            .replace("user", "")
+            .replace("model", "")
+            .replace(query, "")
+            .replace("<system>", "")
+            .replace("</system>", "")
+            .strip()
+        )
+
+    # ---------------------------
+    # 4) SAVE TO CACHE
+    # ---------------------------
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=response,
+            prompt=query,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=query_param.mode,
+        ),
+    )
+    return response
+
+
+async def extract_keywords_only(
+    text: str,
+    param: QueryParam,
+    global_config: dict,
+    hashing_kv: BaseKVStorage = None,
+) -> tuple[list[str], list[str]]:
+    """
+    Extract high-level and low-level keywords from the given 'text' using the LLM.
+    This method does NOT build the final RAG context or provide a final answer.
+    It ONLY extracts keywords (hl_keywords, ll_keywords).
+    """
+
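+    # The (mode, text) pair is hashed for the cache key, so repeated
+    # extraction for the same query and mode is served from llm_response_cache.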
+    # 1. Handle cache if needed
+    args_hash = compute_args_hash(param.mode, text)
+    cached_response, quantized, min_val, max_val = await handle_cache(
+        hashing_kv, args_hash, text, param.mode
+    )
+    if cached_response is not None:
+        # Parse the cached_response if it's JSON containing keywords,
+        # and simply return (hl_keywords, ll_keywords) from the cache.
+        # Assuming cached_response has the same JSON structure:
+        match = re.search(r"\{.*\}", cached_response, re.DOTALL)
+        if match:
+            keywords_data = json.loads(match.group(0))
+            hl_keywords = keywords_data.get("high_level_keywords", [])
+            ll_keywords = keywords_data.get("low_level_keywords", [])
+            return hl_keywords, ll_keywords
+        return [], []
+
+    # 2. Build the examples
+    example_number = global_config["addon_params"].get("example_number", None)
+    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
+        examples = "\n".join(
+            PROMPTS["keywords_extraction_examples"][: int(example_number)]
+        )
+    else:
+        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
+    language = global_config["addon_params"].get(
+        "language", PROMPTS["DEFAULT_LANGUAGE"]
+    )
+
+    # 3. Build the keyword-extraction prompt
+    kw_prompt_temp = PROMPTS["keywords_extraction"]
+    kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language)
+
+    # 4. Call the LLM for keyword extraction
+    use_model_func = global_config["llm_model_func"]
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
+
+    # 5. Parse out JSON from the LLM response
+    match = re.search(r"\{.*\}", result, re.DOTALL)
+    if not match:
+        logger.error("No JSON-like structure found in the result.")
+        return [], []
+    try:
+        keywords_data = json.loads(match.group(0))
+    except json.JSONDecodeError as e:
+        logger.error(f"JSON parsing error: {e}")
+        return [], []
+
+    hl_keywords = keywords_data.get("high_level_keywords", [])
+    ll_keywords = keywords_data.get("low_level_keywords", [])
+
+    # 6. Cache the raw result if needed
+    await save_to_cache(
+        hashing_kv,
+        CacheData(
+            args_hash=args_hash,
+            content=result,
+            prompt=text,
+            quantized=quantized,
+            min_val=min_val,
+            max_val=max_val,
+            mode=param.mode,
+        ),
+    )
+    return hl_keywords, ll_keywords
+
 
 async def _build_query_context(
     query: list,
diff --git a/test.py b/test.py
index 80bcaa6d..895f0b30 100644
--- a/test.py
+++ b/test.py
@@ -39,4 +39,4 @@ print(
 # Perform hybrid search
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
-)
+)
\ No newline at end of file
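
Usage sketch (illustrative, not part of the patch): assuming a LightRAG instance
configured as in test.py, the new two-step API can be exercised as below; the
working directory and prompt text are hypothetical.

    from lightrag import LightRAG, QueryParam

    rag = LightRAG(working_dir="./dickens")  # hypothetical setup; see test.py

    # Keywords are extracted from `query` first, then injected into the
    # formatted prompt that drives the hybrid retrieval path.
    answer = rag.query_with_separate_keyword_extraction(
        query="What are the top themes in this story?",
        prompt="Answer concisely and cite supporting passages.",
        param=QueryParam(mode="hybrid"),
    )
    print(answer)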