refactor database connection management and improve storage lifecycle handling

This commit is contained in: update
ArnoChen
2025-02-19 03:46:18 +08:00
parent 780d0b45f7
commit e194e04226
6 changed files with 376 additions and 195 deletions

lightrag/api/lightrag_server.py (View File)

@@ -15,11 +15,6 @@ import logging
 import argparse
 from typing import List, Any, Literal, Optional, Dict
 from pydantic import BaseModel, Field, field_validator
-from lightrag import LightRAG, QueryParam
-from lightrag.base import DocProcessingStatus, DocStatus
-from lightrag.types import GPTKeywordExtractionFormat
-from lightrag.api import __api_version__
-from lightrag.utils import EmbeddingFunc
 from pathlib import Path
 import shutil
 import aiofiles
@@ -36,39 +31,13 @@ import configparser
 import traceback
 from datetime import datetime

+from lightrag import LightRAG, QueryParam
+from lightrag.base import DocProcessingStatus, DocStatus
+from lightrag.types import GPTKeywordExtractionFormat
+from lightrag.api import __api_version__
+from lightrag.utils import EmbeddingFunc
 from lightrag.utils import logger
-from .ollama_api import (
-    OllamaAPI,
-)
-from .ollama_api import ollama_server_infos
+from .ollama_api import OllamaAPI, ollama_server_infos
-
-
-def get_db_type_from_storage_class(class_name: str) -> str | None:
-    """Determine database type based on storage class name"""
-    if class_name.startswith("PG"):
-        return "postgres"
-    elif class_name.startswith("Oracle"):
-        return "oracle"
-    elif class_name.startswith("TiDB"):
-        return "tidb"
-    return None
-
-
-def import_db_module(db_type: str):
-    """Dynamically import database module"""
-    if db_type == "postgres":
-        from ..kg.postgres_impl import PostgreSQLDB
-
-        return PostgreSQLDB
-    elif db_type == "oracle":
-        from ..kg.oracle_impl import OracleDB
-
-        return OracleDB
-    elif db_type == "tidb":
-        from ..kg.tidb_impl import TiDB
-
-        return TiDB
-    return None

 # Load environment variables
@@ -929,52 +898,12 @@ def create_app(args):
    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Lifespan context manager for startup and shutdown events"""
-        # Initialize database connections
-        db_instances = {}
-
        # Store background tasks
        app.state.background_tasks = set()

        try:
-            # Check which database types are used
-            db_types = set()
-            for storage_name, storage_instance in storage_instances:
-                db_type = get_db_type_from_storage_class(
-                    storage_instance.__class__.__name__
-                )
-                if db_type:
-                    db_types.add(db_type)
-
-            # Import and initialize databases as needed
-            for db_type in db_types:
-                if db_type == "postgres":
-                    DB = import_db_module("postgres")
-                    db = DB(_get_postgres_config())
-                    await db.initdb()
-                    await db.check_tables()
-                    db_instances["postgres"] = db
-                elif db_type == "oracle":
-                    DB = import_db_module("oracle")
-                    db = DB(_get_oracle_config())
-                    await db.check_tables()
-                    db_instances["oracle"] = db
-                elif db_type == "tidb":
-                    DB = import_db_module("tidb")
-                    db = DB(_get_tidb_config())
-                    await db.check_tables()
-                    db_instances["tidb"] = db
-
-            # Inject database instances into storage classes
-            for storage_name, storage_instance in storage_instances:
-                db_type = get_db_type_from_storage_class(
-                    storage_instance.__class__.__name__
-                )
-                if db_type:
-                    if db_type not in db_instances:
-                        error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
-                        logger.error(error_msg)
-                        raise RuntimeError(error_msg)
-                    storage_instance.db = db_instances[db_type]
-                    logger.info(f"Injected {db_type} db to {storage_name}")
+            # Initialize database connections
+            await rag.initialize_storages()

            # Auto scan documents if enabled
            if args.auto_scan_at_startup:
@@ -1000,17 +929,7 @@ def create_app(args):
        finally:
            # Clean up database connections
-            for db_type, db in db_instances.items():
-                if hasattr(db, "pool"):
-                    await db.pool.close()
-                    # Use more accurate database name display
-                    db_names = {
-                        "postgres": "PostgreSQL",
-                        "oracle": "Oracle",
-                        "tidb": "TiDB",
-                    }
-                    db_name = db_names.get(db_type, db_type)
-                    logger.info(f"Closed {db_name} database connection pool")
+            await rag.finalize_storages()

    # Initialize FastAPI
    app = FastAPI(
@@ -1042,92 +961,6 @@ def create_app(args):
        allow_headers=["*"],
    )

-    # Database configuration functions
-    def _get_postgres_config():
-        return {
-            "host": os.environ.get(
-                "POSTGRES_HOST",
-                config.get("postgres", "host", fallback="localhost"),
-            ),
-            "port": os.environ.get(
-                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
-            ),
-            "user": os.environ.get(
-                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
-            ),
-            "password": os.environ.get(
-                "POSTGRES_PASSWORD",
-                config.get("postgres", "password", fallback=None),
-            ),
-            "database": os.environ.get(
-                "POSTGRES_DATABASE",
-                config.get("postgres", "database", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "POSTGRES_WORKSPACE",
-                config.get("postgres", "workspace", fallback="default"),
-            ),
-        }
-
-    def _get_oracle_config():
-        return {
-            "user": os.environ.get(
-                "ORACLE_USER",
-                config.get("oracle", "user", fallback=None),
-            ),
-            "password": os.environ.get(
-                "ORACLE_PASSWORD",
-                config.get("oracle", "password", fallback=None),
-            ),
-            "dsn": os.environ.get(
-                "ORACLE_DSN",
-                config.get("oracle", "dsn", fallback=None),
-            ),
-            "config_dir": os.environ.get(
-                "ORACLE_CONFIG_DIR",
-                config.get("oracle", "config_dir", fallback=None),
-            ),
-            "wallet_location": os.environ.get(
-                "ORACLE_WALLET_LOCATION",
-                config.get("oracle", "wallet_location", fallback=None),
-            ),
-            "wallet_password": os.environ.get(
-                "ORACLE_WALLET_PASSWORD",
-                config.get("oracle", "wallet_password", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "ORACLE_WORKSPACE",
-                config.get("oracle", "workspace", fallback="default"),
-            ),
-        }
-
-    def _get_tidb_config():
-        return {
-            "host": os.environ.get(
-                "TIDB_HOST",
-                config.get("tidb", "host", fallback="localhost"),
-            ),
-            "port": os.environ.get(
-                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
-            ),
-            "user": os.environ.get(
-                "TIDB_USER",
-                config.get("tidb", "user", fallback=None),
-            ),
-            "password": os.environ.get(
-                "TIDB_PASSWORD",
-                config.get("tidb", "password", fallback=None),
-            ),
-            "database": os.environ.get(
-                "TIDB_DATABASE",
-                config.get("tidb", "database", fallback=None),
-            ),
-            "workspace": os.environ.get(
-                "TIDB_WORKSPACE",
-                config.get("tidb", "workspace", fallback="default"),
-            ),
-        }
-
    # Create the optional API key dependency
    optional_api_key = get_api_key_dependency(api_key)
@@ -1262,6 +1095,7 @@ def create_app(args):
            },
            log_level=args.log_level,
            namespace_prefix=args.namespace_prefix,
+            is_managed_by_server=True,
        )
    else:
        rag = LightRAG(
@@ -1293,20 +1127,9 @@ def create_app(args):
            },
            log_level=args.log_level,
            namespace_prefix=args.namespace_prefix,
+            is_managed_by_server=True,
        )

-    # Collect all storage instances
-    storage_instances = [
-        ("full_docs", rag.full_docs),
-        ("text_chunks", rag.text_chunks),
-        ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
-        ("entities_vdb", rag.entities_vdb),
-        ("relationships_vdb", rag.relationships_vdb),
-        ("chunks_vdb", rag.chunks_vdb),
-        ("doc_status", rag.doc_status),
-        ("llm_response_cache", rag.llm_response_cache),
-    ]
-
    async def pipeline_enqueue_file(file_path: Path) -> bool:
        """Add a file to the queue for processing

lightrag/base.py (View File)

@@ -87,6 +87,14 @@ class StorageNameSpace(ABC):
    namespace: str
    global_config: dict[str, Any]

+    async def initialize(self):
+        """Initialize the storage"""
+        pass
+
+    async def finalize(self):
+        """Finalize the storage"""
+        pass
+
    @abstractmethod
    async def index_done_callback(self) -> None:
        """Commit the storage operations after indexing"""
@@ -247,3 +255,12 @@ class DocStatusStorage(BaseKVStorage, ABC):
        self, status: DocStatus
    ) -> dict[str, DocProcessingStatus]:
        """Get all documents with a specific status"""
+
+
+class StoragesStatus(str, Enum):
+    """Storages status"""
+
+    NOT_CREATED = "not_created"
+    CREATED = "created"
+    INITIALIZED = "initialized"
+    FINALIZED = "finalized"
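Both hooks default to no-ops, so file-based backends keep working unchanged; only storages that hold external connections need to override them. A sketch of a conforming subclass (the class itself is hypothetical, only the hook signatures and the `StoragesStatus` states come from `base.py`):

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class InMemoryKVStorage:  # hypothetical example, not part of this commit
    namespace: str
    global_config: dict[str, Any]
    _data: dict[str, Any] = field(default_factory=dict)

    async def initialize(self):
        # Acquire external resources here (connection pools, file handles).
        self._data = {}

    async def finalize(self):
        # Release them here; LightRAG.finalize_storages() calls this once.
        self._data.clear()

    async def index_done_callback(self) -> None:
        pass  # nothing to commit for an in-memory store
```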

lightrag/kg/oracle_impl.py (View File)

@@ -2,11 +2,11 @@ import array
 import asyncio
 # import html
-# import os
+import os
 from dataclasses import dataclass
 from typing import Any, Union, final
 import numpy as np
+import configparser

 from lightrag.types import KnowledgeGraph
@@ -173,6 +173,72 @@ class OracleDB:
            raise

+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "user": os.environ.get(
+                "ORACLE_USER",
+                config.get("oracle", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "ORACLE_PASSWORD",
+                config.get("oracle", "password", fallback=None),
+            ),
+            "dsn": os.environ.get(
+                "ORACLE_DSN",
+                config.get("oracle", "dsn", fallback=None),
+            ),
+            "config_dir": os.environ.get(
+                "ORACLE_CONFIG_DIR",
+                config.get("oracle", "config_dir", fallback=None),
+            ),
+            "wallet_location": os.environ.get(
+                "ORACLE_WALLET_LOCATION",
+                config.get("oracle", "wallet_location", fallback=None),
+            ),
+            "wallet_password": os.environ.get(
+                "ORACLE_WALLET_PASSWORD",
+                config.get("oracle", "wallet_password", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "ORACLE_WORKSPACE",
+                config.get("oracle", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> OracleDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = OracleDB(config)
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: OracleDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        await db.pool.close()
+                        logger.info("Closed OracleDB database connection pool")
+                        cls._instances["db"] = None
+                else:
+                    await db.pool.close()
+
 @final
 @dataclass
 class OracleKVStorage(BaseKVStorage):
@@ -184,6 +250,15 @@ class OracleKVStorage(BaseKVStorage):
        self._data = {}
        self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -329,6 +404,15 @@ class OracleVectorDBStorage(BaseVectorStorage):
            )
        self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    #################### query method ###############

    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
        embeddings = await self.embedding_func([query])
@@ -368,6 +452,15 @@ class OracleGraphStorage(BaseGraphStorage):
    def __post_init__(self):
        self._max_batch_size = self.global_config.get("embedding_batch_num", 10)

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    #################### insert method ################
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
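All three database backends in this commit follow the same pattern: a module-level `ClientManager` hands out a single shared DB object and counts its borrowers under an asyncio lock. A sketch of the contract (it needs a reachable Oracle instance configured through `ORACLE_*` environment variables or `config.ini`, since `get_client` runs `check_tables` on first use):

```python
import asyncio

async def demo():
    db1 = await ClientManager.get_client()   # first borrower: client built, ref_count = 1
    db2 = await ClientManager.get_client()   # same object handed back, ref_count = 2
    assert db1 is db2

    await ClientManager.release_client(db1)  # ref_count = 1, pool stays open
    await ClientManager.release_client(db2)  # ref_count = 0, pool closed, instance dropped

asyncio.run(demo())
```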

lightrag/kg/postgres_impl.py (View File)

@@ -5,8 +5,8 @@ import os
 import time
 from dataclasses import dataclass
 from typing import Any, Dict, List, Union, final
 import numpy as np
+import configparser

 from lightrag.types import KnowledgeGraph
@@ -182,6 +182,67 @@ class PostgreSQLDB:
            pass

+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "host": os.environ.get(
+                "POSTGRES_HOST",
+                config.get("postgres", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
+            ),
+            "user": os.environ.get(
+                "POSTGRES_USER", config.get("postgres", "user", fallback=None)
+            ),
+            "password": os.environ.get(
+                "POSTGRES_PASSWORD",
+                config.get("postgres", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "POSTGRES_DATABASE",
+                config.get("postgres", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "POSTGRES_WORKSPACE",
+                config.get("postgres", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> PostgreSQLDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = PostgreSQLDB(config)
+                await db.initdb()
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: PostgreSQLDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        await db.pool.close()
+                        logger.info("Closed PostgreSQL database connection pool")
+                        cls._instances["db"] = None
+                else:
+                    await db.pool.close()
+
 @final
 @dataclass
 class PGKVStorage(BaseKVStorage):
@@ -191,6 +252,15 @@ class PGKVStorage(BaseKVStorage):
    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -319,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
            )
        self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    def _upsert_chunks(self, item: dict):
        try:
            upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -435,6 +514,15 @@
 @final
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Filter out duplicated content"""
        sql = SQL_TEMPLATES["filter_keys"].format(
@@ -584,6 +672,15 @@ class PGGraphStorage(BaseGraphStorage):
            "node2vec": self._node2vec_embed,
        }

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    async def index_done_callback(self) -> None:
        # PG handles persistence automatically
        pass
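Because every PG storage class resolves its `db` through the same `ClientManager`, one `PostgreSQLDB` (and one connection pool) serves the KV, vector, doc-status, and graph storages together. A sketch of that sharing, assuming `storages` holds instances of the PG storage classes above and a PostgreSQL server is reachable via the `POSTGRES_*` settings:

```python
import asyncio

async def lifecycle(storages):
    # Every initialize() borrows the same PostgreSQLDB; the pool is built once.
    await asyncio.gather(*(s.initialize() for s in storages))
    assert all(s.db is storages[0].db for s in storages)

    # Every finalize() returns the client; the pool closes only when the last
    # borrower releases it (ref_count reaches 0).
    await asyncio.gather(*(s.finalize() for s in storages))
```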

lightrag/kg/tidb_impl.py (View File)

@@ -14,6 +14,7 @@ from ..namespace import NameSpace, is_namespace
 from ..utils import logger

 import pipmaster as pm
+import configparser

 if not pm.is_installed("pymysql"):
     pm.install("pymysql")
@@ -105,6 +106,63 @@ class TiDB:
            raise

+class ClientManager:
+    _instances = {"db": None, "ref_count": 0}
+    _lock = asyncio.Lock()
+
+    @staticmethod
+    def get_config():
+        config = configparser.ConfigParser()
+        config.read("config.ini", "utf-8")
+
+        return {
+            "host": os.environ.get(
+                "TIDB_HOST",
+                config.get("tidb", "host", fallback="localhost"),
+            ),
+            "port": os.environ.get(
+                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
+            ),
+            "user": os.environ.get(
+                "TIDB_USER",
+                config.get("tidb", "user", fallback=None),
+            ),
+            "password": os.environ.get(
+                "TIDB_PASSWORD",
+                config.get("tidb", "password", fallback=None),
+            ),
+            "database": os.environ.get(
+                "TIDB_DATABASE",
+                config.get("tidb", "database", fallback=None),
+            ),
+            "workspace": os.environ.get(
+                "TIDB_WORKSPACE",
+                config.get("tidb", "workspace", fallback="default"),
+            ),
+        }
+
+    @classmethod
+    async def get_client(cls) -> TiDB:
+        async with cls._lock:
+            if cls._instances["db"] is None:
+                config = ClientManager.get_config()
+                db = TiDB(config)
+                await db.check_tables()
+                cls._instances["db"] = db
+                cls._instances["ref_count"] = 0
+            cls._instances["ref_count"] += 1
+            return cls._instances["db"]
+
+    @classmethod
+    async def release_client(cls, db: TiDB):
+        async with cls._lock:
+            if db is not None:
+                if db is cls._instances["db"]:
+                    cls._instances["ref_count"] -= 1
+                    if cls._instances["ref_count"] == 0:
+                        cls._instances["db"] = None
+
 @final
 @dataclass
 class TiDBKVStorage(BaseKVStorage):
@@ -115,6 +173,15 @@ class TiDBKVStorage(BaseKVStorage):
        self._data = {}
        self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -185,7 +252,7 @@ class TiDBKVStorage(BaseKVStorage):
                    "tokens": item["tokens"],
                    "chunk_order_index": item["chunk_order_index"],
                    "full_doc_id": item["full_doc_id"],
-                    "content_vector": f'{item["__vector__"].tolist()}',
+                    "content_vector": f"{item['__vector__'].tolist()}",
                    "workspace": self.db.workspace,
                }
            )
@@ -226,6 +293,15 @@ class TiDBVectorDBStorage(BaseVectorStorage):
            )
        self.cosine_better_than_threshold = cosine_threshold

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
        """Search from tidb vector"""
        embeddings = await self.embedding_func([query])
@@ -290,7 +366,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                "id": item["id"],
                "name": item["entity_name"],
                "content": item["content"],
-                "content_vector": f'{item["content_vector"].tolist()}',
+                "content_vector": f"{item['content_vector'].tolist()}",
                "workspace": self.db.workspace,
            }
            # update entity_id if node inserted by graph_storage_instance before
@@ -312,7 +388,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
                "source_name": item["src_id"],
                "target_name": item["tgt_id"],
                "content": item["content"],
-                "content_vector": f'{item["content_vector"].tolist()}',
+                "content_vector": f"{item['content_vector'].tolist()}",
                "workspace": self.db.workspace,
            }
            # update relation_id if node inserted by graph_storage_instance before
@@ -351,6 +427,15 @@ class TiDBGraphStorage(BaseGraphStorage):
    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]

+    async def initialize(self):
+        if not hasattr(self, "db") or self.db is None:
+            self.db = await ClientManager.get_client()
+
+    async def finalize(self):
+        if hasattr(self, "db") and self.db is not None:
+            await ClientManager.release_client(self.db)
+            self.db = None
+
    #################### upsert method ################
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        entity_name = node_id
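The three `content_vector` edits in this file are quote normalization only: an f-string just needs its inner quotes to differ from the outer ones (reusing the same quote is a syntax error before Python 3.12), and both spellings yield the same string. A quick illustration:

```python
import numpy as np

item = {"content_vector": np.array([0.1, 0.2])}
old = f'{item["content_vector"].tolist()}'  # single quotes outside, double inside
new = f"{item['content_vector'].tolist()}"  # double quotes outside, single inside
assert old == new == "[0.1, 0.2]"  # identical output; purely a style change
```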

lightrag/lightrag.py (View File)

@@ -17,6 +17,7 @@ from .base import (
     DocStatusStorage,
     QueryParam,
     StorageNameSpace,
+    StoragesStatus,
 )
 from .namespace import NameSpace, make_namespace
 from .operate import (
@@ -348,6 +349,9 @@ class LightRAG:
    # Extensions
    addon_params: dict[str, Any] = field(default_factory=dict)

+    # Ownership
+    is_managed_by_server: bool = False
+
    """Dictionary for additional parameters and extensions."""
    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
        convert_response_to_json
@@ -440,7 +444,10 @@ class LightRAG:
            **self.vector_db_storage_cls_kwargs,
        }

-        # show config
+        # Life cycle
+        self.storages_status = StoragesStatus.NOT_CREATED
+
+        # Show config
        global_config = asdict(self)
        _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -547,6 +554,65 @@ class LightRAG:
            )
        )

+        self.storages_status = StoragesStatus.CREATED
+
+        # Initialize storages
+        if not self.is_managed_by_server:
+            loop = always_get_an_event_loop()
+            loop.run_until_complete(self.initialize_storages())
+
+    def __del__(self):
+        # Finalize storages
+        if not self.is_managed_by_server:
+            loop = always_get_an_event_loop()
+            loop.run_until_complete(self.finalize_storages())
+
+    async def initialize_storages(self):
+        """Asynchronously initialize the storages"""
+        if self.storages_status == StoragesStatus.CREATED:
+            tasks = []
+
+            for storage in (
+                self.full_docs,
+                self.text_chunks,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+                self.llm_response_cache,
+                self.doc_status,
+            ):
+                if storage:
+                    tasks.append(storage.initialize())
+
+            await asyncio.gather(*tasks)
+
+            self.storages_status = StoragesStatus.INITIALIZED
+            logger.debug("Initialized Storages")
+
+    async def finalize_storages(self):
+        """Asynchronously finalize the storages"""
+        if self.storages_status == StoragesStatus.INITIALIZED:
+            tasks = []
+
+            for storage in (
+                self.full_docs,
+                self.text_chunks,
+                self.entities_vdb,
+                self.relationships_vdb,
+                self.chunks_vdb,
+                self.chunk_entity_relation_graph,
+                self.llm_response_cache,
+                self.doc_status,
+            ):
+                if storage:
+                    tasks.append(storage.finalize())
+
+            await asyncio.gather(*tasks)
+
+            logger.debug("Finalized Storages")
+            self.storages_status = StoragesStatus.FINALIZED
+
    async def get_graph_labels(self):
        text = await self.chunk_entity_relation_graph.get_all_labels()
        return text
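The `is_managed_by_server` flag and the `storages_status` guard together give `LightRAG` two ownership modes, and the guard makes both calls idempotent (initializing an already-initialized instance is a no-op). A sketch of the two modes, with constructor arguments beyond `working_dir` elided:

```python
from lightrag import LightRAG

# Library use: LightRAG owns the lifecycle. __init__ drives
# initialize_storages() on an event loop, and __del__ drives finalize_storages().
rag = LightRAG(working_dir="./rag_storage")

# Server use: the FastAPI lifespan owns the lifecycle instead.
rag_managed = LightRAG(working_dir="./rag_storage", is_managed_by_server=True)
# ...inside the server's lifespan context:
#     await rag_managed.initialize_storages()
#     ...
#     await rag_managed.finalize_storages()
```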