diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 3a7d340a..48b464a8 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -30,11 +30,10 @@ from .namespace import NameSpace, make_namespace
 from .operate import (
     chunking_by_token_size,
     extract_entities,
-    extract_keywords_only,
     kg_query,
-    kg_query_with_keywords,
     mix_kg_vector_query,
     naive_query,
+    query_with_keywords,
 )
 from .prompt import GRAPH_FIELD_SEP, PROMPTS
 from .utils import (
@@ -45,6 +44,9 @@ from .utils import (
     encode_string_by_tiktoken,
     lazy_external_import,
     limit_async_func_call,
+    get_content_summary,
+    clean_text,
+    check_storage_env_vars,
     logger,
 )
 from .types import KnowledgeGraph
@@ -309,7 +311,7 @@ class LightRAG:
             # Verify storage implementation compatibility
             verify_storage_implementation(storage_type, storage_name)
             # Check environment variables
-            # self.check_storage_env_vars(storage_name)
+            check_storage_env_vars(storage_name)
 
         # Ensure vector_db_storage_cls_kwargs has required fields
         self.vector_db_storage_cls_kwargs = {
@@ -536,11 +538,6 @@ class LightRAG:
         storage_class = lazy_external_import(import_path, storage_name)
         return storage_class
 
-    @staticmethod
-    def clean_text(text: str) -> str:
-        """Clean text by removing null bytes (0x00) and whitespace"""
-        return text.strip().replace("\x00", "")
-
     def insert(
         self,
         input: str | list[str],
@@ -602,8 +599,8 @@ class LightRAG:
         update_storage = False
         try:
             # Clean input texts
-            full_text = self.clean_text(full_text)
-            text_chunks = [self.clean_text(chunk) for chunk in text_chunks]
+            full_text = clean_text(full_text)
+            text_chunks = [clean_text(chunk) for chunk in text_chunks]
 
             # Process cleaned texts
             if doc_id is None:
@@ -682,7 +679,7 @@ class LightRAG:
             contents = {id_: doc for id_, doc in zip(ids, input)}
         else:
             # Clean input text and remove duplicates
-            input = list(set(self.clean_text(doc) for doc in input))
+            input = list(set(clean_text(doc) for doc in input))
             # Generate contents dict of MD5 hash IDs and documents
             contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}
 
@@ -698,7 +695,7 @@ class LightRAG:
         new_docs: dict[str, Any] = {
             id_: {
                 "content": content,
-                "content_summary": self._get_content_summary(content),
+                "content_summary": get_content_summary(content),
                 "content_length": len(content),
                 "status": DocStatus.PENDING,
                 "created_at": datetime.now().isoformat(),
@@ -1063,7 +1060,7 @@ class LightRAG:
         all_chunks_data: dict[str, dict[str, str]] = {}
         chunk_to_source_map: dict[str, str] = {}
         for chunk_data in custom_kg.get("chunks", []):
-            chunk_content = self.clean_text(chunk_data["content"])
+            chunk_content = clean_text(chunk_data["content"])
             source_id = chunk_data["source_id"]
             tokens = len(
                 encode_string_by_tiktoken(
@@ -1296,8 +1293,17 @@ class LightRAG:
         self, query: str, prompt: str, param: QueryParam = QueryParam()
     ):
         """
-        1. Extract keywords from the 'query' using new function in operate.py.
-        2. Then run the standard aquery() flow with the final prompt (formatted_question).
+        Query with separate keyword extraction step.
+
+        This method extracts keywords from the query first, then uses them for the query.
+
+        Args:
+            query: User query
+            prompt: Additional prompt for the query
+            param: Query parameters
+
+        Returns:
+            Query response
         """
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
@@ -1308,66 +1314,29 @@ class LightRAG:
         self, query: str, prompt: str, param: QueryParam = QueryParam()
     ) -> str | AsyncIterator[str]:
         """
-        1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
-        2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
+        Async version of query_with_separate_keyword_extraction.
+
+        Args:
+            query: User query
+            prompt: Additional prompt for the query
+            param: Query parameters
+
+        Returns:
+            Query response or async iterator
         """
-        # ---------------------
-        # STEP 1: Keyword Extraction
-        # ---------------------
-        hl_keywords, ll_keywords = await extract_keywords_only(
-            text=query,
+        response = await query_with_keywords(
+            query=query,
+            prompt=prompt,
             param=param,
+            knowledge_graph_inst=self.chunk_entity_relation_graph,
+            entities_vdb=self.entities_vdb,
+            relationships_vdb=self.relationships_vdb,
+            chunks_vdb=self.chunks_vdb,
+            text_chunks_db=self.text_chunks,
             global_config=asdict(self),
-            hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
+            hashing_kv=self.llm_response_cache,
         )
-        param.hl_keywords = hl_keywords
-        param.ll_keywords = ll_keywords
-
-        # ---------------------
-        # STEP 2: Final Query Logic
-        # ---------------------
-
-        # Create a new string with the prompt and the keywords
-        ll_keywords_str = ", ".join(ll_keywords)
-        hl_keywords_str = ", ".join(hl_keywords)
-        formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
-
-        if param.mode in ["local", "global", "hybrid"]:
-            response = await kg_query_with_keywords(
-                formatted_question,
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
-            )
-        elif param.mode == "naive":
-            response = await naive_query(
-                formatted_question,
-                self.chunks_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
-            )
-        elif param.mode == "mix":
-            response = await mix_kg_vector_query(
-                formatted_question,
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.chunks_vdb,
-                self.text_chunks,
-                param,
-                asdict(self),
-                hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
-            )
-        else:
-            raise ValueError(f"Unknown mode {param.mode}")
-
         await self._query_done()
         return response
@@ -1465,21 +1434,6 @@ class LightRAG:
             ]
         )
 
-    def _get_content_summary(self, content: str, max_length: int = 100) -> str:
-        """Get summary of document content
-
-        Args:
-            content: Original document content
-            max_length: Maximum length of summary
-
-        Returns:
-            Truncated content with ellipsis if needed
-        """
-        content = content.strip()
-        if len(content) <= max_length:
-            return content
-        return content[:max_length] + "..."
-
     async def get_processing_status(self) -> dict[str, int]:
         """Get current document processing status counts
 
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 5baec1eb..1815f308 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1916,3 +1916,90 @@ async def kg_query_with_keywords(
     )
 
     return response
+
+
+async def query_with_keywords(
+    query: str,
+    prompt: str,
+    param: QueryParam,
+    knowledge_graph_inst: BaseGraphStorage,
+    entities_vdb: BaseVectorStorage,
+    relationships_vdb: BaseVectorStorage,
+    chunks_vdb: BaseVectorStorage,
+    text_chunks_db: BaseKVStorage,
+    global_config: dict[str, str],
+    hashing_kv: BaseKVStorage | None = None,
+) -> str | AsyncIterator[str]:
+    """
+    Extract keywords from the query and then use them for retrieving information.
+
+    1. Extracts high-level and low-level keywords from the query
+    2. Formats the query with the extracted keywords and prompt
+    3. Uses the appropriate query method based on param.mode
+
+    Args:
+        query: The user's query
+        prompt: Additional prompt to prepend to the query
+        param: Query parameters
+        knowledge_graph_inst: Knowledge graph storage
+        entities_vdb: Entities vector database
+        relationships_vdb: Relationships vector database
+        chunks_vdb: Document chunks vector database
+        text_chunks_db: Text chunks storage
+        global_config: Global configuration
+        hashing_kv: Cache storage
+
+    Returns:
+        Query response or async iterator
+    """
+    # Extract keywords
+    hl_keywords, ll_keywords = await extract_keywords_only(
+        text=query,
+        param=param,
+        global_config=global_config,
+        hashing_kv=hashing_kv,
+    )
+
+    param.hl_keywords = hl_keywords
+    param.ll_keywords = ll_keywords
+
+    # Create a new string with the prompt and the keywords
+    ll_keywords_str = ", ".join(ll_keywords)
+    hl_keywords_str = ", ".join(hl_keywords)
+    formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
+
+    # Use appropriate query method based on mode
+    if param.mode in ["local", "global", "hybrid"]:
+        return await kg_query_with_keywords(
+            formatted_question,
+            knowledge_graph_inst,
+            entities_vdb,
+            relationships_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    elif param.mode == "naive":
+        return await naive_query(
+            formatted_question,
+            chunks_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    elif param.mode == "mix":
+        return await mix_kg_vector_query(
+            formatted_question,
+            knowledge_graph_inst,
+            entities_vdb,
+            relationships_vdb,
+            chunks_vdb,
+            text_chunks_db,
+            param,
+            global_config,
+            hashing_kv=hashing_kv,
+        )
+    else:
+        raise ValueError(f"Unknown mode {param.mode}")
diff --git a/lightrag/utils.py b/lightrag/utils.py
index e8f79610..b8f00c5d 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -890,3 +890,52 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
         return cls(*args, **kwargs)
 
     return import_class
+
+
+def get_content_summary(content: str, max_length: int = 100) -> str:
+    """Get summary of document content
+
+    Args:
+        content: Original document content
+        max_length: Maximum length of summary
+
+    Returns:
+        Truncated content with ellipsis if needed
+    """
+    content = content.strip()
+    if len(content) <= max_length:
+        return content
+    return content[:max_length] + "..."
+
+
+def clean_text(text: str) -> str:
+    """Clean text by removing null bytes (0x00) and whitespace
+
+    Args:
+        text: Input text to clean
+
+    Returns:
+        Cleaned text
+    """
+    return text.strip().replace("\x00", "")
+
+
+def check_storage_env_vars(storage_name: str) -> None:
+    """Check if all required environment variables for storage implementation exist
+
+    Args:
+        storage_name: Storage implementation name
+
+    Raises:
+        ValueError: If required environment variables are missing
+    """
+    from lightrag.kg import STORAGE_ENV_REQUIREMENTS
+
+    required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+    missing_vars = [var for var in required_vars if var not in os.environ]
+
+    if missing_vars:
+        raise ValueError(
+            f"Storage implementation '{storage_name}' requires the following "
+            f"environment variables: {', '.join(missing_vars)}"
+        )
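
Usage sketch (illustrative, not part of the diff): the public entry point is unchanged by this refactor; LightRAG.query_with_separate_keyword_extraction now simply delegates to operate.query_with_keywords. The working directory and model setup below are placeholder assumptions, not values from this PR.

    from lightrag import LightRAG, QueryParam

    # Assumes llm_model_func/embedding_func are configured as usual for LightRAG.
    rag = LightRAG(working_dir="./rag_storage")

    param = QueryParam(mode="hybrid")  # "local"/"global"/"hybrid" -> kg_query_with_keywords
    answer = rag.query_with_separate_keyword_extraction(
        query="How do entities and relationships get linked to text chunks?",
        prompt="Answer concisely.",
        param=param,
    )
    print(answer)

With mode="naive" the same call routes to naive_query, with mode="mix" to mix_kg_vector_query, and any other mode raises ValueError, exactly as in the dispatch above.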
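
The three helpers promoted to lightrag.utils have no dependencies beyond the STORAGE_ENV_REQUIREMENTS lookup, so they can be sanity-checked in isolation. A minimal sketch; "MongoKVStorage" is only an example storage name, assumed to be a key in STORAGE_ENV_REQUIREMENTS (unknown names fall back to an empty requirement list):

    from lightrag.utils import clean_text, get_content_summary, check_storage_env_vars

    # clean_text strips surrounding whitespace and removes embedded null bytes.
    assert clean_text("  doc body\x00 tail  ") == "doc body tail"

    # get_content_summary truncates to max_length characters and appends "...".
    summary = get_content_summary("x" * 250)
    assert len(summary) == 103 and summary.endswith("...")

    # check_storage_env_vars raises ValueError naming any missing variables.
    try:
        check_storage_env_vars("MongoKVStorage")  # example name (assumption)
    except ValueError as e:
        print(e)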