|
|
|
@@ -27,8 +27,7 @@ if not pm.is_installed("motor"):
|
|
|
|
|
pm.install("motor")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from motor.motor_asyncio import AsyncIOMotorClient
|
|
|
|
|
from pymongo import MongoClient
|
|
|
|
|
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
|
|
|
|
from pymongo.operations import SearchIndexModel
|
|
|
|
|
from pymongo.errors import PyMongoError
|
|
|
|
|
except ImportError as e:
|
|
|
|
@@ -40,31 +39,59 @@ config = configparser.ConfigParser()
|
|
|
|
|
config.read("config.ini", "utf-8")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClientManager:
    """Reference-counted holder for one shared Motor database handle.

    Every storage that needs MongoDB calls :meth:`get_client` once and
    :meth:`release_client` once. The first acquisition creates the client;
    the last release drops the handle and closes the client so its
    connection pool does not linger.
    """

    # "db" is the shared database handle (None while nobody holds one);
    # "ref_count" counts get_client() calls that have not been released yet.
    _instances = {"db": None, "ref_count": 0}
    # Serializes creation and release so the handle and its count stay in step.
    _lock = asyncio.Lock()

    @classmethod
    async def get_client(cls) -> "AsyncIOMotorDatabase":
        """Return the shared database handle, creating it on first use.

        The connection URI is read from the ``MONGO_URI`` environment
        variable, then from ``[mongodb] uri`` in the parsed config, then from
        a localhost default. The database name is resolved the same way from
        ``MONGO_DATABASE`` / ``[mongodb] database`` / ``"LightRAG"``.

        Returns:
            The shared database handle. Each call increments the reference
            count and must be paired with one :meth:`release_client` call.
        """
        async with cls._lock:
            if cls._instances["db"] is None:
                uri = os.environ.get(
                    "MONGO_URI",
                    config.get(
                        "mongodb",
                        "uri",
                        fallback="mongodb://root:root@localhost:27017/",
                    ),
                )
                database_name = os.environ.get(
                    "MONGO_DATABASE",
                    config.get("mongodb", "database", fallback="LightRAG"),
                )
                client = AsyncIOMotorClient(uri)
                db = client.get_database(database_name)
                cls._instances["db"] = db
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: "AsyncIOMotorDatabase"):
        """Give back a handle obtained from :meth:`get_client`.

        Args:
            db: The handle being released. ``None`` and handles other than
                the currently shared one are ignored.

        When the reference count reaches zero the shared handle is forgotten
        and its client is closed.
        """
        async with cls._lock:
            if db is None or db is not cls._instances["db"]:
                return
            cls._instances["ref_count"] -= 1
            if cls._instances["ref_count"] == 0:
                cls._instances["db"] = None
                # Nobody holds the handle any more: close the client that
                # owns it. Without this, each get/release cycle would leave
                # an open connection pool behind.
                db.client.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@final
|
|
|
|
|
@dataclass
|
|
|
|
|
class MongoKVStorage(BaseKVStorage):
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
uri = os.environ.get(
|
|
|
|
|
"MONGO_URI",
|
|
|
|
|
config.get(
|
|
|
|
|
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
client = AsyncIOMotorClient(uri)
|
|
|
|
|
database = client.get_database(
|
|
|
|
|
os.environ.get(
|
|
|
|
|
"MONGO_DATABASE",
|
|
|
|
|
config.get("mongodb", "database", fallback="LightRAG"),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._collection_name = self.namespace
|
|
|
|
|
|
|
|
|
|
self._data = database.get_collection(self._collection_name)
|
|
|
|
|
logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
|
|
|
|
async def initialize(self):
    """Bind this KV storage to the shared database and open its collection.

    Returns immediately when ``self.db`` is already set. Otherwise it takes
    a handle from ``ClientManager`` and looks up (or creates) the collection
    named by ``self._collection_name``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
    self._data = await get_or_create_collection(self.db, self._collection_name)
    logger.debug(f"Use MongoDB as KV {self._collection_name}")
|
|
|
|
|
|
|
|
|
|
# Ensure collection exists
|
|
|
|
|
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
|
|
|
|
async def finalize(self):
    """Hand the database handle back to ``ClientManager`` and clear it.

    Does nothing when this storage holds no handle.
    """
    handle = getattr(self, "db", None)
    if handle is None:
        return
    await ClientManager.release_client(handle)
    self.db = None
|
|
|
|
|
|
|
|
|
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
|
|
|
|
return await self._data.find_one({"_id": id})
|
|
|
|
@@ -122,27 +149,18 @@ class MongoKVStorage(BaseKVStorage):
|
|
|
|
|
@dataclass
|
|
|
|
|
class MongoDocStatusStorage(DocStatusStorage):
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
uri = os.environ.get(
|
|
|
|
|
"MONGO_URI",
|
|
|
|
|
config.get(
|
|
|
|
|
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
client = AsyncIOMotorClient(uri)
|
|
|
|
|
database = client.get_database(
|
|
|
|
|
os.environ.get(
|
|
|
|
|
"MONGO_DATABASE",
|
|
|
|
|
config.get("mongodb", "database", fallback="LightRAG"),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._collection_name = self.namespace
|
|
|
|
|
self._data = database.get_collection(self._collection_name)
|
|
|
|
|
|
|
|
|
|
logger.debug(f"Use MongoDB as doc status {self._collection_name}")
|
|
|
|
|
async def initialize(self):
    """Bind this doc-status storage to the shared database and open its collection.

    Returns immediately when ``self.db`` is already set. Otherwise it takes
    a handle from ``ClientManager`` and looks up (or creates) the collection
    named by ``self._collection_name``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
    self._data = await get_or_create_collection(self.db, self._collection_name)
    logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
|
|
|
|
|
|
|
|
|
|
# Ensure collection exists
|
|
|
|
|
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
|
|
|
|
async def finalize(self):
    """Hand the database handle back to ``ClientManager`` and clear it.

    Does nothing when this storage holds no handle.
    """
    handle = getattr(self, "db", None)
    if handle is None:
        return
    await ClientManager.release_client(handle)
    self.db = None
|
|
|
|
|
|
|
|
|
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
|
|
|
|
return await self._data.find_one({"_id": id})
|
|
|
|
@@ -212,27 +230,20 @@ class MongoGraphStorage(BaseGraphStorage):
|
|
|
|
|
global_config=global_config,
|
|
|
|
|
embedding_func=embedding_func,
|
|
|
|
|
)
|
|
|
|
|
uri = os.environ.get(
|
|
|
|
|
"MONGO_URI",
|
|
|
|
|
config.get(
|
|
|
|
|
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
client = AsyncIOMotorClient(uri)
|
|
|
|
|
database = client.get_database(
|
|
|
|
|
os.environ.get(
|
|
|
|
|
"MONGO_DATABASE",
|
|
|
|
|
config.get("mongodb", "database", fallback="LightRAG"),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._collection_name = self.namespace
|
|
|
|
|
self.collection = database.get_collection(self._collection_name)
|
|
|
|
|
|
|
|
|
|
logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
|
|
|
|
async def initialize(self):
    """Bind this graph storage to the shared database and open its collection.

    Returns immediately when ``self.db`` is already set. Otherwise it takes
    a handle from ``ClientManager`` and stores the looked-up (or newly
    created) collection on ``self.collection``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
    graph_collection = await get_or_create_collection(self.db, self._collection_name)
    self.collection = graph_collection
    logger.debug(f"Use MongoDB as KG {self._collection_name}")
|
|
|
|
|
|
|
|
|
|
# Ensure collection exists
|
|
|
|
|
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
|
|
|
|
async def finalize(self):
    """Hand the database handle back to ``ClientManager`` and clear it.

    Does nothing when this storage holds no handle.
    """
    handle = getattr(self, "db", None)
    if handle is None:
        return
    await ClientManager.release_client(handle)
    self.db = None
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|
# -------------------------------------------------------------------------
|
|
|
|
@@ -779,40 +790,26 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|
|
|
|
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
|
|
|
|
|
)
|
|
|
|
|
self.cosine_better_than_threshold = cosine_threshold
|
|
|
|
|
|
|
|
|
|
uri = os.environ.get(
|
|
|
|
|
"MONGO_URI",
|
|
|
|
|
config.get(
|
|
|
|
|
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
client = AsyncIOMotorClient(uri)
|
|
|
|
|
database = client.get_database(
|
|
|
|
|
os.environ.get(
|
|
|
|
|
"MONGO_DATABASE",
|
|
|
|
|
config.get("mongodb", "database", fallback="LightRAG"),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._collection_name = self.namespace
|
|
|
|
|
self._data = database.get_collection(self._collection_name)
|
|
|
|
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
|
|
|
|
|
|
|
|
|
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
|
|
|
|
async def initialize(self):
|
|
|
|
|
if not hasattr(self, "db") or self.db is None:
|
|
|
|
|
self.db = await ClientManager.get_client()
|
|
|
|
|
self._data = await get_or_create_collection(self.db, self._collection_name)
|
|
|
|
|
|
|
|
|
|
# Ensure collection exists
|
|
|
|
|
create_collection_if_not_exists(uri, database.name, self._collection_name)
|
|
|
|
|
# Ensure vector index exists
|
|
|
|
|
await self.create_vector_index()
|
|
|
|
|
|
|
|
|
|
# Ensure vector index exists
|
|
|
|
|
self.create_vector_index(uri, database.name, self._collection_name)
|
|
|
|
|
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
|
|
|
|
|
|
|
|
|
|
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
|
|
|
|
|
async def finalize(self):
    """Hand the database handle back to ``ClientManager`` and clear it.

    Does nothing when this storage holds no handle.
    """
    handle = getattr(self, "db", None)
    if handle is None:
        return
    await ClientManager.release_client(handle)
    self.db = None
|
|
|
|
|
|
|
|
|
|
async def create_vector_index(self):
|
|
|
|
|
"""Creates an Atlas Vector Search index."""
|
|
|
|
|
client = MongoClient(uri)
|
|
|
|
|
collection = client.get_database(database_name).get_collection(
|
|
|
|
|
self._collection_name
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
search_index_model = SearchIndexModel(
|
|
|
|
|
definition={
|
|
|
|
@@ -829,7 +826,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|
|
|
|
type="vectorSearch",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
collection.create_search_index(search_index_model)
|
|
|
|
|
await self._data.create_search_index(search_index_model)
|
|
|
|
|
logger.info("Vector index created successfully.")
|
|
|
|
|
|
|
|
|
|
except PyMongoError as _:
|
|
|
|
@@ -923,15 +920,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
|
|
|
|
|
"""Check if the collection exists. if not, create it."""
|
|
|
|
|
client = MongoClient(uri)
|
|
|
|
|
database = client.get_database(database_name)
|
|
|
|
|
|
|
|
|
|
collection_names = database.list_collection_names()
|
|
|
|
|
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
    """Return the collection named *collection_name* from *db*, creating it if absent.

    Args:
        db: Database handle to look the collection up in.
        collection_name: Name of the collection to fetch or create.

    Returns:
        The collection object: the result of ``db.create_collection`` when
        the name was not listed, otherwise ``db.get_collection``.
    """
    existing_names = await db.list_collection_names()

    if collection_name in existing_names:
        logger.debug(f"Collection '{collection_name}' already exists.")
        return db.get_collection(collection_name)

    # Create only through the async handle passed in. This function has no
    # `database` name in scope, so a synchronous `database.create_collection`
    # call here would fail with NameError.
    collection = await db.create_collection(collection_name)
    logger.info(f"Created collection: {collection_name}")
    return collection
|
|
|
|
|