improve MongoDB client management and storage init

This commit is contained in:
ArnoChen
2025-02-19 04:30:52 +08:00
parent 7a970451b9
commit 6d8e627f85
7 changed files with 92 additions and 97 deletions

View File

@@ -17,7 +17,6 @@ from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc
import numpy as np import numpy as np
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd()) print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent script_directory = Path(__file__).resolve().parent.parent

View File

@@ -6,7 +6,6 @@ from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc
import numpy as np import numpy as np
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd()) print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent script_directory = Path(__file__).resolve().parent.parent

View File

@@ -4,7 +4,6 @@ import os
import numpy as np import numpy as np
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.kg.tidb_impl import TiDB
from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc

View File

@@ -5,7 +5,6 @@ import time
from dotenv import load_dotenv from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.kg.postgres_impl import PostgreSQLDB
from lightrag.llm.zhipu import zhipu_complete from lightrag.llm.zhipu import zhipu_complete
from lightrag.llm.ollama import ollama_embedding from lightrag.llm.ollama import ollama_embedding
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc

View File

@@ -31,13 +31,17 @@ import configparser
import traceback import traceback
from datetime import datetime from datetime import datetime
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
print(root_path)
sys.path.append(root_path)
from lightrag import LightRAG, QueryParam from lightrag import LightRAG, QueryParam
from lightrag.base import DocProcessingStatus, DocStatus from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.types import GPTKeywordExtractionFormat from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__ from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc from lightrag.utils import EmbeddingFunc
from lightrag.utils import logger from lightrag.utils import logger
from .ollama_api import OllamaAPI, ollama_server_infos from ollama_api import OllamaAPI, ollama_server_infos
# Load environment variables # Load environment variables

View File

@@ -27,8 +27,7 @@ if not pm.is_installed("motor"):
pm.install("motor") pm.install("motor")
try: try:
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from pymongo import MongoClient
from pymongo.operations import SearchIndexModel from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError from pymongo.errors import PyMongoError
except ImportError as e: except ImportError as e:
@@ -40,31 +39,59 @@ config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@classmethod
async def get_client(cls) -> AsyncIOMotorDatabase:
async with cls._lock:
if cls._instances["db"] is None:
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb",
"uri",
fallback="mongodb://root:root@localhost:27017/",
),
)
database_name = os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
client = AsyncIOMotorClient(uri)
db = client.get_database(database_name)
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: AsyncIOMotorDatabase):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
cls._instances["db"] = None
@final @final
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name) async def initialize(self):
logger.debug(f"Use MongoDB as KV {self._collection_name}") if not hasattr(self, "db") or self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
# Ensure collection exists async def finalize(self):
create_collection_if_not_exists(uri, database.name, self._collection_name) if hasattr(self, "db") and self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id}) return await self._data.find_one({"_id": id})
@@ -122,27 +149,18 @@ class MongoKVStorage(BaseKVStorage):
@dataclass @dataclass
class MongoDocStatusStorage(DocStatusStorage): class MongoDocStatusStorage(DocStatusStorage):
def __post_init__(self): def __post_init__(self):
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as doc status {self._collection_name}") async def initialize(self):
if not hasattr(self, "db") or self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
# Ensure collection exists async def finalize(self):
create_collection_if_not_exists(uri, database.name, self._collection_name) if hasattr(self, "db") and self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id}) return await self._data.find_one({"_id": id})
@@ -212,27 +230,20 @@ class MongoGraphStorage(BaseGraphStorage):
global_config=global_config, global_config=global_config,
embedding_func=embedding_func, embedding_func=embedding_func,
) )
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace self._collection_name = self.namespace
self.collection = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as KG {self._collection_name}") async def initialize(self):
if not hasattr(self, "db") or self.db is None:
self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection(
self.db, self._collection_name
)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
# Ensure collection exists async def finalize(self):
create_collection_if_not_exists(uri, database.name, self._collection_name) if hasattr(self, "db") and self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
# #
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
@@ -779,40 +790,26 @@ class MongoVectorDBStorage(BaseVectorStorage):
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
) )
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
logger.debug(f"Use MongoDB as VDB {self._collection_name}") async def initialize(self):
if not hasattr(self, "db") or self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
# Ensure collection exists # Ensure vector index exists
create_collection_if_not_exists(uri, database.name, self._collection_name) await self.create_vector_index()
# Ensure vector index exists logger.debug(f"Use MongoDB as VDB {self._collection_name}")
self.create_vector_index(uri, database.name, self._collection_name)
def create_vector_index(self, uri: str, database_name: str, collection_name: str): async def finalize(self):
if hasattr(self, "db") and self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def create_vector_index(self):
"""Creates an Atlas Vector Search index.""" """Creates an Atlas Vector Search index."""
client = MongoClient(uri)
collection = client.get_database(database_name).get_collection(
self._collection_name
)
try: try:
search_index_model = SearchIndexModel( search_index_model = SearchIndexModel(
definition={ definition={
@@ -829,7 +826,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
type="vectorSearch", type="vectorSearch",
) )
collection.create_search_index(search_index_model) await self._data.create_search_index(search_index_model)
logger.info("Vector index created successfully.") logger.info("Vector index created successfully.")
except PyMongoError as _: except PyMongoError as _:
@@ -923,15 +920,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
raise NotImplementedError raise NotImplementedError
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str): async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
"""Check if the collection exists. if not, create it.""" collection_names = await db.list_collection_names()
client = MongoClient(uri)
database = client.get_database(database_name)
collection_names = database.list_collection_names()
if collection_name not in collection_names: if collection_name not in collection_names:
database.create_collection(collection_name) collection = await db.create_collection(collection_name)
logger.info(f"Created collection: {collection_name}") logger.info(f"Created collection: {collection_name}")
return collection
else: else:
logger.debug(f"Collection '{collection_name}' already exists.") logger.debug(f"Collection '{collection_name}' already exists.")
return db.get_collection(collection_name)

View File

@@ -609,9 +609,9 @@ class LightRAG:
tasks.append(storage.finalize()) tasks.append(storage.finalize())
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
logger.debug("Finalized Storages")
self.storages_status = StoragesStatus.FINALIZED self.storages_status = StoragesStatus.FINALIZED
logger.debug("Finalized Storages")
async def get_graph_labels(self): async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels() text = await self.chunk_entity_relation_graph.get_all_labels()