diff --git a/docker-compose.yml b/docker-compose.yml
index b5659692..4ced24ca 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,5 +1,3 @@
-version: '3.8'
-
 services:
   lightrag:
     build: .
diff --git a/examples/lightrag_api_oracle_demo.py b/examples/lightrag_api_oracle_demo.py
index e66e3f94..3675795e 100644
--- a/examples/lightrag_api_oracle_demo.py
+++ b/examples/lightrag_api_oracle_demo.py
@@ -98,7 +98,6 @@ async def init():
     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
-    # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
     rag = LightRAG(
         enable_llm_cache=False,
         working_dir=WORKING_DIR,
diff --git a/examples/lightrag_openai_compatible_stream_demo.py b/examples/lightrag_openai_compatible_stream_demo.py
index 93c4297c..a974ca14 100644
--- a/examples/lightrag_openai_compatible_stream_demo.py
+++ b/examples/lightrag_openai_compatible_stream_demo.py
@@ -1,9 +1,8 @@
-import os
 import inspect
+import os
 from lightrag import LightRAG
 from lightrag.llm import openai_complete, openai_embed
-from lightrag.utils import EmbeddingFunc
-from lightrag.lightrag import always_get_an_event_loop
+from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
 from lightrag import QueryParam
 
 # WorkingDir
diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py
index f4004f84..f2ee9ad8 100644
--- a/examples/lightrag_tidb_demo.py
+++ b/examples/lightrag_tidb_demo.py
@@ -63,7 +63,6 @@ async def main():
     # Initialize LightRAG
     # We use TiDB DB as the KV/vector
-    # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
     rag = LightRAG(
         enable_llm_cache=False,
         working_dir=WORKING_DIR,
diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py
index 087eaac9..2f3eae87 100644
--- a/lightrag/kg/__init__.py
+++ b/lightrag/kg/__init__.py
@@ -1 +1,135 @@
-# print ("init package vars here. ......")
+STORAGE_IMPLEMENTATIONS = {
+    "KV_STORAGE": {
+        "implementations": [
+            "JsonKVStorage",
+            "MongoKVStorage",
+            "RedisKVStorage",
+            "TiDBKVStorage",
+            "PGKVStorage",
+            "OracleKVStorage",
+        ],
+        "required_methods": ["get_by_id", "upsert"],
+    },
+    "GRAPH_STORAGE": {
+        "implementations": [
+            "NetworkXStorage",
+            "Neo4JStorage",
+            "MongoGraphStorage",
+            "TiDBGraphStorage",
+            "AGEStorage",
+            "GremlinStorage",
+            "PGGraphStorage",
+            "OracleGraphStorage",
+        ],
+        "required_methods": ["upsert_node", "upsert_edge"],
+    },
+    "VECTOR_STORAGE": {
+        "implementations": [
+            "NanoVectorDBStorage",
+            "MilvusVectorDBStorage",
+            "ChromaVectorDBStorage",
+            "TiDBVectorDBStorage",
+            "PGVectorStorage",
+            "FaissVectorDBStorage",
+            "QdrantVectorDBStorage",
+            "OracleVectorDBStorage",
+            "MongoVectorDBStorage",
+        ],
+        "required_methods": ["query", "upsert"],
+    },
+    "DOC_STATUS_STORAGE": {
+        "implementations": [
+            "JsonDocStatusStorage",
+            "PGDocStatusStorage",
+            "MongoDocStatusStorage",
+        ],
+        "required_methods": ["get_docs_by_status"],
+    },
+}
+
+# Storage implementation environment variables without default values
+STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
+    # KV Storage Implementations
+    "JsonKVStorage": [],
+    "MongoKVStorage": [],
+    "RedisKVStorage": ["REDIS_URI"],
+    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "OracleKVStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    # Graph Storage Implementations
+    "NetworkXStorage": [],
+    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
+    "MongoGraphStorage": [],
+    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "AGEStorage": [
+        "AGE_POSTGRES_DB",
+        "AGE_POSTGRES_USER",
+        "AGE_POSTGRES_PASSWORD",
+    ],
+    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
+    "PGGraphStorage": [
+        "POSTGRES_USER",
+        "POSTGRES_PASSWORD",
+        "POSTGRES_DATABASE",
+    ],
+    "OracleGraphStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    # Vector Storage Implementations
+    "NanoVectorDBStorage": [],
+    "MilvusVectorDBStorage": [],
+    "ChromaVectorDBStorage": [],
+    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
+    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "FaissVectorDBStorage": [],
+    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
+    "OracleVectorDBStorage": [
+        "ORACLE_DSN",
+        "ORACLE_USER",
+        "ORACLE_PASSWORD",
+        "ORACLE_CONFIG_DIR",
+    ],
+    "MongoVectorDBStorage": [],
+    # Document Status Storage Implementations
+    "JsonDocStatusStorage": [],
+    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "MongoDocStatusStorage": [],
+}
+
+# Storage implementation module mapping
+STORAGES = {
+    "NetworkXStorage": ".kg.networkx_impl",
+    "JsonKVStorage": ".kg.json_kv_impl",
+    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
+    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
+    "Neo4JStorage": ".kg.neo4j_impl",
+    "OracleKVStorage": ".kg.oracle_impl",
+    "OracleGraphStorage": ".kg.oracle_impl",
+    "OracleVectorDBStorage": ".kg.oracle_impl",
+    "MilvusVectorDBStorage": ".kg.milvus_impl",
+    "MongoKVStorage": ".kg.mongo_impl",
+    "MongoDocStatusStorage": ".kg.mongo_impl",
+    "MongoGraphStorage": ".kg.mongo_impl",
+    "MongoVectorDBStorage": ".kg.mongo_impl",
+    "RedisKVStorage": ".kg.redis_impl",
+    "ChromaVectorDBStorage": ".kg.chroma_impl",
+    "TiDBKVStorage": ".kg.tidb_impl",
+    "TiDBVectorDBStorage": ".kg.tidb_impl",
+    "TiDBGraphStorage": ".kg.tidb_impl",
+    "PGKVStorage": ".kg.postgres_impl",
+    "PGVectorStorage": ".kg.postgres_impl",
+    "AGEStorage": ".kg.age_impl",
+    "PGGraphStorage": ".kg.postgres_impl",
+    "GremlinStorage": ".kg.gremlin_impl",
+    "PGDocStatusStorage": ".kg.postgres_impl",
+    "FaissVectorDBStorage": ".kg.faiss_impl",
+    "QdrantVectorDBStorage": ".kg.qdrant_impl",
+}
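Note: the three tables above (`STORAGE_IMPLEMENTATIONS`, `STORAGE_ENV_REQUIREMENTS`, `STORAGES`) are what `LightRAG.verify_storage_implementation` and `LightRAG.check_storage_env_vars` consult (see the methods appended to `lightrag/lightrag.py` further down). A minimal sketch of querying the tables directly — the `missing_env_vars` helper is hypothetical; only the imports and the dict shapes come from this diff:

```python
# Hypothetical helper built on the tables this diff adds; only the
# imports and the dict layout are taken from the new kg/__init__.py.
import os

from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS


def missing_env_vars(storage_name: str) -> list[str]:
    # Implementations with no entry (or an empty list) need no variables.
    required = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
    return [var for var in required if var not in os.environ]


# e.g. ['POSTGRES_USER', 'POSTGRES_PASSWORD', 'POSTGRES_DATABASE'] if unset
print(missing_env_vars("PGKVStorage"))

# Compatibility lookup: which implementations may serve as KV storage?
print(STORAGE_IMPLEMENTATIONS["KV_STORAGE"]["implementations"])
```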
"ChromaVectorDBStorage": ".kg.chroma_impl", + "TiDBKVStorage": ".kg.tidb_impl", + "TiDBVectorDBStorage": ".kg.tidb_impl", + "TiDBGraphStorage": ".kg.tidb_impl", + "PGKVStorage": ".kg.postgres_impl", + "PGVectorStorage": ".kg.postgres_impl", + "AGEStorage": ".kg.age_impl", + "PGGraphStorage": ".kg.postgres_impl", + "GremlinStorage": ".kg.gremlin_impl", + "PGDocStatusStorage": ".kg.postgres_impl", + "FaissVectorDBStorage": ".kg.faiss_impl", + "QdrantVectorDBStorage": ".kg.qdrant_impl", +} diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index de61a2ca..35983ad3 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -44,7 +44,7 @@ class OracleDB: self.increment = 1 logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier") if self.user is None or self.password is None: - raise ValueError("Missing database user or password in addon_params") + raise ValueError("Missing database user or password") try: oracledb.defaults.fetch_lobs = False diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index ababc05f..d7ace41a 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -54,9 +54,7 @@ class PostgreSQLDB: self.pool: Pool | None = None if self.user is None or self.password is None or self.database is None: - raise ValueError( - "Missing database user, password, or database in addon_params" - ) + raise ValueError("Missing database user, password, or database") async def initdb(self): try: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a4daeced..1a8dcf5c 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -6,8 +6,10 @@ import configparser from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial -from typing import Any, AsyncIterator, Callable, Iterator, cast -from asyncio import Lock +from typing import Any, AsyncIterator, Callable, Iterator, cast, final + +from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES + from .base import ( BaseGraphStorage, BaseKVStorage, @@ -32,8 +34,10 @@ from .operate import ( from .prompt import GRAPH_FIELD_SEP from .utils import ( EmbeddingFunc, + always_get_an_event_loop, compute_mdhash_id, convert_response_to_json, + lazy_external_import, limit_async_func_call, logger, set_logger, @@ -43,210 +47,22 @@ from .utils import ( config = configparser.ConfigParser() config.read("config.ini", "utf-8") -# Storage type and implementation compatibility validation table -STORAGE_IMPLEMENTATIONS = { - "KV_STORAGE": { - "implementations": [ - "JsonKVStorage", - "MongoKVStorage", - "RedisKVStorage", - "TiDBKVStorage", - "PGKVStorage", - "OracleKVStorage", - ], - "required_methods": ["get_by_id", "upsert"], - }, - "GRAPH_STORAGE": { - "implementations": [ - "NetworkXStorage", - "Neo4JStorage", - "MongoGraphStorage", - "TiDBGraphStorage", - "AGEStorage", - "GremlinStorage", - "PGGraphStorage", - "OracleGraphStorage", - ], - "required_methods": ["upsert_node", "upsert_edge"], - }, - "VECTOR_STORAGE": { - "implementations": [ - "NanoVectorDBStorage", - "MilvusVectorDBStorage", - "ChromaVectorDBStorage", - "TiDBVectorDBStorage", - "PGVectorStorage", - "FaissVectorDBStorage", - "QdrantVectorDBStorage", - "OracleVectorDBStorage", - "MongoVectorDBStorage", - ], - "required_methods": ["query", "upsert"], - }, - "DOC_STATUS_STORAGE": { - "implementations": [ - "JsonDocStatusStorage", - "PGDocStatusStorage", - "PGDocStatusStorage", - "MongoDocStatusStorage", - ], - 
"required_methods": ["get_docs_by_status"], - }, -} - -# Storage implementation environment variable without default value -STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { - # KV Storage Implementations - "JsonKVStorage": [], - "MongoKVStorage": [], - "RedisKVStorage": ["REDIS_URI"], - "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], - "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "OracleKVStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], - # Graph Storage Implementations - "NetworkXStorage": [], - "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], - "MongoGraphStorage": [], - "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], - "AGEStorage": [ - "AGE_POSTGRES_DB", - "AGE_POSTGRES_USER", - "AGE_POSTGRES_PASSWORD", - ], - "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], - "PGGraphStorage": [ - "POSTGRES_USER", - "POSTGRES_PASSWORD", - "POSTGRES_DATABASE", - ], - "OracleGraphStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], - # Vector Storage Implementations - "NanoVectorDBStorage": [], - "MilvusVectorDBStorage": [], - "ChromaVectorDBStorage": [], - "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], - "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "FaissVectorDBStorage": [], - "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None - "OracleVectorDBStorage": [ - "ORACLE_DSN", - "ORACLE_USER", - "ORACLE_PASSWORD", - "ORACLE_CONFIG_DIR", - ], - "MongoVectorDBStorage": [], - # Document Status Storage Implementations - "JsonDocStatusStorage": [], - "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "MongoDocStatusStorage": [], -} - -# Storage implementation module mapping -STORAGES = { - "NetworkXStorage": ".kg.networkx_impl", - "JsonKVStorage": ".kg.json_kv_impl", - "NanoVectorDBStorage": ".kg.nano_vector_db_impl", - "JsonDocStatusStorage": ".kg.json_doc_status_impl", - "Neo4JStorage": ".kg.neo4j_impl", - "OracleKVStorage": ".kg.oracle_impl", - "OracleGraphStorage": ".kg.oracle_impl", - "OracleVectorDBStorage": ".kg.oracle_impl", - "MilvusVectorDBStorage": ".kg.milvus_impl", - "MongoKVStorage": ".kg.mongo_impl", - "MongoDocStatusStorage": ".kg.mongo_impl", - "MongoGraphStorage": ".kg.mongo_impl", - "MongoVectorDBStorage": ".kg.mongo_impl", - "RedisKVStorage": ".kg.redis_impl", - "ChromaVectorDBStorage": ".kg.chroma_impl", - "TiDBKVStorage": ".kg.tidb_impl", - "TiDBVectorDBStorage": ".kg.tidb_impl", - "TiDBGraphStorage": ".kg.tidb_impl", - "PGKVStorage": ".kg.postgres_impl", - "PGVectorStorage": ".kg.postgres_impl", - "AGEStorage": ".kg.age_impl", - "PGGraphStorage": ".kg.postgres_impl", - "GremlinStorage": ".kg.gremlin_impl", - "PGDocStatusStorage": ".kg.postgres_impl", - "FaissVectorDBStorage": ".kg.faiss_impl", - "QdrantVectorDBStorage": ".kg.qdrant_impl", -} - - -def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: - """Lazily import a class from an external module based on the package of the caller.""" - # Get the caller's module and package - import inspect - - caller_frame = inspect.currentframe().f_back - module = inspect.getmodule(caller_frame) - package = module.__package__ if module else None - - def import_class(*args: Any, **kwargs: Any): - import importlib - - module = importlib.import_module(module_name, package=package) - cls = getattr(module, 
class_name) - return cls(*args, **kwargs) - - return import_class - - -def always_get_an_event_loop() -> asyncio.AbstractEventLoop: - """ - Ensure that there is always an event loop available. - - This function tries to get the current event loop. If the current event loop is closed or does not exist, - it creates a new event loop and sets it as the current event loop. - - Returns: - asyncio.AbstractEventLoop: The current or newly created event loop. - """ - try: - # Try to get the current event loop - current_loop = asyncio.get_event_loop() - if current_loop.is_closed(): - raise RuntimeError("Event loop is closed.") - return current_loop - - except RuntimeError: - # If no event loop exists or it is closed, create a new one - logger.info("Creating a new event loop in main thread.") - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - return new_loop - +@final @dataclass class LightRAG: """LightRAG: Simple and Fast Retrieval-Augmented Generation.""" + # Directory + # --- + working_dir: str = field( - default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" + default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" ) """Directory where cache and temporary files are stored.""" - embedding_cache_config: dict[str, Any] = field( - default_factory=lambda: { - "enabled": False, - "similarity_threshold": 0.95, - "use_llm_check": False, - } - ) - """Configuration for embedding cache. - - enabled: If True, enables caching to avoid redundant computations. - - similarity_threshold: Minimum similarity score to use cached embeddings. - - use_llm_check: If True, validates cached embeddings using an LLM. - """ + # Storage + # --- kv_storage: str = field(default="JsonKVStorage") """Storage backend for key-value data.""" @@ -261,32 +77,74 @@ class LightRAG: """Storage type for tracking document processing statuses.""" # Logging - current_log_level = logger.level - log_level: int = field(default=current_log_level) + # --- + + log_level: int = field(default=logger.level) """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" - log_dir: str = field(default=os.getcwd()) - """Directory where logs are stored. 
-
-    # Text chunking
-    chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
-    """Maximum number of tokens per text chunk when splitting documents."""
-
-    chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
-    """Number of overlapping tokens between consecutive text chunks to preserve context."""
-
-    tiktoken_model_name: str = "gpt-4o-mini"
-    """Model name used for tokenization when chunking text."""
+    log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
+    """Log file path."""
 
     # Entity extraction
-    entity_extract_max_gleaning: int = 1
+    # ---
+
+    entity_extract_max_gleaning: int = field(default=1)
     """Maximum number of entity extraction attempts for ambiguous content."""
 
-    entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500"))
+    entity_summary_to_max_tokens: int = field(
+        default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
+    )
+    """Maximum number of tokens used for summarizing extracted entities."""
+
+    # Text chunking
+    # ---
+
+    chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
+    """Maximum number of tokens per text chunk when splitting documents."""
+
+    chunk_overlap_token_size: int = field(
+        default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
+    )
+    """Number of overlapping tokens between consecutive text chunks to preserve context."""
+
+    tiktoken_model_name: str = field(default="gpt-4o-mini")
+    """Model name used for tokenization when chunking text."""
+
+    chunking_func: Callable[
+        [
+            str,
+            str | None,
+            bool,
+            int,
+            int,
+            str,
+        ],
+        list[dict[str, Any]],
+    ] = field(default_factory=lambda: chunking_by_token_size)
+    """
+    Custom chunking function for splitting text into chunks before processing.
+
+    The function should take the following parameters:
+
+    - `content`: The text to be split into chunks.
+    - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
+    - `split_by_character_only`: If True, the text is split only on the specified character.
+    - `chunk_token_size`: The maximum number of tokens per chunk.
+    - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
+    - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
+
+    The function should return a list of dictionaries, where each dictionary contains the following keys:
+    - `tokens`: The number of tokens in the chunk.
+    - `content`: The text content of the chunk.
+
+    Defaults to `chunking_by_token_size` if not specified.
+    """
+
     # Node embedding
-    node_embedding_algorithm: str = "node2vec"
+    # ---
+
+    node_embedding_algorithm: str = field(default="node2vec")
     """Algorithm used for node embedding in knowledge graphs."""
 
     node2vec_params: dict[str, int] = field(
@@ -308,119 +166,98 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """
 
-    embedding_func: EmbeddingFunc | None = None
+    # Embedding
+    # ---
+
+    embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""
 
-    embedding_batch_num: int = 32
+    embedding_batch_num: int = field(default=32)
     """Batch size for embedding computations."""
 
-    embedding_func_max_async: int = 16
+    embedding_func_max_async: int = field(default=16)
     """Maximum number of concurrent embedding function calls."""
 
+    embedding_cache_config: dict[str, Any] = field(
+        default_factory=lambda: {
+            "enabled": False,
+            "similarity_threshold": 0.95,
+            "use_llm_check": False,
+        }
+    )
+    """Configuration for embedding cache.
+    - enabled: If True, enables caching to avoid redundant computations.
+    - similarity_threshold: Minimum similarity score to use cached embeddings.
+    - use_llm_check: If True, validates cached embeddings using an LLM.
+    """
+
     # LLM Configuration
-    llm_model_func: Callable[..., object] | None = None
+    # ---
+
+    llm_model_func: Callable[..., object] | None = field(default=None)
     """Function for interacting with the large language model (LLM). Must be set before use."""
 
-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
+    llm_model_name: str = field(default="gpt-4o-mini")
     """Name of the LLM model used for generating responses."""
 
-    llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
+    llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
     """Maximum number of tokens allowed per LLM response."""
 
-    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
+    llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
     """Maximum number of concurrent LLM calls."""
 
     llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional keyword arguments passed to the LLM model function."""
 
     # Storage
+    # ---
+
     vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional parameters for vector database storage."""
 
     namespace_prefix: str = field(default="")
     """Prefix for namespacing stored data across different environments."""
 
-    enable_llm_cache: bool = True
+    enable_llm_cache: bool = field(default=True)
     """Enables caching for LLM responses to avoid redundant computations."""
 
-    enable_llm_cache_for_entity_extract: bool = True
+    enable_llm_cache_for_entity_extract: bool = field(default=True)
     """If True, enables caching for entity extraction steps to reduce LLM costs."""
 
     # Extensions
+    # ---
+
+    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
+    """Maximum number of parallel insert operations."""
+
     addon_params: dict[str, Any] = field(default_factory=dict)
+    """Dictionary for additional parameters and extensions."""
 
     # Storages Management
-    auto_manage_storages_states: bool = True
+    # ---
+
+    auto_manage_storages_states: bool = field(default=True)
     """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""
 
-    """Dictionary for additional parameters and extensions."""
-    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
-        convert_response_to_json
+    # Response format
+    # ---
+
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
+        default_factory=lambda: convert_response_to_json
     )
+    """
+    Custom function for converting LLM responses to JSON format.
 
-    # Lock for entity extraction
-    _entity_lock = Lock()
+    The default function is :func:`.utils.convert_response_to_json`.
+ """ - # Custom Chunking Function - chunking_func: Callable[ - [ - str, - str | None, - bool, - int, - int, - str, - ], - list[dict[str, Any]], - ] = chunking_by_token_size - - def verify_storage_implementation( - self, storage_type: str, storage_name: str - ) -> None: - """Verify if storage implementation is compatible with specified storage type - - Args: - storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.) - storage_name: Storage implementation name - - Raises: - ValueError: If storage implementation is incompatible or missing required methods - """ - if storage_type not in STORAGE_IMPLEMENTATIONS: - raise ValueError(f"Unknown storage type: {storage_type}") - - storage_info = STORAGE_IMPLEMENTATIONS[storage_type] - if storage_name not in storage_info["implementations"]: - raise ValueError( - f"Storage implementation '{storage_name}' is not compatible with {storage_type}. " - f"Compatible implementations are: {', '.join(storage_info['implementations'])}" - ) - - def check_storage_env_vars(self, storage_name: str) -> None: - """Check if all required environment variables for storage implementation exist - - Args: - storage_name: Storage implementation name - - Raises: - ValueError: If required environment variables are missing - """ - required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) - missing_vars = [var for var in required_vars if var not in os.environ] - - if missing_vars: - raise ValueError( - f"Storage implementation '{storage_name}' requires the following " - f"environment variables: {', '.join(missing_vars)}" - ) + _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED) def __post_init__(self): - os.makedirs(self.log_dir, exist_ok=True) - log_file = os.path.join(self.log_dir, "lightrag.log") - set_logger(log_file) - logger.setLevel(self.log_level) + os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True) + set_logger(self.log_file_path) logger.info(f"Logger initialized for working directory: {self.working_dir}") + if not os.path.exists(self.working_dir): logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) @@ -448,9 +285,6 @@ class LightRAG: **self.vector_db_storage_cls_kwargs, } - # Life cycle - self.storages_status = StoragesStatus.NOT_CREATED - # Show config global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) @@ -558,7 +392,7 @@ class LightRAG: ) ) - self.storages_status = StoragesStatus.CREATED + self._storages_status = StoragesStatus.CREATED # Initialize storages if self.auto_manage_storages_states: @@ -573,7 +407,7 @@ class LightRAG: async def initialize_storages(self): """Asynchronously initialize the storages""" - if self.storages_status == StoragesStatus.CREATED: + if self._storages_status == StoragesStatus.CREATED: tasks = [] for storage in ( @@ -591,12 +425,12 @@ class LightRAG: await asyncio.gather(*tasks) - self.storages_status = StoragesStatus.INITIALIZED + self._storages_status = StoragesStatus.INITIALIZED logger.debug("Initialized Storages") async def finalize_storages(self): """Asynchronously finalize the storages""" - if self.storages_status == StoragesStatus.INITIALIZED: + if self._storages_status == StoragesStatus.INITIALIZED: tasks = [] for storage in ( @@ -614,7 +448,7 @@ class LightRAG: await asyncio.gather(*tasks) - self.storages_status = StoragesStatus.FINALIZED + self._storages_status = StoragesStatus.FINALIZED logger.debug("Finalized Storages") def _get_storage_class(self, storage_name: str) -> Callable[..., 
@@ -789,10 +623,9 @@ class LightRAG:
             return
 
         # 2. split docs into chunks, insert chunks, update doc status
-        batch_size = self.addon_params.get("insert_batch_size", 10)
         docs_batches = [
-            list(to_process_docs.items())[i : i + batch_size]
-            for i in range(0, len(to_process_docs), batch_size)
+            list(to_process_docs.items())[i : i + self.max_parallel_insert]
+            for i in range(0, len(to_process_docs), self.max_parallel_insert)
         ]
 
         logger.info(f"Number of batches to process: {len(docs_batches)}.")
@@ -1203,7 +1036,6 @@ class LightRAG:
         # ---------------------
        # STEP 1: Keyword Extraction
         # ---------------------
-        # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
         hl_keywords, ll_keywords = await extract_keywords_only(
             text=query,
             param=param,
@@ -1629,3 +1461,43 @@ class LightRAG:
             result["vector_data"] = vector_data[0] if vector_data else None
 
         return result
+
+    def verify_storage_implementation(
+        self, storage_type: str, storage_name: str
+    ) -> None:
+        """Verify if storage implementation is compatible with specified storage type
+
+        Args:
+            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If storage implementation is incompatible or missing required methods
+        """
+        if storage_type not in STORAGE_IMPLEMENTATIONS:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
+        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
+        if storage_name not in storage_info["implementations"]:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
+                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
+            )
+
+    def check_storage_env_vars(self, storage_name: str) -> None:
+        """Check if all required environment variables for storage implementation exist
+
+        Args:
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If required environment variables are missing
+        """
+        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+        missing_vars = [var for var in required_vars if var not in os.environ]
+
+        if missing_vars:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' requires the following "
+                f"environment variables: {', '.join(missing_vars)}"
+            )
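The `chunking_func` docstring above pins down the expected signature and return shape. A hedged sketch of a conforming custom chunker — the paragraph splitting and whitespace token count are illustrative stand-ins, not LightRAG's default `chunking_by_token_size`, and the parameter order follows the documented list:

```python
# Minimal sketch of a custom chunking_func matching the contract documented
# in the diff above; the paragraph-based splitting is illustrative only.
from typing import Any

from lightrag import LightRAG


def chunk_by_paragraph(
    content: str,
    split_by_character: str | None,
    split_by_character_only: bool,
    chunk_token_size: int,
    chunk_overlap_token_size: int,
    tiktoken_model_name: str,
) -> list[dict[str, Any]]:
    # Split on blank lines and approximate token counts with whitespace
    # words instead of tiktoken.
    paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
    return [{"tokens": len(p.split()), "content": p} for p in paragraphs]


# embedding_func and llm_model_func must still be set before use (see the
# field docstrings above); they are omitted here for brevity.
rag = LightRAG(working_dir="./my_cache", chunking_func=chunk_by_paragraph)
```

One caveat worth noting: env-driven defaults such as `MAX_PARALLEL_INSERT`, `CHUNK_SIZE`, and `MAX_TOKENS` now sit inside `field(default=...)` expressions, so they are read once at class-definition (import) time; set those variables before importing `lightrag`.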
+ """ + try: + # Try to get the current event loop + current_loop = asyncio.get_event_loop() + if current_loop.is_closed(): + raise RuntimeError("Event loop is closed.") + return current_loop + + except RuntimeError: + # If no event loop exists or it is closed, create a new one + logger.info("Creating a new event loop in main thread.") + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + return new_loop + + +def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: + """Lazily import a class from an external module based on the package of the caller.""" + # Get the caller's module and package + import inspect + + caller_frame = inspect.currentframe().f_back + module = inspect.getmodule(caller_frame) + package = module.__package__ if module else None + + def import_class(*args: Any, **kwargs: Any): + import importlib + + module = importlib.import_module(module_name, package=package) + cls = getattr(module, class_name) + return cls(*args, **kwargs) + + return import_class diff --git a/reproduce/Step_3.py b/reproduce/Step_3.py index f9ee3257..facb913e 100644 --- a/reproduce/Step_3.py +++ b/reproduce/Step_3.py @@ -1,7 +1,7 @@ import re import json -import asyncio from lightrag import LightRAG, QueryParam +from lightrag.utils import always_get_an_event_loop def extract_queries(file_path): @@ -23,15 +23,6 @@ async def process_query(query_text, rag_instance, query_param): return None, {"query": query_text, "error": str(e)} -def always_get_an_event_loop() -> asyncio.AbstractEventLoop: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop - - def run_queries_and_save_to_json( queries, rag_instance, query_param, output_file, error_file ): diff --git a/reproduce/Step_3_openai_compatible.py b/reproduce/Step_3_openai_compatible.py index e4833adf..885220fa 100644 --- a/reproduce/Step_3_openai_compatible.py +++ b/reproduce/Step_3_openai_compatible.py @@ -1,10 +1,9 @@ import os import re import json -import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.utils import EmbeddingFunc +from lightrag.utils import EmbeddingFunc, always_get_an_event_loop import numpy as np @@ -55,15 +54,6 @@ async def process_query(query_text, rag_instance, query_param): return None, {"query": query_text, "error": str(e)} -def always_get_an_event_loop() -> asyncio.AbstractEventLoop: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop - - def run_queries_and_save_to_json( queries, rag_instance, query_param, output_file, error_file ):