Merge pull request #898 from YanSte/update

Database Cleanup
Yannick Stephan
2025-02-20 13:35:37 +01:00
committed by GitHub
11 changed files with 355 additions and 330 deletions

View File

@@ -1,5 +1,3 @@
-version: '3.8'
 services:
   lightrag:
     build: .

View File

@@ -98,7 +98,6 @@ async def init():
     # Initialize LightRAG
     # We use Oracle DB as the KV/vector/graph storage
-    # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
     rag = LightRAG(
         enable_llm_cache=False,
         working_dir=WORKING_DIR,

View File

@@ -1,9 +1,8 @@
-import os
 import inspect
+import os
 from lightrag import LightRAG
 from lightrag.llm import openai_complete, openai_embed
-from lightrag.utils import EmbeddingFunc
-from lightrag.lightrag import always_get_an_event_loop
+from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
 from lightrag import QueryParam

 # WorkingDir
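always_get_an_event_loop is now imported from lightrag.utils rather than lightrag.lightrag. A usage sketch (the query call is illustrative):

    loop = always_get_an_event_loop()
    loop.run_until_complete(
        rag.aquery("What are the top themes?", param=QueryParam(mode="hybrid"))
    )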

View File

@@ -63,7 +63,6 @@ async def main():
# Initialize LightRAG
     # Initialize LightRAG
     # We use TiDB DB as the KV/vector
-    # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
     rag = LightRAG(
         enable_llm_cache=False,
         working_dir=WORKING_DIR,

View File

@@ -1 +1,136 @@
# print ("init package vars here. ......") STORAGE_IMPLEMENTATIONS = {
"KV_STORAGE": {
"implementations": [
"JsonKVStorage",
"MongoKVStorage",
"RedisKVStorage",
"TiDBKVStorage",
"PGKVStorage",
"OracleKVStorage",
],
"required_methods": ["get_by_id", "upsert"],
},
"GRAPH_STORAGE": {
"implementations": [
"NetworkXStorage",
"Neo4JStorage",
"MongoGraphStorage",
"TiDBGraphStorage",
"AGEStorage",
"GremlinStorage",
"PGGraphStorage",
"OracleGraphStorage",
],
"required_methods": ["upsert_node", "upsert_edge"],
},
"VECTOR_STORAGE": {
"implementations": [
"NanoVectorDBStorage",
"MilvusVectorDBStorage",
"ChromaVectorDBStorage",
"TiDBVectorDBStorage",
"PGVectorStorage",
"FaissVectorDBStorage",
"QdrantVectorDBStorage",
"OracleVectorDBStorage",
"MongoVectorDBStorage",
],
"required_methods": ["query", "upsert"],
},
"DOC_STATUS_STORAGE": {
"implementations": [
"JsonDocStatusStorage",
"PGDocStatusStorage",
"PGDocStatusStorage",
"MongoDocStatusStorage",
],
"required_methods": ["get_docs_by_status"],
},
}
# Storage implementation environment variable without default value
STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
# KV Storage Implementations
"JsonKVStorage": [],
"MongoKVStorage": [],
"RedisKVStorage": ["REDIS_URI"],
"TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"OracleKVStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Graph Storage Implementations
"NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [],
"TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [
"AGE_POSTGRES_DB",
"AGE_POSTGRES_USER",
"AGE_POSTGRES_PASSWORD",
],
"GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
"PGGraphStorage": [
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_DATABASE",
],
"OracleGraphStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
# Vector Storage Implementations
"NanoVectorDBStorage": [],
"MilvusVectorDBStorage": [],
"ChromaVectorDBStorage": [],
"TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"FaissVectorDBStorage": [],
"QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None
"OracleVectorDBStorage": [
"ORACLE_DSN",
"ORACLE_USER",
"ORACLE_PASSWORD",
"ORACLE_CONFIG_DIR",
],
"MongoVectorDBStorage": [],
# Document Status Storage Implementations
"JsonDocStatusStorage": [],
"PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
"MongoDocStatusStorage": [],
}
# Storage implementation module mapping
STORAGES = {
"NetworkXStorage": ".kg.networkx_impl",
"JsonKVStorage": ".kg.json_kv_impl",
"NanoVectorDBStorage": ".kg.nano_vector_db_impl",
"JsonDocStatusStorage": ".kg.json_doc_status_impl",
"Neo4JStorage": ".kg.neo4j_impl",
"OracleKVStorage": ".kg.oracle_impl",
"OracleGraphStorage": ".kg.oracle_impl",
"OracleVectorDBStorage": ".kg.oracle_impl",
"MilvusVectorDBStorage": ".kg.milvus_impl",
"MongoKVStorage": ".kg.mongo_impl",
"MongoDocStatusStorage": ".kg.mongo_impl",
"MongoGraphStorage": ".kg.mongo_impl",
"MongoVectorDBStorage": ".kg.mongo_impl",
"RedisKVStorage": ".kg.redis_impl",
"ChromaVectorDBStorage": ".kg.chroma_impl",
"TiDBKVStorage": ".kg.tidb_impl",
"TiDBVectorDBStorage": ".kg.tidb_impl",
"TiDBGraphStorage": ".kg.tidb_impl",
"PGKVStorage": ".kg.postgres_impl",
"PGVectorStorage": ".kg.postgres_impl",
"AGEStorage": ".kg.age_impl",
"PGGraphStorage": ".kg.postgres_impl",
"GremlinStorage": ".kg.gremlin_impl",
"PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl",
}
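Taken together, these tables drive the storage validation that LightRAG performs below. A hedged sketch of how a caller might consume them (the function name is illustrative, not part of the package):

    import os
    from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES

    def validate_backend(storage_type: str, name: str) -> None:
        # Reject implementations that are not registered for this storage type.
        info = STORAGE_IMPLEMENTATIONS[storage_type]
        if name not in info["implementations"]:
            raise ValueError(f"{name} is not a known {storage_type} implementation")
        # Reject configurations with missing required environment variables.
        missing = [v for v in STORAGE_ENV_REQUIREMENTS.get(name, []) if v not in os.environ]
        if missing:
            raise ValueError(f"{name} requires env vars: {', '.join(missing)}")

    validate_backend("KV_STORAGE", "JsonKVStorage")  # no env vars required
    print(STORAGES["JsonKVStorage"])                 # ".kg.json_kv_impl"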

View File

@@ -44,7 +44,7 @@ class OracleDB:
         self.increment = 1
         logger.info(f"Using the label {self.workspace} for Oracle Graph as identifier")
         if self.user is None or self.password is None:
-            raise ValueError("Missing database user or password in addon_params")
+            raise ValueError("Missing database user or password")

         try:
             oracledb.defaults.fetch_lobs = False

View File

@@ -54,9 +54,7 @@ class PostgreSQLDB:
         self.pool: Pool | None = None
         if self.user is None or self.password is None or self.database is None:
-            raise ValueError(
-                "Missing database user, password, or database in addon_params"
-            )
+            raise ValueError("Missing database user, password, or database")

     async def initdb(self):
         try:
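With the addon_params wording gone, credentials for the Postgres backends are expected to come from the environment, matching the new STORAGE_ENV_REQUIREMENTS table above. A hedged setup sketch (placeholder values):

    import os

    os.environ.setdefault("POSTGRES_USER", "lightrag")
    os.environ.setdefault("POSTGRES_PASSWORD", "change-me")
    os.environ.setdefault("POSTGRES_DATABASE", "lightrag")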

View File

@@ -6,8 +6,10 @@ import configparser
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Any, AsyncIterator, Callable, Iterator, cast
+from typing import Any, AsyncIterator, Callable, Iterator, cast, final
+from asyncio import Lock
+from lightrag.kg import STORAGE_ENV_REQUIREMENTS, STORAGE_IMPLEMENTATIONS, STORAGES

 from .base import (
     BaseGraphStorage,
     BaseKVStorage,
@@ -32,8 +34,10 @@ from .operate import (
 from .prompt import GRAPH_FIELD_SEP
 from .utils import (
     EmbeddingFunc,
+    always_get_an_event_loop,
     compute_mdhash_id,
     convert_response_to_json,
+    lazy_external_import,
     limit_async_func_call,
     logger,
     set_logger,
@@ -43,210 +47,22 @@ from .utils import (
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
-# Storage type and implementation compatibility validation table
-STORAGE_IMPLEMENTATIONS = {
-    "KV_STORAGE": {
-        "implementations": [
-            "JsonKVStorage",
-            "MongoKVStorage",
-            "RedisKVStorage",
-            "TiDBKVStorage",
-            "PGKVStorage",
-            "OracleKVStorage",
-        ],
-        "required_methods": ["get_by_id", "upsert"],
-    },
-    "GRAPH_STORAGE": {
-        "implementations": [
-            "NetworkXStorage",
-            "Neo4JStorage",
-            "MongoGraphStorage",
-            "TiDBGraphStorage",
-            "AGEStorage",
-            "GremlinStorage",
-            "PGGraphStorage",
-            "OracleGraphStorage",
-        ],
-        "required_methods": ["upsert_node", "upsert_edge"],
-    },
-    "VECTOR_STORAGE": {
-        "implementations": [
-            "NanoVectorDBStorage",
-            "MilvusVectorDBStorage",
-            "ChromaVectorDBStorage",
-            "TiDBVectorDBStorage",
-            "PGVectorStorage",
-            "FaissVectorDBStorage",
-            "QdrantVectorDBStorage",
-            "OracleVectorDBStorage",
-            "MongoVectorDBStorage",
-        ],
-        "required_methods": ["query", "upsert"],
-    },
-    "DOC_STATUS_STORAGE": {
-        "implementations": [
-            "JsonDocStatusStorage",
-            "PGDocStatusStorage",
-            "PGDocStatusStorage",
-            "MongoDocStatusStorage",
-        ],
-        "required_methods": ["get_docs_by_status"],
-    },
-}
-
-# Storage implementation environment variable without default value
-STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
-    # KV Storage Implementations
-    "JsonKVStorage": [],
-    "MongoKVStorage": [],
-    "RedisKVStorage": ["REDIS_URI"],
-    "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "OracleKVStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    # Graph Storage Implementations
-    "NetworkXStorage": [],
-    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
-    "MongoGraphStorage": [],
-    "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "AGEStorage": [
-        "AGE_POSTGRES_DB",
-        "AGE_POSTGRES_USER",
-        "AGE_POSTGRES_PASSWORD",
-    ],
-    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
-    "PGGraphStorage": [
-        "POSTGRES_USER",
-        "POSTGRES_PASSWORD",
-        "POSTGRES_DATABASE",
-    ],
-    "OracleGraphStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    # Vector Storage Implementations
-    "NanoVectorDBStorage": [],
-    "MilvusVectorDBStorage": [],
-    "ChromaVectorDBStorage": [],
-    "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
-    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "FaissVectorDBStorage": [],
-    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY has default value None
-    "OracleVectorDBStorage": [
-        "ORACLE_DSN",
-        "ORACLE_USER",
-        "ORACLE_PASSWORD",
-        "ORACLE_CONFIG_DIR",
-    ],
-    "MongoVectorDBStorage": [],
-    # Document Status Storage Implementations
-    "JsonDocStatusStorage": [],
-    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
-    "MongoDocStatusStorage": [],
-}
-
-# Storage implementation module mapping
-STORAGES = {
-    "NetworkXStorage": ".kg.networkx_impl",
-    "JsonKVStorage": ".kg.json_kv_impl",
-    "NanoVectorDBStorage": ".kg.nano_vector_db_impl",
-    "JsonDocStatusStorage": ".kg.json_doc_status_impl",
-    "Neo4JStorage": ".kg.neo4j_impl",
-    "OracleKVStorage": ".kg.oracle_impl",
-    "OracleGraphStorage": ".kg.oracle_impl",
-    "OracleVectorDBStorage": ".kg.oracle_impl",
-    "MilvusVectorDBStorage": ".kg.milvus_impl",
-    "MongoKVStorage": ".kg.mongo_impl",
-    "MongoDocStatusStorage": ".kg.mongo_impl",
-    "MongoGraphStorage": ".kg.mongo_impl",
-    "MongoVectorDBStorage": ".kg.mongo_impl",
-    "RedisKVStorage": ".kg.redis_impl",
-    "ChromaVectorDBStorage": ".kg.chroma_impl",
-    "TiDBKVStorage": ".kg.tidb_impl",
-    "TiDBVectorDBStorage": ".kg.tidb_impl",
-    "TiDBGraphStorage": ".kg.tidb_impl",
-    "PGKVStorage": ".kg.postgres_impl",
-    "PGVectorStorage": ".kg.postgres_impl",
-    "AGEStorage": ".kg.age_impl",
-    "PGGraphStorage": ".kg.postgres_impl",
-    "GremlinStorage": ".kg.gremlin_impl",
-    "PGDocStatusStorage": ".kg.postgres_impl",
-    "FaissVectorDBStorage": ".kg.faiss_impl",
-    "QdrantVectorDBStorage": ".kg.qdrant_impl",
-}
-
-
-def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
-    """Lazily import a class from an external module based on the package of the caller."""
-    # Get the caller's module and package
-    import inspect
-
-    caller_frame = inspect.currentframe().f_back
-    module = inspect.getmodule(caller_frame)
-    package = module.__package__ if module else None
-
-    def import_class(*args: Any, **kwargs: Any):
-        import importlib
-
-        module = importlib.import_module(module_name, package=package)
-        cls = getattr(module, class_name)
-        return cls(*args, **kwargs)
-
-    return import_class
-
-
-def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
-    """
-    Ensure that there is always an event loop available.
-
-    This function tries to get the current event loop. If the current event loop is closed or does not exist,
-    it creates a new event loop and sets it as the current event loop.
-
-    Returns:
-        asyncio.AbstractEventLoop: The current or newly created event loop.
-    """
-    try:
-        # Try to get the current event loop
-        current_loop = asyncio.get_event_loop()
-        if current_loop.is_closed():
-            raise RuntimeError("Event loop is closed.")
-        return current_loop
-
-    except RuntimeError:
-        # If no event loop exists or it is closed, create a new one
-        logger.info("Creating a new event loop in main thread.")
-        new_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(new_loop)
-        return new_loop
+@final
 @dataclass
 class LightRAG:
     """LightRAG: Simple and Fast Retrieval-Augmented Generation."""

+    # Directory
+    # ---
     working_dir: str = field(
-        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
+        default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
     """Directory where cache and temporary files are stored."""

-    embedding_cache_config: dict[str, Any] = field(
-        default_factory=lambda: {
-            "enabled": False,
-            "similarity_threshold": 0.95,
-            "use_llm_check": False,
-        }
-    )
-    """Configuration for embedding cache.
-    - enabled: If True, enables caching to avoid redundant computations.
-    - similarity_threshold: Minimum similarity score to use cached embeddings.
-    - use_llm_check: If True, validates cached embeddings using an LLM.
-    """
+    # Storage
+    # ---
     kv_storage: str = field(default="JsonKVStorage")
     """Storage backend for key-value data."""
@@ -261,32 +77,74 @@ class LightRAG:
"""Storage type for tracking document processing statuses.""" """Storage type for tracking document processing statuses."""
# Logging # Logging
current_log_level = logger.level # ---
log_level: int = field(default=current_log_level)
log_level: int = field(default=logger.level)
"""Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING').""" """Logging level for the system (e.g., 'DEBUG', 'INFO', 'WARNING')."""
log_dir: str = field(default=os.getcwd()) log_file_path: str = field(default=os.path.join(os.getcwd(), "lightrag.log"))
"""Directory where logs are stored. Defaults to the current working directory.""" """Log file path."""
# Text chunking
chunk_token_size: int = int(os.getenv("CHUNK_SIZE", "1200"))
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = int(os.getenv("CHUNK_OVERLAP_SIZE", "100"))
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = "gpt-4o-mini"
"""Model name used for tokenization when chunking text."""
# Entity extraction # Entity extraction
entity_extract_max_gleaning: int = 1 # ---
entity_extract_max_gleaning: int = field(default=1)
"""Maximum number of entity extraction attempts for ambiguous content.""" """Maximum number of entity extraction attempts for ambiguous content."""
entity_summary_to_max_tokens: int = int(os.getenv("MAX_TOKEN_SUMMARY", "500")) entity_summary_to_max_tokens: int = field(
default=int(os.getenv("MAX_TOKEN_SUMMARY", 500))
)
# Text chunking
# ---
chunk_token_size: int = field(default=int(os.getenv("CHUNK_SIZE", 1200)))
"""Maximum number of tokens per text chunk when splitting documents."""
chunk_overlap_token_size: int = field(
default=int(os.getenv("CHUNK_OVERLAP_SIZE", 100))
)
"""Number of overlapping tokens between consecutive text chunks to preserve context."""
tiktoken_model_name: str = field(default="gpt-4o-mini")
"""Model name used for tokenization when chunking text."""
"""Maximum number of tokens used for summarizing extracted entities.""" """Maximum number of tokens used for summarizing extracted entities."""
+    chunking_func: Callable[
+        [
+            str,
+            str | None,
+            bool,
+            int,
+            int,
+            str,
+        ],
+        list[dict[str, Any]],
+    ] = field(default_factory=lambda: chunking_by_token_size)
+    """
+    Custom chunking function for splitting text into chunks before processing.
+
+    The function should take the following parameters:
+
+        - `content`: The text to be split into chunks.
+        - `split_by_character`: The character to split the text on. If None, the text is split into chunks of `chunk_token_size` tokens.
+        - `split_by_character_only`: If True, the text is split only on the specified character.
+        - `chunk_token_size`: The maximum number of tokens per chunk.
+        - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
+        - `tiktoken_model_name`: The name of the tiktoken model to use for tokenization.
+
+    The function should return a list of dictionaries, where each dictionary contains the following keys:
+
+        - `tokens`: The number of tokens in the chunk.
+        - `content`: The text content of the chunk.
+
+    Defaults to `chunking_by_token_size` if not specified.
+    """
     # Node embedding
-    node_embedding_algorithm: str = "node2vec"
+    # ---
+    node_embedding_algorithm: str = field(default="node2vec")
     """Algorithm used for node embedding in knowledge graphs."""
     node2vec_params: dict[str, int] = field(
@@ -308,119 +166,98 @@ class LightRAG:
     - random_seed: Seed value for reproducibility.
     """

-    embedding_func: EmbeddingFunc | None = None
+    # Embedding
+    # ---
+    embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""

-    embedding_batch_num: int = 32
+    embedding_batch_num: int = field(default=32)
     """Batch size for embedding computations."""

-    embedding_func_max_async: int = 16
+    embedding_func_max_async: int = field(default=16)
     """Maximum number of concurrent embedding function calls."""

+    embedding_cache_config: dict[str, Any] = field(
+        default_factory=lambda: {
+            "enabled": False,
+            "similarity_threshold": 0.95,
+            "use_llm_check": False,
+        }
+    )
+    """Configuration for embedding cache.
+    - enabled: If True, enables caching to avoid redundant computations.
+    - similarity_threshold: Minimum similarity score to use cached embeddings.
+    - use_llm_check: If True, validates cached embeddings using an LLM.
+    """

     # LLM Configuration
-    llm_model_func: Callable[..., object] | None = None
+    # ---
+    llm_model_func: Callable[..., object] | None = field(default=None)
     """Function for interacting with the large language model (LLM). Must be set before use."""

-    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
+    llm_model_name: str = field(default="gpt-4o-mini")
     """Name of the LLM model used for generating responses."""

-    llm_model_max_token_size: int = int(os.getenv("MAX_TOKENS", "32768"))
+    llm_model_max_token_size: int = field(default=int(os.getenv("MAX_TOKENS", 32768)))
     """Maximum number of tokens allowed per LLM response."""

-    llm_model_max_async: int = int(os.getenv("MAX_ASYNC", "16"))
+    llm_model_max_async: int = field(default=int(os.getenv("MAX_ASYNC", 16)))
     """Maximum number of concurrent LLM calls."""

     llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional keyword arguments passed to the LLM model function."""
     # Storage
+    # ---
     vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional parameters for vector database storage."""

     namespace_prefix: str = field(default="")
     """Prefix for namespacing stored data across different environments."""

-    enable_llm_cache: bool = True
+    enable_llm_cache: bool = field(default=True)
     """Enables caching for LLM responses to avoid redundant computations."""

-    enable_llm_cache_for_entity_extract: bool = True
+    enable_llm_cache_for_entity_extract: bool = field(default=True)
     """If True, enables caching for entity extraction steps to reduce LLM costs."""

     # Extensions
+    # ---
+    max_parallel_insert: int = field(default=int(os.getenv("MAX_PARALLEL_INSERT", 20)))
+    """Maximum number of parallel insert operations."""

     addon_params: dict[str, Any] = field(default_factory=dict)

     # Storages Management
-    auto_manage_storages_states: bool = True
+    # ---
+    auto_manage_storages_states: bool = field(default=True)
     """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""

-    """Dictionary for additional parameters and extensions."""
-    convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
-        convert_response_to_json
-    )
+    # Storages Management
+    # ---
+    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
+        default_factory=lambda: convert_response_to_json
+    )
+    """
+    Custom function for converting LLM responses to JSON format.
+
+    The default function is :func:`.utils.convert_response_to_json`.
+    """

-    # Lock for entity extraction
-    _entity_lock = Lock()
-
-    # Custom Chunking Function
-    chunking_func: Callable[
-        [
-            str,
-            str | None,
-            bool,
-            int,
-            int,
-            str,
-        ],
-        list[dict[str, Any]],
-    ] = chunking_by_token_size
+    _storages_status: StoragesStatus = field(default=StoragesStatus.NOT_CREATED)
-    def verify_storage_implementation(
-        self, storage_type: str, storage_name: str
-    ) -> None:
-        """Verify if storage implementation is compatible with specified storage type
-
-        Args:
-            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
-            storage_name: Storage implementation name
-
-        Raises:
-            ValueError: If storage implementation is incompatible or missing required methods
-        """
-        if storage_type not in STORAGE_IMPLEMENTATIONS:
-            raise ValueError(f"Unknown storage type: {storage_type}")
-        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
-        if storage_name not in storage_info["implementations"]:
-            raise ValueError(
-                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
-                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
-            )
-
-    def check_storage_env_vars(self, storage_name: str) -> None:
-        """Check if all required environment variables for storage implementation exist
-
-        Args:
-            storage_name: Storage implementation name
-
-        Raises:
-            ValueError: If required environment variables are missing
-        """
-        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
-        missing_vars = [var for var in required_vars if var not in os.environ]
-        if missing_vars:
-            raise ValueError(
-                f"Storage implementation '{storage_name}' requires the following "
-                f"environment variables: {', '.join(missing_vars)}"
-            )
     def __post_init__(self):
-        os.makedirs(self.log_dir, exist_ok=True)
-        log_file = os.path.join(self.log_dir, "lightrag.log")
-        set_logger(log_file)
         logger.setLevel(self.log_level)
+        os.makedirs(os.path.dirname(self.log_file_path), exist_ok=True)
+        set_logger(self.log_file_path)

         logger.info(f"Logger initialized for working directory: {self.working_dir}")

         if not os.path.exists(self.working_dir):
             logger.info(f"Creating working directory {self.working_dir}")
             os.makedirs(self.working_dir)
@@ -448,9 +285,6 @@ class LightRAG:
             **self.vector_db_storage_cls_kwargs,
         }

-        # Life cycle
-        self.storages_status = StoragesStatus.NOT_CREATED
-
         # Show config
         global_config = asdict(self)
         _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
@@ -558,7 +392,7 @@ class LightRAG:
             )
         )

-        self.storages_status = StoragesStatus.CREATED
+        self._storages_status = StoragesStatus.CREATED

         # Initialize storages
         if self.auto_manage_storages_states:
@@ -573,7 +407,7 @@ class LightRAG:
     async def initialize_storages(self):
         """Asynchronously initialize the storages"""
-        if self.storages_status == StoragesStatus.CREATED:
+        if self._storages_status == StoragesStatus.CREATED:
             tasks = []

             for storage in (
@@ -591,12 +425,12 @@ class LightRAG:
             await asyncio.gather(*tasks)

-            self.storages_status = StoragesStatus.INITIALIZED
+            self._storages_status = StoragesStatus.INITIALIZED
             logger.debug("Initialized Storages")

     async def finalize_storages(self):
         """Asynchronously finalize the storages"""
-        if self.storages_status == StoragesStatus.INITIALIZED:
+        if self._storages_status == StoragesStatus.INITIALIZED:
             tasks = []

             for storage in (
@@ -614,7 +448,7 @@ class LightRAG:
             await asyncio.gather(*tasks)

-            self.storages_status = StoragesStatus.FINALIZED
+            self._storages_status = StoragesStatus.FINALIZED
             logger.debug("Finalized Storages")

     def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
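The lifecycle is now tracked in the private _storages_status field (NOT_CREATED -> CREATED -> INITIALIZED -> FINALIZED). A hedged sketch of driving it manually, with constructor arguments other than those shown elided:

    import asyncio
    from lightrag import LightRAG

    async def main():
        rag = LightRAG(working_dir="./cache", auto_manage_storages_states=False)
        await rag.initialize_storages()    # CREATED -> INITIALIZED
        try:
            ...  # inserts / queries go here
        finally:
            await rag.finalize_storages()  # INITIALIZED -> FINALIZED

    asyncio.run(main())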
@@ -789,10 +623,9 @@ class LightRAG:
             return

         # 2. split docs into chunks, insert chunks, update doc status
-        batch_size = self.addon_params.get("insert_batch_size", 10)
         docs_batches = [
-            list(to_process_docs.items())[i : i + batch_size]
-            for i in range(0, len(to_process_docs), batch_size)
+            list(to_process_docs.items())[i : i + self.max_parallel_insert]
+            for i in range(0, len(to_process_docs), self.max_parallel_insert)
         ]

         logger.info(f"Number of batches to process: {len(docs_batches)}.")
@@ -1203,7 +1036,6 @@ class LightRAG:
         # ---------------------
         # STEP 1: Keyword Extraction
         # ---------------------
-        # We'll assume 'extract_keywords_only(...)' returns (hl_keywords, ll_keywords).
         hl_keywords, ll_keywords = await extract_keywords_only(
             text=query,
             param=param,
@@ -1629,3 +1461,43 @@ class LightRAG:
result["vector_data"] = vector_data[0] if vector_data else None result["vector_data"] = vector_data[0] if vector_data else None
return result return result
+    def verify_storage_implementation(
+        self, storage_type: str, storage_name: str
+    ) -> None:
+        """Verify if storage implementation is compatible with specified storage type
+
+        Args:
+            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If storage implementation is incompatible or missing required methods
+        """
+        if storage_type not in STORAGE_IMPLEMENTATIONS:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
+        if storage_name not in storage_info["implementations"]:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
+                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
+            )
+
+    def check_storage_env_vars(self, storage_name: str) -> None:
+        """Check if all required environment variables for storage implementation exist
+
+        Args:
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If required environment variables are missing
+        """
+        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+        missing_vars = [var for var in required_vars if var not in os.environ]
+        if missing_vars:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' requires the following "
+                f"environment variables: {', '.join(missing_vars)}"
+            )

View File

@@ -713,3 +713,47 @@ def get_conversation_turns(
         )

     return "\n".join(formatted_turns)
+
+def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
+    """
+    Ensure that there is always an event loop available.
+
+    This function tries to get the current event loop. If the current event loop is closed or does not exist,
+    it creates a new event loop and sets it as the current event loop.
+
+    Returns:
+        asyncio.AbstractEventLoop: The current or newly created event loop.
+    """
+    try:
+        # Try to get the current event loop
+        current_loop = asyncio.get_event_loop()
+        if current_loop.is_closed():
+            raise RuntimeError("Event loop is closed.")
+        return current_loop
+
+    except RuntimeError:
+        # If no event loop exists or it is closed, create a new one
+        logger.info("Creating a new event loop in main thread.")
+        new_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(new_loop)
+        return new_loop
+
+
+def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
+    """Lazily import a class from an external module based on the package of the caller."""
+    # Get the caller's module and package
+    import inspect
+
+    caller_frame = inspect.currentframe().f_back
+    module = inspect.getmodule(caller_frame)
+    package = module.__package__ if module else None
+
+    def import_class(*args: Any, **kwargs: Any):
+        import importlib
+
+        # Import the module within the caller's package context
+        module = importlib.import_module(module_name, package=package)
+        cls = getattr(module, class_name)
+        return cls(*args, **kwargs)
+
+    return import_class
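A hedged usage sketch from inside the lightrag package, where the relative module path resolves (the module path and class name come from the STORAGES table; the constructor arguments are illustrative, not a documented signature):

    Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
    # The neo4j driver is only imported when the class is instantiated:
    storage = Neo4JStorage(namespace="entities", global_config={}, embedding_func=None)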

View File

@@ -1,7 +1,7 @@
 import re
 import json
-import asyncio

 from lightrag import LightRAG, QueryParam
+from lightrag.utils import always_get_an_event_loop

 def extract_queries(file_path):
@@ -23,15 +23,6 @@ async def process_query(query_text, rag_instance, query_param):
         return None, {"query": query_text, "error": str(e)}


-def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
-    try:
-        loop = asyncio.get_event_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-    return loop


 def run_queries_and_save_to_json(
     queries, rag_instance, query_param, output_file, error_file
 ):

View File

@@ -1,10 +1,9 @@
 import os
 import re
 import json
-import asyncio

 from lightrag import LightRAG, QueryParam
 from lightrag.llm.openai import openai_complete_if_cache, openai_embed
-from lightrag.utils import EmbeddingFunc
+from lightrag.utils import EmbeddingFunc, always_get_an_event_loop
 import numpy as np
@@ -55,15 +54,6 @@ async def process_query(query_text, rag_instance, query_param):
         return None, {"query": query_text, "error": str(e)}


-def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
-    try:
-        loop = asyncio.get_event_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-    return loop


 def run_queries_and_save_to_json(
     queries, rag_instance, query_param, output_file, error_file
 ):