improve MongoDB client management and storage init

This commit is contained in:
ArnoChen
2025-02-19 04:30:52 +08:00
parent 7a970451b9
commit 6d8e627f85
7 changed files with 92 additions and 97 deletions

View File

@@ -17,7 +17,6 @@ from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent

View File

@@ -6,7 +6,6 @@ from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent

View File

@@ -4,7 +4,6 @@ import os
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.kg.tidb_impl import TiDB
from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
from lightrag.utils import EmbeddingFunc

View File

@@ -5,7 +5,6 @@ import time
from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam
from lightrag.kg.postgres_impl import PostgreSQLDB
from lightrag.llm.zhipu import zhipu_complete
from lightrag.llm.ollama import ollama_embedding
from lightrag.utils import EmbeddingFunc

View File

@@ -31,13 +31,17 @@ import configparser
import traceback
from datetime import datetime
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
print(root_path)
sys.path.append(root_path)
from lightrag import LightRAG, QueryParam
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from lightrag.utils import logger
from .ollama_api import OllamaAPI, ollama_server_infos
from ollama_api import OllamaAPI, ollama_server_infos
# Load environment variables

View File

@@ -27,8 +27,7 @@ if not pm.is_installed("motor"):
pm.install("motor")
try:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
@@ -40,31 +39,59 @@ config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
class ClientManager:
    """Reference-counted holder for one shared MongoDB database handle.

    Every storage calls ``get_client()`` to obtain the shared handle and
    ``release_client()`` when it is done. The handle is created lazily on
    the first acquisition; when the last holder releases it, the handle is
    dropped and the underlying client is closed.
    """

    # "db" holds the shared AsyncIOMotorDatabase (None while nothing is open);
    # "ref_count" counts get_client() calls that have not been released yet.
    _instances = {"db": None, "ref_count": 0}
    # Serializes creation and release so the counter and handle stay in step.
    _lock = asyncio.Lock()

    @classmethod
    async def get_client(cls) -> AsyncIOMotorDatabase:
        """Return the shared database handle, creating it on first use.

        The connection URI is read from the ``MONGO_URI`` environment
        variable, falling back to ``[mongodb] uri`` in ``config.ini`` and
        then to a localhost default. The database name is resolved the same
        way from ``MONGO_DATABASE`` / ``[mongodb] database``.

        Returns:
            The shared database handle. Each call increments the reference
            count and must be paired with ``release_client()``.
        """
        async with cls._lock:
            if cls._instances["db"] is None:
                uri = os.environ.get(
                    "MONGO_URI",
                    config.get(
                        "mongodb",
                        "uri",
                        fallback="mongodb://root:root@localhost:27017/",
                    ),
                )
                database_name = os.environ.get(
                    "MONGO_DATABASE",
                    config.get("mongodb", "database", fallback="LightRAG"),
                )
                client = AsyncIOMotorClient(uri)
                cls._instances["db"] = client.get_database(database_name)
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: AsyncIOMotorDatabase):
        """Give back a handle previously obtained from ``get_client()``.

        Args:
            db: The handle to release. ``None`` and handles that are not the
                currently shared one are ignored.
        """
        async with cls._lock:
            if db is None or db is not cls._instances["db"]:
                return
            cls._instances["ref_count"] -= 1
            if cls._instances["ref_count"] == 0:
                cls._instances["db"] = None
                # Last holder is gone: close the client so its connections
                # are released now instead of lingering until garbage
                # collection.
                db.client.close()
@final
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
    """Hand the MongoDB handle back to ClientManager and forget it.

    Does nothing when this storage holds no handle (``db`` missing or None).
    """
    handle = getattr(self, "db", None)
    if handle is None:
        return
    await ClientManager.release_client(handle)
    self.db = None
async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id})
@@ -122,27 +149,18 @@ class MongoKVStorage(BaseKVStorage):
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self):
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as doc status {self._collection_name}")
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
    """Hand the MongoDB handle back to ClientManager and forget it.

    Does nothing when this storage holds no handle (``db`` missing or None).
    """
    handle = getattr(self, "db", None)
    if handle is None:
        return
    await ClientManager.release_client(handle)
    self.db = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id})
@@ -212,27 +230,20 @@ class MongoGraphStorage(BaseGraphStorage):
global_config=global_config,
embedding_func=embedding_func,
)
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self.collection = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection(
self.db, self._collection_name
)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
    """Hand the MongoDB handle back to ClientManager and forget it.

    Does nothing when this storage holds no handle (``db`` missing or None).
    """
    handle = getattr(self, "db", None)
    if handle is None:
        return
    await ClientManager.release_client(handle)
    self.db = None
#
# -------------------------------------------------------------------------
@@ -779,40 +790,26 @@ class MongoVectorDBStorage(BaseVectorStorage):
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
self._max_batch_size = self.global_config["embedding_batch_num"]
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
async def initialize(self):
if not hasattr(self, "db") or self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
# Ensure vector index exists
await self.create_vector_index()
# Ensure vector index exists
self.create_vector_index(uri, database.name, self._collection_name)
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
async def finalize(self):
    """Hand the MongoDB handle back to ClientManager and forget it.

    Does nothing when this storage holds no handle (``db`` missing or None).
    """
    handle = getattr(self, "db", None)
    if handle is None:
        return
    await ClientManager.release_client(handle)
    self.db = None
async def create_vector_index(self):
"""Creates an Atlas Vector Search index."""
client = MongoClient(uri)
collection = client.get_database(database_name).get_collection(
self._collection_name
)
try:
search_index_model = SearchIndexModel(
definition={
@@ -829,7 +826,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
type="vectorSearch",
)
collection.create_search_index(search_index_model)
await self._data.create_search_index(search_index_model)
logger.info("Vector index created successfully.")
except PyMongoError as _:
@@ -923,15 +920,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
raise NotImplementedError
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
"""Check if the collection exists. if not, create it."""
client = MongoClient(uri)
database = client.get_database(database_name)
collection_names = database.list_collection_names()
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
collection_names = await db.list_collection_names()
if collection_name not in collection_names:
database.create_collection(collection_name)
collection = await db.create_collection(collection_name)
logger.info(f"Created collection: {collection_name}")
return collection
else:
logger.debug(f"Collection '{collection_name}' already exists.")
return db.get_collection(collection_name)

View File

@@ -609,9 +609,9 @@ class LightRAG:
tasks.append(storage.finalize())
await asyncio.gather(*tasks)
logger.debug("Finalized Storages")
self.storages_status = StoragesStatus.FINALIZED
logger.debug("Finalized Storages")
async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels()