diff --git a/examples/lightrag_api_oracle_demo.py b/examples/lightrag_api_oracle_demo.py
index 6162a300..e66e3f94 100644
--- a/examples/lightrag_api_oracle_demo.py
+++ b/examples/lightrag_api_oracle_demo.py
@@ -17,7 +17,6 @@ from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from lightrag.kg.oracle_impl import OracleDB
 
 print(os.getcwd())
 script_directory = Path(__file__).resolve().parent.parent
 
@@ -48,6 +47,14 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
+os.environ["ORACLE_USER"] = ""
+os.environ["ORACLE_PASSWORD"] = ""
+os.environ["ORACLE_DSN"] = ""
+os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
+os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
+os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
+os.environ["ORACLE_WORKSPACE"] = "company"
+
 
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
@@ -89,20 +96,6 @@ async def init():
     # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
     # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-    oracle_db = OracleDB(
-        config={
-            "user": "",
-            "password": "",
-            "dsn": "",
-            "config_dir": "path_to_config_dir",
-            "wallet_location": "path_to_wallet_location",
-            "wallet_password": "wallet_password",
-            "workspace": "company",
-        }  # specify which docs you want to store and query
-    )
-
-    # Check if Oracle DB tables exist, if not, tables will be created
-    await oracle_db.check_tables()
 
     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
     # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -121,11 +114,6 @@ async def init():
         vector_storage="OracleVectorDBStorage",
     )
 
-    # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
-    rag.graph_storage_cls.db = oracle_db
-    rag.key_string_value_json_storage_cls.db = oracle_db
-    rag.vector_db_storage_cls.db = oracle_db
-
     return rag
 
diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py
index 9c90424e..c6121840 100644
--- a/examples/lightrag_oracle_demo.py
+++ b/examples/lightrag_oracle_demo.py
@@ -6,7 +6,6 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from lightrag.kg.oracle_impl import OracleDB
 
 print(os.getcwd())
 script_directory = Path(__file__).resolve().parent.parent
@@ -26,6 +25,14 @@ MAX_TOKENS = 4000
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
+os.environ["ORACLE_USER"] = "username"
+os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
+os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
+os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
+os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
+os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
+os.environ["ORACLE_WORKSPACE"] = "company"
+
 
 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
@@ -63,26 +70,6 @@ async def main():
     embedding_dimension = await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")
 
-    # Create Oracle DB connection
-    # The `config` parameter is the connection configuration of Oracle DB
-    # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
-    # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
-    # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-    oracle_db = OracleDB(
-        config={
-            "user": "username",
-            "password": "xxxxxxxxx",
-            "dsn": "xxxxxxx_medium",
-            "config_dir": "dir/path/to/oracle/config",
-            "wallet_location": "dir/path/to/oracle/wallet",
-            "wallet_password": "xxxxxxxxx",
-            "workspace": "company",  # specify which docs you want to store and query
-        }
-    )
-
-    # Check if Oracle DB tables exist, if not, tables will be created
-    await oracle_db.check_tables()
-
     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
     # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -112,26 +99,6 @@ async def main():
         },
     )
 
-    # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
-
-    for storage in [
-        rag.vector_db_storage_cls,
-        rag.graph_storage_cls,
-        rag.doc_status,
-        rag.full_docs,
-        rag.text_chunks,
-        rag.llm_response_cache,
-        rag.key_string_value_json_storage_cls,
-        rag.chunks_vdb,
-        rag.relationships_vdb,
-        rag.entities_vdb,
-        rag.graph_storage_cls,
-        rag.chunk_entity_relation_graph,
-        rag.llm_response_cache,
-    ]:
-        # set client
-        storage.db = oracle_db
-
     # Extract and Insert into LightRAG storage
     with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
         all_text = f.read()
diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py
index b8e4d35c..f4004f84 100644
--- a/examples/lightrag_tidb_demo.py
+++ b/examples/lightrag_tidb_demo.py
@@ -4,7 +4,6 @@ import os
 import numpy as np
 
 from lightrag import LightRAG, QueryParam
-from lightrag.kg.tidb_impl import TiDB
 from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
 from lightrag.utils import EmbeddingFunc
 
@@ -17,11 +16,11 @@ APIKEY = ""
 CHATMODEL = ""
 EMBEDMODEL = ""
 
-TIDB_HOST = ""
-TIDB_PORT = ""
-TIDB_USER = ""
-TIDB_PASSWORD = ""
-TIDB_DATABASE = "lightrag"
+os.environ["TIDB_HOST"] = ""
+os.environ["TIDB_PORT"] = ""
+os.environ["TIDB_USER"] = ""
+os.environ["TIDB_PASSWORD"] = ""
+os.environ["TIDB_DATABASE"] = "lightrag"
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -62,21 +61,6 @@ async def main():
     embedding_dimension = await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")
 
-    # Create TiDB DB connection
-    tidb = TiDB(
-        config={
-            "host": TIDB_HOST,
-            "port": TIDB_PORT,
-            "user": TIDB_USER,
-            "password": TIDB_PASSWORD,
-            "database": TIDB_DATABASE,
-            "workspace": "company",  # specify which docs you want to store and query
-        }
-    )
-
-    # Check if TiDB DB tables exist, if not, tables will be created
-    await tidb.check_tables()
-
     # Initialize LightRAG
     # We use TiDB DB as the KV/vector
     # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -95,15 +79,6 @@ async def main():
         graph_storage="TiDBGraphStorage",
     )
 
-    if rag.llm_response_cache:
-        rag.llm_response_cache.db = tidb
-    rag.full_docs.db = tidb
-    rag.text_chunks.db = tidb
-    rag.entities_vdb.db = tidb
-    rag.relationships_vdb.db = tidb
-    rag.chunks_vdb.db = tidb
-    rag.chunk_entity_relation_graph.db = tidb
-
     # Extract and Insert into LightRAG storage
     with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
         await rag.ainsert(f.read())
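The example scripts above all make the same move: the hand-built `OracleDB`/`TiDB` client, the manual `check_tables()` call, and the `storage.db = ...` wiring are gone, and connection settings are exported as environment variables before `LightRAG` is constructed. A minimal sketch of the new calling convention (the values are placeholders, not working credentials):

```python
import os

# Before (removed above): build the client yourself and inject it.
#   tidb = TiDB(config={"host": ..., "port": ..., ...})
#   await tidb.check_tables()
#   rag.chunks_vdb.db = tidb

# After: export the settings up front; the storage classes construct and
# share the client internally when the storages are initialized.
os.environ["TIDB_HOST"] = "localhost"
os.environ["TIDB_PORT"] = "4000"
os.environ["TIDB_USER"] = "root"
os.environ["TIDB_PASSWORD"] = "secret"
os.environ["TIDB_DATABASE"] = "lightrag"
```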
diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py
index f2066d09..913361b3 100644
--- a/examples/lightrag_zhipu_postgres_demo.py
+++ b/examples/lightrag_zhipu_postgres_demo.py
@@ -5,7 +5,6 @@ import time
 from dotenv import load_dotenv
 
 from lightrag import LightRAG, QueryParam
-from lightrag.kg.postgres_impl import PostgreSQLDB
 from lightrag.llm.zhipu import zhipu_complete
 from lightrag.llm.ollama import ollama_embedding
 from lightrag.utils import EmbeddingFunc
@@ -22,22 +21,14 @@ if not os.path.exists(WORKING_DIR):
 # AGE
 os.environ["AGE_GRAPH_NAME"] = "dickens"
 
-postgres_db = PostgreSQLDB(
-    config={
-        "host": "localhost",
-        "port": 15432,
-        "user": "rag",
-        "password": "rag",
-        "database": "rag",
-    }
-)
+os.environ["POSTGRES_HOST"] = "localhost"
+os.environ["POSTGRES_PORT"] = "15432"
+os.environ["POSTGRES_USER"] = "rag"
+os.environ["POSTGRES_PASSWORD"] = "rag"
+os.environ["POSTGRES_DATABASE"] = "rag"
 
 
 async def main():
-    await postgres_db.initdb()
-    # Check if PostgreSQL DB tables exist, if not, tables will be created
-    await postgres_db.check_tables()
-
     rag = LightRAG(
         working_dir=WORKING_DIR,
         llm_model_func=zhipu_complete,
@@ -57,17 +48,7 @@ async def main():
         graph_storage="PGGraphStorage",
         vector_storage="PGVectorStorage",
     )
-    # Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
-    rag.doc_status.db = postgres_db
-    rag.full_docs.db = postgres_db
-    rag.text_chunks.db = postgres_db
-    rag.llm_response_cache.db = postgres_db
-    rag.key_string_value_json_storage_cls.db = postgres_db
-    rag.chunks_vdb.db = postgres_db
-    rag.relationships_vdb.db = postgres_db
-    rag.entities_vdb.db = postgres_db
-    rag.graph_storage_cls.db = postgres_db
-    rag.chunk_entity_relation_graph.db = postgres_db
+    # add embedding_func for graph database; it was removed in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
     rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
 
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index fba81086..a0b03129 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -15,11 +15,6 @@ import logging
 import argparse
 from typing import List, Any, Literal, Optional, Dict
 from pydantic import BaseModel, Field, field_validator
-from lightrag import LightRAG, QueryParam
-from lightrag.base import DocProcessingStatus, DocStatus
-from lightrag.types import GPTKeywordExtractionFormat
-from lightrag.api import __api_version__
-from lightrag.utils import EmbeddingFunc
 from pathlib import Path
 import shutil
 import aiofiles
@@ -36,39 +31,13 @@ import configparser
 import traceback
 from datetime import datetime
 
+from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
+from lightrag.types import GPTKeywordExtractionFormat
+from lightrag.api import __api_version__
+from lightrag.utils import EmbeddingFunc
 from lightrag.utils import logger
-from .ollama_api import (
-    OllamaAPI,
-)
-from .ollama_api import ollama_server_infos
-
-
-def get_db_type_from_storage_class(class_name: str) -> str | None:
-    """Determine database type based on storage class name"""
-    if class_name.startswith("PG"):
-        return "postgres"
-    elif class_name.startswith("Oracle"):
-        return "oracle"
-    elif class_name.startswith("TiDB"):
-        return "tidb"
-    return None
-
-
-def import_db_module(db_type: str):
-    """Dynamically import database module"""
-    if db_type == "postgres":
-        from ..kg.postgres_impl import PostgreSQLDB
-
-        return PostgreSQLDB
-    elif db_type == "oracle":
-        from ..kg.oracle_impl import OracleDB
-
-        return OracleDB
-    elif db_type == "tidb":
-        from ..kg.tidb_impl import TiDB
-
-        return TiDB
-    return None
+from .ollama_api import OllamaAPI, ollama_server_infos
 
 
 # Load environment variables
@@ -929,52 +898,12 @@ def create_app(args):
     @asynccontextmanager
     async def lifespan(app: FastAPI):
        """Lifespan context manager for startup and shutdown events"""
-        # Initialize database connections
-        db_instances = {}
 
         # Store background tasks
         app.state.background_tasks = set()
 
         try:
-            # Check which database types are used
-            db_types = set()
-            for storage_name, storage_instance in storage_instances:
-                db_type = get_db_type_from_storage_class(
-                    storage_instance.__class__.__name__
-                )
-                if db_type:
-                    db_types.add(db_type)
-
-            # Import and initialize databases as needed
-            for db_type in db_types:
-                if db_type == "postgres":
-                    DB = import_db_module("postgres")
-                    db = DB(_get_postgres_config())
-                    await db.initdb()
-                    await db.check_tables()
-                    db_instances["postgres"] = db
-                elif db_type == "oracle":
-                    DB = import_db_module("oracle")
-                    db = DB(_get_oracle_config())
-                    await db.check_tables()
-                    db_instances["oracle"] = db
-                elif db_type == "tidb":
-                    DB = import_db_module("tidb")
-                    db = DB(_get_tidb_config())
-                    await db.check_tables()
-                    db_instances["tidb"] = db
-
-            # Inject database instances into storage classes
-            for storage_name, storage_instance in storage_instances:
-                db_type = get_db_type_from_storage_class(
-                    storage_instance.__class__.__name__
-                )
-                if db_type:
-                    if db_type not in db_instances:
-                        error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
-                        logger.error(error_msg)
-                        raise RuntimeError(error_msg)
-                    storage_instance.db = db_instances[db_type]
-                    logger.info(f"Injected {db_type} db to {storage_name}")
+            # Initialize database connections
+            await rag.initialize_storages()
 
             # Auto scan documents if enabled
             if args.auto_scan_at_startup:
@@ -1000,17 +929,7 @@ def create_app(args):
         finally:
             # Clean up database connections
-            for db_type, db in db_instances.items():
-                if hasattr(db, "pool"):
-                    await db.pool.close()
-                    # Use more accurate database name display
-                    db_names = {
-                        "postgres": "PostgreSQL",
-                        "oracle": "Oracle",
-                        "tidb": "TiDB",
-                    }
-                    db_name = db_names.get(db_type, db_type)
-                    logger.info(f"Closed {db_name} database connection pool")
+            await rag.finalize_storages()
 
     # Initialize FastAPI
     app = FastAPI(
@@ -1042,92 +961,6 @@ def create_app(args):
         allow_headers=["*"],
     )
 
-    # Database configuration functions
-    def _get_postgres_config():
-        return {
-            "host": os.environ.get(
-                "POSTGRES_HOST",
-                config.get("postgres", "host", fallback="localhost"),
-            ),
-            "port": os.environ.get(
-                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
-            ),
-            "user": os.environ.get(
-                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
-            ),
-            "password": os.environ.get(
-                "POSTGRES_PASSWORD",
-                config.get("postgres", "password", fallback=None),
-            ),
-            "database": os.environ.get(
-                "POSTGRES_DATABASE",
-                config.get("postgres", "database", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "POSTGRES_WORKSPACE",
-                config.get("postgres", "workspace", fallback="default"),
-            ),
-        }
-
-    def _get_oracle_config():
-        return {
-            "user": os.environ.get(
-                "ORACLE_USER",
-                config.get("oracle", "user", fallback=None),
-            ),
-            "password": os.environ.get(
-                "ORACLE_PASSWORD",
-                config.get("oracle", "password", fallback=None),
-            ),
-            "dsn": os.environ.get(
-                "ORACLE_DSN",
-                config.get("oracle", "dsn", fallback=None),
-            ),
-            "config_dir": os.environ.get(
-                "ORACLE_CONFIG_DIR",
-                config.get("oracle", "config_dir", fallback=None),
-            ),
-            "wallet_location": os.environ.get(
-                "ORACLE_WALLET_LOCATION",
-                config.get("oracle", "wallet_location", fallback=None),
-            ),
-            "wallet_password": os.environ.get(
-                "ORACLE_WALLET_PASSWORD",
-                config.get("oracle", "wallet_password", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "ORACLE_WORKSPACE",
-                config.get("oracle", "workspace", fallback="default"),
-            ),
-        }
-
-    def _get_tidb_config():
-        return {
-            "host": os.environ.get(
-                "TIDB_HOST",
-                config.get("tidb", "host", fallback="localhost"),
-            ),
-            "port": os.environ.get(
-                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
-            ),
-            "user": os.environ.get(
-                "TIDB_USER",
-                config.get("tidb", "user", fallback=None),
-            ),
-            "password": os.environ.get(
-                "TIDB_PASSWORD",
-                config.get("tidb", "password", fallback=None),
-            ),
-            "database": os.environ.get(
-                "TIDB_DATABASE",
-                config.get("tidb", "database", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "TIDB_WORKSPACE",
-                config.get("tidb", "workspace", fallback="default"),
-            ),
-        }
-
     # Create the optional API key dependency
     optional_api_key = get_api_key_dependency(api_key)
@@ -1262,6 +1095,7 @@ def create_app(args):
             },
             log_level=args.log_level,
             namespace_prefix=args.namespace_prefix,
+            auto_manage_storages_states=False,
         )
     else:
         rag = LightRAG(
@@ -1293,20 +1127,9 @@ def create_app(args):
             },
             log_level=args.log_level,
             namespace_prefix=args.namespace_prefix,
+            auto_manage_storages_states=False,
         )
 
-    # Collect all storage instances
-    storage_instances = [
-        ("full_docs", rag.full_docs),
-        ("text_chunks", rag.text_chunks),
-        ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
-        ("entities_vdb", rag.entities_vdb),
-        ("relationships_vdb", rag.relationships_vdb),
-        ("chunks_vdb", rag.chunks_vdb),
-        ("doc_status", rag.doc_status),
-        ("llm_response_cache", rag.llm_response_cache),
-    ]
-
     async def pipeline_enqueue_file(file_path: Path) -> bool:
         """Add a file to the queue for processing
diff --git a/lightrag/base.py b/lightrag/base.py
index 79cc5639..5f6a1bf1 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -87,6 +87,14 @@ class StorageNameSpace(ABC):
     namespace: str
     global_config: dict[str, Any]
 
+    async def initialize(self):
+        """Initialize the storage"""
+        pass
+
+    async def finalize(self):
+        """Finalize the storage"""
+        pass
+
     @abstractmethod
     async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
@@ -247,3 +255,12 @@ class DocStatusStorage(BaseKVStorage, ABC):
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""
+
+
+class StoragesStatus(str, Enum):
+    """Storages status"""
+
+    NOT_CREATED = "not_created"
+    CREATED = "created"
+    INITIALIZED = "initialized"
+    FINALIZED = "finalized"
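The `base.py` change above is what the per-backend rewrites below build on: every `StorageNameSpace` now carries no-op `initialize()`/`finalize()` hooks, and `StoragesStatus` names the life-cycle states. A custom backend only needs to override the two hooks. A minimal sketch, in which the class, its `db` handle, and `connect_somehow()` are hypothetical and the abstract query methods required by `BaseKVStorage` are omitted:

```python
from dataclasses import dataclass, field
from typing import Any

from lightrag.base import BaseKVStorage


@dataclass
class MyKVStorage(BaseKVStorage):
    # Hypothetical backend handle; LightRAG itself never touches it directly.
    db: Any = field(default=None)

    async def initialize(self):
        # Called by rag.initialize_storages(); acquire the shared client lazily.
        if self.db is None:
            self.db = await connect_somehow()  # hypothetical connection helper

    async def finalize(self):
        # Called by rag.finalize_storages(); release the client.
        if self.db is not None:
            await self.db.close()
            self.db = None

    # get_by_id, upsert, and the other abstract methods are omitted here.
```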
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index 2210e9f4..c7b16a70 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -1,5 +1,5 @@
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import numpy as np
 import configparser
 import asyncio
@@ -26,8 +26,11 @@ if not pm.is_installed("motor"):
     pm.install("motor")
 
 try:
-    from motor.motor_asyncio import AsyncIOMotorClient
-    from pymongo import MongoClient
+    from motor.motor_asyncio import (
+        AsyncIOMotorClient,
+        AsyncIOMotorDatabase,
+        AsyncIOMotorCollection,
+    )
     from pymongo.operations import SearchIndexModel
     from pymongo.errors import PyMongoError
 except ImportError as e:
@@ -39,31 +42,63 @@ config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
 
 
+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @classmethod
+    async def get_client(cls) -> AsyncIOMotorDatabase:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                uri = os.environ.get(
+                    "MONGO_URI",
+                    config.get(
+                        "mongodb",
+                        "uri",
+                        fallback="mongodb://root:root@localhost:27017/",
+                    ),
+                )
+                database_name = os.environ.get(
+                    "MONGO_DATABASE",
+                    config.get("mongodb", "database", fallback="LightRAG"),
+                )
+                client = AsyncIOMotorClient(uri)
+                db = client.get_database(database_name)
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: AsyncIOMotorDatabase):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        cls._instances["db"] = None
+
+
 @final
 @dataclass
 class MongoKVStorage(BaseKVStorage):
-    def __post_init__(self):
-        uri = os.environ.get(
-            "MONGO_URI",
-            config.get(
-                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-            ),
-        )
-        client = AsyncIOMotorClient(uri)
-        database = client.get_database(
-            os.environ.get(
-                "MONGO_DATABASE",
-                config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        )
+    db: AsyncIOMotorDatabase = field(default=None)
+    _data: AsyncIOMotorCollection = field(default=None)
 
+    def __post_init__(self):
         self._collection_name = self.namespace
-        self._data = database.get_collection(self._collection_name)
-        logger.debug(f"Use MongoDB as KV {self._collection_name}")
 
-        # Ensure collection exists
-        create_collection_if_not_exists(uri, database.name, self._collection_name)
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+            self._data = await get_or_create_collection(self.db, self._collection_name)
+            logger.debug(f"Use MongoDB as KV {self._collection_name}")
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+            self._data = None
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
         return await self._data.find_one({"_id": id})
@@ -120,28 +155,23 @@ class MongoKVStorage(BaseKVStorage):
 @final
 @dataclass
 class MongoDocStatusStorage(DocStatusStorage):
+    db: AsyncIOMotorDatabase = field(default=None)
+    _data: AsyncIOMotorCollection = field(default=None)
+
     def __post_init__(self):
-        uri = os.environ.get(
-            "MONGO_URI",
-            config.get(
-                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-            ),
-        )
-        client = AsyncIOMotorClient(uri)
-        database = client.get_database(
-            os.environ.get(
-                "MONGO_DATABASE",
-                config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        )
-
         self._collection_name = self.namespace
-        self._data = database.get_collection(self._collection_name)
-        logger.debug(f"Use MongoDB as doc status {self._collection_name}")
 
-        # Ensure collection exists
-        create_collection_if_not_exists(uri, database.name, self._collection_name)
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+            self._data = await get_or_create_collection(self.db, self._collection_name)
+            logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+            self._data = None
 
     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return await self._data.find_one({"_id": id})
@@ -202,36 +232,33 @@ class MongoDocStatusStorage(DocStatusStorage):
 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
     """
-    A concrete implementation using MongoDB’s $graphLookup to demonstrate multi-hop queries.
+    A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
     """
 
+    db: AsyncIOMotorDatabase = field(default=None)
+    collection: AsyncIOMotorCollection = field(default=None)
+
     def __init__(self, namespace, global_config, embedding_func):
         super().__init__(
             namespace=namespace,
             global_config=global_config,
             embedding_func=embedding_func,
         )
-        uri = os.environ.get(
-            "MONGO_URI",
-            config.get(
-                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-            ),
-        )
-        client = AsyncIOMotorClient(uri)
-        database = client.get_database(
-            os.environ.get(
-                "MONGO_DATABASE",
-                config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        )
-
         self._collection_name = self.namespace
-        self.collection = database.get_collection(self._collection_name)
-        logger.debug(f"Use MongoDB as KG {self._collection_name}")
 
-        # Ensure collection exists
-        create_collection_if_not_exists(uri, database.name, self._collection_name)
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+            self.collection = await get_or_create_collection(
+                self.db, self._collection_name
+            )
+            logger.debug(f"Use MongoDB as KG {self._collection_name}")
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+            self.collection = None
 
     #
     # -------------------------------------------------------------------------
@@ -770,6 +797,9 @@ class MongoGraphStorage(BaseGraphStorage):
 @final
 @dataclass
 class MongoVectorDBStorage(BaseVectorStorage):
+    db: AsyncIOMotorDatabase = field(default=None)
+    _data: AsyncIOMotorCollection = field(default=None)
+
     def __post_init__(self):
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -778,41 +808,36 @@ class MongoVectorDBStorage(BaseVectorStorage):
                 "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
             )
         self.cosine_better_than_threshold = cosine_threshold
-
-        uri = os.environ.get(
-            "MONGO_URI",
-            config.get(
-                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-            ),
-        )
-        client = AsyncIOMotorClient(uri)
-        database = client.get_database(
-            os.environ.get(
-                "MONGO_DATABASE",
-                config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        )
-
         self._collection_name = self.namespace
-        self._data = database.get_collection(self._collection_name)
         self._max_batch_size = self.global_config["embedding_batch_num"]
-        logger.debug(f"Use MongoDB as VDB {self._collection_name}")
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+            self._data = await get_or_create_collection(self.db, self._collection_name)
 
-        # Ensure collection exists
-        create_collection_if_not_exists(uri, database.name, self._collection_name)
+            # Ensure vector index exists
+            await self.create_vector_index_if_not_exists()
 
-        # Ensure vector index exists
-        self.create_vector_index(uri, database.name, self._collection_name)
+            logger.debug(f"Use MongoDB as VDB {self._collection_name}")
 
-    def create_vector_index(self, uri: str, database_name: str, collection_name: str):
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+            self._data = None
+
+    async def create_vector_index_if_not_exists(self):
         """Creates an Atlas Vector Search index."""
-        client = MongoClient(uri)
-        collection = client.get_database(database_name).get_collection(
-            self._collection_name
-        )
-
         try:
+            index_name = "vector_knn_index"
+
+            indexes = await self._data.list_search_indexes().to_list(length=None)
+            for index in indexes:
+                if index["name"] == index_name:
+                    logger.debug("vector index already exists")
+                    return
+
             search_index_model = SearchIndexModel(
@@ -824,11 +849,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
                         }
                     ]
                 },
-                name="vector_knn_index",
+                name=index_name,
                 type="vectorSearch",
             )
 
-            collection.create_search_index(search_index_model)
+            await self._data.create_search_index(search_index_model)
             logger.info("Vector index created successfully.")
 
         except PyMongoError as _:
@@ -913,15 +938,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
         raise NotImplementedError
 
 
-def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
-    """Check if the collection exists. if not, create it."""
-    client = MongoClient(uri)
-    database = client.get_database(database_name)
-
-    collection_names = database.list_collection_names()
+async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
+    collection_names = await db.list_collection_names()
 
     if collection_name not in collection_names:
-        database.create_collection(collection_name)
+        collection = await db.create_collection(collection_name)
         logger.info(f"Created collection: {collection_name}")
+        return collection
     else:
         logger.debug(f"Collection '{collection_name}' already exists.")
+        return db.get_collection(collection_name)
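The `ClientManager` introduced above replaces per-storage client construction with a reference-counted singleton: every storage that calls `get_client()` under the lock receives the same database handle, and the handle is only dropped when the last holder releases it. The behavior in isolation, as a runnable sketch with a stub object standing in for the real Motor client:

```python
import asyncio


class FakeClientManager:
    _instances = {"db": None, "ref_count": 0}
    _lock = asyncio.Lock()

    @classmethod
    async def get_client(cls):
        async with cls._lock:
            if cls._instances["db"] is None:
                cls._instances["db"] = object()  # stands in for a real client
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db):
        async with cls._lock:
            if db is cls._instances["db"]:
                cls._instances["ref_count"] -= 1
                if cls._instances["ref_count"] == 0:
                    # A real manager closes connection pools here.
                    cls._instances["db"] = None


async def demo():
    a = await FakeClientManager.get_client()
    b = await FakeClientManager.get_client()
    assert a is b  # both storages share one client
    await FakeClientManager.release_client(a)
    assert FakeClientManager._instances["db"] is not None  # still held by b
    await FakeClientManager.release_client(b)
    assert FakeClientManager._instances["db"] is None  # last release tears down


asyncio.run(demo())
```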
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 1bb12ccf..3e0c6799 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -2,11 +2,11 @@ import array
 import asyncio
 
 # import html
-# import os
-from dataclasses import dataclass
+import os
+from dataclasses import dataclass, field
 from typing import Any, Union, final
-
 import numpy as np
+import configparser
 
 from lightrag.types import KnowledgeGraph
@@ -177,17 +177,91 @@ class OracleDB:
             raise
 
 
+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "user": os.environ.get(
+                "ORACLE_USER",
+                config.get("oracle", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "ORACLE_PASSWORD",
+                config.get("oracle", "password", fallback=None),
+            ),
+            "dsn": os.environ.get(
+                "ORACLE_DSN",
+                config.get("oracle", "dsn", fallback=None),
+            ),
+            "config_dir": os.environ.get(
+                "ORACLE_CONFIG_DIR",
+                config.get("oracle", "config_dir", fallback=None),
+            ),
+            "wallet_location": os.environ.get(
+                "ORACLE_WALLET_LOCATION",
+                config.get("oracle", "wallet_location", fallback=None),
+            ),
+            "wallet_password": os.environ.get(
+                "ORACLE_WALLET_PASSWORD",
+                config.get("oracle", "wallet_password", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "ORACLE_WORKSPACE",
+                config.get("oracle", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> OracleDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = OracleDB(config)
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: OracleDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        await db.pool.close()
+                        logger.info("Closed OracleDB database connection pool")
+                        cls._instances["db"] = None
+                else:
+                    await db.pool.close()
+
+
 @final
 @dataclass
 class OracleKVStorage(BaseKVStorage):
-    # db instance must be injected before use
-    # db: OracleDB
+    db: OracleDB = field(default=None)
     meta_fields = None
 
     def __post_init__(self):
         self._data = {}
         self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     ################ QUERY METHODS ################
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -324,6 +398,8 @@ class OracleKVStorage(BaseKVStorage):
 @final
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
+    db: OracleDB = field(default=None)
+
     def __post_init__(self):
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = config.get("cosine_better_than_threshold")
@@ -333,6 +409,15 @@ class OracleVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     #################### query method ###############
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func([query])
@@ -369,9 +454,20 @@ class OracleVectorDBStorage(BaseVectorStorage):
 @final
 @dataclass
 class OracleGraphStorage(BaseGraphStorage):
+    db: OracleDB = field(default=None)
+
     def __post_init__(self):
         self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     #################### insert method ################
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index 70900c93..fd560668 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -3,10 +3,10 @@ import inspect
 import json
 import os
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Union, final
-
 import numpy as np
+import configparser
 
 from lightrag.types import KnowledgeGraph
@@ -181,15 +181,84 @@ class PostgreSQLDB:
         pass
 
 
+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "host": os.environ.get(
+                "POSTGRES_HOST",
+                config.get("postgres", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
+            ),
+            "user": os.environ.get(
+                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
+            ),
+            "password": os.environ.get(
+                "POSTGRES_PASSWORD",
+                config.get("postgres", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "POSTGRES_DATABASE",
+                config.get("postgres", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "POSTGRES_WORKSPACE",
+                config.get("postgres", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> PostgreSQLDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = PostgreSQLDB(config)
+                await db.initdb()
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: PostgreSQLDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        await db.pool.close()
+                        logger.info("Closed PostgreSQL database connection pool")
+                        cls._instances["db"] = None
+                else:
+                    await db.pool.close()
+
+
 @final
 @dataclass
 class PGKVStorage(BaseKVStorage):
-    # db instance must be injected before use
-    # db: PostgreSQLDB
+    db: PostgreSQLDB = field(default=None)
 
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     ################ QUERY METHODS ################
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -308,6 +377,8 @@ class PGKVStorage(BaseKVStorage):
 @final
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
+    db: PostgreSQLDB = field(default=None)
+
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -318,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     def _upsert_chunks(self, item: dict):
         try:
             upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -426,6 +506,17 @@ class PGVectorStorage(BaseVectorStorage):
 @final
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
+    db: PostgreSQLDB = field(default=None)
+
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
@@ -565,6 +656,8 @@ class PGGraphQueryException(Exception):
 @final
 @dataclass
 class PGGraphStorage(BaseGraphStorage):
+    db: PostgreSQLDB = field(default=None)
+
     @staticmethod
     def load_nx_graph(file_name):
         print("no preloading of graph with AGE in production")
@@ -575,6 +668,15 @@ class PGGraphStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     async def index_done_callback(self) -> None:
         # PG handles persistence automatically
         pass
diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
index e8e7800d..b94148d6 100644
--- a/lightrag/kg/tidb_impl.py
+++ b/lightrag/kg/tidb_impl.py
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Union, final
 
 import numpy as np
@@ -13,6 +13,7 @@ from ..namespace import NameSpace, is_namespace
 from ..utils import logger
 
 import pipmaster as pm
+import configparser
 
 if not pm.is_installed("pymysql"):
     pm.install("pymysql")
@@ -104,16 +105,81 @@ class TiDB:
             raise
 
 
+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "host": os.environ.get(
+                "TIDB_HOST",
+                config.get("tidb", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
+            ),
+            "user": os.environ.get(
+                "TIDB_USER",
+                config.get("tidb", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "TIDB_PASSWORD",
+                config.get("tidb", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "TIDB_DATABASE",
+                config.get("tidb", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "TIDB_WORKSPACE",
+                config.get("tidb", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> TiDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = TiDB(config)
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: TiDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        cls._instances["db"] = None
+
+
 @final
 @dataclass
 class TiDBKVStorage(BaseKVStorage):
-    # db instance must be injected before use
-    # db: TiDB
+    db: TiDB = field(default=None)
 
     def __post_init__(self):
         self._data = {}
         self._max_batch_size = self.global_config["embedding_batch_num"]
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     ################ QUERY METHODS ################
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -184,7 +250,7 @@ class TiDBKVStorage(BaseKVStorage):
                     "tokens": item["tokens"],
                     "chunk_order_index": item["chunk_order_index"],
                     "full_doc_id": item["full_doc_id"],
-                    "content_vector": f'{item["__vector__"].tolist()}',
+                    "content_vector": f"{item['__vector__'].tolist()}",
                     "workspace": self.db.workspace,
                 }
             )
@@ -212,6 +278,8 @@ class TiDBKVStorage(BaseKVStorage):
 @final
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
+    db: TiDB = field(default=None)
+
     def __post_init__(self):
         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -225,6 +293,15 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """Search from tidb vector"""
         embeddings = await self.embedding_func([query])
@@ -282,7 +359,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                 "id": item["id"],
                 "name": item["entity_name"],
                 "content": item["content"],
-                "content_vector": f'{item["content_vector"].tolist()}',
+                "content_vector": f"{item['content_vector'].tolist()}",
                 "workspace": self.db.workspace,
             }
             # update entity_id if node inserted by graph_storage_instance before
@@ -304,7 +381,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                 "source_name": item["src_id"],
                 "target_name": item["tgt_id"],
                 "content": item["content"],
-                "content_vector": f'{item["content_vector"].tolist()}',
+                "content_vector": f"{item['content_vector'].tolist()}",
                 "workspace": self.db.workspace,
             }
             # update relation_id if node inserted by graph_storage_instance before
@@ -337,12 +414,20 @@ class TiDBVectorDBStorage(BaseVectorStorage):
 @final
 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):
-    # db instance must be injected before use
-    # db: TiDB
+    db: TiDB = field(default=None)
 
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
 
+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
     #################### upsert method ################
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         entity_name = node_id
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 7b3c9295..f9ab2333 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -17,6 +17,7 @@ from .base import (
     DocStatusStorage,
     QueryParam,
     StorageNameSpace,
+    StoragesStatus,
 )
 from .namespace import NameSpace, make_namespace
 from .operate import (
@@ -348,6 +349,10 @@ class LightRAG:
     # Extensions
     addon_params: dict[str, Any] = field(default_factory=dict)
 
+    # Storages Management
+    auto_manage_storages_states: bool = True
+    """If True, lightrag will automatically call initialize_storages and finalize_storages at the appropriate times."""
+
     """Dictionary for additional parameters and extensions."""
     convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
         convert_response_to_json
@@ -440,7 +445,10 @@ class LightRAG:
             **self.vector_db_storage_cls_kwargs,
         }
 
-        # show config
+        # Life cycle
+        self.storages_status = StoragesStatus.NOT_CREATED
+
+        # Show config
         global_config = asdict(self)
         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
@@ -547,6 +555,65 @@ class LightRAG:
             )
         )
 
+        self.storages_status = StoragesStatus.CREATED
+
+        # Initialize storages
+        if self.auto_manage_storages_states:
+            loop = always_get_an_event_loop()
+            loop.run_until_complete(self.initialize_storages())
+
+    def __del__(self):
+        # Finalize storages
+        if self.auto_manage_storages_states:
+            loop = always_get_an_event_loop()
+            loop.run_until_complete(self.finalize_storages())
+
+    async def initialize_storages(self):
+        """Asynchronously initialize the storages"""
+        if self.storages_status == StoragesStatus.CREATED:
+            tasks = []
+
+            for storage in (
+                self.full_docs,
+                self.text_chunks,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+                self.llm_response_cache,
+                self.doc_status,
+            ):
+                if storage:
+                    tasks.append(storage.initialize())
+
+            await asyncio.gather(*tasks)
+
+            self.storages_status = StoragesStatus.INITIALIZED
+            logger.debug("Initialized Storages")
+
+    async def finalize_storages(self):
+        """Asynchronously finalize the storages"""
+        if self.storages_status == StoragesStatus.INITIALIZED:
+            tasks = []
+
+            for storage in (
+                self.full_docs,
+                self.text_chunks,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+                self.llm_response_cache,
+                self.doc_status,
+            ):
+                if storage:
+                    tasks.append(storage.finalize())
+
+            await asyncio.gather(*tasks)
+
+            self.storages_status = StoragesStatus.FINALIZED
+            logger.debug("Finalized Storages")
+
     async def get_graph_labels(self):
         text = await self.chunk_entity_relation_graph.get_all_labels()
         return text
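Taken together, the server-side changes (`auto_manage_storages_states=False` plus explicit `initialize_storages()` / `finalize_storages()` calls in the lifespan handler) suggest the following pattern for embedders who own their event loop. A sketch, with the `LightRAG` constructor arguments elided since they depend on your model and storage choices:

```python
import asyncio

from lightrag import LightRAG


async def run(rag: LightRAG) -> None:
    # Fans out to every storage's initialize(); the first storage per backend
    # builds the shared client via ClientManager, the rest reuse it.
    await rag.initialize_storages()
    try:
        await rag.ainsert("some text")
    finally:
        # Each storage releases its reference; the last release closes the pool.
        await rag.finalize_storages()


# rag = LightRAG(..., auto_manage_storages_states=False)
# asyncio.run(run(rag))
```

With the default `auto_manage_storages_states=True`, the constructor and `__del__` drive the same two calls through `always_get_an_event_loop()`, so simple scripts like the demos above need no explicit life-cycle code.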