cleanup code
This commit is contained in:
@@ -32,8 +32,10 @@ from .operate import (
|
||||
from .prompt import GRAPH_FIELD_SEP
|
||||
from .utils import (
|
||||
EmbeddingFunc,
|
||||
always_get_an_event_loop,
|
||||
compute_mdhash_id,
|
||||
convert_response_to_json,
|
||||
lazy_external_import,
|
||||
limit_async_func_call,
|
||||
logger,
|
||||
set_logger,
|
||||
@@ -182,48 +184,9 @@ STORAGES = {
|
||||
}
|
||||
|
||||
|
||||
def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
|
||||
"""Lazily import a class from an external module based on the package of the caller."""
|
||||
# Get the caller's module and package
|
||||
import inspect
|
||||
|
||||
caller_frame = inspect.currentframe().f_back
|
||||
module = inspect.getmodule(caller_frame)
|
||||
package = module.__package__ if module else None
|
||||
|
||||
def import_class(*args: Any, **kwargs: Any):
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_name, package=package)
|
||||
cls = getattr(module, class_name)
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
return import_class
|
||||
|
||||
|
||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||
"""
|
||||
Ensure that there is always an event loop available.
|
||||
|
||||
This function tries to get the current event loop. If the current event loop is closed or does not exist,
|
||||
it creates a new event loop and sets it as the current event loop.
|
||||
|
||||
Returns:
|
||||
asyncio.AbstractEventLoop: The current or newly created event loop.
|
||||
"""
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
current_loop = asyncio.get_event_loop()
|
||||
if current_loop.is_closed():
|
||||
raise RuntimeError("Event loop is closed.")
|
||||
return current_loop
|
||||
|
||||
except RuntimeError:
|
||||
# If no event loop exists or it is closed, create a new one
|
||||
logger.info("Creating a new event loop in main thread.")
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
return new_loop
|
||||
|
||||
|
||||
@final
|
||||
@@ -428,46 +391,6 @@ class LightRAG:
|
||||
The default function is :func:`.utils.convert_response_to_json`.
|
||||
"""
|
||||
|
||||
def verify_storage_implementation(
|
||||
self, storage_type: str, storage_name: str
|
||||
) -> None:
|
||||
"""Verify if storage implementation is compatible with specified storage type
|
||||
|
||||
Args:
|
||||
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
|
||||
storage_name: Storage implementation name
|
||||
|
||||
Raises:
|
||||
ValueError: If storage implementation is incompatible or missing required methods
|
||||
"""
|
||||
if storage_type not in STORAGE_IMPLEMENTATIONS:
|
||||
raise ValueError(f"Unknown storage type: {storage_type}")
|
||||
|
||||
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
|
||||
if storage_name not in storage_info["implementations"]:
|
||||
raise ValueError(
|
||||
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
|
||||
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
|
||||
)
|
||||
|
||||
def check_storage_env_vars(self, storage_name: str) -> None:
|
||||
"""Check if all required environment variables for storage implementation exist
|
||||
|
||||
Args:
|
||||
storage_name: Storage implementation name
|
||||
|
||||
Raises:
|
||||
ValueError: If required environment variables are missing
|
||||
"""
|
||||
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Storage implementation '{storage_name}' requires the following "
|
||||
f"environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
log_file = os.path.join(self.log_dir, "lightrag.log")
|
||||
@@ -1681,3 +1604,43 @@ class LightRAG:
|
||||
result["vector_data"] = vector_data[0] if vector_data else None
|
||||
|
||||
return result
|
||||
|
||||
def verify_storage_implementation(
|
||||
self, storage_type: str, storage_name: str
|
||||
) -> None:
|
||||
"""Verify if storage implementation is compatible with specified storage type
|
||||
|
||||
Args:
|
||||
storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
|
||||
storage_name: Storage implementation name
|
||||
|
||||
Raises:
|
||||
ValueError: If storage implementation is incompatible or missing required methods
|
||||
"""
|
||||
if storage_type not in STORAGE_IMPLEMENTATIONS:
|
||||
raise ValueError(f"Unknown storage type: {storage_type}")
|
||||
|
||||
storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
|
||||
if storage_name not in storage_info["implementations"]:
|
||||
raise ValueError(
|
||||
f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
|
||||
f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
|
||||
)
|
||||
|
||||
def check_storage_env_vars(self, storage_name: str) -> None:
|
||||
"""Check if all required environment variables for storage implementation exist
|
||||
|
||||
Args:
|
||||
storage_name: Storage implementation name
|
||||
|
||||
Raises:
|
||||
ValueError: If required environment variables are missing
|
||||
"""
|
||||
required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
|
||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
||||
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Storage implementation '{storage_name}' requires the following "
|
||||
f"environment variables: {', '.join(missing_vars)}"
|
||||
)
|
Reference in New Issue
Block a user