diff --git a/README.md b/README.md index da1a1d56..8a0da666 100644 --- a/README.md +++ b/README.md @@ -330,6 +330,26 @@ rag = LightRAG( with open("./newText.txt") as f: rag.insert(f.read()) ``` +### Separate Keyword Extraction +We've introduced a new function `query_with_separate_keyword_extraction` to enhance the keyword extraction capabilities. This function separates the keyword extraction process from the user's prompt, focusing solely on the query to improve the relevance of extracted keywords. + +##### How It Works? +The function operates by dividing the input into two parts: +- `User Query` +- `Prompt` + +It then performs keyword extraction exclusively on the `user query`. This separation ensures that the extraction process is focused and relevant, unaffected by any additional language in the `prompt`. It also allows the `prompt` to serve purely for response formatting, maintaining the intent and clarity of the user's original question. + +##### Usage Example +This `example` shows how to tailor the function for educational content, focusing on detailed explanations for older students. + +```python +rag.query_with_separate_keyword_extraction( + query="Explain the law of gravity", + prompt="Provide a detailed explanation suitable for high school students studying physics.", + param=QueryParam(mode="hybrid") +) +``` ### Using Neo4J for Storage diff --git a/examples/query_keyword_separation_example.py b/examples/query_keyword_separation_example.py new file mode 100644 index 00000000..f11ce8c1 --- /dev/null +++ b/examples/query_keyword_separation_example.py @@ -0,0 +1,116 @@ +import os +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.utils import EmbeddingFunc +import numpy as np +from dotenv import load_dotenv +import logging +from openai import AzureOpenAI + +logging.basicConfig(level=logging.INFO) + +load_dotenv() + +AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION") +AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT") +AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY") +AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT") + +AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT") +AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION") + +WORKING_DIR = "./dickens" + +if os.path.exists(WORKING_DIR): + import shutil + + shutil.rmtree(WORKING_DIR) + +os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs +) -> str: + client = AzureOpenAI( + api_key=AZURE_OPENAI_API_KEY, + api_version=AZURE_OPENAI_API_VERSION, + azure_endpoint=AZURE_OPENAI_ENDPOINT, + ) + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + if history_messages: + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + chat_completion = client.chat.completions.create( + model=AZURE_OPENAI_DEPLOYMENT, # model = "deployment_name". + messages=messages, + temperature=kwargs.get("temperature", 0), + top_p=kwargs.get("top_p", 1), + n=kwargs.get("n", 1), + ) + return chat_completion.choices[0].message.content + + +async def embedding_func(texts: list[str]) -> np.ndarray: + client = AzureOpenAI( + api_key=AZURE_OPENAI_API_KEY, + api_version=AZURE_EMBEDDING_API_VERSION, + azure_endpoint=AZURE_OPENAI_ENDPOINT, + ) + embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts) + + embeddings = [item.embedding for item in embedding.data] + return np.array(embeddings) + + +async def test_funcs(): + result = await llm_model_func("How are you?") + print("Resposta do llm_model_func: ", result) + + result = await embedding_func(["How are you?"]) + print("Resultado do embedding_func: ", result.shape) + print("Dimensão da embedding: ", result.shape[1]) + + +asyncio.run(test_funcs()) + +embedding_dimension = 3072 + +rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dimension, + max_token_size=8192, + func=embedding_func, + ), +) + +book1 = open("./book_1.txt", encoding="utf-8") +book2 = open("./book_2.txt", encoding="utf-8") + +rag.insert([book1.read(), book2.read()]) + + +# Example function demonstrating the new query_with_separate_keyword_extraction usage +async def run_example(): + query = "What are the top themes in this story?" + prompt = "Please simplify the response for a young audience." + + # Using the new method to ensure the keyword extraction is only applied to the query + response = rag.query_with_separate_keyword_extraction( + query=query, + prompt=prompt, + param=QueryParam(mode="hybrid"), # Adjust QueryParam mode as necessary + ) + + print("Extracted Response:", response) + + +# Run the example asynchronously +if __name__ == "__main__": + asyncio.run(run_example()) diff --git a/lightrag/base.py b/lightrag/base.py index 94a39cf3..7b3504d0 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -31,6 +31,8 @@ class QueryParam: max_token_for_global_context: int = 4000 # Number of tokens for the entity descriptions max_token_for_local_context: int = 4000 + hl_keywords: list[str] = field(default_factory=list) + ll_keywords: list[str] = field(default_factory=list) @dataclass diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 596fbdbf..cacdfc50 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -17,6 +17,8 @@ from .operate import ( kg_query, naive_query, mix_kg_vector_query, + extract_keywords_only, + kg_query_with_keywords, ) from .utils import ( @@ -753,6 +755,114 @@ class LightRAG: await self._query_done() return response + def query_with_separate_keyword_extraction( + self, query: str, prompt: str, param: QueryParam = QueryParam() + ): + """ + 1. Extract keywords from the 'query' using new function in operate.py. + 2. Then run the standard aquery() flow with the final prompt (formatted_question). + """ + + loop = always_get_an_event_loop() + return loop.run_until_complete( + self.aquery_with_separate_keyword_extraction(query, prompt, param) + ) + + async def aquery_with_separate_keyword_extraction( + self, query: str, prompt: str, param: QueryParam = QueryParam() + ): + """ + 1. Calls extract_keywords_only to get HL/LL keywords from 'query'. + 2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed. + """ + + # --------------------- + # STEP 1: Keyword Extraction + # --------------------- + # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords). + hl_keywords, ll_keywords = await extract_keywords_only( + text=query, + param=param, + global_config=asdict(self), + hashing_kv=self.llm_response_cache + or self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + + param.hl_keywords = (hl_keywords,) + param.ll_keywords = (ll_keywords,) + + # --------------------- + # STEP 2: Final Query Logic + # --------------------- + + # Create a new string with the prompt and the keywords + ll_keywords_str = ", ".join(ll_keywords) + hl_keywords_str = ", ".join(hl_keywords) + formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}" + + if param.mode in ["local", "global", "hybrid"]: + response = await kg_query_with_keywords( + formatted_question, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + elif param.mode == "naive": + response = await naive_query( + formatted_question, + self.chunks_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + elif param.mode == "mix": + response = await mix_kg_vector_query( + formatted_question, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.chunks_vdb, + self.text_chunks, + param, + asdict(self), + hashing_kv=self.llm_response_cache + if self.llm_response_cache + and hasattr(self.llm_response_cache, "global_config") + else self.key_string_value_json_storage_cls( + namespace="llm_response_cache", + global_config=asdict(self), + embedding_func=None, + ), + ) + else: + raise ValueError(f"Unknown mode {param.mode}") + + await self._query_done() + return response + async def _query_done(self): tasks = [] for storage_inst in [self.llm_response_cache]: diff --git a/lightrag/operate.py b/lightrag/operate.py index 7216c07f..7df489b3 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -681,6 +681,219 @@ async def kg_query( return response +async def kg_query_with_keywords( + query: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage[TextChunkSchema], + query_param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, +) -> str: + """ + Refactored kg_query that does NOT extract keywords by itself. + It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty. + Then it uses those to build context and produce a final LLM response. + """ + + # --------------------------- + # 0) Handle potential cache + # --------------------------- + use_model_func = global_config["llm_model_func"] + args_hash = compute_args_hash(query_param.mode, query) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, query, query_param.mode + ) + if cached_response is not None: + return cached_response + + # --------------------------- + # 1) RETRIEVE KEYWORDS FROM query_param + # --------------------------- + + # If these fields don't exist, default to empty lists/strings. + hl_keywords = getattr(query_param, "hl_keywords", []) or [] + ll_keywords = getattr(query_param, "ll_keywords", []) or [] + + # If neither has any keywords, you could handle that logic here. + if not hl_keywords and not ll_keywords: + logger.warning( + "No keywords found in query_param. Could default to global mode or fail." + ) + return PROMPTS["fail_response"] + if not ll_keywords and query_param.mode in ["local", "hybrid"]: + logger.warning("low_level_keywords is empty, switching to global mode.") + query_param.mode = "global" + if not hl_keywords and query_param.mode in ["global", "hybrid"]: + logger.warning("high_level_keywords is empty, switching to local mode.") + query_param.mode = "local" + + # Flatten low-level and high-level keywords if needed + ll_keywords_flat = ( + [item for sublist in ll_keywords for item in sublist] + if any(isinstance(i, list) for i in ll_keywords) + else ll_keywords + ) + hl_keywords_flat = ( + [item for sublist in hl_keywords for item in sublist] + if any(isinstance(i, list) for i in hl_keywords) + else hl_keywords + ) + + # Join the flattened lists + ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else "" + hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else "" + + keywords = [ll_keywords_str, hl_keywords_str] + + logger.info("Using %s mode for query processing", query_param.mode) + + # --------------------------- + # 2) BUILD CONTEXT + # --------------------------- + context = await _build_query_context( + keywords, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + text_chunks_db, + query_param, + ) + if not context: + return PROMPTS["fail_response"] + + # If only context is needed, return it + if query_param.only_need_context: + return context + + # --------------------------- + # 3) BUILD THE SYSTEM PROMPT + CALL LLM + # --------------------------- + sys_prompt_temp = PROMPTS["rag_response"] + sys_prompt = sys_prompt_temp.format( + context_data=context, response_type=query_param.response_type + ) + + if query_param.only_need_prompt: + return sys_prompt + + # Now call the LLM with the final system prompt + response = await use_model_func( + query, + system_prompt=sys_prompt, + stream=query_param.stream, + ) + + # Clean up the response + if isinstance(response, str) and len(response) > len(sys_prompt): + response = ( + response.replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + + # --------------------------- + # 4) SAVE TO CACHE + # --------------------------- + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=query_param.mode, + ), + ) + return response + + +async def extract_keywords_only( + text: str, + param: QueryParam, + global_config: dict, + hashing_kv: BaseKVStorage = None, +) -> tuple[list[str], list[str]]: + """ + Extract high-level and low-level keywords from the given 'text' using the LLM. + This method does NOT build the final RAG context or provide a final answer. + It ONLY extracts keywords (hl_keywords, ll_keywords). + """ + + # 1. Handle cache if needed + args_hash = compute_args_hash(param.mode, text) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, text, param.mode + ) + if cached_response is not None: + # parse the cached_response if it’s JSON containing keywords + # or simply return (hl_keywords, ll_keywords) from cached + # Assuming cached_response is in the same JSON structure: + match = re.search(r"\{.*\}", cached_response, re.DOTALL) + if match: + keywords_data = json.loads(match.group(0)) + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + return hl_keywords, ll_keywords + return [], [] + + # 2. Build the examples + example_number = global_config["addon_params"].get("example_number", None) + if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]): + examples = "\n".join( + PROMPTS["keywords_extraction_examples"][: int(example_number)] + ) + else: + examples = "\n".join(PROMPTS["keywords_extraction_examples"]) + language = global_config["addon_params"].get( + "language", PROMPTS["DEFAULT_LANGUAGE"] + ) + + # 3. Build the keyword-extraction prompt + kw_prompt_temp = PROMPTS["keywords_extraction"] + kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language) + + # 4. Call the LLM for keyword extraction + use_model_func = global_config["llm_model_func"] + result = await use_model_func(kw_prompt, keyword_extraction=True) + + # 5. Parse out JSON from the LLM response + match = re.search(r"\{.*\}", result, re.DOTALL) + if not match: + logger.error("No JSON-like structure found in the result.") + return [], [] + try: + keywords_data = json.loads(match.group(0)) + except json.JSONDecodeError as e: + logger.error(f"JSON parsing error: {e}") + return [], [] + + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + + # 6. Cache the result if needed + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=result, + prompt=text, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=param.mode, + ), + ) + return hl_keywords, ll_keywords + + async def _build_query_context( query: list, knowledge_graph_inst: BaseGraphStorage,