diff --git a/examples/lightrag_api_oracle_demo.py b/examples/lightrag_api_oracle_demo.py index 97e483f6..e66e3f94 100644 --- a/examples/lightrag_api_oracle_demo.py +++ b/examples/lightrag_api_oracle_demo.py @@ -17,7 +17,6 @@ from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.utils import EmbeddingFunc import numpy as np -from lightrag.kg.oracle_impl import OracleDB print(os.getcwd()) script_directory = Path(__file__).resolve().parent.parent diff --git a/examples/lightrag_oracle_demo.py b/examples/lightrag_oracle_demo.py index ed843299..c6121840 100644 --- a/examples/lightrag_oracle_demo.py +++ b/examples/lightrag_oracle_demo.py @@ -6,7 +6,6 @@ from lightrag import LightRAG, QueryParam from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.utils import EmbeddingFunc import numpy as np -from lightrag.kg.oracle_impl import OracleDB print(os.getcwd()) script_directory = Path(__file__).resolve().parent.parent diff --git a/examples/lightrag_tidb_demo.py b/examples/lightrag_tidb_demo.py index 597aeeb2..f4004f84 100644 --- a/examples/lightrag_tidb_demo.py +++ b/examples/lightrag_tidb_demo.py @@ -4,7 +4,6 @@ import os import numpy as np from lightrag import LightRAG, QueryParam -from lightrag.kg.tidb_impl import TiDB from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache from lightrag.utils import EmbeddingFunc diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py index a1fd2cab..913361b3 100644 --- a/examples/lightrag_zhipu_postgres_demo.py +++ b/examples/lightrag_zhipu_postgres_demo.py @@ -5,7 +5,6 @@ import time from dotenv import load_dotenv from lightrag import LightRAG, QueryParam -from lightrag.kg.postgres_impl import PostgreSQLDB from lightrag.llm.zhipu import zhipu_complete from lightrag.llm.ollama import ollama_embedding from lightrag.utils import EmbeddingFunc diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index b3a72d4d..897c46f2 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -31,13 +31,17 @@ import configparser import traceback from datetime import datetime +root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +print(root_path) +sys.path.append(root_path) + from lightrag import LightRAG, QueryParam from lightrag.base import DocProcessingStatus, DocStatus from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ from lightrag.utils import EmbeddingFunc from lightrag.utils import logger -from .ollama_api import OllamaAPI, ollama_server_infos +from ollama_api import OllamaAPI, ollama_server_infos # Load environment variables diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 8cfc84b9..b02c8a8c 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -27,8 +27,7 @@ if not pm.is_installed("motor"): pm.install("motor") try: - from motor.motor_asyncio import AsyncIOMotorClient - from pymongo import MongoClient + from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase from pymongo.operations import SearchIndexModel from pymongo.errors import PyMongoError except ImportError as e: @@ -40,31 +39,59 @@ config = configparser.ConfigParser() config.read("config.ini", "utf-8") +class ClientManager: + _instances = {"db": None, "ref_count": 0} + _lock = asyncio.Lock() + + @classmethod + async def get_client(cls) -> AsyncIOMotorDatabase: + async with cls._lock: + if cls._instances["db"] is None: + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", + "uri", + fallback="mongodb://root:root@localhost:27017/", + ), + ) + database_name = os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + client = AsyncIOMotorClient(uri) + db = client.get_database(database_name) + cls._instances["db"] = db + cls._instances["ref_count"] = 0 + cls._instances["ref_count"] += 1 + return cls._instances["db"] + + @classmethod + async def release_client(cls, db: AsyncIOMotorDatabase): + async with cls._lock: + if db is not None: + if db is cls._instances["db"]: + cls._instances["ref_count"] -= 1 + if cls._instances["ref_count"] == 0: + cls._instances["db"] = None + + @final @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): - uri = os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) - client = AsyncIOMotorClient(uri) - database = client.get_database( - os.environ.get( - "MONGO_DATABASE", - config.get("mongodb", "database", fallback="LightRAG"), - ) - ) - self._collection_name = self.namespace - self._data = database.get_collection(self._collection_name) - logger.debug(f"Use MongoDB as KV {self._collection_name}") + async def initialize(self): + if not hasattr(self, "db") or self.db is None: + self.db = await ClientManager.get_client() + self._data = await get_or_create_collection(self.db, self._collection_name) + logger.debug(f"Use MongoDB as KV {self._collection_name}") - # Ensure collection exists - create_collection_if_not_exists(uri, database.name, self._collection_name) + async def finalize(self): + if hasattr(self, "db") and self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def get_by_id(self, id: str) -> dict[str, Any] | None: return await self._data.find_one({"_id": id}) @@ -122,27 +149,18 @@ class MongoKVStorage(BaseKVStorage): @dataclass class MongoDocStatusStorage(DocStatusStorage): def __post_init__(self): - uri = os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) - client = AsyncIOMotorClient(uri) - database = client.get_database( - os.environ.get( - "MONGO_DATABASE", - config.get("mongodb", "database", fallback="LightRAG"), - ) - ) - self._collection_name = self.namespace - self._data = database.get_collection(self._collection_name) - logger.debug(f"Use MongoDB as doc status {self._collection_name}") + async def initialize(self): + if not hasattr(self, "db") or self.db is None: + self.db = await ClientManager.get_client() + self._data = await get_or_create_collection(self.db, self._collection_name) + logger.debug(f"Use MongoDB as DocStatus {self._collection_name}") - # Ensure collection exists - create_collection_if_not_exists(uri, database.name, self._collection_name) + async def finalize(self): + if hasattr(self, "db") and self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return await self._data.find_one({"_id": id}) @@ -212,27 +230,20 @@ class MongoGraphStorage(BaseGraphStorage): global_config=global_config, embedding_func=embedding_func, ) - uri = os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) - client = AsyncIOMotorClient(uri) - database = client.get_database( - os.environ.get( - "MONGO_DATABASE", - config.get("mongodb", "database", fallback="LightRAG"), - ) - ) - self._collection_name = self.namespace - self.collection = database.get_collection(self._collection_name) - logger.debug(f"Use MongoDB as KG {self._collection_name}") + async def initialize(self): + if not hasattr(self, "db") or self.db is None: + self.db = await ClientManager.get_client() + self.collection = await get_or_create_collection( + self.db, self._collection_name + ) + logger.debug(f"Use MongoDB as KG {self._collection_name}") - # Ensure collection exists - create_collection_if_not_exists(uri, database.name, self._collection_name) + async def finalize(self): + if hasattr(self, "db") and self.db is not None: + await ClientManager.release_client(self.db) + self.db = None # # ------------------------------------------------------------------------- @@ -779,40 +790,26 @@ class MongoVectorDBStorage(BaseVectorStorage): "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" ) self.cosine_better_than_threshold = cosine_threshold - - uri = os.environ.get( - "MONGO_URI", - config.get( - "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" - ), - ) - client = AsyncIOMotorClient(uri) - database = client.get_database( - os.environ.get( - "MONGO_DATABASE", - config.get("mongodb", "database", fallback="LightRAG"), - ) - ) - self._collection_name = self.namespace - self._data = database.get_collection(self._collection_name) self._max_batch_size = self.global_config["embedding_batch_num"] - logger.debug(f"Use MongoDB as VDB {self._collection_name}") + async def initialize(self): + if not hasattr(self, "db") or self.db is None: + self.db = await ClientManager.get_client() + self._data = await get_or_create_collection(self.db, self._collection_name) - # Ensure collection exists - create_collection_if_not_exists(uri, database.name, self._collection_name) + # Ensure vector index exists + await self.create_vector_index() - # Ensure vector index exists - self.create_vector_index(uri, database.name, self._collection_name) + logger.debug(f"Use MongoDB as VDB {self._collection_name}") - def create_vector_index(self, uri: str, database_name: str, collection_name: str): + async def finalize(self): + if hasattr(self, "db") and self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + + async def create_vector_index(self): """Creates an Atlas Vector Search index.""" - client = MongoClient(uri) - collection = client.get_database(database_name).get_collection( - self._collection_name - ) - try: search_index_model = SearchIndexModel( definition={ @@ -829,7 +826,7 @@ class MongoVectorDBStorage(BaseVectorStorage): type="vectorSearch", ) - collection.create_search_index(search_index_model) + await self._data.create_search_index(search_index_model) logger.info("Vector index created successfully.") except PyMongoError as _: @@ -923,15 +920,13 @@ class MongoVectorDBStorage(BaseVectorStorage): raise NotImplementedError -def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): - """Check if the collection exists. if not, create it.""" - client = MongoClient(uri) - database = client.get_database(database_name) - - collection_names = database.list_collection_names() +async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str): + collection_names = await db.list_collection_names() if collection_name not in collection_names: - database.create_collection(collection_name) + collection = await db.create_collection(collection_name) logger.info(f"Created collection: {collection_name}") + return collection else: logger.debug(f"Collection '{collection_name}' already exists.") + return db.get_collection(collection_name) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index d18e4d7a..ca1ac7ce 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -609,9 +609,9 @@ class LightRAG: tasks.append(storage.finalize()) await asyncio.gather(*tasks) - logger.debug("Finalized Storages") self.storages_status = StoragesStatus.FINALIZED + logger.debug("Finalized Storages") async def get_graph_labels(self): text = await self.chunk_entity_relation_graph.get_all_labels()