improve MongoDB client management and storage init
This commit is contained in:
@@ -31,13 +31,17 @@ import configparser
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
print(root_path)
|
||||
sys.path.append(root_path)
|
||||
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.base import DocProcessingStatus, DocStatus
|
||||
from lightrag.types import GPTKeywordExtractionFormat
|
||||
from lightrag.api import __api_version__
|
||||
from lightrag.utils import EmbeddingFunc
|
||||
from lightrag.utils import logger
|
||||
from .ollama_api import OllamaAPI, ollama_server_infos
|
||||
from ollama_api import OllamaAPI, ollama_server_infos
|
||||
|
||||
|
||||
# Load environment variables
|
||||
|
@@ -27,8 +27,7 @@ if not pm.is_installed("motor"):
|
||||
pm.install("motor")
|
||||
|
||||
try:
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import MongoClient
|
||||
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||
from pymongo.operations import SearchIndexModel
|
||||
from pymongo.errors import PyMongoError
|
||||
except ImportError as e:
|
||||
@@ -40,31 +39,59 @@ config = configparser.ConfigParser()
|
||||
config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
class ClientManager:
|
||||
_instances = {"db": None, "ref_count": 0}
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@classmethod
|
||||
async def get_client(cls) -> AsyncIOMotorDatabase:
|
||||
async with cls._lock:
|
||||
if cls._instances["db"] is None:
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb",
|
||||
"uri",
|
||||
fallback="mongodb://root:root@localhost:27017/",
|
||||
),
|
||||
)
|
||||
database_name = os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
db = client.get_database(database_name)
|
||||
cls._instances["db"] = db
|
||||
cls._instances["ref_count"] = 0
|
||||
cls._instances["ref_count"] += 1
|
||||
return cls._instances["db"]
|
||||
|
||||
@classmethod
|
||||
async def release_client(cls, db: AsyncIOMotorDatabase):
|
||||
async with cls._lock:
|
||||
if db is not None:
|
||||
if db is cls._instances["db"]:
|
||||
cls._instances["ref_count"] -= 1
|
||||
if cls._instances["ref_count"] == 0:
|
||||
cls._instances["db"] = None
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class MongoKVStorage(BaseKVStorage):
|
||||
def __post_init__(self):
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
|
||||
self._collection_name = self.namespace
|
||||
|
||||
self._data = database.get_collection(self._collection_name)
|
||||
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
||||
async def initialize(self):
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
async def finalize(self):
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||
return await self._data.find_one({"_id": id})
|
||||
@@ -122,27 +149,18 @@ class MongoKVStorage(BaseKVStorage):
|
||||
@dataclass
|
||||
class MongoDocStatusStorage(DocStatusStorage):
|
||||
def __post_init__(self):
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
|
||||
self._collection_name = self.namespace
|
||||
self._data = database.get_collection(self._collection_name)
|
||||
|
||||
logger.debug(f"Use MongoDB as doc status {self._collection_name}")
|
||||
async def initialize(self):
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
async def finalize(self):
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||
return await self._data.find_one({"_id": id})
|
||||
@@ -212,27 +230,20 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||
global_config=global_config,
|
||||
embedding_func=embedding_func,
|
||||
)
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
|
||||
self._collection_name = self.namespace
|
||||
self.collection = database.get_collection(self._collection_name)
|
||||
|
||||
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
||||
async def initialize(self):
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
self.collection = await get_or_create_collection(
|
||||
self.db, self._collection_name
|
||||
)
|
||||
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
async def finalize(self):
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
#
|
||||
# -------------------------------------------------------------------------
|
||||
@@ -779,40 +790,26 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
||||
)
|
||||
self.cosine_better_than_threshold = cosine_threshold
|
||||
|
||||
uri = os.environ.get(
|
||||
"MONGO_URI",
|
||||
config.get(
|
||||
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
||||
),
|
||||
)
|
||||
client = AsyncIOMotorClient(uri)
|
||||
database = client.get_database(
|
||||
os.environ.get(
|
||||
"MONGO_DATABASE",
|
||||
config.get("mongodb", "database", fallback="LightRAG"),
|
||||
)
|
||||
)
|
||||
|
||||
self._collection_name = self.namespace
|
||||
self._data = database.get_collection(self._collection_name)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
|
||||
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
||||
async def initialize(self):
|
||||
if not hasattr(self, "db") or self.db is None:
|
||||
self.db = await ClientManager.get_client()
|
||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||
|
||||
# Ensure collection exists
|
||||
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
||||
# Ensure vector index exists
|
||||
await self.create_vector_index()
|
||||
|
||||
# Ensure vector index exists
|
||||
self.create_vector_index(uri, database.name, self._collection_name)
|
||||
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
||||
|
||||
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
|
||||
async def finalize(self):
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
await ClientManager.release_client(self.db)
|
||||
self.db = None
|
||||
|
||||
async def create_vector_index(self):
|
||||
"""Creates an Atlas Vector Search index."""
|
||||
client = MongoClient(uri)
|
||||
collection = client.get_database(database_name).get_collection(
|
||||
self._collection_name
|
||||
)
|
||||
|
||||
try:
|
||||
search_index_model = SearchIndexModel(
|
||||
definition={
|
||||
@@ -829,7 +826,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
type="vectorSearch",
|
||||
)
|
||||
|
||||
collection.create_search_index(search_index_model)
|
||||
await self._data.create_search_index(search_index_model)
|
||||
logger.info("Vector index created successfully.")
|
||||
|
||||
except PyMongoError as _:
|
||||
@@ -923,15 +920,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
|
||||
"""Check if the collection exists. if not, create it."""
|
||||
client = MongoClient(uri)
|
||||
database = client.get_database(database_name)
|
||||
|
||||
collection_names = database.list_collection_names()
|
||||
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
|
||||
collection_names = await db.list_collection_names()
|
||||
|
||||
if collection_name not in collection_names:
|
||||
database.create_collection(collection_name)
|
||||
collection = await db.create_collection(collection_name)
|
||||
logger.info(f"Created collection: {collection_name}")
|
||||
return collection
|
||||
else:
|
||||
logger.debug(f"Collection '{collection_name}' already exists.")
|
||||
return db.get_collection(collection_name)
|
||||
|
@@ -609,9 +609,9 @@ class LightRAG:
|
||||
tasks.append(storage.finalize())
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
logger.debug("Finalized Storages")
|
||||
|
||||
self.storages_status = StoragesStatus.FINALIZED
|
||||
logger.debug("Finalized Storages")
|
||||
|
||||
async def get_graph_labels(self):
|
||||
text = await self.chunk_entity_relation_graph.get_all_labels()
|
||||
|
Reference in New Issue
Block a user