diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index f34fe4b1..c2859829 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -172,8 +172,8 @@ class OracleDB: @dataclass class OracleKVStorage(BaseKVStorage): - # should pass db object to self.db - db: OracleDB = None + # db instance must be injected before use + # db: OracleDB meta_fields = None def __post_init__(self): @@ -318,8 +318,8 @@ class OracleKVStorage(BaseKVStorage): @dataclass class OracleVectorDBStorage(BaseVectorStorage): - # should pass db object to self.db - db: OracleDB = None + # db instance must be injected before use + # db: OracleDB cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): @@ -361,8 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage): @dataclass class OracleGraphStorage(BaseGraphStorage): - # should pass db object to self.db - db: OracleDB = None + # db instance must be injected before use + # db: OracleDB def __post_init__(self): """从graphml文件加载图""" diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 526f54a7..221202ab 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -177,7 +177,8 @@ class PostgreSQLDB: @dataclass class PGKVStorage(BaseKVStorage): - db: PostgreSQLDB = None + # db instance must be injected before use + # db: PostgreSQLDB def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] @@ -296,8 +297,9 @@ class PGKVStorage(BaseKVStorage): @dataclass class PGVectorStorage(BaseVectorStorage): + # db instance must be injected before use + # db: PostgreSQLDB cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) - db: PostgreSQLDB = None def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] @@ -418,10 +420,8 @@ class PGVectorStorage(BaseVectorStorage): class PGDocStatusStorage(DocStatusStorage): """PostgreSQL implementation of document status storage""" - db: PostgreSQLDB = None - - def __post_init__(self): - pass + # db instance must be injected before use + db: PostgreSQLDB async def filter_keys(self, data: set[str]) -> set[str]: """Return keys that don't exist in storage""" @@ -577,19 +577,15 @@ class PGGraphQueryException(Exception): @dataclass class PGGraphStorage(BaseGraphStorage): - db: PostgreSQLDB = None + # db instance must be injected before use + # db: PostgreSQLDB @staticmethod def load_nx_graph(file_name): print("no preloading of graph with AGE in production") - def __init__(self, namespace, global_config, embedding_func): - super().__init__( - namespace=namespace, - global_config=global_config, - embedding_func=embedding_func, - ) - self.graph_name = os.environ["AGE_GRAPH_NAME"] + def __post_init__(self): + self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag") self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 7c75e2d8..ba5a6240 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -101,8 +101,8 @@ class TiDB: @dataclass class TiDBKVStorage(BaseKVStorage): - # should pass db object to self.db - db: TiDB = None + # db instance must be injected before use + # db: TiDB def __post_init__(self): self._data = {} @@ -210,8 +210,8 @@ class TiDBKVStorage(BaseKVStorage): @dataclass class TiDBVectorDBStorage(BaseVectorStorage): - # should pass db object to self.db - db: TiDB = None + # db instance must be injected before use + # db: TiDB cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): @@ -333,8 +333,8 @@ class TiDBVectorDBStorage(BaseVectorStorage): @dataclass class TiDBGraphStorage(BaseGraphStorage): - # should pass db object to self.db - db: TiDB = None + # db instance must be injected before use + # db: TiDB def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 48c20428..5648c85d 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -243,6 +243,9 @@ class LightRAG: graph_storage: str = field(default="NetworkXStorage") """Storage backend for knowledge graphs.""" + doc_status_storage: str = field(default="JsonDocStatusStorage") + """Storage type for tracking document processing statuses.""" + # Logging current_log_level = logger.level log_level: int = field(default=current_log_level) @@ -339,9 +342,6 @@ class LightRAG: convert_response_to_json ) - doc_status_storage: str = field(default="JsonDocStatusStorage") - """Storage type for tracking document processing statuses.""" - # Custom Chunking Function chunking_func: Callable[ [ @@ -355,6 +355,91 @@ class LightRAG: list[dict[str, Any]], ] = chunking_by_token_size + def _get_postgres_config(self): + return { + "host": os.environ.get( + "POSTGRES_HOST", + config.get("postgres", "host", fallback="localhost"), + ), + "port": os.environ.get( + "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) + ), + "user": os.environ.get( + "POSTGRES_USER", config.get("postgres", "user", fallback=None) + ), + "password": os.environ.get( + "POSTGRES_PASSWORD", + config.get("postgres", "password", fallback=None), + ), + "database": os.environ.get( + "POSTGRES_DATABASE", + config.get("postgres", "database", fallback=None), + ), + "workspace": os.environ.get( + "POSTGRES_WORKSPACE", + config.get("postgres", "workspace", fallback="default"), + ), + } + + def _get_oracle_config(self): + return { + "user": os.environ.get( + "ORACLE_USER", + config.get("oracle", "user", fallback=None), + ), + "password": os.environ.get( + "ORACLE_PASSWORD", + config.get("oracle", "password", fallback=None), + ), + "dsn": os.environ.get( + "ORACLE_DSN", + config.get("oracle", "dsn", fallback=None), + ), + "config_dir": os.environ.get( + "ORACLE_CONFIG_DIR", + config.get("oracle", "config_dir", fallback=None), + ), + "wallet_location": os.environ.get( + "ORACLE_WALLET_LOCATION", + config.get("oracle", "wallet_location", fallback=None), + ), + "wallet_password": os.environ.get( + "ORACLE_WALLET_PASSWORD", + config.get("oracle", "wallet_password", fallback=None), + ), + "workspace": os.environ.get( + "ORACLE_WORKSPACE", + config.get("oracle", "workspace", fallback="default"), + ), + } + + def _get_tidb_config(self): + return { + "host": os.environ.get( + "TIDB_HOST", + config.get("tidb", "host", fallback="localhost"), + ), + "port": os.environ.get( + "TIDB_PORT", config.get("tidb", "port", fallback=4000) + ), + "user": os.environ.get( + "TIDB_USER", + config.get("tidb", "user", fallback=None), + ), + "password": os.environ.get( + "TIDB_PASSWORD", + config.get("tidb", "password", fallback=None), + ), + "database": os.environ.get( + "TIDB_DATABASE", + config.get("tidb", "database", fallback=None), + ), + "workspace": os.environ.get( + "TIDB_WORKSPACE", + config.get("tidb", "workspace", fallback="default"), + ), + } + def verify_storage_implementation( self, storage_type: str, storage_name: str ) -> None: @@ -456,167 +541,6 @@ class LightRAG: # Initialize document status storage self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) - # Check if Oracle storage implementation is used - if ( - self.kv_storage == "OracleKVStorage" - or self.vector_storage == "OracleVectorDBStorage" - or self.graph_storage == "OracleGraphStorage" - ): - # Get parameters from environment variables or config file - dbconfig = { - "user": os.environ.get( - "ORACLE_USER", - config.get("oracle", "user", fallback=None), - ), - "password": os.environ.get( - "ORACLE_PASSWORD", - config.get("oracle", "password", fallback=None), - ), - "dsn": os.environ.get( - "ORACLE_DSN", - config.get("oracle", "dsn", fallback=None), - ), - "config_dir": os.environ.get( - "ORACLE_CONFIG_DIR", - config.get("oracle", "config_dir", fallback=None), - ), - "wallet_location": os.environ.get( - "ORACLE_WALLET_LOCATION", - config.get("oracle", "wallet_location", fallback=None), - ), - "wallet_password": os.environ.get( - "ORACLE_WALLET_PASSWORD", - config.get("oracle", "wallet_password", fallback=None), - ), - "workspace": os.environ.get( - "ORACLE_WORKSPACE", - config.get("oracle", "workspace", fallback="default"), - ), - } - - # Initialize OracleDB object - from .kg.oracle_impl import OracleDB - - oracle_db = OracleDB(dbconfig) - # Check if DB tables exist, if not, tables will be created - loop = always_get_an_event_loop() - loop.run_until_complete(oracle_db.check_tables()) - - # Only inject db object for Oracle storage implementations - if self.kv_storage == "OracleKVStorage": - self.key_string_value_json_storage_cls.db = oracle_db - if self.vector_storage == "OracleVectorDBStorage": - self.vector_db_storage_cls.db = oracle_db - if self.graph_storage == "OracleGraphStorage": - self.graph_storage_cls.db = oracle_db - - # Check if TiDB storage implementation is used - if ( - self.kv_storage == "TiDBKVStorage" - or self.vector_storage == "TiDBVectorDBStorage" - or self.graph_storage == "TiDBGraphStorage" - ): - # Get parameters from environment variables or config file - dbconfig = { - "host": os.environ.get( - "TIDB_HOST", - config.get("tidb", "host", fallback="localhost"), - ), - "port": os.environ.get( - "TIDB_PORT", config.get("tidb", "port", fallback=4000) - ), - "user": os.environ.get( - "TIDB_USER", - config.get("tidb", "user", fallback=None), - ), - "password": os.environ.get( - "TIDB_PASSWORD", - config.get("tidb", "password", fallback=None), - ), - "database": os.environ.get( - "TIDB_DATABASE", - config.get("tidb", "database", fallback=None), - ), - "workspace": os.environ.get( - "TIDB_WORKSPACE", - config.get("tidb", "workspace", fallback="default"), - ), - } - - # Initialize TiDB object - from .kg.tidb_impl import TiDB - - tidb_db = TiDB(dbconfig) - # Check if DB tables exist, if not, tables will be created - loop = always_get_an_event_loop() - loop.run_until_complete(tidb_db.check_tables()) - - # Only inject db object for TiDB storage implementations - if self.kv_storage == "TiDBKVStorage": - self.key_string_value_json_storage_cls.db = tidb_db - if self.vector_storage == "TiDBVectorDBStorage": - self.vector_db_storage_cls.db = tidb_db - if self.graph_storage == "TiDBGraphStorage": - self.graph_storage_cls.db = tidb_db - - # Check if PostgreSQL storage implementation is used - if ( - self.kv_storage == "PGKVStorage" - or self.vector_storage == "PGVectorStorage" - or self.graph_storage == "PGGraphStorage" - or self.doc_status_storage == "PGDocStatusStorage" - ): - # Read configuration file - config_parser = configparser.ConfigParser() - if os.path.exists("config.ini"): - config_parser.read("config.ini") - - # Get parameters from environment variables or config file - dbconfig = { - "host": os.environ.get( - "POSTGRES_HOST", - config.get("postgres", "host", fallback="localhost"), - ), - "port": os.environ.get( - "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) - ), - "user": os.environ.get( - "POSTGRES_USER", config.get("postgres", "user", fallback=None) - ), - "password": os.environ.get( - "POSTGRES_PASSWORD", - config.get("postgres", "password", fallback=None), - ), - "database": os.environ.get( - "POSTGRES_DATABASE", - config.get("postgres", "database", fallback=None), - ), - "workspace": os.environ.get( - "POSTGRES_WORKSPACE", - config.get("postgres", "workspace", fallback="default"), - ), - } - - # Initialize PostgreSQLDB object - from .kg.postgres_impl import PostgreSQLDB - - postgres_db = PostgreSQLDB(dbconfig) - # Initialize and check tables - loop = always_get_an_event_loop() - loop.run_until_complete(postgres_db.initdb()) - # Check if DB tables exist, if not, tables will be created - loop.run_until_complete(postgres_db.check_tables()) - - # Only inject db object for PostgreSQL storage implementations - if self.kv_storage == "PGKVStorage": - self.key_string_value_json_storage_cls.db = postgres_db - if self.vector_storage == "PGVectorStorage": - self.vector_db_storage_cls.db = postgres_db - if self.graph_storage == "PGGraphStorage": - self.graph_storage_cls.db = postgres_db - if self.doc_status_storage == "PGDocStatusStorage": - self.doc_status_storage_cls.db = postgres_db - self.llm_response_cache = self.key_string_value_json_storage_cls( namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE @@ -664,6 +588,13 @@ class LightRAG: embedding_func=self.embedding_func, ) + # Initialize document status storage + self.doc_status: DocStatusStorage = self.doc_status_storage_cls( + namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), + global_config=global_config, + embedding_func=None, + ) + # What's for, Is this nessisary ? if self.llm_response_cache and hasattr( self.llm_response_cache, "global_config" @@ -677,16 +608,21 @@ class LightRAG: embedding_func=self.embedding_func, ) - # self.json_doc_status_storage = self.key_string_value_json_storage_cls( - # namespace=self.namespace_prefix + "json_doc_status_storage", - # embedding_func=None, - # ) - self.doc_status: DocStatusStorage = self.doc_status_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), - global_config=global_config, - embedding_func=None, - ) + # Collect all storage instances + storage_instances = [ + self.full_docs, + self.text_chunks, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.chunks_vdb, + self.doc_status, + ] + + # Initialize database connections if needed + loop = always_get_an_event_loop() + loop.run_until_complete(self._initialize_database_if_needed(storage_instances)) self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -710,8 +646,81 @@ class LightRAG: storage_class = lazy_external_import(import_path, storage_name) return storage_class + async def _initialize_database_if_needed(self, storage_instances: list): + """Intialize database connection and inject it to storage implementation if needed""" + from .kg.postgres_impl import PostgreSQLDB + from .kg.oracle_impl import OracleDB + from .kg.tidb_impl import TiDB + from .kg.postgres_impl import ( + PGKVStorage, + PGVectorStorage, + PGGraphStorage, + PGDocStatusStorage, + ) + from .kg.oracle_impl import ( + OracleKVStorage, + OracleVectorDBStorage, + OracleGraphStorage, + ) + from .kg.tidb_impl import ( + TiDBKVStorage, + TiDBVectorDBStorage, + TiDBGraphStorage) + + # Checking if PostgreSQL is needed + if any( + isinstance( + storage, + (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + ) + for storage in storage_instances + ): + postgres_db = PostgreSQLDB(self._get_postgres_config()) + await postgres_db.initdb() + await postgres_db.check_tables() + for storage in storage_instances: + if isinstance( + storage, + (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + ): + storage.db = postgres_db + logger.info(f"Injected postgres_db to {storage.__class__.__name__}") + + # Checking if Oracle is needed + if any( + isinstance( + storage, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage) + ) + for storage in storage_instances + ): + oracle_db = OracleDB(self._get_oracle_config()) + await oracle_db.check_tables() + for storage in storage_instances: + if isinstance( + storage, + (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), + ): + storage.db = oracle_db + logger.info(f"Injected oracle_db to {storage.__class__.__name__}") + + # Checking if TiDB is needed + if any( + isinstance(storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)) + for storage in storage_instances + ): + tidb_db = TiDB(self._get_tidb_config()) + await tidb_db.check_tables() + # 注入db实例 + for storage in storage_instances: + if isinstance( + storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage) + ): + storage.db = tidb_db + logger.info(f"Injected tidb_db to {storage.__class__.__name__}") + def set_storage_client(self, db_client): - # Now only tested on Oracle Database + # Inject db to storage implementation (only tested on Oracle Database + # Deprecated, seting correct value to *_storage creating LightRAG insteaded for storage in [ self.vector_db_storage_cls, self.graph_storage_cls,