From 9f2c659d9cba73214e178c7b8910b3e510ea760b Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 12:54:14 +0100 Subject: [PATCH 01/14] remove unused log --- lightrag/kg/oracle_impl.py | 2 +- lightrag/kg/postgres_impl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index de61a2ca..35983ad3 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -44,7 +44,7 @@ class OracleDB: self.increment = 1 logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier") if self.user is None or self.password is None: - raise ValueError("Missing database user or password in addon_params") + raise ValueError("Missing database user or password") try: oracledb.defaults.fetch_lobs = False diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index ababc05f..52370821 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -55,7 +55,7 @@ class PostgreSQLDB: if self.user is None or self.password is None or self.database is None: raise ValueError( - "Missing database user, password, or database in addon_params" + "Missing database user, password, or database" ) async def initdb(self): From de56aeb7c5fcb60ea2c391a54f7e1d4ed0559178 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 12:54:52 +0100 Subject: [PATCH 02/14] removed lock --- lightrag/lightrag.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a4daeced..a34ae20d 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -7,7 +7,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial from typing import Any, AsyncIterator, Callable, Iterator, cast -from asyncio import Lock + from .base import ( BaseGraphStorage, BaseKVStorage, @@ -358,9 +358,6 @@ class LightRAG: convert_response_to_json ) - # Lock for entity extraction - _entity_lock = Lock() - # Custom Chunking Function chunking_func: Callable[ [ @@ -1203,7 +1200,6 @@ class LightRAG: # --------------------- # STEP 1: Keyword Extraction # --------------------- - # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords). hl_keywords, ll_keywords = await extract_keywords_only( text=query, param=param, From bae21a6fadbc6093bb4094ec6e151fff9592d721 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 12:57:25 +0100 Subject: [PATCH 03/14] added max paralle insert --- lightrag/lightrag.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a34ae20d..22c32770 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -347,6 +347,9 @@ class LightRAG: """If True, enables caching for entity extraction steps to reduce LLM costs.""" # Extensions + max_parallel_insert: int = field(default_factory=lambda: int(os.getenv("MAX_PARALLEL_INSERT", 20))) + """Maximum number of parallel insert operations.""" + addon_params: dict[str, Any] = field(default_factory=dict) # Storages Management @@ -786,10 +789,9 @@ class LightRAG: return # 2. split docs into chunks, insert chunks, update doc status - batch_size = self.addon_params.get("insert_batch_size", 10) docs_batches = [ - list(to_process_docs.items())[i : i + batch_size] - for i in range(0, len(to_process_docs), batch_size) + list(to_process_docs.items())[i : i + self.max_parallel_insert] + for i in range(0, len(to_process_docs), self.max_parallel_insert) ] logger.info(f"Number of batches to process: {len(docs_batches)}.") From 37addb7c01682d1e273c74532bbf9ceada0020f2 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:05:35 +0100 Subject: [PATCH 04/14] added final --- lightrag/lightrag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 22c32770..f2d48444 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -6,7 +6,7 @@ import configparser from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Any, AsyncIterator, Callable, Iterator, cast +from typing import Any, AsyncIterator, Callable, Iterator, cast, final from .base import ( BaseGraphStorage, @@ -225,7 +225,7 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: asyncio.set_event_loop(new_loop) return new_loop - +@final @dataclass class LightRAG: """LightRAG: Simple and Fast Retrieval-Augmented Generation.""" From 2370a4336b0e16387db942a6e4f03c29d32ee5e7 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:05:59 +0100 Subject: [PATCH 05/14] added field --- lightrag/lightrag.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index f2d48444..28d5d078 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -231,12 +231,12 @@ class LightRAG: """LightRAG: Simple and Fast Retrieval-Augmented Generation.""" working_dir: str = field( - default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" + default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) """Directory where cache and temporary files are stored.""" embedding_cache_config: dict[str, Any] = field( - default_factory=lambda: { + default={ "enabled": False, "similarity_threshold": 0.95, "use_llm_check": False, @@ -261,32 +261,31 @@ class LightRAG: """Storage type for tracking document processing statuses.""" # Logging - current_log_level = logger.level - log_level: int = field(default=current_log_level) + log_level: int = field(default=logger.level) """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" log_dir: str = field(default=os.getcwd()) """Directory where logs are stored. Defaults to the current working directory.""" # Text chunking - chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200")) + chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200))) """Maximum number of tokens per text chunk when splitting documents.""" - chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100")) + chunk_overlap_token_size: int = field(default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))) """Number of overlapping tokens between consecutive text chunks to preserve context.""" - tiktoken_model_name: str = "gpt-4o-mini" + tiktoken_model_name: str = field(default="gpt-4o-mini") """Model name used for tokenization when chunking text.""" # Entity extraction - entity_extract_max_gleaning: int = 1 + entity_extract_max_gleaning: int = field(default=1) """Maximum number of entity extraction attempts for ambiguous content.""" - entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500")) + entity_summary_to_max_tokens: int = field(default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))) """Maximum number of tokens used for summarizing extracted entities.""" # Node embedding - node_embedding_algorithm: str = "node2vec" + node_embedding_algorithm: str = field(default="node2vec") """Algorithm used for node embedding in knowledge graphs.""" node2vec_params: dict[str, int] = field( From f5a93c7bb5d78f5b6a0c94e96ead46c7cb3bd147 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:06:16 +0100 Subject: [PATCH 06/14] added fields --- lightrag/lightrag.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 28d5d078..5b01c18b 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -307,26 +307,26 @@ class LightRAG: - random_seed: Seed value for reproducibility. """ - embedding_func: EmbeddingFunc | None = None + embedding_func: EmbeddingFunc | None = field(default=None) """Function for computing text embeddings. Must be set before use.""" - embedding_batch_num: int = 32 + embedding_batch_num: int = field(default=32) """Batch size for embedding computations.""" - embedding_func_max_async: int = 16 + embedding_func_max_async: int = field(default=16) """Maximum number of concurrent embedding function calls.""" # LLM Configuration - llm_model_func: Callable[..., object] | None = None + llm_model_func: Callable[..., object] | None = field(default=None) """Function for interacting with the large language model (LLM). Must be set before use.""" - llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" + llm_model_name: str = field(default="gpt-4o-mini") """Name of the LLM model used for generating responses.""" - llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768")) + llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768))) """Maximum number of tokens allowed per LLM response.""" - llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16")) + llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16))) """Maximum number of concurrent LLM calls.""" llm_model_kwargs: dict[str, Any] = field(default_factory=dict) From 4b478d1c0ff521cd25d607f3ff1a5c3a66e238d3 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:06:34 +0100 Subject: [PATCH 07/14] added fields --- lightrag/lightrag.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5b01c18b..5706e189 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -339,20 +339,20 @@ class LightRAG: namespace_prefix: str = field(default="") """Prefix for namespacing stored data across different environments.""" - enable_llm_cache: bool = True + enable_llm_cache: bool = field(default=True) """Enables caching for LLM responses to avoid redundant computations.""" - enable_llm_cache_for_entity_extract: bool = True + enable_llm_cache_for_entity_extract: bool = field(default=True) """If True, enables caching for entity extraction steps to reduce LLM costs.""" # Extensions - max_parallel_insert: int = field(default_factory=lambda: int(os.getenv("MAX_PARALLEL_INSERT", 20))) + max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20))) """Maximum number of parallel insert operations.""" addon_params: dict[str, Any] = field(default_factory=dict) # Storages Management - auto_manage_storages_states: bool = True + auto_manage_storages_states: bool = field(default=True) """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times.""" """Dictionary for additional parameters and extensions.""" From 32d0f1acb04c9499024b7d953957736cef0c850c Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:09:33 +0100 Subject: [PATCH 08/14] added docs and fields --- lightrag/kg/postgres_impl.py | 4 +--- lightrag/lightrag.py | 44 ++++++++++++++++++++++++++++-------- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 52370821..d7ace41a 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -54,9 +54,7 @@ class PostgreSQLDB: self.pool: Pool | None = None if self.user is None or self.password is None or self.database is None: - raise ValueError( - "Missing database user, password, or database" - ) + raise ValueError("Missing database user, password, or database") async def initdb(self): try: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5706e189..247e09ec 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -225,6 +225,7 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: asyncio.set_event_loop(new_loop) return new_loop + @final @dataclass class LightRAG: @@ -271,7 +272,9 @@ class LightRAG: chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200))) """Maximum number of tokens per text chunk when splitting documents.""" - chunk_overlap_token_size: int = field(default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))) + chunk_overlap_token_size: int = field( + default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100)) + ) """Number of overlapping tokens between consecutive text chunks to preserve context.""" tiktoken_model_name: str = field(default="gpt-4o-mini") @@ -281,11 +284,13 @@ class LightRAG: entity_extract_max_gleaning: int = field(default=1) """Maximum number of entity extraction attempts for ambiguous content.""" - entity_summary_to_max_tokens: int = field(default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))) + entity_summary_to_max_tokens: int = field( + default=int(os.getenv("MAX_TOKEN_SUMMARY", 500)) + ) """Maximum number of tokens used for summarizing extracted entities.""" # Node embedding - node_embedding_algorithm: str = field(default="node2vec") + node_embedding_algorithm: str = field(default="node2vec") """Algorithm used for node embedding in knowledge graphs.""" node2vec_params: dict[str, int] = field( @@ -348,19 +353,22 @@ class LightRAG: # Extensions max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20))) """Maximum number of parallel insert operations.""" - + addon_params: dict[str, Any] = field(default_factory=dict) # Storages Management auto_manage_storages_states: bool = field(default=True) """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times.""" - """Dictionary for additional parameters and extensions.""" - convert_response_to_json_func: Callable[[str], dict[str, Any]] = ( - convert_response_to_json + convert_response_to_json_func: Callable[[str], dict[str, Any]] = field( + default_factory=lambda: convert_response_to_json ) + """ + Custom function for converting LLM responses to JSON format. + + The default function is :func:`.utils.convert_response_to_json`. + """ - # Custom Chunking Function chunking_func: Callable[ [ str, @@ -371,7 +379,25 @@ class LightRAG: str, ], list[dict[str, Any]], - ] = chunking_by_token_size + ] = field(default_factory=lambda: chunking_by_token_size) + """ + Custom chunking function for splitting text into chunks before processing. + + The function should take the following parameters: + + - `content`: The text to be split into chunks. + - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens. + - `split_by_character_only`: If True, the text is split only on the specified character. + - `chunk_token_size`: The maximum number of tokens per chunk. + - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks. + - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization. + + The function should return a list of dictionaries, where each dictionary contains the following keys: + - `tokens`: The number of tokens in the chunk. + - `content`: The text content of the chunk. + + Defaults to `chunking_by_token_size` if not specified. + """ def verify_storage_implementation( self, storage_type: str, storage_name: str From 72b978d6d5dec60f18431ddf7f2488b6908a2d32 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:13:38 +0100 Subject: [PATCH 09/14] cleanup --- lightrag/lightrag.py | 227 ++++++++++++++++++++++++------------------- 1 file changed, 128 insertions(+), 99 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 247e09ec..481025af 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -231,23 +231,16 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: class LightRAG: """LightRAG: Simple and Fast Retrieval-Augmented Generation.""" + # Directory + # --- + working_dir: str = field( default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) """Directory where cache and temporary files are stored.""" - embedding_cache_config: dict[str, Any] = field( - default={ - "enabled": False, - "similarity_threshold": 0.95, - "use_llm_check": False, - } - ) - """Configuration for embedding cache. - - enabled: If True, enables caching to avoid redundant computations. - - similarity_threshold: Minimum similarity score to use cached embeddings. - - use_llm_check: If True, validates cached embeddings using an LLM. - """ + # Storage + # --- kv_storage: str = field(default="JsonKVStorage") """Storage backend for key-value data.""" @@ -262,13 +255,27 @@ class LightRAG: """Storage type for tracking document processing statuses.""" # Logging + # --- + log_level: int = field(default=logger.level) """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" log_dir: str = field(default=os.getcwd()) """Directory where logs are stored. Defaults to the current working directory.""" + # Entity extraction + # --- + + entity_extract_max_gleaning: int = field(default=1) + """Maximum number of entity extraction attempts for ambiguous content.""" + + entity_summary_to_max_tokens: int = field( + default=int(os.getenv("MAX_TOKEN_SUMMARY", 500)) + ) + # Text chunking + # --- + chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200))) """Maximum number of tokens per text chunk when splitting documents.""" @@ -280,95 +287,8 @@ class LightRAG: tiktoken_model_name: str = field(default="gpt-4o-mini") """Model name used for tokenization when chunking text.""" - # Entity extraction - entity_extract_max_gleaning: int = field(default=1) - """Maximum number of entity extraction attempts for ambiguous content.""" - - entity_summary_to_max_tokens: int = field( - default=int(os.getenv("MAX_TOKEN_SUMMARY", 500)) - ) """Maximum number of tokens used for summarizing extracted entities.""" - # Node embedding - node_embedding_algorithm: str = field(default="node2vec") - """Algorithm used for node embedding in knowledge graphs.""" - - node2vec_params: dict[str, int] = field( - default_factory=lambda: { - "dimensions": 1536, - "num_walks": 10, - "walk_length": 40, - "window_size": 2, - "iterations": 3, - "random_seed": 3, - } - ) - """Configuration for the node2vec embedding algorithm: - - dimensions: Number of dimensions for embeddings. - - num_walks: Number of random walks per node. - - walk_length: Number of steps per random walk. - - window_size: Context window size for training. - - iterations: Number of iterations for training. - - random_seed: Seed value for reproducibility. - """ - - embedding_func: EmbeddingFunc | None = field(default=None) - """Function for computing text embeddings. Must be set before use.""" - - embedding_batch_num: int = field(default=32) - """Batch size for embedding computations.""" - - embedding_func_max_async: int = field(default=16) - """Maximum number of concurrent embedding function calls.""" - - # LLM Configuration - llm_model_func: Callable[..., object] | None = field(default=None) - """Function for interacting with the large language model (LLM). Must be set before use.""" - - llm_model_name: str = field(default="gpt-4o-mini") - """Name of the LLM model used for generating responses.""" - - llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768))) - """Maximum number of tokens allowed per LLM response.""" - - llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16))) - """Maximum number of concurrent LLM calls.""" - - llm_model_kwargs: dict[str, Any] = field(default_factory=dict) - """Additional keyword arguments passed to the LLM model function.""" - - # Storage - vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict) - """Additional parameters for vector database storage.""" - - namespace_prefix: str = field(default="") - """Prefix for namespacing stored data across different environments.""" - - enable_llm_cache: bool = field(default=True) - """Enables caching for LLM responses to avoid redundant computations.""" - - enable_llm_cache_for_entity_extract: bool = field(default=True) - """If True, enables caching for entity extraction steps to reduce LLM costs.""" - - # Extensions - max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20))) - """Maximum number of parallel insert operations.""" - - addon_params: dict[str, Any] = field(default_factory=dict) - - # Storages Management - auto_manage_storages_states: bool = field(default=True) - """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times.""" - - convert_response_to_json_func: Callable[[str], dict[str, Any]] = field( - default_factory=lambda: convert_response_to_json - ) - """ - Custom function for converting LLM responses to JSON format. - - The default function is :func:`.utils.convert_response_to_json`. - """ - chunking_func: Callable[ [ str, @@ -399,6 +319,115 @@ class LightRAG: Defaults to `chunking_by_token_size` if not specified. """ + # Node embedding + # --- + + node_embedding_algorithm: str = field(default="node2vec") + """Algorithm used for node embedding in knowledge graphs.""" + + node2vec_params: dict[str, int] = field( + default_factory=lambda: { + "dimensions": 1536, + "num_walks": 10, + "walk_length": 40, + "window_size": 2, + "iterations": 3, + "random_seed": 3, + } + ) + """Configuration for the node2vec embedding algorithm: + - dimensions: Number of dimensions for embeddings. + - num_walks: Number of random walks per node. + - walk_length: Number of steps per random walk. + - window_size: Context window size for training. + - iterations: Number of iterations for training. + - random_seed: Seed value for reproducibility. + """ + + # Embedding + # --- + + embedding_func: EmbeddingFunc | None = field(default=None) + """Function for computing text embeddings. Must be set before use.""" + + embedding_batch_num: int = field(default=32) + """Batch size for embedding computations.""" + + embedding_func_max_async: int = field(default=16) + """Maximum number of concurrent embedding function calls.""" + + embedding_cache_config: dict[str, Any] = field( + default={ + "enabled": False, + "similarity_threshold": 0.95, + "use_llm_check": False, + } + ) + """Configuration for embedding cache. + - enabled: If True, enables caching to avoid redundant computations. + - similarity_threshold: Minimum similarity score to use cached embeddings. + - use_llm_check: If True, validates cached embeddings using an LLM. + """ + + # LLM Configuration + # --- + + llm_model_func: Callable[..., object] | None = field(default=None) + """Function for interacting with the large language model (LLM). Must be set before use.""" + + llm_model_name: str = field(default="gpt-4o-mini") + """Name of the LLM model used for generating responses.""" + + llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768))) + """Maximum number of tokens allowed per LLM response.""" + + llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16))) + """Maximum number of concurrent LLM calls.""" + + llm_model_kwargs: dict[str, Any] = field(default_factory=dict) + """Additional keyword arguments passed to the LLM model function.""" + + # Storage + # --- + + vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict) + """Additional parameters for vector database storage.""" + + namespace_prefix: str = field(default="") + """Prefix for namespacing stored data across different environments.""" + + enable_llm_cache: bool = field(default=True) + """Enables caching for LLM responses to avoid redundant computations.""" + + enable_llm_cache_for_entity_extract: bool = field(default=True) + """If True, enables caching for entity extraction steps to reduce LLM costs.""" + + # Extensions + # --- + + max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20))) + """Maximum number of parallel insert operations.""" + + addon_params: dict[str, Any] = field(default_factory=dict) + + # Storages Management + # --- + + auto_manage_storages_states: bool = field(default=True) + """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times.""" + + # Storages Management + # --- + + convert_response_to_json_func: Callable[[str], dict[str, Any]] = field( + default_factory=lambda: convert_response_to_json + ) + """ + Custom function for converting LLM responses to JSON format. + + The default function is :func:`.utils.convert_response_to_json`. + """ + def verify_storage_implementation( self, storage_type: str, storage_name: str ) -> None: From 32e489865c0378c61fe240151b3b1572f3b7e1ae Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:18:17 +0100 Subject: [PATCH 10/14] cleanup code --- docker-compose.yml | 2 - examples/lightrag_api_oracle_demo.py | 1 - .../lightrag_openai_compatible_stream_demo.py | 7 - examples/lightrag_tidb_demo.py | 1 - lightrag/lightrag.py | 121 ++++++------------ lightrag/utils.py | 44 +++++++ reproduce/Step_3.py | 10 +- reproduce/Step_3_openai_compatible.py | 11 +- 8 files changed, 89 insertions(+), 108 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index b5659692..4ced24ca 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,3 @@ -version: '3.8' - services: lightrag: build: . diff --git a/examples/lightrag_api_oracle_demo.py b/examples/lightrag_api_oracle_demo.py index e66e3f94..3675795e 100644 --- a/examples/lightrag_api_oracle_demo.py +++ b/examples/lightrag_api_oracle_demo.py @@ -98,7 +98,6 @@ async def init(): # Initialize LightRAG # We use Oracle DB as the KV/vector/graph storage - # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( enable_llm_cache=False, working_dir=WORKING_DIR, diff --git a/examples/lightrag_openai_compatible_stream_demo.py b/examples/lightrag_openai_compatible_stream_demo.py index 93c4297c..7509e4dc 100644 --- a/examples/lightrag_openai_compatible_stream_demo.py +++ b/examples/lightrag_openai_compatible_stream_demo.py @@ -1,9 +1,7 @@ import os -import inspect from lightrag import LightRAG from lightrag.llm import openai_complete, openai_embed from lightrag.utils import EmbeddingFunc -from lightrag.lightrag import always_get_an_event_loop from lightrag import QueryParam # WorkingDir @@ -48,8 +46,3 @@ async def print_stream(stream): print(chunk, end="", flush=True) -loop = always_get_an_event_loop() -if inspect.isasyncgen(resp): - loop.run_until_complete(print_stream(resp)) -else: - print(resp) diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py index f4004f84..f2ee9ad8 100644 --- a/examples/lightrag_tidb_demo.py +++ b/examples/lightrag_tidb_demo.py @@ -63,7 +63,6 @@ async def main(): # Initialize LightRAG # We use TiDB DB as the KV/vector - # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt rag = LightRAG( enable_llm_cache=False, working_dir=WORKING_DIR, diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 481025af..8b695883 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -32,8 +32,10 @@ from .operate import ( from .prompt import GRAPH_FIELD_SEP from .utils import ( EmbeddingFunc, + always_get_an_event_loop, compute_mdhash_id, convert_response_to_json, + lazy_external_import, limit_async_func_call, logger, set_logger, @@ -182,48 +184,9 @@ STORAGES = { } -def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: - """Lazily import a class from an external module based on the package of the caller.""" - # Get the caller's module and package - import inspect - - caller_frame = inspect.currentframe().f_back - module = inspect.getmodule(caller_frame) - package = module.__package__ if module else None - - def import_class(*args: Any, **kwargs: Any): - import importlib - - module = importlib.import_module(module_name, package=package) - cls = getattr(module, class_name) - return cls(*args, **kwargs) - - return import_class -def always_get_an_event_loop() -> asyncio.AbstractEventLoop: - """ - Ensure that there is always an event loop available. - This function tries to get the current event loop. If the current event loop is closed or does not exist, - it creates a new event loop and sets it as the current event loop. - - Returns: - asyncio.AbstractEventLoop: The current or newly created event loop. - """ - try: - # Try to get the current event loop - current_loop = asyncio.get_event_loop() - if current_loop.is_closed(): - raise RuntimeError("Event loop is closed.") - return current_loop - - except RuntimeError: - # If no event loop exists or it is closed, create a new one - logger.info("Creating a new event loop in main thread.") - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - return new_loop @final @@ -428,46 +391,6 @@ class LightRAG: The default function is :func:`.utils.convert_response_to_json`. """ - def verify_storage_implementation( - self, storage_type: str, storage_name: str - ) -> None: - """Verify if storage implementation is compatible with specified storage type - - Args: - storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.) - storage_name: Storage implementation name - - Raises: - ValueError: If storage implementation is incompatible or missing required methods - """ - if storage_type not in STORAGE_IMPLEMENTATIONS: - raise ValueError(f"Unknown storage type: {storage_type}") - - storage_info = STORAGE_IMPLEMENTATIONS[storage_type] - if storage_name not in storage_info["implementations"]: - raise ValueError( - f"Storage implementation '{storage_name}' is not compatible with {storage_type}. " - f"Compatible implementations are: {', '.join(storage_info['implementations'])}" - ) - - def check_storage_env_vars(self, storage_name: str) -> None: - """Check if all required environment variables for storage implementation exist - - Args: - storage_name: Storage implementation name - - Raises: - ValueError: If required environment variables are missing - """ - required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) - missing_vars = [var for var in required_vars if var not in os.environ] - - if missing_vars: - raise ValueError( - f"Storage implementation '{storage_name}' requires the following " - f"environment variables: {', '.join(missing_vars)}" - ) - def __post_init__(self): os.makedirs(self.log_dir, exist_ok=True) log_file = os.path.join(self.log_dir, "lightrag.log") @@ -1681,3 +1604,43 @@ class LightRAG: result["vector_data"] = vector_data[0] if vector_data else None return result + + def verify_storage_implementation( + self, storage_type: str, storage_name: str + ) -> None: + """Verify if storage implementation is compatible with specified storage type + + Args: + storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.) + storage_name: Storage implementation name + + Raises: + ValueError: If storage implementation is incompatible or missing required methods + """ + if storage_type not in STORAGE_IMPLEMENTATIONS: + raise ValueError(f"Unknown storage type: {storage_type}") + + storage_info = STORAGE_IMPLEMENTATIONS[storage_type] + if storage_name not in storage_info["implementations"]: + raise ValueError( + f"Storage implementation '{storage_name}' is not compatible with {storage_type}. " + f"Compatible implementations are: {', '.join(storage_info['implementations'])}" + ) + + def check_storage_env_vars(self, storage_name: str) -> None: + """Check if all required environment variables for storage implementation exist + + Args: + storage_name: Storage implementation name + + Raises: + ValueError: If required environment variables are missing + """ + required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) + missing_vars = [var for var in required_vars if var not in os.environ] + + if missing_vars: + raise ValueError( + f"Storage implementation '{storage_name}' requires the following " + f"environment variables: {', '.join(missing_vars)}" + ) \ No newline at end of file diff --git a/lightrag/utils.py b/lightrag/utils.py index d932f149..62f62d4d 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -713,3 +713,47 @@ def get_conversation_turns( ) return "\n".join(formatted_turns) + +def always_get_an_event_loop() -> asyncio.AbstractEventLoop: + """ + Ensure that there is always an event loop available. + + This function tries to get the current event loop. If the current event loop is closed or does not exist, + it creates a new event loop and sets it as the current event loop. + + Returns: + asyncio.AbstractEventLoop: The current or newly created event loop. + """ + try: + # Try to get the current event loop + current_loop = asyncio.get_event_loop() + if current_loop.is_closed(): + raise RuntimeError("Event loop is closed.") + return current_loop + + except RuntimeError: + # If no event loop exists or it is closed, create a new one + logger.info("Creating a new event loop in main thread.") + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + return new_loop + + +def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: + """Lazily import a class from an external module based on the package of the caller.""" + # Get the caller's module and package + import inspect + + caller_frame = inspect.currentframe().f_back + module = inspect.getmodule(caller_frame) + package = module.__package__ if module else None + + def import_class(*args: Any, **kwargs: Any): + import importlib + + module = importlib.import_module(module_name, package=package) + cls = getattr(module, class_name) + return cls(*args, **kwargs) + + return import_class + \ No newline at end of file diff --git a/reproduce/Step_3.py b/reproduce/Step_3.py index f9ee3257..be5ba99d 100644 --- a/reproduce/Step_3.py +++ b/reproduce/Step_3.py @@ -1,7 +1,7 @@ import re import json -import asyncio from lightrag import LightRAG, QueryParam +from lightrag.utils import always_get_an_event_loop def extract_queries(file_path): @@ -23,14 +23,6 @@ async def process_query(query_text, rag_instance, query_param): return None, {"query": query_text, "error": str(e)} -def always_get_an_event_loop() -> asyncio.AbstractEventLoop: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop - def run_queries_and_save_to_json( queries, rag_instance, query_param, output_file, error_file diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py index e4833adf..b1d33f93 100644 --- a/reproduce/Step_3_openai_compatible.py +++ b/reproduce/Step_3_openai_compatible.py @@ -1,10 +1,9 @@ import os import re import json -import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.utils import EmbeddingFunc +from lightrag.utils import EmbeddingFunc, always_get_an_event_loop import numpy as np @@ -55,13 +54,7 @@ async def process_query(query_text, rag_instance, query_param): return None, {"query": query_text, "error": str(e)} -def always_get_an_event_loop() -> asyncio.AbstractEventLoop: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop + def run_queries_and_save_to_json( From c7bc2c63cfaab68a263ebe14a626c24079a123b0 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:21:41 +0100 Subject: [PATCH 11/14] cleanup storages --- .../lightrag_openai_compatible_stream_demo.py | 2 - lightrag/kg/__init__.py | 137 +++++++++++++++- lightrag/lightrag.py | 147 +----------------- lightrag/utils.py | 6 +- reproduce/Step_3.py | 1 - reproduce/Step_3_openai_compatible.py | 3 - 6 files changed, 142 insertions(+), 154 deletions(-) diff --git a/examples/lightrag_openai_compatible_stream_demo.py b/examples/lightrag_openai_compatible_stream_demo.py index 7509e4dc..750f139e 100644 --- a/examples/lightrag_openai_compatible_stream_demo.py +++ b/examples/lightrag_openai_compatible_stream_demo.py @@ -44,5 +44,3 @@ async def print_stream(stream): async for chunk in stream: if chunk: print(chunk, end="", flush=True) - - diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index 087eaac9..2f3eae87 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -1 +1,136 @@ -# print ("init package vars here. ......") +STORAGE_IMPLEMENTATIONS = { + "KV_STORAGE": { + "implementations": [ + "JsonKVStorage", + "MongoKVStorage", + "RedisKVStorage", + "TiDBKVStorage", + "PGKVStorage", + "OracleKVStorage", + ], + "required_methods": ["get_by_id", "upsert"], + }, + "GRAPH_STORAGE": { + "implementations": [ + "NetworkXStorage", + "Neo4JStorage", + "MongoGraphStorage", + "TiDBGraphStorage", + "AGEStorage", + "GremlinStorage", + "PGGraphStorage", + "OracleGraphStorage", + ], + "required_methods": ["upsert_node", "upsert_edge"], + }, + "VECTOR_STORAGE": { + "implementations": [ + "NanoVectorDBStorage", + "MilvusVectorDBStorage", + "ChromaVectorDBStorage", + "TiDBVectorDBStorage", + "PGVectorStorage", + "FaissVectorDBStorage", + "QdrantVectorDBStorage", + "OracleVectorDBStorage", + "MongoVectorDBStorage", + ], + "required_methods": ["query", "upsert"], + }, + "DOC_STATUS_STORAGE": { + "implementations": [ + "JsonDocStatusStorage", + "PGDocStatusStorage", + "PGDocStatusStorage", + "MongoDocStatusStorage", + ], + "required_methods": ["get_docs_by_status"], + }, +} + +# Storage implementation environment variable without default value +STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { + # KV Storage Implementations + "JsonKVStorage": [], + "MongoKVStorage": [], + "RedisKVStorage": ["REDIS_URI"], + "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], + "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "OracleKVStorage": [ + "ORACLE_DSN", + "ORACLE_USER", + "ORACLE_PASSWORD", + "ORACLE_CONFIG_DIR", + ], + # Graph Storage Implementations + "NetworkXStorage": [], + "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], + "MongoGraphStorage": [], + "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], + "AGEStorage": [ + "AGE_POSTGRES_DB", + "AGE_POSTGRES_USER", + "AGE_POSTGRES_PASSWORD", + ], + "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], + "PGGraphStorage": [ + "POSTGRES_USER", + "POSTGRES_PASSWORD", + "POSTGRES_DATABASE", + ], + "OracleGraphStorage": [ + "ORACLE_DSN", + "ORACLE_USER", + "ORACLE_PASSWORD", + "ORACLE_CONFIG_DIR", + ], + # Vector Storage Implementations + "NanoVectorDBStorage": [], + "MilvusVectorDBStorage": [], + "ChromaVectorDBStorage": [], + "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], + "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "FaissVectorDBStorage": [], + "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None + "OracleVectorDBStorage": [ + "ORACLE_DSN", + "ORACLE_USER", + "ORACLE_PASSWORD", + "ORACLE_CONFIG_DIR", + ], + "MongoVectorDBStorage": [], + # Document Status Storage Implementations + "JsonDocStatusStorage": [], + "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "MongoDocStatusStorage": [], +} + +# Storage implementation module mapping +STORAGES = { + "NetworkXStorage": ".kg.networkx_impl", + "JsonKVStorage": ".kg.json_kv_impl", + "NanoVectorDBStorage": ".kg.nano_vector_db_impl", + "JsonDocStatusStorage": ".kg.json_doc_status_impl", + "Neo4JStorage": ".kg.neo4j_impl", + "OracleKVStorage": ".kg.oracle_impl", + "OracleGraphStorage": ".kg.oracle_impl", + "OracleVectorDBStorage": ".kg.oracle_impl", + "MilvusVectorDBStorage": ".kg.milvus_impl", + "MongoKVStorage": ".kg.mongo_impl", + "MongoDocStatusStorage": ".kg.mongo_impl", + "MongoGraphStorage": ".kg.mongo_impl", + "MongoVectorDBStorage": ".kg.mongo_impl", + "RedisKVStorage": ".kg.redis_impl", + "ChromaVectorDBStorage": ".kg.chroma_impl", + "TiDBKVStorage": ".kg.tidb_impl", + "TiDBVectorDBStorage": ".kg.tidb_impl", + "TiDBGraphStorage": ".kg.tidb_impl", + "PGKVStorage": ".kg.postgres_impl", + "PGVectorStorage": ".kg.postgres_impl", + "AGEStorage": ".kg.age_impl", + "PGGraphStorage": ".kg.postgres_impl", + "GremlinStorage": ".kg.gremlin_impl", + "PGDocStatusStorage": ".kg.postgres_impl", + "FaissVectorDBStorage": ".kg.faiss_impl", + "QdrantVectorDBStorage": ".kg.qdrant_impl", +} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8b695883..174947f3 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -8,6 +8,8 @@ from datetime import datetime from functools import partial from typing import Any, AsyncIterator, Callable, Iterator, cast, final +from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES + from .base import ( BaseGraphStorage, BaseKVStorage, @@ -45,149 +47,6 @@ from .utils import ( config = configparser.ConfigParser() config.read("config.ini", "utf-8") -# Storage type and implementation compatibility validation table -STORAGE_IMPLEMENTATIONS = { - "KV_STORAGE": { - "implementations": [ - "JsonKVStorage", - "MongoKVStorage", - "RedisKVStorage", - "TiDBKVStorage", - "PGKVStorage", - "OracleKVStorage", - ], - "required_methods": ["get_by_id", "upsert"], - }, - "GRAPH_STORAGE": { - "implementations": [ - "NetworkXStorage", - "Neo4JStorage", - "MongoGraphStorage", - "TiDBGraphStorage", - "AGEStorage", - "GremlinStorage", - "PGGraphStorage", - "OracleGraphStorage", - ], - "required_methods": ["upsert_node", "upsert_edge"], - }, - "VECTOR_STORAGE": { - "implementations": [ - "NanoVectorDBStorage", - "MilvusVectorDBStorage", - "ChromaVectorDBStorage", - "TiDBVectorDBStorage", - "PGVectorStorage", - "FaissVectorDBStorage", - "QdrantVectorDBStorage", - "OracleVectorDBStorage", - "MongoVectorDBStorage", - ], - "required_methods": ["query", "upsert"], - }, - "DOC_STATUS_STORAGE": { - "implementations": [ - "JsonDocStatusStorage", - "PGDocStatusStorage", - "PGDocStatusStorage", - "MongoDocStatusStorage", - ], - "required_methods": ["get_docs_by_status"], - }, -} - -# Storage implementation environment variable without default value -STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { - # KV Storage Implementations - "JsonKVStorage": [], - "MongoKVStorage": [], - "RedisKVStorage": ["REDIS_URI"], - "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], - "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "OracleKVStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], - # Graph Storage Implementations - "NetworkXStorage": [], - "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], - "MongoGraphStorage": [], - "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], - "AGEStorage": [ - "AGE_POSTGRES_DB", - "AGE_POSTGRES_USER", - "AGE_POSTGRES_PASSWORD", - ], - "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], - "PGGraphStorage": [ - "POSTGRES_USER", - "POSTGRES_PASSWORD", - "POSTGRES_DATABASE", - ], - "OracleGraphStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], - # Vector Storage Implementations - "NanoVectorDBStorage": [], - "MilvusVectorDBStorage": [], - "ChromaVectorDBStorage": [], - "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], - "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "FaissVectorDBStorage": [], - "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None - "OracleVectorDBStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], - "MongoVectorDBStorage": [], - # Document Status Storage Implementations - "JsonDocStatusStorage": [], - "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "MongoDocStatusStorage": [], -} - -# Storage implementation module mapping -STORAGES = { - "NetworkXStorage": ".kg.networkx_impl", - "JsonKVStorage": ".kg.json_kv_impl", - "NanoVectorDBStorage": ".kg.nano_vector_db_impl", - "JsonDocStatusStorage": ".kg.json_doc_status_impl", - "Neo4JStorage": ".kg.neo4j_impl", - "OracleKVStorage": ".kg.oracle_impl", - "OracleGraphStorage": ".kg.oracle_impl", - "OracleVectorDBStorage": ".kg.oracle_impl", - "MilvusVectorDBStorage": ".kg.milvus_impl", - "MongoKVStorage": ".kg.mongo_impl", - "MongoDocStatusStorage": ".kg.mongo_impl", - "MongoGraphStorage": ".kg.mongo_impl", - "MongoVectorDBStorage": ".kg.mongo_impl", - "RedisKVStorage": ".kg.redis_impl", - "ChromaVectorDBStorage": ".kg.chroma_impl", - "TiDBKVStorage": ".kg.tidb_impl", - "TiDBVectorDBStorage": ".kg.tidb_impl", - "TiDBGraphStorage": ".kg.tidb_impl", - "PGKVStorage": ".kg.postgres_impl", - "PGVectorStorage": ".kg.postgres_impl", - "AGEStorage": ".kg.age_impl", - "PGGraphStorage": ".kg.postgres_impl", - "GremlinStorage": ".kg.gremlin_impl", - "PGDocStatusStorage": ".kg.postgres_impl", - "FaissVectorDBStorage": ".kg.faiss_impl", - "QdrantVectorDBStorage": ".kg.qdrant_impl", -} - - - - - - @final @dataclass @@ -1643,4 +1502,4 @@ class LightRAG: raise ValueError( f"Storage implementation '{storage_name}' requires the following " f"environment variables: {', '.join(missing_vars)}" - ) \ No newline at end of file + ) diff --git a/lightrag/utils.py b/lightrag/utils.py index 62f62d4d..d402d14c 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -714,6 +714,7 @@ def get_conversation_turns( return "\n".join(formatted_turns) + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: """ Ensure that there is always an event loop available. @@ -737,8 +738,8 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop: new_loop = asyncio.new_event_loop() asyncio.set_event_loop(new_loop) return new_loop - - + + def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: """Lazily import a class from an external module based on the package of the caller.""" # Get the caller's module and package @@ -756,4 +757,3 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any return cls(*args, **kwargs) return import_class - \ No newline at end of file diff --git a/reproduce/Step_3.py b/reproduce/Step_3.py index be5ba99d..facb913e 100644 --- a/reproduce/Step_3.py +++ b/reproduce/Step_3.py @@ -23,7 +23,6 @@ async def process_query(query_text, rag_instance, query_param): return None, {"query": query_text, "error": str(e)} - def run_queries_and_save_to_json( queries, rag_instance, query_param, output_file, error_file ): diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py index b1d33f93..885220fa 100644 --- a/reproduce/Step_3_openai_compatible.py +++ b/reproduce/Step_3_openai_compatible.py @@ -54,9 +54,6 @@ async def process_query(query_text, rag_instance, query_param): return None, {"query": query_text, "error": str(e)} - - - def run_queries_and_save_to_json( queries, rag_instance, query_param, output_file, error_file ): From 59bb75d4a1b3552381a6f53a02753fd818608ec3 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:27:55 +0100 Subject: [PATCH 12/14] added log path --- lightrag/lightrag.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 174947f3..9f4db5ab 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -82,8 +82,8 @@ class LightRAG: log_level: int = field(default=logger.level) """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" - log_dir: str = field(default=os.getcwd()) - """Directory where logs are stored. Defaults to the current working directory.""" + log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log")) + """Log file path.""" # Entity extraction # --- @@ -251,9 +251,8 @@ class LightRAG: """ def __post_init__(self): - os.makedirs(self.log_dir, exist_ok=True) - log_file = os.path.join(self.log_dir, "lightrag.log") - set_logger(log_file) + os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) + set_logger(self.log_file_path) logger.setLevel(self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") From 60717fd6be185cc3d4a93f57473a7091adc5b327 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:30:30 +0100 Subject: [PATCH 13/14] cleanup storage state --- lightrag/lightrag.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 9f4db5ab..1a8dcf5c 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -250,12 +250,14 @@ class LightRAG: The default function is :func:`.utils.convert_response_to_json`. """ + _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) + def __post_init__(self): + logger.setLevel(self.log_level) os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) set_logger(self.log_file_path) - - logger.setLevel(self.log_level) logger.info(f"Logger initialized for working directory: {self.working_dir}") + if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) @@ -283,9 +285,6 @@ class LightRAG: **self.vector_db_storage_cls_kwargs, } - # Life cycle - self.storages_status = StoragesStatus.NOT_CREATED - # Show config global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) @@ -393,7 +392,7 @@ class LightRAG: ) ) - self.storages_status = StoragesStatus.CREATED + self._storages_status = StoragesStatus.CREATED # Initialize storages if self.auto_manage_storages_states: @@ -408,7 +407,7 @@ class LightRAG: async def initialize_storages(self): """Asynchronously initialize the storages""" - if self.storages_status == StoragesStatus.CREATED: + if self._storages_status == StoragesStatus.CREATED: tasks = [] for storage in ( @@ -426,12 +425,12 @@ class LightRAG: await asyncio.gather(*tasks) - self.storages_status = StoragesStatus.INITIALIZED + self._storages_status = StoragesStatus.INITIALIZED logger.debug("Initialized Storages") async def finalize_storages(self): """Asynchronously finalize the storages""" - if self.storages_status == StoragesStatus.INITIALIZED: + if self._storages_status == StoragesStatus.INITIALIZED: tasks = [] for storage in ( @@ -449,7 +448,7 @@ class LightRAG: await asyncio.gather(*tasks) - self.storages_status = StoragesStatus.FINALIZED + self._storages_status = StoragesStatus.FINALIZED logger.debug("Finalized Storages") def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: From 38dc2466dade429e1a54c07103779aff071a2012 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Thu, 20 Feb 2025 13:34:47 +0100 Subject: [PATCH 14/14] cleanup --- examples/lightrag_openai_compatible_stream_demo.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/lightrag_openai_compatible_stream_demo.py b/examples/lightrag_openai_compatible_stream_demo.py index 750f139e..a974ca14 100644 --- a/examples/lightrag_openai_compatible_stream_demo.py +++ b/examples/lightrag_openai_compatible_stream_demo.py @@ -1,7 +1,8 @@ +import inspect import os from lightrag import LightRAG from lightrag.llm import openai_complete, openai_embed -from lightrag.utils import EmbeddingFunc +from lightrag.utils import EmbeddingFunc, always_get_an_event_loop from lightrag import QueryParam # WorkingDir @@ -44,3 +45,10 @@ async def print_stream(stream): async for chunk in stream: if chunk: print(chunk, end="", flush=True) + + +loop = always_get_an_event_loop() +if inspect.isasyncgen(resp): + loop.run_until_complete(print_stream(resp)) +else: + print(resp)