From e194e0422668558c6e089d7417003cfeae18f896 Mon Sep 17 00:00:00 2001
From: ArnoChen
Date: Wed, 19 Feb 2025 03:46:18 +0800
Subject: [PATCH] refactor database connection management and improve storage
 lifecycle handling
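
This replaces the server-side database bootstrapping in lightrag_server.py
with a storage-owned lifecycle: StorageNameSpace gains initialize()/finalize()
hooks, the Oracle/PostgreSQL/TiDB storage classes acquire their shared
database handle from a per-backend ClientManager (a reference-counted
singleton), and LightRAG exposes initialize_storages()/finalize_storages()
together with a StoragesStatus enum that tracks the lifecycle. The new
is_managed_by_server flag lets the API server drive both calls from its
FastAPI lifespan handler; otherwise the LightRAG constructor and __del__
run them on a loop obtained via always_get_an_event_loop().

A minimal usage sketch (the working_dir value and the inserted text below
are illustrative placeholders, not part of this change):

    rag = LightRAG(working_dir="./rag_storage", is_managed_by_server=True)
    await rag.initialize_storages()    # storages fetch shared DB clients
    try:
        await rag.ainsert("example document text")
    finally:
        await rag.finalize_storages()  # last release closes the pool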
---
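Review note (ignored by git am, illustration only): each backend module now
carries an identical ClientManager whose get_client()/release_client() pair
reference-counts a single shared database handle, roughly:

    db_a = await ClientManager.get_client()   # first caller builds the client
    db_b = await ClientManager.get_client()   # same instance, ref_count == 2
    assert db_a is db_b
    await ClientManager.release_client(db_a)  # ref_count == 1, pool stays open
    await ClientManager.release_client(db_b)  # ref_count == 0, pool is closed

The TiDB variant only drops its cached instance on the last release; unlike
Oracle and PostgreSQL it has no connection pool attribute to close.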
 lightrag/api/lightrag_server.py | 199 ++------------------------------
 lightrag/base.py                |  17 +++
 lightrag/kg/oracle_impl.py      |  97 +++++++++++++++-
 lightrag/kg/postgres_impl.py    |  99 +++++++++++++++-
 lightrag/kg/tidb_impl.py        |  91 ++++++++++++++-
 lightrag/lightrag.py            |  68 ++++++++++-
 6 files changed, 376 insertions(+), 195 deletions(-)

diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index fba81086..b3a72d4d 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -15,11 +15,6 @@ import logging
 import argparse
 from typing import List, Any, Literal, Optional, Dict
 from pydantic import BaseModel, Field, field_validator
-from lightrag import LightRAG, QueryParam
-from lightrag.base import DocProcessingStatus, DocStatus
-from lightrag.types import GPTKeywordExtractionFormat
-from lightrag.api import __api_version__
-from lightrag.utils import EmbeddingFunc
 from pathlib import Path
 import shutil
 import aiofiles
@@ -36,39 +31,13 @@ import configparser
 import traceback
 from datetime import datetime

+from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
+from lightrag.types import GPTKeywordExtractionFormat
+from lightrag.api import __api_version__
+from lightrag.utils import EmbeddingFunc
 from lightrag.utils import logger
-from .ollama_api import (
-    OllamaAPI,
-)
-from .ollama_api import ollama_server_infos
-
-
-def get_db_type_from_storage_class(class_name: str) -> str | None:
-    """Determine database type based on storage class name"""
-    if class_name.startswith("PG"):
-        return "postgres"
-    elif class_name.startswith("Oracle"):
-        return "oracle"
-    elif class_name.startswith("TiDB"):
-        return "tidb"
-    return None
-
-
-def import_db_module(db_type: str):
-    """Dynamically import database module"""
-    if db_type == "postgres":
-        from ..kg.postgres_impl import PostgreSQLDB
-
-        return PostgreSQLDB
-    elif db_type == "oracle":
-        from ..kg.oracle_impl import OracleDB
-
-        return OracleDB
-    elif db_type == "tidb":
-        from ..kg.tidb_impl import TiDB
-
-        return TiDB
-    return None
+from .ollama_api import OllamaAPI, ollama_server_infos


 # Load environment variables
@@ -929,52 +898,12 @@ def create_app(args):
     @asynccontextmanager
     async def lifespan(app: FastAPI):
         """Lifespan context manager for startup and shutdown events"""
-        # Initialize database connections
-        db_instances = {}

         # Store background tasks
         app.state.background_tasks = set()

         try:
-            # Check which database types are used
-            db_types = set()
-            for storage_name, storage_instance in storage_instances:
-                db_type = get_db_type_from_storage_class(
-                    storage_instance.__class__.__name__
-                )
-                if db_type:
-                    db_types.add(db_type)
-
-            # Import and initialize databases as needed
-            for db_type in db_types:
-                if db_type == "postgres":
-                    DB = import_db_module("postgres")
-                    db = DB(_get_postgres_config())
-                    await db.initdb()
-                    await db.check_tables()
-                    db_instances["postgres"] = db
-                elif db_type == "oracle":
-                    DB = import_db_module("oracle")
-                    db = DB(_get_oracle_config())
-                    await db.check_tables()
-                    db_instances["oracle"] = db
-                elif db_type == "tidb":
-                    DB = import_db_module("tidb")
-                    db = DB(_get_tidb_config())
-                    await db.check_tables()
-                    db_instances["tidb"] = db
-
-            # Inject database instances into storage classes
-            for storage_name, storage_instance in storage_instances:
-                db_type = get_db_type_from_storage_class(
-                    storage_instance.__class__.__name__
-                )
-                if db_type:
-                    if db_type not in db_instances:
-                        error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
-                        logger.error(error_msg)
-                        raise RuntimeError(error_msg)
-                    storage_instance.db = db_instances[db_type]
-                    logger.info(f"Injected {db_type} db to {storage_name}")
+            # Initialize database connections
+            await rag.initialize_storages()

             # Auto scan documents if enabled
             if args.auto_scan_at_startup:
@@ -1000,17 +929,7 @@ def create_app(args):

         finally:
             # Clean up database connections
-            for db_type, db in db_instances.items():
-                if hasattr(db, "pool"):
-                    await db.pool.close()
-                    # Use more accurate database name display
-                    db_names = {
-                        "postgres": "PostgreSQL",
-                        "oracle": "Oracle",
-                        "tidb": "TiDB",
-                    }
-                    db_name = db_names.get(db_type, db_type)
-                    logger.info(f"Closed {db_name} database connection pool")
+            await rag.finalize_storages()

     # Initialize FastAPI
     app = FastAPI(
@@ -1042,92 +961,6 @@ def create_app(args):
         allow_headers=["*"],
     )

-    # Database configuration functions
-    def _get_postgres_config():
-        return {
-            "host": os.environ.get(
-                "POSTGRES_HOST",
-                config.get("postgres", "host", fallback="localhost"),
-            ),
-            "port": os.environ.get(
-                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
-            ),
-            "user": os.environ.get(
-                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
-            ),
-            "password": os.environ.get(
-                "POSTGRES_PASSWORD",
-                config.get("postgres", "password", fallback=None),
-            ),
-            "database": os.environ.get(
-                "POSTGRES_DATABASE",
-                config.get("postgres", "database", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "POSTGRES_WORKSPACE",
-                config.get("postgres", "workspace", fallback="default"),
-            ),
-        }
-
-    def _get_oracle_config():
-        return {
-            "user": os.environ.get(
-                "ORACLE_USER",
-                config.get("oracle", "user", fallback=None),
-            ),
-            "password": os.environ.get(
-                "ORACLE_PASSWORD",
-                config.get("oracle", "password", fallback=None),
-            ),
-            "dsn": os.environ.get(
-                "ORACLE_DSN",
-                config.get("oracle", "dsn", fallback=None),
-            ),
-            "config_dir": os.environ.get(
-                "ORACLE_CONFIG_DIR",
-                config.get("oracle", "config_dir", fallback=None),
-            ),
-            "wallet_location": os.environ.get(
-                "ORACLE_WALLET_LOCATION",
-                config.get("oracle", "wallet_location", fallback=None),
-            ),
-            "wallet_password": os.environ.get(
-                "ORACLE_WALLET_PASSWORD",
-                config.get("oracle", "wallet_password", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "ORACLE_WORKSPACE",
-                config.get("oracle", "workspace", fallback="default"),
-            ),
-        }
-
-    def _get_tidb_config():
-        return {
-            "host": os.environ.get(
-                "TIDB_HOST",
-                config.get("tidb", "host", fallback="localhost"),
-            ),
-            "port": os.environ.get(
-                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
-            ),
-            "user": os.environ.get(
-                "TIDB_USER",
-                config.get("tidb", "user", fallback=None),
-            ),
-            "password": os.environ.get(
-                "TIDB_PASSWORD",
-                config.get("tidb", "password", fallback=None),
-            ),
-            "database": os.environ.get(
-                "TIDB_DATABASE",
-                config.get("tidb", "database", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "TIDB_WORKSPACE",
-                config.get("tidb", "workspace", fallback="default"),
-            ),
-        }
-
     # Create the optional API key dependency
     optional_api_key = get_api_key_dependency(api_key)

@@ -1262,6 +1095,7 @@ def create_app(args):
             },
             log_level=args.log_level,
             namespace_prefix=args.namespace_prefix,
+            is_managed_by_server=True,
         )
     else:
         rag = LightRAG(
@@ -1293,20 +1127,9 @@ def create_app(args):
             },
             log_level=args.log_level,
             namespace_prefix=args.namespace_prefix,
+            is_managed_by_server=True,
         )

-    # Collect all storage instances
-    storage_instances = [
-        ("full_docs", rag.full_docs),
-        ("text_chunks", rag.text_chunks),
-        ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
-        ("entities_vdb", rag.entities_vdb),
-        ("relationships_vdb", rag.relationships_vdb),
-        ("chunks_vdb", rag.chunks_vdb),
-        ("doc_status", rag.doc_status),
-        ("llm_response_cache", rag.llm_response_cache),
-    ]
-
     async def pipeline_enqueue_file(file_path: Path) -> bool:
         """Add a file to the queue for processing
diff --git a/lightrag/base.py b/lightrag/base.py
index 79cc5639..5f6a1bf1 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -87,6 +87,14 @@ class StorageNameSpace(ABC):
     namespace: str
     global_config: dict[str, Any]

+    async def initialize(self):
+        """Initialize the storage"""
+        pass
+
+    async def finalize(self):
+        """Finalize the storage"""
+        pass
+
     @abstractmethod
     async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
@@ -247,3 +255,12 @@ class DocStatusStorage(BaseKVStorage, ABC):
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
+
+
+class StoragesStatus(str, Enum):
+    """Storages status"""
+
+    NOT_CREATED = "not_created"
+    CREATED = "created"
+    INITIALIZED = "initialized"
+    FINALIZED = "finalized"
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 0b21f620..8391acaa 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -2,11 +2,11 @@ import array
 import asyncio

 # import html
-# import os
+import os
 from dataclasses import dataclass
 from typing import Any, Union, final
-
 import numpy as np
+import configparser

 from lightrag.types import KnowledgeGraph

@@ -173,6 +173,72 @@ class OracleDB:
             raise


+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "user": os.environ.get(
+                "ORACLE_USER",
+                config.get("oracle", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "ORACLE_PASSWORD",
+                config.get("oracle", "password", fallback=None),
+            ),
+            "dsn": os.environ.get(
+                "ORACLE_DSN",
+                config.get("oracle", "dsn", fallback=None),
+            ),
+            "config_dir": os.environ.get(
+                "ORACLE_CONFIG_DIR",
+                config.get("oracle", "config_dir", fallback=None),
+            ),
+            "wallet_location": os.environ.get(
+                "ORACLE_WALLET_LOCATION",
+                config.get("oracle", "wallet_location", fallback=None),
+            ),
+            "wallet_password": os.environ.get(
+                "ORACLE_WALLET_PASSWORD",
+                config.get("oracle", "wallet_password", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "ORACLE_WORKSPACE",
+                config.get("oracle", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> OracleDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = OracleDB(config)
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: OracleDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        await db.pool.close()
+                        logger.info("Closed OracleDB database connection pool")
+                        cls._instances["db"] = None
+                else:
+                    await db.pool.close()
+
+
 @final
 @dataclass
 class OracleKVStorage(BaseKVStorage):
@@ -184,6 +250,15 @@ class OracleKVStorage(BaseKVStorage):
         self._data = {}
         self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     ################ QUERY METHODS ################

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -329,6 +404,15 @@ class OracleVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     #################### query method ###############
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func([query])
@@ -368,6 +452,15 @@ class OracleGraphStorage(BaseGraphStorage):
     def __post_init__(self):
         self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     #################### insert method ################

     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index f7866e42..2ec16716 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -5,8 +5,8 @@ import os
 import time
 from dataclasses import dataclass
 from typing import Any, Dict, List, Union, final
-
 import numpy as np
+import configparser

 from lightrag.types import KnowledgeGraph

@@ -182,6 +182,67 @@ class PostgreSQLDB:
         pass


+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "host": os.environ.get(
+                "POSTGRES_HOST",
+                config.get("postgres", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
+            ),
+            "user": os.environ.get(
+                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
+            ),
+            "password": os.environ.get(
+                "POSTGRES_PASSWORD",
+                config.get("postgres", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "POSTGRES_DATABASE",
+                config.get("postgres", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "POSTGRES_WORKSPACE",
+                config.get("postgres", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> PostgreSQLDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = PostgreSQLDB(config)
+                await db.initdb()
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: PostgreSQLDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        await db.pool.close()
+                        logger.info("Closed PostgreSQL database connection pool")
+                        cls._instances["db"] = None
+                else:
+                    await db.pool.close()
+
+
 @final
 @dataclass
 class PGKVStorage(BaseKVStorage):
@@ -191,6 +252,15 @@ class PGKVStorage(BaseKVStorage):
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     ################ QUERY METHODS ################

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -319,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     def _upsert_chunks(self, item: dict):
         try:
             upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -435,6 +514,15 @@ class PGVectorStorage(BaseVectorStorage):
 @final
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
@@ -584,6 +672,15 @@ class PGGraphStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     async def index_done_callback(self) -> None:
         # PG handles persistence automatically
         pass
diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
index 110a404a..dc0dc422 100644
--- a/lightrag/kg/tidb_impl.py
+++ b/lightrag/kg/tidb_impl.py
@@ -14,6 +14,7 @@ from ..namespace import NameSpace, is_namespace
 from ..utils import logger

 import pipmaster as pm
+import configparser

 if not pm.is_installed("pymysql"):
     pm.install("pymysql")
@@ -105,6 +106,63 @@ class TiDB:
             raise


+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "host": os.environ.get(
+                "TIDB_HOST",
+                config.get("tidb", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
+            ),
+            "user": os.environ.get(
+                "TIDB_USER",
+                config.get("tidb", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "TIDB_PASSWORD",
+                config.get("tidb", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "TIDB_DATABASE",
+                config.get("tidb", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "TIDB_WORKSPACE",
+                config.get("tidb", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> TiDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = TiDB(config)
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: TiDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        cls._instances["db"] = None
+
+
 @final
 @dataclass
 class TiDBKVStorage(BaseKVStorage):
@@ -115,6 +173,15 @@ class TiDBKVStorage(BaseKVStorage):
         self._data = {}
         self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     ################ QUERY METHODS ################

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -185,7 +252,7 @@ class TiDBKVStorage(BaseKVStorage):
                     "tokens": item["tokens"],
                     "chunk_order_index": item["chunk_order_index"],
                     "full_doc_id": item["full_doc_id"],
-                    "content_vector": f'{item["__vector__"].tolist()}',
+                    "content_vector": f"{item['__vector__'].tolist()}",
                     "workspace": self.db.workspace,
                 }
             )
@@ -226,6 +293,15 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """Search from tidb vector"""
         embeddings = await self.embedding_func([query])
@@ -290,7 +366,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                 "id": item["id"],
                 "name": item["entity_name"],
                 "content": item["content"],
-                "content_vector": f'{item["content_vector"].tolist()}',
+                "content_vector": f"{item['content_vector'].tolist()}",
                 "workspace": self.db.workspace,
             }
             # update entity_id if node inserted by graph_storage_instance before
@@ -312,7 +388,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                 "source_name": item["src_id"],
                 "target_name": item["tgt_id"],
                 "content": item["content"],
-                "content_vector": f'{item["content_vector"].tolist()}',
+                "content_vector": f"{item['content_vector'].tolist()}",
                 "workspace": self.db.workspace,
             }
             # update relation_id if node inserted by graph_storage_instance before
@@ -351,6 +427,15 @@ class TiDBGraphStorage(BaseGraphStorage):
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     #################### upsert method ################
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         entity_name = node_id
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 09a8df3f..d18e4d7a 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -17,6 +17,7 @@ from .base import (
     DocStatusStorage,
     QueryParam,
     StorageNameSpace,
+    StoragesStatus,
 )
 from .namespace import NameSpace, make_namespace
 from .operate import (
@@ -348,6 +349,9 @@ class LightRAG:
     # Extensions
     addon_params: dict[str, Any] = field(default_factory=dict)

+    # Ownership
+    is_managed_by_server: bool = False
+
     """Dictionary for additional parameters and extensions."""
     convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
         convert_response_to_json
@@ -440,7 +444,10 @@ class LightRAG:
             **self.vector_db_storage_cls_kwargs,
         }

-        # show config
+        # Life cycle
+        self.storages_status = StoragesStatus.NOT_CREATED
+
+        # Show config
         global_config = asdict(self)
         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
@@ -547,6 +554,65 @@ class LightRAG:
             )
         )

+        self.storages_status = StoragesStatus.CREATED
+
+        # Initialize storages
+        if not self.is_managed_by_server:
+            loop = always_get_an_event_loop()
+            loop.run_until_complete(self.initialize_storages())
+
+    def __del__(self):
+        # Finalize storages
+        if not self.is_managed_by_server:
+            loop = always_get_an_event_loop()
+            loop.run_until_complete(self.finalize_storages())
+
+    async def initialize_storages(self):
+        """Asynchronously initialize the storages"""
+        if self.storages_status == StoragesStatus.CREATED:
+            tasks = []
+
+            for storage in (
+                self.full_docs,
+                self.text_chunks,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+                self.llm_response_cache,
+                self.doc_status,
+            ):
+                if storage:
+                    tasks.append(storage.initialize())
+
+            await asyncio.gather(*tasks)
+
+            self.storages_status = StoragesStatus.INITIALIZED
+            logger.debug("Initialized Storages")
+
+    async def finalize_storages(self):
+        """Asynchronously finalize the storages"""
+        if self.storages_status == StoragesStatus.INITIALIZED:
+            tasks = []
+
+            for storage in (
+                self.full_docs,
+                self.text_chunks,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+                self.llm_response_cache,
+                self.doc_status,
+            ):
+                if storage:
+                    tasks.append(storage.finalize())
+
+            await asyncio.gather(*tasks)
+            logger.debug("Finalized Storages")
+
+            self.storages_status = StoragesStatus.FINALIZED
+
     async def get_graph_labels(self):
         text = await self.chunk_entity_relation_graph.get_all_labels()
         return text