diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index 8d13fab0..8d85c292 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -26,7 +26,6 @@ import shutil
 import aiofiles
 from ascii_colors import trace_exception, ASCIIColors
 import sys
-import configparser
 from fastapi import Depends, Security
 from fastapi.security import APIKeyHeader
 from fastapi.middleware.cors import CORSMiddleware
@@ -44,13 +43,40 @@ load_dotenv(override=True)


 class RAGStorageConfig:
-    KV_STORAGE = "JsonKVStorage"
-    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
-    GRAPH_STORAGE = "NetworkXStorage"
-    VECTOR_STORAGE = "NanoVectorDBStorage"
+    """Storage configuration; defaults can be overridden via environment variables and command line arguments"""
+
+    # Default storage implementations
+    DEFAULT_KV_STORAGE = "JsonKVStorage"
+    DEFAULT_DOC_STATUS_STORAGE = "JsonDocStatusStorage"
+    DEFAULT_GRAPH_STORAGE = "NetworkXStorage"
+    DEFAULT_VECTOR_STORAGE = "NanoVectorDBStorage"
+
+    def __init__(self):
+        # Read from environment variables, falling back to the defaults
+        self.KV_STORAGE = os.getenv("LIGHTRAG_KV_STORAGE", self.DEFAULT_KV_STORAGE)
+        self.DOC_STATUS_STORAGE = os.getenv(
+            "LIGHTRAG_DOC_STATUS_STORAGE", self.DEFAULT_DOC_STATUS_STORAGE
+        )
+        self.GRAPH_STORAGE = os.getenv(
+            "LIGHTRAG_GRAPH_STORAGE", self.DEFAULT_GRAPH_STORAGE
+        )
+        self.VECTOR_STORAGE = os.getenv(
+            "LIGHTRAG_VECTOR_STORAGE", self.DEFAULT_VECTOR_STORAGE
+        )
+
+    def update_from_args(self, args):
+        """Update the configuration from command line arguments"""
+        if hasattr(args, "kv_storage"):
+            self.KV_STORAGE = args.kv_storage
+        if hasattr(args, "doc_status_storage"):
+            self.DOC_STATUS_STORAGE = args.doc_status_storage
+        if hasattr(args, "graph_storage"):
+            self.GRAPH_STORAGE = args.graph_storage
+        if hasattr(args, "vector_storage"):
+            self.VECTOR_STORAGE = args.vector_storage


-# Initialize rag storage config
+# Initialize storage configuration
 rag_storage_config = RAGStorageConfig()

 # Global progress tracker
@@ -81,60 +107,6 @@ def estimate_tokens(text: str) -> int:
     return int(tokens)


-# read config.ini
-config = configparser.ConfigParser()
-config.read("config.ini", "utf-8")
-# Redis config
-redis_uri = config.get("redis", "uri", fallback=None)
-if redis_uri:
-    os.environ["REDIS_URI"] = redis_uri
-    rag_storage_config.KV_STORAGE = "RedisKVStorage"
-    rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
-
-# Neo4j config
-neo4j_uri = config.get("neo4j", "uri", fallback=None)
-neo4j_username = config.get("neo4j", "username", fallback=None)
-neo4j_password = config.get("neo4j", "password", fallback=None)
-if neo4j_uri:
-    os.environ["NEO4J_URI"] = neo4j_uri
-    os.environ["NEO4J_USERNAME"] = neo4j_username
-    os.environ["NEO4J_PASSWORD"] = neo4j_password
-    rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
-
-# Milvus config
-milvus_uri = config.get("milvus", "uri", fallback=None)
-milvus_user = config.get("milvus", "user", fallback=None)
-milvus_password = config.get("milvus", "password", fallback=None)
-milvus_db_name = config.get("milvus", "db_name", fallback=None)
-if milvus_uri:
-    os.environ["MILVUS_URI"] = milvus_uri
-    os.environ["MILVUS_USER"] = milvus_user
-    os.environ["MILVUS_PASSWORD"] = milvus_password
-    os.environ["MILVUS_DB_NAME"] = milvus_db_name
-    rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
-
-# Qdrant config
-qdrant_uri = config.get("qdrant", "uri", fallback=None)
-qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
-if qdrant_uri:
-    os.environ["QDRANT_URL"] = qdrant_uri
-    if qdrant_api_key:
-        os.environ["QDRANT_API_KEY"] = qdrant_api_key
-    rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
-
-# MongoDB config
-mongo_uri = config.get("mongodb", "uri", fallback=None)
-mongo_database = config.get("mongodb", "database", fallback="LightRAG")
-mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
-if mongo_uri:
-    os.environ["MONGO_URI"] = mongo_uri
-    os.environ["MONGO_DATABASE"] = mongo_database
-    rag_storage_config.KV_STORAGE = "MongoKVStorage"
-    rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
-    if mongo_graph:
-        rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
-
-
 def get_default_host(binding_type: str) -> str:
     default_hosts = {
         "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -340,6 +312,27 @@ def parse_args() -> argparse.Namespace:
         description="LightRAG FastAPI Server with separate working and input directories"
     )

+    parser.add_argument(
+        "--kv-storage",
+        default=rag_storage_config.KV_STORAGE,
+        help=f"KV storage implementation (default: {rag_storage_config.KV_STORAGE})",
+    )
+    parser.add_argument(
+        "--doc-status-storage",
+        default=rag_storage_config.DOC_STATUS_STORAGE,
+        help=f"Document status storage implementation (default: {rag_storage_config.DOC_STATUS_STORAGE})",
+    )
+    parser.add_argument(
+        "--graph-storage",
+        default=rag_storage_config.GRAPH_STORAGE,
+        help=f"Graph storage implementation (default: {rag_storage_config.GRAPH_STORAGE})",
+    )
+    parser.add_argument(
+        "--vector-storage",
+        default=rag_storage_config.VECTOR_STORAGE,
+        help=f"Vector storage implementation (default: {rag_storage_config.VECTOR_STORAGE})",
+    )
+
     # Bindings configuration
     parser.add_argument(
         "--llm-binding",
@@ -554,6 +547,8 @@

     args = parser.parse_args()

+    rag_storage_config.update_from_args(args)
+
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name

     return args
diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py
index d3c64617..f38fd00a 100644
--- a/lightrag/kg/gremlin_impl.py
+++ b/lightrag/kg/gremlin_impl.py
@@ -47,7 +47,9 @@ class GremlinStorage(BaseGraphStorage):

         # All vertices will have graph={GRAPH} property, so that we can
         # have several logical graphs for one source
-        GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
+        GRAPH = GremlinStorage._to_value_map(
+            os.environ.get("GREMLIN_GRAPH", "LightRAG")
+        )

         self.graph_name = GRAPH
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index 2995fd9b..35260652 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -5,14 +5,17 @@ from dataclasses import dataclass
 import numpy as np
 from lightrag.utils import logger
 from ..base import BaseVectorStorage
-
 import pipmaster as pm
+import configparser

 if not pm.is_installed("pymilvus"):
     pm.install("pymilvus")

 from pymilvus import MilvusClient

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+

 @dataclass
 class MilvusVectorDBStorge(BaseVectorStorage):
     @staticmethod
@@ -27,14 +30,11 @@ class MilvusVectorDBStorge(BaseVectorStorage):

     def __post_init__(self):
         self._client = MilvusClient(
-            uri=os.environ.get(
-                "MILVUS_URI",
-                os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
-            ),
-            user=os.environ.get("MILVUS_USER", ""),
-            password=os.environ.get("MILVUS_PASSWORD", ""),
-            token=os.environ.get("MILVUS_TOKEN", ""),
-            db_name=os.environ.get("MILVUS_DB_NAME", ""),
+            uri=os.environ.get("MILVUS_URI", config.get("milvus", "uri", fallback=os.path.join(self.global_config["working_dir"], "milvus_lite.db"))),
+            user=os.environ.get("MILVUS_USER", config.get("milvus", "user", fallback=None)),
+            password=os.environ.get("MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)),
+            token=os.environ.get("MILVUS_TOKEN", config.get("milvus", "token", fallback=None)),
"token", fallback=None)), + db_name = os.environ.get("MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)), ) self._max_batch_size = self.global_config["embedding_batch_num"] MilvusVectorDBStorge.create_collection_if_not_exist( diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 4f919ecd..142050cf 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,8 +1,8 @@ import os from dataclasses import dataclass - import numpy as np import pipmaster as pm +import configparser from tqdm.asyncio import tqdm as tqdm_async if not pm.is_installed("pymongo"): @@ -12,22 +12,23 @@ if not pm.is_installed("motor"): pm.install("motor") from typing import Any, List, Tuple, Union - from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient - from ..base import BaseGraphStorage, BaseKVStorage from ..namespace import NameSpace, is_namespace from ..utils import logger +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): client = MongoClient( - os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/")) ) - database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) + database = client.get_database(os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG"))) self._data = database.get_collection(self.namespace) logger.info(f"Use MongoDB as KV {self.namespace}") @@ -90,10 +91,10 @@ class MongoGraphStorage(BaseGraphStorage): embedding_func=embedding_func, ) self.client = AsyncIOMotorClient( - os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/")) ) - self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")] - self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")] + self.db = self.client[os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG"))] + self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"))] # # ------------------------------------------------------------------------- diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index fe01aaf3..61f85ad4 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -5,6 +5,7 @@ import re from dataclasses import dataclass from typing import Any, Union, Tuple, List, Dict import pipmaster as pm +import configparser if not pm.is_installed("neo4j"): pm.install("neo4j") @@ -27,6 +28,9 @@ from ..utils import logger from ..base import BaseGraphStorage +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + @dataclass class Neo4JStorage(BaseGraphStorage): @staticmethod @@ -41,13 +45,15 @@ class Neo4JStorage(BaseGraphStorage): ) self._driver = None self._driver_lock = asyncio.Lock() - URI = os.environ["NEO4J_URI"] - USERNAME = os.environ["NEO4J_USERNAME"] - PASSWORD = os.environ["NEO4J_PASSWORD"] - MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) + + URI = os.environ["NEO4J_URI", config.get("neo4j", "uri", fallback=None)] + USERNAME = os.environ["NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)] + PASSWORD = os.environ["NEO4J_PASSWORD", config.get("neo4j", "password", 
+        MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", config.get("neo4j", "connection_pool_size", fallback=800))
         DATABASE = os.environ.get(
             "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
         )
+
         self._driver: AsyncDriver = AsyncGraphDatabase.driver(
             URI, auth=(USERNAME, PASSWORD)
         )
diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py
deleted file mode 100644
index 304d556c..00000000
--- a/lightrag/kg/postgres_impl_test.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import asyncio
-import sys
-import os
-import pipmaster as pm
-
-if not pm.is_installed("psycopg-pool"):
-    pm.install("psycopg-pool")
-    pm.install("psycopg[binary,pool]")
-if not pm.is_installed("asyncpg"):
-    pm.install("asyncpg")
-
-import asyncpg
-import psycopg
-from psycopg_pool import AsyncConnectionPool
-
-from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
-from ..namespace import NameSpace
-
-DB = "rag"
-USER = "rag"
-PASSWORD = "rag"
-HOST = "localhost"
-PORT = "15432"
-os.environ["AGE_GRAPH_NAME"] = "dickens"
-
-if sys.platform.startswith("win"):
-    import asyncio.windows_events
-
-    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-
-
-async def get_pool():
-    return await asyncpg.create_pool(
-        f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
-        min_size=10,
-        max_size=10,
-        max_queries=5000,
-        max_inactive_connection_lifetime=300.0,
-    )
-
-
-async def main1():
-    connection_string = (
-        f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
-    )
-    pool = AsyncConnectionPool(connection_string, open=False)
-    await pool.open()
-
-    try:
-        conn = await pool.getconn(timeout=10)
-        async with conn.cursor() as curs:
-            try:
-                await curs.execute('SET search_path = ag_catalog, "$user", public')
-                await curs.execute("SELECT create_graph('dickens-2')")
-                await conn.commit()
-                print("create_graph success")
-            except (
-                psycopg.errors.InvalidSchemaName,
-                psycopg.errors.UniqueViolation,
-            ):
-                print("create_graph already exists")
-                await conn.rollback()
-    finally:
-        pass
-
-
-db = PostgreSQLDB(
-    config={
-        "host": "localhost",
-        "port": 15432,
-        "user": "rag",
-        "password": "rag",
-        "database": "r1",
-    }
-)
-
-
-async def query_with_age():
-    await db.initdb()
-    graph = PGGraphStorage(
-        namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
-        global_config={},
-        embedding_func=None,
-    )
-    graph.db = db
-    res = await graph.get_node('"A CHRISTMAS CAROL"')
-    print("Node is: ", res)
-    res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
-    print("Edge is: ", res)
-    res = await graph.get_node_edges('"SCROOGE"')
-    print("Node Edges are: ", res)
-
-
-async def create_edge_with_age():
-    await db.initdb()
-    graph = PGGraphStorage(
-        namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
-        global_config={},
-        embedding_func=None,
-    )
-    graph.db = db
-    await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
-    await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
-    await graph.upsert_edge(
-        '"THE CRATCHITS"',
-        '"THE GIRLS"',
-        edge_data={
-            "weight": 7.0,
-            "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
-            "keywords": '"family, collective effort"',
-            "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
-        },
-    )
-    res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
-    print("Edge is: ", res)
-
-
-async def main():
-    pool = await get_pool()
-    sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
-    # cypher = "MATCH (n:how_are_you_doing) RETURN n"
-    async with pool.acquire() as conn:
-        try:
-            await conn.execute(
-                """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
-            )
-        except asyncpg.exceptions.InvalidSchemaNameError:
-            print("create_graph already exists")
-        # stmt = await conn.prepare(sql)
-        row = await conn.fetch(sql)
-        print("row is: ", row)
-
-        row = await conn.fetchrow("select '100'::int + 200 as result")
-        print(row)  #
-
-
-if __name__ == "__main__":
-    asyncio.run(query_with_age())
diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py
index e2f8d3a2..0a971b73 100644
--- a/lightrag/kg/qdrant_impl.py
+++ b/lightrag/kg/qdrant_impl.py
@@ -5,11 +5,10 @@ from dataclasses import dataclass
 import numpy as np
 import hashlib
 import uuid
-
 from ..utils import logger
 from ..base import BaseVectorStorage
-
 import pipmaster as pm
+import configparser

 if not pm.is_installed("qdrant_client"):
     pm.install("qdrant_client")
@@ -17,6 +16,9 @@ if not pm.is_installed("qdrant_client"):

 from qdrant_client import QdrantClient, models

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
 def compute_mdhash_id_for_qdrant(
     content: str, prefix: str = "", style: str = "simple"
 ) -> str:
@@ -57,8 +59,8 @@ class QdrantVectorDBStorage(BaseVectorStorage):

     def __post_init__(self):
         self._client = QdrantClient(
-            url=os.environ.get("QDRANT_URL"),
-            api_key=os.environ.get("QDRANT_API_KEY", None),
+            url=os.environ.get("QDRANT_URL", config.get("qdrant", "uri", fallback=None)),
+            api_key=os.environ.get("QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)),
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
         QdrantVectorDBStorage.create_collection_if_not_exist(
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index e97a6afc..fef62d5f 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -3,6 +3,7 @@ from typing import Any, Union
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
 import pipmaster as pm
+import configparser

 if not pm.is_installed("redis"):
     pm.install("redis")
@@ -14,10 +15,13 @@ from lightrag.base import BaseKVStorage

 import json

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+

 @dataclass
 class RedisKVStorage(BaseKVStorage):
     def __post_init__(self):
-        redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
+        redis_url = os.environ.get("REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379"))
         self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 347f0f4c..de0e4f59 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -35,6 +35,89 @@ from .utils import (
     set_logger,
 )

+# Storage type and implementation compatibility validation table
+STORAGE_IMPLEMENTATIONS = {
+    "KV_STORAGE": {
+        "implementations": [
+            "JsonKVStorage",
+            "MongoKVStorage",
+            "RedisKVStorage",
+            "TiDBKVStorage",
+            "PGKVStorage",
+            "OracleKVStorage",
+        ],
+        "required_methods": ["get_by_id", "upsert"],
+    },
+    "GRAPH_STORAGE": {
+        "implementations": [
+            "NetworkXStorage",
+            "Neo4JStorage",
+            "MongoGraphStorage",
+            "TiDBGraphStorage",
+            "AGEStorage",
+            "GremlinStorage",
+            "PGGraphStorage",
+            "OracleGraphStorage",
+        ],
+        "required_methods": ["upsert_node", "upsert_edge"],
+    },
+    "VECTOR_STORAGE": {
+        "implementations": [
+            "NanoVectorDBStorage",
"MilvusVectorDBStorge", + "ChromaVectorDBStorage", + "TiDBVectorDBStorage", + "PGVectorStorage", + "FaissVectorDBStorage", + "QdrantVectorDBStorage", + "OracleVectorDBStorage", + ], + "required_methods": ["query", "upsert"], + }, + "DOC_STATUS_STORAGE": { + "implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"], + "required_methods": ["get_pending_docs"], + }, +} + +# Storage implementation environment variable without default value +STORAGE_ENV_REQUIREMENTS = { + # KV Storage Implementations + "JsonKVStorage": [], + "MongoKVStorage": [], + "RedisKVStorage": ["REDIS_URI"], + "TiDBKVStorage": ["TIDB_URI", "TIDB_DATABASE"], + "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "OracleKVStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + # Graph Storage Implementations + "NetworkXStorage": [], + "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], + "MongoGraphStorage": [], + "TiDBGraphStorage": ["TIDB_URI", "TIDB_DATABASE"], + "AGEStorage": [ + "AGE_POSTGRES_DB", + "AGE_POSTGRES_USER", + "AGE_POSTGRES_PASSWORD", + "AGE_GRAPH_NAME", + ], + "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], + "PGGraphStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "OracleGraphStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + # Vector Storage Implementations + "NanoVectorDBStorage": [], + "MilvusVectorDBStorge": [], + "ChromaVectorDBStorage": [], + "TiDBVectorDBStorage": ["TIDB_URI", "TIDB_DATABASE"], + "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "FaissVectorDBStorage": [], + "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None + "OracleVectorDBStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + # Document Status Storage Implementations + "JsonDocStatusStorage": [], + "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], +} + +# Storage implementation module mapping STORAGES = { "NetworkXStorage": ".kg.networkx_impl", "JsonKVStorage": ".kg.json_kv_impl", @@ -193,6 +276,61 @@ class LightRAG: list[dict[str, Any]], ] = chunking_by_token_size + def verify_storage_implementation( + self, storage_type: str, storage_name: str + ) -> None: + """Verify if storage implementation is compatible with specified storage type + + Args: + storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.) + storage_name: Storage implementation name + + Raises: + ValueError: If storage implementation is incompatible or missing required methods + """ + if storage_type not in STORAGE_IMPLEMENTATIONS: + raise ValueError(f"Unknown storage type: {storage_type}") + + storage_info = STORAGE_IMPLEMENTATIONS[storage_type] + if storage_name not in storage_info["implementations"]: + raise ValueError( + f"Storage implementation '{storage_name}' is not compatible with {storage_type}. 
" + f"Compatible implementations are: {', '.join(storage_info['implementations'])}" + ) + + # Get storage class + storage_class = self._get_storage_class(storage_name) + + # Check required methods + missing_methods = [] + for method in storage_info["required_methods"]: + if not hasattr(storage_class, method): + missing_methods.append(method) + + if missing_methods: + raise ValueError( + f"Storage implementation '{storage_name}' is missing required methods: " + f"{', '.join(missing_methods)}" + ) + + def check_storage_env_vars(self, storage_name: str) -> None: + """Check if all required environment variables for storage implementation exist + + Args: + storage_name: Storage implementation name + + Raises: + ValueError: If required environment variables are missing + """ + required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) + missing_vars = [var for var in required_vars if var not in os.environ] + + if missing_vars: + raise ValueError( + f"Storage implementation '{storage_name}' requires the following " + f"environment variables: {', '.join(missing_vars)}" + ) + def __post_init__(self): os.makedirs(self.log_dir, exist_ok=True) log_file = os.path.join(self.log_dir, "lightrag.log") @@ -204,6 +342,20 @@ class LightRAG: logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) + # Verify storage implementation compatibility and environment variables + storage_configs = [ + ("KV_STORAGE", self.kv_storage), + ("VECTOR_STORAGE", self.vector_storage), + ("GRAPH_STORAGE", self.graph_storage), + ("DOC_STATUS_STORAGE", self.doc_status_storage), + ] + + for storage_type, storage_name in storage_configs: + # Verify storage implementation compatibility + self.verify_storage_implementation(storage_type, storage_name) + # Check environment variables + self.check_storage_env_vars(storage_name) + # show config global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])