Merge pull request #846 from ArnoChenFx/db-connection-and-storage-lifecycle

Refactor Database Connection Management and Improve Storage Lifecycle Handling
Yannick Stephan
2025-02-18 22:39:31 +01:00
committed by GitHub
11 changed files with 540 additions and 416 deletions

View File

@@ -17,7 +17,6 @@ from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from lightrag.kg.oracle_impl import OracleDB

 print(os.getcwd())
 script_directory = Path(__file__).resolve().parent.parent
@@ -48,6 +47,14 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

+os.environ["ORACLE_USER"] = ""
+os.environ["ORACLE_PASSWORD"] = ""
+os.environ["ORACLE_DSN"] = ""
+os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
+os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
+os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
+os.environ["ORACLE_WORKSPACE"] = "company"

 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
@@ -89,20 +96,6 @@ async def init():
     # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
     # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-    oracle_db = OracleDB(
-        config={
-            "user": "",
-            "password": "",
-            "dsn": "",
-            "config_dir": "path_to_config_dir",
-            "wallet_location": "path_to_wallet_location",
-            "wallet_password": "wallet_password",
-            "workspace": "company",
-        }  # specify which docs you want to store and query
-    )
-    # Check if Oracle DB tables exist, if not, tables will be created
-    await oracle_db.check_tables()

     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
     # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -121,11 +114,6 @@ async def init():
         vector_storage="OracleVectorDBStorage",
     )

-    # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
-    rag.graph_storage_cls.db = oracle_db
-    rag.key_string_value_json_storage_cls.db = oracle_db
-    rag.vector_db_storage_cls.db = oracle_db

     return rag
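
Note: after this change the demo no longer builds an OracleDB by hand or injects it into the storage classes; the Oracle backends resolve a shared connection themselves from the ORACLE_* environment variables (or config.ini). A minimal sketch of the resulting usage, assuming llm_model_func, embedding_func, embedding_dimension, and EMBEDDING_MAX_TOKEN_SIZE are defined as in the demo:

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dimension,
            max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
            func=embedding_func,
        ),
        kv_storage="OracleKVStorage",
        graph_storage="OracleGraphStorage",
        vector_storage="OracleVectorDBStorage",
    )
    # With the default auto_manage_storages_states=True, LightRAG calls
    # initialize_storages() itself, and each Oracle storage obtains the shared
    # connection pool from ClientManager on first use.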

View File

@@ -6,7 +6,6 @@ from lightrag import LightRAG, QueryParam
 from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
-from lightrag.kg.oracle_impl import OracleDB

 print(os.getcwd())
 script_directory = Path(__file__).resolve().parent.parent
@@ -26,6 +25,14 @@ MAX_TOKENS = 4000
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)

+os.environ["ORACLE_USER"] = "username"
+os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
+os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
+os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
+os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
+os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
+os.environ["ORACLE_WORKSPACE"] = "company"

 async def llm_model_func(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
@@ -63,26 +70,6 @@ async def main():
     embedding_dimension = await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")

-    # Create Oracle DB connection
-    # The `config` parameter is the connection configuration of Oracle DB
-    # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
-    # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
-    # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
-    oracle_db = OracleDB(
-        config={
-            "user": "username",
-            "password": "xxxxxxxxx",
-            "dsn": "xxxxxxx_medium",
-            "config_dir": "dir/path/to/oracle/config",
-            "wallet_location": "dir/path/to/oracle/wallet",
-            "wallet_password": "xxxxxxxxx",
-            "workspace": "company",  # specify which docs you want to store and query
-        }
-    )
-    # Check if Oracle DB tables exist, if not, tables will be created
-    await oracle_db.check_tables()

     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
     # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -112,26 +99,6 @@ async def main():
         },
     )

-    # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
-    for storage in [
-        rag.vector_db_storage_cls,
-        rag.graph_storage_cls,
-        rag.doc_status,
-        rag.full_docs,
-        rag.text_chunks,
-        rag.llm_response_cache,
-        rag.key_string_value_json_storage_cls,
-        rag.chunks_vdb,
-        rag.relationships_vdb,
-        rag.entities_vdb,
-        rag.graph_storage_cls,
-        rag.chunk_entity_relation_graph,
-        rag.llm_response_cache,
-    ]:
-        # set client
-        storage.db = oracle_db

     # Extract and Insert into LightRAG storage
     with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
         all_text = f.read()

View File

@@ -4,7 +4,6 @@ import os
 import numpy as np
 from lightrag import LightRAG, QueryParam
-from lightrag.kg.tidb_impl import TiDB
 from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
 from lightrag.utils import EmbeddingFunc
@@ -17,11 +16,11 @@ APIKEY = ""
CHATMODEL = "" CHATMODEL = ""
EMBEDMODEL = "" EMBEDMODEL = ""
TIDB_HOST = "" os.environ["TIDB_HOST"] = ""
TIDB_PORT = "" os.environ["TIDB_PORT"] = ""
TIDB_USER = "" os.environ["TIDB_USER"] = ""
TIDB_PASSWORD = "" os.environ["TIDB_PASSWORD"] = ""
TIDB_DATABASE = "lightrag" os.environ["TIDB_DATABASE"] = "lightrag"
if not os.path.exists(WORKING_DIR): if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR) os.mkdir(WORKING_DIR)
@@ -62,21 +61,6 @@ async def main():
     embedding_dimension = await get_embedding_dim()
     print(f"Detected embedding dimension: {embedding_dimension}")

-    # Create TiDB DB connection
-    tidb = TiDB(
-        config={
-            "host": TIDB_HOST,
-            "port": TIDB_PORT,
-            "user": TIDB_USER,
-            "password": TIDB_PASSWORD,
-            "database": TIDB_DATABASE,
-            "workspace": "company",  # specify which docs you want to store and query
-        }
-    )
-    # Check if TiDB DB tables exist, if not, tables will be created
-    await tidb.check_tables()

     # Initialize LightRAG
     # We use TiDB DB as the KV/vector
     # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -95,15 +79,6 @@ async def main():
graph_storage="TiDBGraphStorage", graph_storage="TiDBGraphStorage",
) )
if rag.llm_response_cache:
rag.llm_response_cache.db = tidb
rag.full_docs.db = tidb
rag.text_chunks.db = tidb
rag.entities_vdb.db = tidb
rag.relationships_vdb.db = tidb
rag.chunks_vdb.db = tidb
rag.chunk_entity_relation_graph.db = tidb
# Extract and Insert into LightRAG storage # Extract and Insert into LightRAG storage
with open("./dickens/demo.txt", "r", encoding="utf-8") as f: with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
await rag.ainsert(f.read()) await rag.ainsert(f.read())

View File

@@ -5,7 +5,6 @@ import time
 from dotenv import load_dotenv
 from lightrag import LightRAG, QueryParam
-from lightrag.kg.postgres_impl import PostgreSQLDB
 from lightrag.llm.zhipu import zhipu_complete
 from lightrag.llm.ollama import ollama_embedding
 from lightrag.utils import EmbeddingFunc
@@ -22,22 +21,14 @@ if not os.path.exists(WORKING_DIR):
 # AGE
 os.environ["AGE_GRAPH_NAME"] = "dickens"

-postgres_db = PostgreSQLDB(
-    config={
-        "host": "localhost",
-        "port": 15432,
-        "user": "rag",
-        "password": "rag",
-        "database": "rag",
-    }
-)
+os.environ["POSTGRES_HOST"] = "localhost"
+os.environ["POSTGRES_PORT"] = "15432"
+os.environ["POSTGRES_USER"] = "rag"
+os.environ["POSTGRES_PASSWORD"] = "rag"
+os.environ["POSTGRES_DATABASE"] = "rag"

 async def main():
-    await postgres_db.initdb()
-    # Check if PostgreSQL DB tables exist, if not, tables will be created
-    await postgres_db.check_tables()
     rag = LightRAG(
         working_dir=WORKING_DIR,
         llm_model_func=zhipu_complete,
@@ -57,17 +48,7 @@ async def main():
graph_storage="PGGraphStorage", graph_storage="PGGraphStorage",
vector_storage="PGVectorStorage", vector_storage="PGVectorStorage",
) )
# Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
rag.doc_status.db = postgres_db
rag.full_docs.db = postgres_db
rag.text_chunks.db = postgres_db
rag.llm_response_cache.db = postgres_db
rag.key_string_value_json_storage_cls.db = postgres_db
rag.chunks_vdb.db = postgres_db
rag.relationships_vdb.db = postgres_db
rag.entities_vdb.db = postgres_db
rag.graph_storage_cls.db = postgres_db
rag.chunk_entity_relation_graph.db = postgres_db
# add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func

View File

@@ -15,11 +15,6 @@ import logging
 import argparse
 from typing import List, Any, Literal, Optional, Dict
 from pydantic import BaseModel, Field, field_validator
-from lightrag import LightRAG, QueryParam
-from lightrag.base import DocProcessingStatus, DocStatus
-from lightrag.types import GPTKeywordExtractionFormat
-from lightrag.api import __api_version__
-from lightrag.utils import EmbeddingFunc
 from pathlib import Path
 import shutil
 import aiofiles
@@ -36,39 +31,13 @@ import configparser
 import traceback
 from datetime import datetime

+from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
+from lightrag.types import GPTKeywordExtractionFormat
+from lightrag.api import __api_version__
+from lightrag.utils import EmbeddingFunc
 from lightrag.utils import logger
-from .ollama_api import (
-    OllamaAPI,
-)
-from .ollama_api import ollama_server_infos
+from .ollama_api import OllamaAPI, ollama_server_infos

-def get_db_type_from_storage_class(class_name: str) -> str | None:
-    """Determine database type based on storage class name"""
-    if class_name.startswith("PG"):
-        return "postgres"
-    elif class_name.startswith("Oracle"):
-        return "oracle"
-    elif class_name.startswith("TiDB"):
-        return "tidb"
-    return None

-def import_db_module(db_type: str):
-    """Dynamically import database module"""
-    if db_type == "postgres":
-        from ..kg.postgres_impl import PostgreSQLDB

-        return PostgreSQLDB
-    elif db_type == "oracle":
-        from ..kg.oracle_impl import OracleDB

-        return OracleDB
-    elif db_type == "tidb":
-        from ..kg.tidb_impl import TiDB

-        return TiDB
-    return None

 # Load environment variables
@@ -929,52 +898,12 @@ def create_app(args):
     @asynccontextmanager
     async def lifespan(app: FastAPI):
         """Lifespan context manager for startup and shutdown events"""
-        # Initialize database connections
-        db_instances = {}

         # Store background tasks
         app.state.background_tasks = set()

         try:
-            # Check which database types are used
-            db_types = set()
-            for storage_name, storage_instance in storage_instances:
-                db_type = get_db_type_from_storage_class(
-                    storage_instance.__class__.__name__
-                )
-                if db_type:
-                    db_types.add(db_type)

-            # Import and initialize databases as needed
-            for db_type in db_types:
-                if db_type == "postgres":
-                    DB = import_db_module("postgres")
-                    db = DB(_get_postgres_config())
-                    await db.initdb()
-                    await db.check_tables()
-                    db_instances["postgres"] = db
-                elif db_type == "oracle":
-                    DB = import_db_module("oracle")
-                    db = DB(_get_oracle_config())
-                    await db.check_tables()
-                    db_instances["oracle"] = db
-                elif db_type == "tidb":
-                    DB = import_db_module("tidb")
-                    db = DB(_get_tidb_config())
-                    await db.check_tables()
-                    db_instances["tidb"] = db

-            # Inject database instances into storage classes
-            for storage_name, storage_instance in storage_instances:
-                db_type = get_db_type_from_storage_class(
-                    storage_instance.__class__.__name__
-                )
-                if db_type:
-                    if db_type not in db_instances:
-                        error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
-                        logger.error(error_msg)
-                        raise RuntimeError(error_msg)
-                    storage_instance.db = db_instances[db_type]
-                    logger.info(f"Injected {db_type} db to {storage_name}")
+            # Initialize database connections
+            await rag.initialize_storages()

             # Auto scan documents if enabled
             if args.auto_scan_at_startup:
@@ -1000,17 +929,7 @@ def create_app(args):
         finally:
             # Clean up database connections
-            for db_type, db in db_instances.items():
-                if hasattr(db, "pool"):
-                    await db.pool.close()
-                    # Use more accurate database name display
-                    db_names = {
-                        "postgres": "PostgreSQL",
-                        "oracle": "Oracle",
-                        "tidb": "TiDB",
-                    }
-                    db_name = db_names.get(db_type, db_type)
-                    logger.info(f"Closed {db_name} database connection pool")
+            await rag.finalize_storages()

     # Initialize FastAPI
     app = FastAPI(
@@ -1042,92 +961,6 @@ def create_app(args):
allow_headers=["*"], allow_headers=["*"],
) )
# Database configuration functions
def _get_postgres_config():
return {
"host": os.environ.get(
"POSTGRES_HOST",
config.get("postgres", "host", fallback="localhost"),
),
"port": os.environ.get(
"POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
),
"user": os.environ.get(
"POSTGRES_USER", config.get("postgres", "user", fallback=None)
),
"password": os.environ.get(
"POSTGRES_PASSWORD",
config.get("postgres", "password", fallback=None),
),
"database": os.environ.get(
"POSTGRES_DATABASE",
config.get("postgres", "database", fallback=None),
),
"workspace": os.environ.get(
"POSTGRES_WORKSPACE",
config.get("postgres", "workspace", fallback="default"),
),
}
def _get_oracle_config():
return {
"user": os.environ.get(
"ORACLE_USER",
config.get("oracle", "user", fallback=None),
),
"password": os.environ.get(
"ORACLE_PASSWORD",
config.get("oracle", "password", fallback=None),
),
"dsn": os.environ.get(
"ORACLE_DSN",
config.get("oracle", "dsn", fallback=None),
),
"config_dir": os.environ.get(
"ORACLE_CONFIG_DIR",
config.get("oracle", "config_dir", fallback=None),
),
"wallet_location": os.environ.get(
"ORACLE_WALLET_LOCATION",
config.get("oracle", "wallet_location", fallback=None),
),
"wallet_password": os.environ.get(
"ORACLE_WALLET_PASSWORD",
config.get("oracle", "wallet_password", fallback=None),
),
"workspace": os.environ.get(
"ORACLE_WORKSPACE",
config.get("oracle", "workspace", fallback="default"),
),
}
def _get_tidb_config():
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
# Create the optional API key dependency # Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key) optional_api_key = get_api_key_dependency(api_key)
@@ -1262,6 +1095,7 @@ def create_app(args):
             },
             log_level=args.log_level,
             namespace_prefix=args.namespace_prefix,
+            auto_manage_storages_states=False,
         )
     else:
         rag = LightRAG(
@@ -1293,20 +1127,9 @@ def create_app(args):
             },
             log_level=args.log_level,
             namespace_prefix=args.namespace_prefix,
+            auto_manage_storages_states=False,
         )

-    # Collect all storage instances
-    storage_instances = [
-        ("full_docs", rag.full_docs),
-        ("text_chunks", rag.text_chunks),
-        ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
-        ("entities_vdb", rag.entities_vdb),
-        ("relationships_vdb", rag.relationships_vdb),
-        ("chunks_vdb", rag.chunks_vdb),
-        ("doc_status", rag.doc_status),
-        ("llm_response_cache", rag.llm_response_cache),
-    ]

     async def pipeline_enqueue_file(file_path: Path) -> bool:
         """Add a file to the queue for processing

View File

@@ -87,6 +87,14 @@ class StorageNameSpace(ABC):
     namespace: str
     global_config: dict[str, Any]

+    async def initialize(self):
+        """Initialize the storage"""
+        pass

+    async def finalize(self):
+        """Finalize the storage"""
+        pass

     @abstractmethod
     async def index_done_callback(self) -> None:
         """Commit the storage operations after indexing"""
@@ -247,3 +255,12 @@ class DocStatusStorage(BaseKVStorage, ABC):
         self, status: DocStatus
     ) -> dict[str, DocProcessingStatus]:
         """Get all documents with a specific status"""

+class StoragesStatus(str, Enum):
+    """Storages status"""

+    NOT_CREATED = "not_created"
+    CREATED = "created"
+    INITIALIZED = "initialized"
+    FINALIZED = "finalized"
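
Note: initialize() and finalize() default to no-ops on StorageNameSpace, so file-based backends need no changes, while database-backed backends override them to acquire and release a shared client. A sketch of the contract for a hypothetical backend (abstract query methods elided; connect/close are stand-ins, not real LightRAG helpers):

    from dataclasses import dataclass, field

    @dataclass
    class MyDBKVStorage(BaseKVStorage):
        db: object = field(default=None)

        async def initialize(self):
            # Idempotent: a repeated call must not open a second client
            if self.db is None:
                self.db = await connect()  # hypothetical helper

        async def finalize(self):
            # Release so a shared pool can be reference-counted down
            if self.db is not None:
                await close(self.db)  # hypothetical helper
                self.db = None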

View File

@@ -1,5 +1,5 @@
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import numpy as np
 import configparser
 import asyncio
@@ -26,8 +26,11 @@ if not pm.is_installed("motor"):
pm.install("motor") pm.install("motor")
try: try:
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import (
from pymongo import MongoClient AsyncIOMotorClient,
AsyncIOMotorDatabase,
AsyncIOMotorCollection,
)
from pymongo.operations import SearchIndexModel from pymongo.operations import SearchIndexModel
from pymongo.errors import PyMongoError from pymongo.errors import PyMongoError
except ImportError as e: except ImportError as e:
@@ -39,31 +42,63 @@ config = configparser.ConfigParser()
config.read("config.ini", "utf-8") config.read("config.ini", "utf-8")
class ClientManager:
_instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@classmethod
async def get_client(cls) -> AsyncIOMotorDatabase:
async with cls._lock:
if cls._instances["db"] is None:
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb",
"uri",
fallback="mongodb://root:root@localhost:27017/",
),
)
database_name = os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
client = AsyncIOMotorClient(uri)
db = client.get_database(database_name)
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: AsyncIOMotorDatabase):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
cls._instances["db"] = None
@final @final
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
def __post_init__(self): db: AsyncIOMotorDatabase = field(default=None)
uri = os.environ.get( _data: AsyncIOMotorCollection = field(default=None)
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
def __post_init__(self):
self._collection_name = self.namespace self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name) async def initialize(self):
logger.debug(f"Use MongoDB as KV {self._collection_name}") if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug(f"Use MongoDB as KV {self._collection_name}")
# Ensure collection exists async def finalize(self):
create_collection_if_not_exists(uri, database.name, self._collection_name) if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
return await self._data.find_one({"_id": id}) return await self._data.find_one({"_id": id})
@@ -120,28 +155,23 @@ class MongoKVStorage(BaseKVStorage):
 @final
 @dataclass
 class MongoDocStatusStorage(DocStatusStorage):
+    db: AsyncIOMotorDatabase = field(default=None)
+    _data: AsyncIOMotorCollection = field(default=None)

     def __post_init__(self):
-        uri = os.environ.get(
-            "MONGO_URI",
-            config.get(
-                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-            ),
-        )
-        client = AsyncIOMotorClient(uri)
-        database = client.get_database(
-            os.environ.get(
-                "MONGO_DATABASE",
-                config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        )
         self._collection_name = self.namespace
-        self._data = database.get_collection(self._collection_name)
-        logger.debug(f"Use MongoDB as doc status {self._collection_name}")

-        # Ensure collection exists
-        create_collection_if_not_exists(uri, database.name, self._collection_name)

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+            self._data = await get_or_create_collection(self.db, self._collection_name)
+            logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+            self._data = None

     async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
         return await self._data.find_one({"_id": id})
@@ -202,36 +232,33 @@ class MongoDocStatusStorage(DocStatusStorage):
 @dataclass
 class MongoGraphStorage(BaseGraphStorage):
     """
-    A concrete implementation using MongoDBs $graphLookup to demonstrate multi-hop queries.
+    A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
     """

+    db: AsyncIOMotorDatabase = field(default=None)
+    collection: AsyncIOMotorCollection = field(default=None)

     def __init__(self, namespace, global_config, embedding_func):
         super().__init__(
             namespace=namespace,
             global_config=global_config,
             embedding_func=embedding_func,
         )
-        uri = os.environ.get(
-            "MONGO_URI",
-            config.get(
-                "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
-            ),
-        )
-        client = AsyncIOMotorClient(uri)
-        database = client.get_database(
-            os.environ.get(
-                "MONGO_DATABASE",
-                config.get("mongodb", "database", fallback="LightRAG"),
-            )
-        )
         self._collection_name = self.namespace
-        self.collection = database.get_collection(self._collection_name)
-        logger.debug(f"Use MongoDB as KG {self._collection_name}")

-        # Ensure collection exists
-        create_collection_if_not_exists(uri, database.name, self._collection_name)

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()
+            self.collection = await get_or_create_collection(
+                self.db, self._collection_name
+            )
+            logger.debug(f"Use MongoDB as KG {self._collection_name}")

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+            self.collection = None

     #
     # -------------------------------------------------------------------------
@@ -770,6 +797,9 @@ class MongoGraphStorage(BaseGraphStorage):
 @final
 @dataclass
 class MongoVectorDBStorage(BaseVectorStorage):
+    db: AsyncIOMotorDatabase = field(default=None)
+    _data: AsyncIOMotorCollection = field(default=None)

     def __post_init__(self):
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -778,41 +808,36 @@ class MongoVectorDBStorage(BaseVectorStorage):
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
) )
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
uri = os.environ.get(
"MONGO_URI",
config.get(
"mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
),
)
client = AsyncIOMotorClient(uri)
database = client.get_database(
os.environ.get(
"MONGO_DATABASE",
config.get("mongodb", "database", fallback="LightRAG"),
)
)
self._collection_name = self.namespace self._collection_name = self.namespace
self._data = database.get_collection(self._collection_name)
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
logger.debug(f"Use MongoDB as VDB {self._collection_name}") async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name)
# Ensure collection exists # Ensure vector index exists
create_collection_if_not_exists(uri, database.name, self._collection_name) await self.create_vector_index_if_not_exists()
# Ensure vector index exists logger.debug(f"Use MongoDB as VDB {self._collection_name}")
self.create_vector_index(uri, database.name, self._collection_name)
def create_vector_index(self, uri: str, database_name: str, collection_name: str): async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
self._data = None
async def create_vector_index_if_not_exists(self):
"""Creates an Atlas Vector Search index.""" """Creates an Atlas Vector Search index."""
client = MongoClient(uri)
collection = client.get_database(database_name).get_collection(
self._collection_name
)
try: try:
index_name = "vector_knn_index"
indexes = await self._data.list_search_indexes().to_list(length=None)
for index in indexes:
if index["name"] == index_name:
logger.debug("vector index already exist")
return
search_index_model = SearchIndexModel( search_index_model = SearchIndexModel(
definition={ definition={
"fields": [ "fields": [
@@ -824,11 +849,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
                         }
                     ]
                 },
-                name="vector_knn_index",
+                name=index_name,
                 type="vectorSearch",
             )

-            collection.create_search_index(search_index_model)
+            await self._data.create_search_index(search_index_model)
             logger.info("Vector index created successfully.")

         except PyMongoError as _:
@@ -913,15 +938,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
         raise NotImplementedError

-def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
-    """Check if the collection exists. if not, create it."""
-    client = MongoClient(uri)
-    database = client.get_database(database_name)
-    collection_names = database.list_collection_names()
+async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
+    collection_names = await db.list_collection_names()

     if collection_name not in collection_names:
-        database.create_collection(collection_name)
+        collection = await db.create_collection(collection_name)
         logger.info(f"Created collection: {collection_name}")
+        return collection
     else:
         logger.debug(f"Collection '{collection_name}' already exists.")
+        return db.get_collection(collection_name)

View File

@@ -2,11 +2,11 @@ import array
 import asyncio

 # import html
-# import os
+import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Union, final
 import numpy as np
+import configparser

 from lightrag.types import KnowledgeGraph
@@ -177,17 +177,91 @@ class OracleDB:
         raise

+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()

+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")

+        return {
+            "user": os.environ.get(
+                "ORACLE_USER",
+                config.get("oracle", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "ORACLE_PASSWORD",
+                config.get("oracle", "password", fallback=None),
+            ),
+            "dsn": os.environ.get(
+                "ORACLE_DSN",
+                config.get("oracle", "dsn", fallback=None),
+            ),
+            "config_dir": os.environ.get(
+                "ORACLE_CONFIG_DIR",
+                config.get("oracle", "config_dir", fallback=None),
+            ),
+            "wallet_location": os.environ.get(
+                "ORACLE_WALLET_LOCATION",
+                config.get("oracle", "wallet_location", fallback=None),
+            ),
+            "wallet_password": os.environ.get(
+                "ORACLE_WALLET_PASSWORD",
+                config.get("oracle", "wallet_password", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "ORACLE_WORKSPACE",
+                config.get("oracle", "workspace", fallback="default"),
+            ),
+        }

+    @classmethod
+    async def get_client(cls) -> OracleDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = OracleDB(config)
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]

+    @classmethod
+    async def release_client(cls, db: OracleDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        await db.pool.close()
+                        logger.info("Closed OracleDB database connection pool")
+                        cls._instances["db"] = None
+                else:
+                    await db.pool.close()

 @final
 @dataclass
 class OracleKVStorage(BaseKVStorage):
-    # db instance must be injected before use
-    # db: OracleDB
+    db: OracleDB = field(default=None)
     meta_fields = None

     def __post_init__(self):
         self._data = {}
         self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     ################ QUERY METHODS ################

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -324,6 +398,8 @@ class OracleKVStorage(BaseKVStorage):
 @final
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
+    db: OracleDB = field(default=None)

     def __post_init__(self):
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = config.get("cosine_better_than_threshold")
@@ -333,6 +409,15 @@ class OracleVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     #################### query method ###############
     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func([query])
@@ -369,9 +454,20 @@ class OracleVectorDBStorage(BaseVectorStorage):
 @final
 @dataclass
 class OracleGraphStorage(BaseGraphStorage):
+    db: OracleDB = field(default=None)

     def __post_init__(self):
         self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     #################### insert method ################
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:

View File

@@ -3,10 +3,10 @@ import inspect
 import json
 import os
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Union, final
 import numpy as np
+import configparser

 from lightrag.types import KnowledgeGraph
@@ -181,15 +181,84 @@ class PostgreSQLDB:
         pass

+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()

+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")

+        return {
+            "host": os.environ.get(
+                "POSTGRES_HOST",
+                config.get("postgres", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
+            ),
+            "user": os.environ.get(
+                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
+            ),
+            "password": os.environ.get(
+                "POSTGRES_PASSWORD",
+                config.get("postgres", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "POSTGRES_DATABASE",
+                config.get("postgres", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "POSTGRES_WORKSPACE",
+                config.get("postgres", "workspace", fallback="default"),
+            ),
+        }

+    @classmethod
+    async def get_client(cls) -> PostgreSQLDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = PostgreSQLDB(config)
+                await db.initdb()
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]

+    @classmethod
+    async def release_client(cls, db: PostgreSQLDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        await db.pool.close()
+                        logger.info("Closed PostgreSQL database connection pool")
+                        cls._instances["db"] = None
+                else:
+                    await db.pool.close()

 @final
 @dataclass
 class PGKVStorage(BaseKVStorage):
-    # db instance must be injected before use
-    # db: PostgreSQLDB
+    db: PostgreSQLDB = field(default=None)

     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     ################ QUERY METHODS ################

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -308,6 +377,8 @@ class PGKVStorage(BaseKVStorage):
 @final
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
+    db: PostgreSQLDB = field(default=None)

     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -318,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     def _upsert_chunks(self, item: dict):
         try:
             upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -426,6 +506,17 @@ class PGVectorStorage(BaseVectorStorage):
 @final
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
+    db: PostgreSQLDB = field(default=None)

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     async def filter_keys(self, keys: set[str]) -> set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
@@ -565,6 +656,8 @@ class PGGraphQueryException(Exception):
 @final
 @dataclass
 class PGGraphStorage(BaseGraphStorage):
+    db: PostgreSQLDB = field(default=None)

     @staticmethod
     def load_nx_graph(file_name):
         print("no preloading of graph with AGE in production")
@@ -575,6 +668,15 @@ class PGGraphStorage(BaseGraphStorage):
"node2vec": self._node2vec_embed, "node2vec": self._node2vec_embed,
} }
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# PG handles persistence automatically # PG handles persistence automatically
pass pass
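
Note: get_config() reads the environment first and falls back to a config.ini in the working directory. A hypothetical config.ini matching the keys read above (values are placeholders):

    [postgres]
    host = localhost
    port = 15432
    user = rag
    password = rag
    database = rag
    workspace = default

Environment variables such as POSTGRES_HOST always take precedence over the file.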

View File

@@ -1,6 +1,6 @@
 import asyncio
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Union, final
 import numpy as np
@@ -13,6 +13,7 @@ from ..namespace import NameSpace, is_namespace
 from ..utils import logger
 import pipmaster as pm
+import configparser

 if not pm.is_installed("pymysql"):
     pm.install("pymysql")
@@ -104,16 +105,81 @@ class TiDB:
         raise

+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()

+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")

+        return {
+            "host": os.environ.get(
+                "TIDB_HOST",
+                config.get("tidb", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
+            ),
+            "user": os.environ.get(
+                "TIDB_USER",
+                config.get("tidb", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "TIDB_PASSWORD",
+                config.get("tidb", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "TIDB_DATABASE",
+                config.get("tidb", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "TIDB_WORKSPACE",
+                config.get("tidb", "workspace", fallback="default"),
+            ),
+        }

+    @classmethod
+    async def get_client(cls) -> TiDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = TiDB(config)
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]

+    @classmethod
+    async def release_client(cls, db: TiDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        cls._instances["db"] = None

 @final
 @dataclass
 class TiDBKVStorage(BaseKVStorage):
-    # db instance must be injected before use
-    # db: TiDB
+    db: TiDB = field(default=None)

     def __post_init__(self):
         self._data = {}
         self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     ################ QUERY METHODS ################

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -184,7 +250,7 @@ class TiDBKVStorage(BaseKVStorage):
"tokens": item["tokens"], "tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"], "chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"], "full_doc_id": item["full_doc_id"],
"content_vector": f'{item["__vector__"].tolist()}', "content_vector": f"{item['__vector__'].tolist()}",
"workspace": self.db.workspace, "workspace": self.db.workspace,
} }
) )
@@ -212,6 +278,8 @@ class TiDBKVStorage(BaseKVStorage):
 @final
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
+    db: TiDB = field(default=None)

     def __post_init__(self):
         self._client_file_name = os.path.join(
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -225,6 +293,15 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             )
         self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """Search from tidb vector"""
         embeddings = await self.embedding_func([query])
@@ -282,7 +359,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"id": item["id"], "id": item["id"],
"name": item["entity_name"], "name": item["entity_name"],
"content": item["content"], "content": item["content"],
"content_vector": f'{item["content_vector"].tolist()}', "content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace, "workspace": self.db.workspace,
} }
# update entity_id if node inserted by graph_storage_instance before # update entity_id if node inserted by graph_storage_instance before
@@ -304,7 +381,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
"source_name": item["src_id"], "source_name": item["src_id"],
"target_name": item["tgt_id"], "target_name": item["tgt_id"],
"content": item["content"], "content": item["content"],
"content_vector": f'{item["content_vector"].tolist()}', "content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace, "workspace": self.db.workspace,
} }
# update relation_id if node inserted by graph_storage_instance before # update relation_id if node inserted by graph_storage_instance before
@@ -337,12 +414,20 @@ class TiDBVectorDBStorage(BaseVectorStorage):
 @final
 @dataclass
 class TiDBGraphStorage(BaseGraphStorage):
-    # db instance must be injected before use
-    # db: TiDB
+    db: TiDB = field(default=None)

     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if self.db is None:
+            self.db = await ClientManager.get_client()

+    async def finalize(self):
+        if self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None

     #################### upsert method ################
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         entity_name = node_id

View File

@@ -17,6 +17,7 @@ from .base import (
     DocStatusStorage,
     QueryParam,
     StorageNameSpace,
+    StoragesStatus,
 )
 from .namespace import NameSpace, make_namespace
 from .operate import (
@@ -348,6 +349,10 @@ class LightRAG:
     # Extensions
     addon_params: dict[str, Any] = field(default_factory=dict)

+    # Storages Management
+    auto_manage_storages_states: bool = True
+    """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""

     """Dictionary for additional parameters and extensions."""
     convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
         convert_response_to_json
@@ -440,7 +445,10 @@ class LightRAG:
             **self.vector_db_storage_cls_kwargs,
         }

-        # show config
+        # Life cycle
+        self.storages_status = StoragesStatus.NOT_CREATED

+        # Show config
         global_config = asdict(self)
         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in global_config.items()])
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
@@ -547,6 +555,65 @@ class LightRAG:
             )
         )

+        self.storages_status = StoragesStatus.CREATED

+        # Initialize storages
+        if self.auto_manage_storages_states:
+            loop = always_get_an_event_loop()
+            loop.run_until_complete(self.initialize_storages())

+    def __del__(self):
+        # Finalize storages
+        if self.auto_manage_storages_states:
+            loop = always_get_an_event_loop()
+            loop.run_until_complete(self.finalize_storages())

+    async def initialize_storages(self):
+        """Asynchronously initialize the storages"""
+        if self.storages_status == StoragesStatus.CREATED:
+            tasks = []

+            for storage in (
+                self.full_docs,
+                self.text_chunks,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+                self.llm_response_cache,
+                self.doc_status,
+            ):
+                if storage:
+                    tasks.append(storage.initialize())

+            await asyncio.gather(*tasks)

+            self.storages_status = StoragesStatus.INITIALIZED
+            logger.debug("Initialized Storages")

+    async def finalize_storages(self):
+        """Asynchronously finalize the storages"""
+        if self.storages_status == StoragesStatus.INITIALIZED:
+            tasks = []

+            for storage in (
+                self.full_docs,
+                self.text_chunks,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+                self.llm_response_cache,
+                self.doc_status,
+            ):
+                if storage:
+                    tasks.append(storage.finalize())

+            await asyncio.gather(*tasks)

+            self.storages_status = StoragesStatus.FINALIZED
+            logger.debug("Finalized Storages")

     async def get_graph_labels(self):
         text = await self.chunk_entity_relation_graph.get_all_labels()
         return text
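
Note: callers that need deterministic startup and shutdown (such as the API server above) can set auto_manage_storages_states=False and drive the lifecycle explicitly. A minimal sketch, with model and embedding configuration elided:

    import asyncio
    from lightrag import LightRAG

    async def run():
        rag = LightRAG(
            working_dir="./rag_storage",
            # ... llm_model_func / embedding_func / storage backends ...
            auto_manage_storages_states=False,
        )
        await rag.initialize_storages()    # CREATED -> INITIALIZED, pools opened
        try:
            await rag.ainsert("some document text")
        finally:
            await rag.finalize_storages()  # INITIALIZED -> FINALIZED, pools released

    asyncio.run(run())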