Merge pull request #846 from ArnoChenFx/db-connection-and-storage-lifecycle

Refactor Database Connection Management and Improve Storage Lifecycle Handling
Authored by Yannick Stephan on 2025-02-18 22:39:31 +01:00, committed by GitHub
11 changed files with 540 additions and 416 deletions

View File

@@ -17,7 +17,6 @@ from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
@@ -48,6 +47,14 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
os.environ["ORACLE_USER"] = ""
os.environ["ORACLE_PASSWORD"] = ""
os.environ["ORACLE_DSN"] = ""
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
os.environ["ORACLE_WORKSPACE"] = "company"
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
@@ -89,20 +96,6 @@ async def init():
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
oracle_db = OracleDB(
config={
"user": "",
"password": "",
"dsn": "",
"config_dir": "path_to_config_dir",
"wallet_location": "path_to_wallet_location",
"wallet_password": "wallet_password",
"workspace": "company",
} # specify which docs you want to store and query
)
# Check if Oracle DB tables exist; if not, they will be created
await oracle_db.check_tables()
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
@@ -121,11 +114,6 @@ async def init():
vector_storage="OracleVectorDBStorage",
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
rag.graph_storage_cls.db = oracle_db
rag.key_string_value_json_storage_cls.db = oracle_db
rag.vector_db_storage_cls.db = oracle_db
return rag
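With this refactor, the example no longer constructs an OracleDB by hand or assigns it to each storage class; connection settings move to environment variables and LightRAG manages the storage lifecycle itself. A minimal sketch of the resulting setup, assuming the llm_model_func and embedding_func defined earlier in this example (credential values are placeholders):

import os
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc

# Connection settings are read by the Oracle ClientManager on first use
os.environ["ORACLE_USER"] = "username"        # placeholder credentials
os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
os.environ["ORACLE_WORKSPACE"] = "company"

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=embedding_dimension, max_token_size=512, func=embedding_func
    ),
    kv_storage="OracleKVStorage",
    graph_storage="OracleGraphStorage",
    vector_storage="OracleVectorDBStorage",
)
# No rag.<storage>.db = oracle_db wiring remains: with the default
# auto_manage_storages_states=True, LightRAG runs initialize_storages()
# during construction, and each Oracle storage fetches the shared
# connection pool from ClientManager (which also runs check_tables() once).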

View File

@@ -6,7 +6,6 @@ from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.oracle_impl import OracleDB
print(os.getcwd())
script_directory = Path(__file__).resolve().parent.parent
@@ -26,6 +25,14 @@ MAX_TOKENS = 4000
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
os.environ["ORACLE_USER"] = "username"
os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
os.environ["ORACLE_WORKSPACE"] = "company"
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
@@ -63,26 +70,6 @@ async def main():
embedding_dimension = await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
# Create Oracle DB connection
# The `config` parameter is the connection configuration of Oracle DB
# More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
# We store data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
# Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
oracle_db = OracleDB(
config={
"user": "username",
"password": "xxxxxxxxx",
"dsn": "xxxxxxx_medium",
"config_dir": "dir/path/to/oracle/config",
"wallet_location": "dir/path/to/oracle/wallet",
"wallet_password": "xxxxxxxxx",
"workspace": "company", # specify which docs you want to store and query
}
)
# Check if Oracle DB tables exist; if not, they will be created
await oracle_db.check_tables()
# Initialize LightRAG
# We use Oracle DB as the KV/vector/graph storage
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
@@ -112,26 +99,6 @@ async def main():
},
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
for storage in [
rag.vector_db_storage_cls,
rag.graph_storage_cls,
rag.doc_status,
rag.full_docs,
rag.text_chunks,
rag.llm_response_cache,
rag.key_string_value_json_storage_cls,
rag.chunks_vdb,
rag.relationships_vdb,
rag.entities_vdb,
rag.graph_storage_cls,
rag.chunk_entity_relation_graph,
rag.llm_response_cache,
]:
# set client
storage.db = oracle_db
# Extract and Insert into LightRAG storage
with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
all_text = f.read()

View File

@@ -4,7 +4,6 @@ import os
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.kg.tidb_impl import TiDB
from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
from lightrag.utils import EmbeddingFunc
@@ -17,11 +16,11 @@ APIKEY = ""
CHATMODEL = ""
EMBEDMODEL = ""
TIDB_HOST = ""
TIDB_PORT = ""
TIDB_USER = ""
TIDB_PASSWORD = ""
TIDB_DATABASE = "lightrag"
os.environ["TIDB_HOST"] = ""
os.environ["TIDB_PORT"] = ""
os.environ["TIDB_USER"] = ""
os.environ["TIDB_PASSWORD"] = ""
os.environ["TIDB_DATABASE"] = "lightrag"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
@@ -62,21 +61,6 @@ async def main():
embedding_dimension = await get_embedding_dim()
print(f"Detected embedding dimension: {embedding_dimension}")
# Create TiDB connection
tidb = TiDB(
config={
"host": TIDB_HOST,
"port": TIDB_PORT,
"user": TIDB_USER,
"password": TIDB_PASSWORD,
"database": TIDB_DATABASE,
"workspace": "company", # specify which docs you want to store and query
}
)
# Check if TiDB tables exist; if not, they will be created
await tidb.check_tables()
# Initialize LightRAG
# We use TiDB as the KV/vector storage
# You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
@@ -95,15 +79,6 @@ async def main():
graph_storage="TiDBGraphStorage",
)
if rag.llm_response_cache:
rag.llm_response_cache.db = tidb
rag.full_docs.db = tidb
rag.text_chunks.db = tidb
rag.entities_vdb.db = tidb
rag.relationships_vdb.db = tidb
rag.chunks_vdb.db = tidb
rag.chunk_entity_relation_graph.db = tidb
# Extract and Insert into LightRAG storage
with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read())
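The TiDB demo shrinks the same way: check_tables() now runs inside the TiDB ClientManager the first time any TiDB-backed storage initializes, so only the environment variables and storage class names remain. A hedged sketch reusing the example's placeholders:

os.environ["TIDB_HOST"] = ""        # fill in as in the example above
os.environ["TIDB_PORT"] = ""
os.environ["TIDB_USER"] = ""
os.environ["TIDB_PASSWORD"] = ""
os.environ["TIDB_DATABASE"] = "lightrag"

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=embedding_dimension, max_token_size=512, func=embedding_func
    ),
    kv_storage="TiDBKVStorage",
    vector_storage="TiDBVectorDBStorage",
    graph_storage="TiDBGraphStorage",
)
# rag is ready: TiDB tables were checked/created during storage initialization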

View File

@@ -5,7 +5,6 @@ import time
from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam
from lightrag.kg.postgres_impl import PostgreSQLDB
from lightrag.llm.zhipu import zhipu_complete
from lightrag.llm.ollama import ollama_embedding
from lightrag.utils import EmbeddingFunc
@@ -22,22 +21,14 @@ if not os.path.exists(WORKING_DIR):
# AGE
os.environ["AGE_GRAPH_NAME"] = "dickens"
postgres_db = PostgreSQLDB(
config={
"host": "localhost",
"port": 15432,
"user": "rag",
"password": "rag",
"database": "rag",
}
)
os.environ["POSTGRES_HOST"] = "localhost"
os.environ["POSTGRES_PORT"] = "15432"
os.environ["POSTGRES_USER"] = "rag"
os.environ["POSTGRES_PASSWORD"] = "rag"
os.environ["POSTGRES_DATABASE"] = "rag"
async def main():
await postgres_db.initdb()
# Check if PostgreSQL tables exist; if not, they will be created
await postgres_db.check_tables()
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=zhipu_complete,
@@ -57,17 +48,7 @@ async def main():
graph_storage="PGGraphStorage",
vector_storage="PGVectorStorage",
)
# Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
rag.doc_status.db = postgres_db
rag.full_docs.db = postgres_db
rag.text_chunks.db = postgres_db
rag.llm_response_cache.db = postgres_db
rag.key_string_value_json_storage_cls.db = postgres_db
rag.chunks_vdb.db = postgres_db
rag.relationships_vdb.db = postgres_db
rag.entities_vdb.db = postgres_db
rag.graph_storage_cls.db = postgres_db
rag.chunk_entity_relation_graph.db = postgres_db
# Add embedding_func for the graph database; it was deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
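The PostgreSQL demo loses one more step: initdb() and check_tables() now run inside the PostgreSQL ClientManager when the first storage initializes. A condensed sketch of the trimmed main(), with placeholder embedding parameters (the ollama_embedding import is shown in the diff above):

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=zhipu_complete,
    embedding_func=EmbeddingFunc(
        embedding_dim=768,                  # placeholder dimension
        max_token_size=8192,
        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
    ),
    kv_storage="PGKVStorage",
    doc_status_storage="PGDocStatusStorage",
    graph_storage="PGGraphStorage",
    vector_storage="PGVectorStorage",
)
# No postgres_db object and no per-storage db assignments: the first PG
# storage to initialize calls ClientManager.get_client(), which performs
# initdb() and check_tables() once and shares a single pool with all storages.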

View File

@@ -15,11 +15,6 @@ import logging
import argparse
from typing import List, Any, Literal, Optional, Dict
from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG, QueryParam
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from pathlib import Path
import shutil
import aiofiles
@@ -36,39 +31,13 @@ import configparser
import traceback
from datetime import datetime
from lightrag import LightRAG, QueryParam
from lightrag.base import DocProcessingStatus, DocStatus
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
from lightrag.utils import EmbeddingFunc
from lightrag.utils import logger
from .ollama_api import (
OllamaAPI,
)
from .ollama_api import ollama_server_infos
def get_db_type_from_storage_class(class_name: str) -> str | None:
"""Determine database type based on storage class name"""
if class_name.startswith("PG"):
return "postgres"
elif class_name.startswith("Oracle"):
return "oracle"
elif class_name.startswith("TiDB"):
return "tidb"
return None
def import_db_module(db_type: str):
"""Dynamically import database module"""
if db_type == "postgres":
from ..kg.postgres_impl import PostgreSQLDB
return PostgreSQLDB
elif db_type == "oracle":
from ..kg.oracle_impl import OracleDB
return OracleDB
elif db_type == "tidb":
from ..kg.tidb_impl import TiDB
return TiDB
return None
from .ollama_api import OllamaAPI, ollama_server_infos
# Load environment variables
@@ -929,52 +898,12 @@ def create_app(args):
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
# Initialize database connections
db_instances = {}
# Store background tasks
app.state.background_tasks = set()
try:
# Check which database types are used
db_types = set()
for storage_name, storage_instance in storage_instances:
db_type = get_db_type_from_storage_class(
storage_instance.__class__.__name__
)
if db_type:
db_types.add(db_type)
# Import and initialize databases as needed
for db_type in db_types:
if db_type == "postgres":
DB = import_db_module("postgres")
db = DB(_get_postgres_config())
await db.initdb()
await db.check_tables()
db_instances["postgres"] = db
elif db_type == "oracle":
DB = import_db_module("oracle")
db = DB(_get_oracle_config())
await db.check_tables()
db_instances["oracle"] = db
elif db_type == "tidb":
DB = import_db_module("tidb")
db = DB(_get_tidb_config())
await db.check_tables()
db_instances["tidb"] = db
# Inject database instances into storage classes
for storage_name, storage_instance in storage_instances:
db_type = get_db_type_from_storage_class(
storage_instance.__class__.__name__
)
if db_type:
if db_type not in db_instances:
error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
logger.error(error_msg)
raise RuntimeError(error_msg)
storage_instance.db = db_instances[db_type]
logger.info(f"Injected {db_type} db to {storage_name}")
# Initialize database connections
await rag.initialize_storages()
# Auto scan documents if enabled
if args.auto_scan_at_startup:
@@ -1000,17 +929,7 @@ def create_app(args):
finally:
# Clean up database connections
for db_type, db in db_instances.items():
if hasattr(db, "pool"):
await db.pool.close()
# Use a more accurate database name for display
db_names = {
"postgres": "PostgreSQL",
"oracle": "Oracle",
"tidb": "TiDB",
}
db_name = db_names.get(db_type, db_type)
logger.info(f"Closed {db_name} database connection pool")
await rag.finalize_storages()
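The net effect on the server is that the lifespan hook no longer knows about concrete database types at all; per-backend setup and teardown live behind initialize_storages() and finalize_storages(). A condensed sketch of the resulting hook (rag is built with auto_manage_storages_states=False, as the later hunks show; the auto-scan branch is omitted):

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events"""
    app.state.background_tasks = set()
    try:
        # Each storage acquires its shared client via its ClientManager here
        await rag.initialize_storages()
        yield
    finally:
        # Releases every client; the last release closes the connection pool
        await rag.finalize_storages()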
# Initialize FastAPI
app = FastAPI(
@@ -1042,92 +961,6 @@ def create_app(args):
allow_headers=["*"],
)
# Database configuration functions
def _get_postgres_config():
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
def _get_oracle_config():
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
def _get_tidb_config():
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)
@@ -1262,6 +1095,7 @@ def create_app(args):
},
log_level=args.log_level,
namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
)
else:
rag = LightRAG(
@@ -1293,20 +1127,9 @@ def create_app(args):
},
log_level=args.log_level,
namespace_prefix=args.namespace_prefix,
auto_manage_storages_states=False,
)
# Collect all storage instances
storage_instances = [
("full_docs", rag.full_docs),
("text_chunks", rag.text_chunks),
("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
("entities_vdb", rag.entities_vdb),
("relationships_vdb", rag.relationships_vdb),
("chunks_vdb", rag.chunks_vdb),
("doc_status", rag.doc_status),
("llm_response_cache", rag.llm_response_cache),
]
async def pipeline_enqueue_file(file_path: Path) -> bool:
"""Add a file to the queue for processing

View File

@@ -87,6 +87,14 @@ class StorageNameSpace(ABC):
namespace: str
global_config: dict[str, Any]
async def initialize(self):
"""Initialize the storage"""
pass
async def finalize(self):
"""Finalize the storage"""
pass
@abstractmethod
async def index_done_callback(self) -> None:
"""Commit the storage operations after indexing"""
@@ -247,3 +255,12 @@ class DocStatusStorage(BaseKVStorage, ABC):
self, status: DocStatus
) -> dict[str, DocProcessingStatus]:
"""Get all documents with a specific status"""
class StoragesStatus(str, Enum):
"""Storages status"""
NOT_CREATED = "not_created"
CREATED = "created"
INITIALIZED = "initialized"
FINALIZED = "finalized"
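Because initialize() and finalize() are no-ops by default on StorageNameSpace, file-based backends need no changes, while database-backed ones override both hooks. A hypothetical backend would follow the same pattern (MyClientManager is an assumed stand-in, and the abstract query methods are omitted):

from dataclasses import dataclass, field
from typing import Any
from lightrag.base import BaseKVStorage

@dataclass
class MyDBKVStorage(BaseKVStorage):
    db: Any = field(default=None)  # may be injected before initialize()

    async def initialize(self):
        # Acquire the shared client only if none was injected beforehand
        if self.db is None:
            self.db = await MyClientManager.get_client()

    async def finalize(self):
        # Hand the client back; the manager closes it at ref count zero
        if self.db is not None:
            await MyClientManager.release_client(self.db)
            self.db = None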

View File

@@ -1,5 +1,5 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
import numpy as np
import configparser
import asyncio
@@ -26,8 +26,11 @@ if not pm.is_installed("motor"):
pm.install("motor")
try:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError
except ImportError as e:
@@ -39,31 +42,63 @@ config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@classmethod
async def get_client(cls) -> AsyncIOMotorDatabase:
async with cls._lock:
if cls._instances["db"] is None:
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb",
"uri",
fallback="mongodb://root:root@localhost:27017/",
),
)
database_name = os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
client = AsyncIOMotorClient(uri)
db = client.get_database(database_name)
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: AsyncIOMotorDatabase):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
cls._instances["db"] = None
@final
@dataclass
class MongoKVStorage(BaseKVStorage):
def __post_init__(self):
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self):
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id})
@@ -120,28 +155,23 @@ class MongoKVStorage(BaseKVStorage):
@final
@dataclass
class MongoDocStatusStorage(DocStatusStorage):
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self):
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as doc status {self._collection_name}")
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id})
@@ -202,36 +232,33 @@ class MongoDocStatusStorage(DocStatusStorage):
@dataclass
class MongoGraphStorage(BaseGraphStorage):
"""
A concrete implementation using MongoDBs $graphLookup to demonstrate multi-hop queries.
A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
"""
db: AsyncIOMotorDatabase = field(default=None)
collection: AsyncIOMotorCollection = field(default=None)
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
global_config=global_config,
embedding_func=embedding_func,
)
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self.collection = database.get_collection(self._collection_name)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection(
self.db, self._collection_name
)
logger.debug(f"Use MongoDB as KG {self._collection_name}")
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self.collection = None
#
# -------------------------------------------------------------------------
@@ -770,6 +797,9 @@ class MongoGraphStorage(BaseGraphStorage):
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
db: AsyncIOMotorDatabase = field(default=None)
_data: AsyncIOMotorCollection = field(default=None)
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -778,41 +808,36 @@ class MongoVectorDBStorage(BaseVectorStorage):
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
self._max_batch_size = self.global_config["embedding_batch_num"]
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
# Ensure collection exists
create_collection_if_not_exists(uri, database.name, self._collection_name)
# Ensure vector index exists
await self.create_vector_index_if_not_exists()
# Ensure vector index exists
self.create_vector_index(uri, database.name, self._collection_name)
logger.debug(f"Use MongoDB as VDB {self._collection_name}")
def create_vector_index(self, uri: str, database_name: str, collection_name: str):
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def create_vector_index_if_not_exists(self):
"""Creates an Atlas Vector Search index."""
client = MongoClient(uri)
collection = client.get_database(database_name).get_collection(
self._collection_name
)
try:
index_name = "vector_knn_index"
indexes = await self._data.list_search_indexes().to_list(length=None)
for index in indexes:
if index["name"] == index_name:
logger.debug("vector index already exist")
return
search_index_model = SearchIndexModel(
definition={
"fields": [
@@ -824,11 +849,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
}
]
},
name="vector_knn_index",
name=index_name,
type="vectorSearch",
)
collection.create_search_index(search_index_model)
await self._data.create_search_index(search_index_model)
logger.info("Vector index created successfully.")
except PyMongoError as _:
@@ -913,15 +938,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
raise NotImplementedError
def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
"""Check if the collection exists. if not, create it."""
client = MongoClient(uri)
database = client.get_database(database_name)
collection_names = database.list_collection_names()
async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
collection_names = await db.list_collection_names()
if collection_name not in collection_names:
database.create_collection(collection_name)
collection = await db.create_collection(collection_name)
logger.info(f"Created collection: {collection_name}")
return collection
else:
logger.debug(f"Collection '{collection_name}' already exists.")
return db.get_collection(collection_name)
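The reference counting in ClientManager is what lets all Mongo storages share one Motor client safely: get_client() returns the same database object and bumps the count, and release_client() drops the cached instance only when the count returns to zero. A small illustration of that contract (requires a reachable MongoDB; the URI comes from MONGO_URI or config.ini):

import asyncio
from lightrag.kg.mongo_impl import ClientManager

async def demo():
    db1 = await ClientManager.get_client()   # creates the client, ref_count = 1
    db2 = await ClientManager.get_client()   # same object returned, ref_count = 2
    assert db1 is db2
    await ClientManager.release_client(db1)  # ref_count = 1, client kept alive
    await ClientManager.release_client(db2)  # ref_count = 0, cached instance dropped

asyncio.run(demo())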

View File

@@ -2,11 +2,11 @@ import array
import asyncio
# import html
# import os
from dataclasses import dataclass
import os
from dataclasses import dataclass, field
from typing import Any, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph
@@ -177,17 +177,91 @@ class OracleDB:
raise
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
@classmethod
async def get_client(cls) -> OracleDB:
async with cls._lock:
if cls._instances["db"] is None:
config = ClientManager.get_config()
db = OracleDB(config)
await db.check_tables()
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: OracleDB):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
await db.pool.close()
logger.info("Closed OracleDB database connection pool")
cls._instances["db"] = None
else:
await db.pool.close()
@final
@dataclass
class OracleKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: OracleDB
db: OracleDB = field(default=None)
meta_fields = None
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -324,6 +398,8 @@ class OracleKVStorage(BaseKVStorage):
@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
db: OracleDB = field(default=None)
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
@@ -333,6 +409,15 @@ class OracleVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
#################### query method ###############
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
embeddings = await self.embedding_func([query])
@@ -369,9 +454,20 @@ class OracleVectorDBStorage(BaseVectorStorage):
@final
@dataclass
class OracleGraphStorage(BaseGraphStorage):
db: OracleDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
#################### insert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
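Every initialize() here guards with `if self.db is None`, so manual injection still works: a storage handed an OracleDB before initialization keeps it, and on finalize() a non-shared instance falls into release_client()'s else branch, which closes its pool directly. A hedged sketch (constructor fields abbreviated; my_embedding_func is a placeholder):

external_db = OracleDB(ClientManager.get_config())
await external_db.check_tables()

storage = OracleKVStorage(
    namespace="full_docs",
    global_config={"embedding_batch_num": 10},
    embedding_func=my_embedding_func,
)
storage.db = external_db    # injected before initialize()
await storage.initialize()  # db is not None, so no ClientManager lookup happens
# ... use the storage ...
await storage.finalize()    # not the shared client: its pool is closed directly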

View File

@@ -3,10 +3,10 @@ import inspect
import json
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union, final
import numpy as np
import configparser
from lightrag.types import KnowledgeGraph
@@ -181,15 +181,84 @@ class PostgreSQLDB:
pass
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
@classmethod
async def get_client(cls) -> PostgreSQLDB:
async with cls._lock:
if cls._instances["db"] is None:
config = ClientManager.get_config()
db = PostgreSQLDB(config)
await db.initdb()
await db.check_tables()
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: PostgreSQLDB):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
await db.pool.close()
logger.info("Closed PostgreSQL database connection pool")
cls._instances["db"] = None
else:
await db.pool.close()
@final
@dataclass
class PGKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: PostgreSQLDB
db: PostgreSQLDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -308,6 +377,8 @@ class PGKVStorage(BaseKVStorage):
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
db: PostgreSQLDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -318,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
def _upsert_chunks(self, item: dict):
try:
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -426,6 +506,17 @@ class PGVectorStorage(BaseVectorStorage):
@final
@dataclass
class PGDocStatusStorage(DocStatusStorage):
db: PostgreSQLDB = field(default=None)
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content"""
sql = SQL_TEMPLATES["filter_keys"].format(
@@ -565,6 +656,8 @@ class PGGraphQueryException(Exception):
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
db: PostgreSQLDB = field(default=None)
@staticmethod
def load_nx_graph(file_name):
print("no preloading of graph with AGE in production")
@@ -575,6 +668,15 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed,
}
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def index_done_callback(self) -> None:
# PG handles persistence automatically
pass

View File

@@ -1,6 +1,6 @@
import asyncio
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Union, final
import numpy as np
@@ -13,6 +13,7 @@ from ..namespace import NameSpace, is_namespace
from ..utils import logger
import pipmaster as pm
import configparser
if not pm.is_installed("pymysql"):
pm.install("pymysql")
@@ -104,16 +105,81 @@ class TiDB:
raise
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config():
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
@classmethod
async def get_client(cls) -> TiDB:
async with cls._lock:
if cls._instances["db"] is None:
config = ClientManager.get_config()
db = TiDB(config)
await db.check_tables()
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: TiDB):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
cls._instances["db"] = None
@final
@dataclass
class TiDBKVStorage(BaseKVStorage):
# db instance must be injected before use
# db: TiDB
db: TiDB = field(default=None)
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -184,7 +250,7 @@ class TiDBKVStorage(BaseKVStorage):
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": f'{item["__vector__"].tolist()}',
"content_vector": f"{item['__vector__'].tolist()}",
"workspace": self.db.workspace,
}
)
@@ -212,6 +278,8 @@ class TiDBKVStorage(BaseKVStorage):
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
db: TiDB = field(default=None)
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -225,6 +293,15 @@ class TiDBVectorDBStorage(BaseVectorStorage):
)
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
"""Search from tidb vector"""
embeddings = await self.embedding_func([query])
@@ -282,7 +359,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"id": item["id"],
"name": item["entity_name"],
"content": item["content"],
"content_vector": f'{item["content_vector"].tolist()}',
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
}
# update entity_id if node inserted by graph_storage_instance before
@@ -304,7 +381,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"source_name": item["src_id"],
"target_name": item["tgt_id"],
"content": item["content"],
"content_vector": f'{item["content_vector"].tolist()}',
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
}
# update relation_id if node inserted by graph_storage_instance before
@@ -337,12 +414,20 @@ class TiDBVectorDBStorage(BaseVectorStorage):
@final
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
# db instance must be injected before use
# db: TiDB
db: TiDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
#################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id

View File

@@ -17,6 +17,7 @@ from .base import (
DocStatusStorage,
QueryParam,
StorageNameSpace,
StoragesStatus,
)
from .namespace import NameSpace, make_namespace
from .operate import (
@@ -348,6 +349,10 @@ class LightRAG:
# Extensions
addon_params: dict[str, Any] = field(default_factory=dict)
# Storages Management
auto_manage_storages_states: bool = True
"""If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""
"""Dictionary for additional parameters and extensions."""
convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
convert_response_to_json
@@ -440,7 +445,10 @@ class LightRAG:
**self.vector_db_storage_cls_kwargs,
}
# show config
# Life cycle
self.storages_status = StoragesStatus.NOT_CREATED
# Show config
global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -547,6 +555,65 @@ class LightRAG:
)
)
self.storages_status = StoragesStatus.CREATED
# Initialize storages
if self.auto_manage_storages_states:
loop = always_get_an_event_loop()
loop.run_until_complete(self.initialize_storages())
def __del__(self):
# Finalize storages
if self.auto_manage_storages_states:
loop = always_get_an_event_loop()
loop.run_until_complete(self.finalize_storages())
async def initialize_storages(self):
"""Asynchronously initialize the storages"""
if self.storages_status == StoragesStatus.CREATED:
tasks = []
for storage in (
self.full_docs,
self.text_chunks,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
self.llm_response_cache,
self.doc_status,
):
if storage:
tasks.append(storage.initialize())
await asyncio.gather(*tasks)
self.storages_status = StoragesStatus.INITIALIZED
logger.debug("Initialized Storages")
async def finalize_storages(self):
"""Asynchronously finalize the storages"""
if self.storages_status == StoragesStatus.INITIALIZED:
tasks = []
for storage in (
self.full_docs,
self.text_chunks,
self.entities_vdb,
self.relationships_vdb,
self.chunks_vdb,
self.chunk_entity_relation_graph,
self.llm_response_cache,
self.doc_status,
):
if storage:
tasks.append(storage.finalize())
await asyncio.gather(*tasks)
self.storages_status = StoragesStatus.FINALIZED
logger.debug("Finalized Storages")
async def get_graph_labels(self):
text = await self.chunk_entity_relation_graph.get_all_labels()
return text
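Taken together, storage life cycles form a small state machine: NOT_CREATED at the top of __init__, CREATED once the storage objects exist, INITIALIZED after initialize_storages(), FINALIZED after finalize_storages(). With the default auto_manage_storages_states=True the transitions run inside __init__ and __del__; callers that need explicit control opt out, as the API server above does. A sketch of explicit management (the llm/embedding callables are placeholders):

rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=my_llm_func,
    embedding_func=my_embedding_func,
    auto_manage_storages_states=False,  # the caller owns the lifecycle
)
# Storages exist but hold no connections yet (StoragesStatus.CREATED)

await rag.initialize_storages()         # -> StoragesStatus.INITIALIZED
try:
    await rag.ainsert("Some document text.")
    answer = await rag.aquery("What is this document about?")
finally:
    await rag.finalize_storages()       # -> StoragesStatus.FINALIZED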