feat: optimize storage configuration and environment variables

* add storage type compatibility validation table
* add environment variable checks for storage implementations
* read storage settings from config.ini and environment variables during initialization
yangdx
2025-02-11 00:55:52 +08:00
parent d0779209d9
commit 56c1792767
9 changed files with 249 additions and 225 deletions

View File

@@ -26,7 +26,6 @@ import shutil
import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import sys
-import configparser
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
@@ -44,13 +43,40 @@ load_dotenv(override=True)
class RAGStorageConfig:
-    KV_STORAGE = "JsonKVStorage"
-    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
-    GRAPH_STORAGE = "NetworkXStorage"
-    VECTOR_STORAGE = "NanoVectorDBStorage"
+    """Storage configuration class; defaults can be overridden via environment variables and command-line arguments"""
+
+    # Default storage implementations
+    DEFAULT_KV_STORAGE = "JsonKVStorage"
+    DEFAULT_DOC_STATUS_STORAGE = "JsonDocStatusStorage"
+    DEFAULT_GRAPH_STORAGE = "NetworkXStorage"
+    DEFAULT_VECTOR_STORAGE = "NanoVectorDBStorage"
+
+    def __init__(self):
+        # Read configuration from environment variables, falling back to the defaults
+        self.KV_STORAGE = os.getenv("LIGHTRAG_KV_STORAGE", self.DEFAULT_KV_STORAGE)
+        self.DOC_STATUS_STORAGE = os.getenv(
+            "LIGHTRAG_DOC_STATUS_STORAGE", self.DEFAULT_DOC_STATUS_STORAGE
+        )
+        self.GRAPH_STORAGE = os.getenv(
+            "LIGHTRAG_GRAPH_STORAGE", self.DEFAULT_GRAPH_STORAGE
+        )
+        self.VECTOR_STORAGE = os.getenv(
+            "LIGHTRAG_VECTOR_STORAGE", self.DEFAULT_VECTOR_STORAGE
+        )
+
+    def update_from_args(self, args):
+        """Update configuration from command-line arguments"""
+        if hasattr(args, "kv_storage"):
+            self.KV_STORAGE = args.kv_storage
+        if hasattr(args, "doc_status_storage"):
+            self.DOC_STATUS_STORAGE = args.doc_status_storage
+        if hasattr(args, "graph_storage"):
+            self.GRAPH_STORAGE = args.graph_storage
+        if hasattr(args, "vector_storage"):
+            self.VECTOR_STORAGE = args.vector_storage


-# Initialize rag storage config
+# Initialize storage configuration
rag_storage_config = RAGStorageConfig()

# Global progress tracker
@@ -81,60 +107,6 @@ def estimate_tokens(text: str) -> int:
    return int(tokens)

-# read config.ini
-config = configparser.ConfigParser()
-config.read("config.ini", "utf-8")
-
-# Redis config
-redis_uri = config.get("redis", "uri", fallback=None)
-if redis_uri:
-    os.environ["REDIS_URI"] = redis_uri
-    rag_storage_config.KV_STORAGE = "RedisKVStorage"
-    rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
-
-# Neo4j config
-neo4j_uri = config.get("neo4j", "uri", fallback=None)
-neo4j_username = config.get("neo4j", "username", fallback=None)
-neo4j_password = config.get("neo4j", "password", fallback=None)
-if neo4j_uri:
-    os.environ["NEO4J_URI"] = neo4j_uri
-    os.environ["NEO4J_USERNAME"] = neo4j_username
-    os.environ["NEO4J_PASSWORD"] = neo4j_password
-    rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
-
-# Milvus config
-milvus_uri = config.get("milvus", "uri", fallback=None)
-milvus_user = config.get("milvus", "user", fallback=None)
-milvus_password = config.get("milvus", "password", fallback=None)
-milvus_db_name = config.get("milvus", "db_name", fallback=None)
-if milvus_uri:
-    os.environ["MILVUS_URI"] = milvus_uri
-    os.environ["MILVUS_USER"] = milvus_user
-    os.environ["MILVUS_PASSWORD"] = milvus_password
-    os.environ["MILVUS_DB_NAME"] = milvus_db_name
-    rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
-
-# Qdrant config
-qdrant_uri = config.get("qdrant", "uri", fallback=None)
-qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
-if qdrant_uri:
-    os.environ["QDRANT_URL"] = qdrant_uri
-    if qdrant_api_key:
-        os.environ["QDRANT_API_KEY"] = qdrant_api_key
-    rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
-
-# MongoDB config
-mongo_uri = config.get("mongodb", "uri", fallback=None)
-mongo_database = config.get("mongodb", "database", fallback="LightRAG")
-mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
-if mongo_uri:
-    os.environ["MONGO_URI"] = mongo_uri
-    os.environ["MONGO_DATABASE"] = mongo_database
-    rag_storage_config.KV_STORAGE = "MongoKVStorage"
-    rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
-    if mongo_graph:
-        rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"


def get_default_host(binding_type: str) -> str:
    default_hosts = {
        "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -340,6 +312,27 @@ def parse_args() -> argparse.Namespace:
description="LightRAG FastAPI Server with separate working and input directories" description="LightRAG FastAPI Server with separate working and input directories"
) )
parser.add_argument(
"--kv-storage",
default=rag_storage_config.KV_STORAGE,
help=f"KV存储实现 (default: {rag_storage_config.KV_STORAGE})",
)
parser.add_argument(
"--doc-status-storage",
default=rag_storage_config.DOC_STATUS_STORAGE,
help=f"文档状态存储实现 (default: {rag_storage_config.DOC_STATUS_STORAGE})",
)
parser.add_argument(
"--graph-storage",
default=rag_storage_config.GRAPH_STORAGE,
help=f"图存储实现 (default: {rag_storage_config.GRAPH_STORAGE})",
)
parser.add_argument(
"--vector-storage",
default=rag_storage_config.VECTOR_STORAGE,
help=f"向量存储实现 (default: {rag_storage_config.VECTOR_STORAGE})",
)
# Bindings configuration # Bindings configuration
parser.add_argument( parser.add_argument(
"--llm-binding", "--llm-binding",
@@ -554,6 +547,8 @@ def parse_args() -> argparse.Namespace:
    args = parser.parse_args()

+    rag_storage_config.update_from_args(args)
+
    ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name

    return args
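Taken together, the changes above give a three-level precedence for choosing storage backends: built-in defaults, then the LIGHTRAG_* environment variables, then the new CLI flags. A minimal, self-contained sketch of that behaviour; the trimmed class below only mirrors the RAGStorageConfig added in this diff and is illustrative, not the actual module:

import os


class RAGStorageConfig:
    # Trimmed stand-in for the class added above
    DEFAULT_KV_STORAGE = "JsonKVStorage"
    DEFAULT_GRAPH_STORAGE = "NetworkXStorage"

    def __init__(self):
        # Environment variables override the built-in defaults
        self.KV_STORAGE = os.getenv("LIGHTRAG_KV_STORAGE", self.DEFAULT_KV_STORAGE)
        self.GRAPH_STORAGE = os.getenv("LIGHTRAG_GRAPH_STORAGE", self.DEFAULT_GRAPH_STORAGE)

    def update_from_args(self, args):
        # Command-line arguments override both defaults and environment variables
        if hasattr(args, "graph_storage"):
            self.GRAPH_STORAGE = args.graph_storage


os.environ["LIGHTRAG_GRAPH_STORAGE"] = "Neo4JStorage"
cfg = RAGStorageConfig()
print(cfg.GRAPH_STORAGE)  # Neo4JStorage (env var beats default)
print(cfg.KV_STORAGE)     # JsonKVStorage (default, no env var set)


class Args:  # stands in for the argparse.Namespace returned by parse_args()
    graph_storage = "NetworkXStorage"


cfg.update_from_args(Args())
print(cfg.GRAPH_STORAGE)  # NetworkXStorage (CLI flag beats env var)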

View File

@@ -47,7 +47,9 @@ class GremlinStorage(BaseGraphStorage):
        # All vertices will have graph={GRAPH} property, so that we can
        # have several logical graphs for one source
-        GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
+        GRAPH = GremlinStorage._to_value_map(
+            os.environ.get("GREMLIN_GRAPH", "LightRAG")
+        )
        self.graph_name = GRAPH

View File

@@ -5,14 +5,17 @@ from dataclasses import dataclass
import numpy as np
from lightrag.utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
+import configparser

if not pm.is_installed("pymilvus"):
    pm.install("pymilvus")

from pymilvus import MilvusClient

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
@dataclass
class MilvusVectorDBStorge(BaseVectorStorage):
    @staticmethod
@@ -27,14 +30,11 @@ class MilvusVectorDBStorge(BaseVectorStorage):
    def __post_init__(self):
        self._client = MilvusClient(
-            uri=os.environ.get(
-                "MILVUS_URI",
-                os.path.join(self.global_config["working_dir"], "milvus_lite.db"),
-            ),
-            user=os.environ.get("MILVUS_USER", ""),
-            password=os.environ.get("MILVUS_PASSWORD", ""),
-            token=os.environ.get("MILVUS_TOKEN", ""),
-            db_name=os.environ.get("MILVUS_DB_NAME", ""),
+            uri=os.environ.get("MILVUS_URI", config.get("milvus", "uri", fallback=os.path.join(self.global_config["working_dir"], "milvus_lite.db"))),
+            user=os.environ.get("MILVUS_USER", config.get("milvus", "user", fallback=None)),
+            password=os.environ.get("MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)),
+            token=os.environ.get("MILVUS_TOKEN", config.get("milvus", "token", fallback=None)),
+            db_name=os.environ.get("MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)),
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        MilvusVectorDBStorge.create_collection_if_not_exist(
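The Milvus change above is the first instance of the lookup pattern this commit applies across the storage backends: environment variable first, then a config.ini entry, then a built-in default. A minimal sketch of that chain, factored into a helper; get_setting is hypothetical and only for illustration:

import configparser
import os

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")  # an absent file simply yields an empty parser


def get_setting(env_key: str, section: str, option: str, default=None):
    # 1) environment variable, 2) config.ini entry, 3) built-in default
    return os.environ.get(env_key, config.get(section, option, fallback=default))


milvus_uri = get_setting("MILVUS_URI", "milvus", "uri", default="./milvus_lite.db")
milvus_user = get_setting("MILVUS_USER", "milvus", "user")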

View File

@@ -1,8 +1,8 @@
import os
from dataclasses import dataclass
import numpy as np
import pipmaster as pm
+import configparser
from tqdm.asyncio import tqdm as tqdm_async

if not pm.is_installed("pymongo"):
    pm.install("pymongo")
@@ -12,22 +12,23 @@ if not pm.is_installed("motor"):
pm.install("motor") pm.install("motor")
from typing import Any, List, Tuple, Union from typing import Any, List, Tuple, Union
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient from pymongo import MongoClient
from ..base import BaseGraphStorage, BaseKVStorage from ..base import BaseGraphStorage, BaseKVStorage
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@dataclass @dataclass
class MongoKVStorage(BaseKVStorage): class MongoKVStorage(BaseKVStorage):
def __post_init__(self): def __post_init__(self):
client = MongoClient( client = MongoClient(
os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"))
) )
database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) database = client.get_database(os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG")))
self._data = database.get_collection(self.namespace) self._data = database.get_collection(self.namespace)
logger.info(f"Use MongoDB as KV {self.namespace}") logger.info(f"Use MongoDB as KV {self.namespace}")
@@ -90,10 +91,10 @@ class MongoGraphStorage(BaseGraphStorage):
            embedding_func=embedding_func,
        )
        self.client = AsyncIOMotorClient(
-            os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/")
+            os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"))
        )
-        self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")]
-        self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")]
+        self.db = self.client[os.environ.get("MONGO_DATABASE", config.get("mongodb", "database", fallback="LightRAG"))]
+        self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", config.get("mongodb", "kg_collection", fallback="MDB_KG"))]
        #
        # -------------------------------------------------------------------------
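For reference, the sections and keys these backends look up can live in a config.ini next to the server. The snippet below is loaded via read_string so it stays runnable; the [mongodb] and [redis] keys match the lookups in this diff, while the values are placeholders:

import configparser
import os

ini_text = """
[mongodb]
uri = mongodb://root:root@localhost:27017/
database = LightRAG
kg_collection = MDB_KG

[redis]
uri = redis://localhost:6379
"""

config = configparser.ConfigParser()
config.read_string(ini_text)

# Same fallback chain as MongoKVStorage.__post_init__ above
mongo_uri = os.environ.get(
    "MONGO_URI",
    config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"),
)
print(mongo_uri)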

View File

@@ -5,6 +5,7 @@ import re
from dataclasses import dataclass
from typing import Any, Union, Tuple, List, Dict
import pipmaster as pm
+import configparser

if not pm.is_installed("neo4j"):
    pm.install("neo4j")
@@ -27,6 +28,9 @@ from ..utils import logger
from ..base import BaseGraphStorage

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+
@dataclass
class Neo4JStorage(BaseGraphStorage):
    @staticmethod
@@ -41,13 +45,15 @@ class Neo4JStorage(BaseGraphStorage):
        )
        self._driver = None
        self._driver_lock = asyncio.Lock()
-        URI = os.environ["NEO4J_URI"]
-        USERNAME = os.environ["NEO4J_USERNAME"]
-        PASSWORD = os.environ["NEO4J_PASSWORD"]
-        MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800)
+        URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
+        USERNAME = os.environ.get("NEO4J_USERNAME", config.get("neo4j", "username", fallback=None))
+        PASSWORD = os.environ.get("NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None))
+        MAX_CONNECTION_POOL_SIZE = os.environ.get(
+            "NEO4J_MAX_CONNECTION_POOL_SIZE",
+            config.get("neo4j", "connection_pool_size", fallback=800),
+        )
        DATABASE = os.environ.get(
            "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace)
        )
        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
            URI, auth=(USERNAME, PASSWORD)
        )
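Unlike the Milvus or Redis settings, NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD have no usable defaults (they are also listed as required in STORAGE_ENV_REQUIREMENTS further below), so a fail-fast variant of the lookup is worth sketching. require_setting is a hypothetical helper, not code from this commit:

import configparser
import os

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")


def require_setting(env_key: str, section: str, option: str) -> str:
    # Environment variable first, then config.ini; fail loudly if neither is set.
    value = os.environ.get(env_key, config.get(section, option, fallback=None))
    if value is None:
        raise ValueError(f"{env_key} is not set and [{section}] {option} is missing from config.ini")
    return value


uri = require_setting("NEO4J_URI", "neo4j", "uri")
username = require_setting("NEO4J_USERNAME", "neo4j", "username")
password = require_setting("NEO4J_PASSWORD", "neo4j", "password")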

View File

@@ -1,138 +0,0 @@
import asyncio
import sys
import os

import pipmaster as pm

if not pm.is_installed("psycopg-pool"):
    pm.install("psycopg-pool")
    pm.install("psycopg[binary,pool]")

if not pm.is_installed("asyncpg"):
    pm.install("asyncpg")

import asyncpg
import psycopg
from psycopg_pool import AsyncConnectionPool

from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
from ..namespace import NameSpace

DB = "rag"
USER = "rag"
PASSWORD = "rag"
HOST = "localhost"
PORT = "15432"
os.environ["AGE_GRAPH_NAME"] = "dickens"

if sys.platform.startswith("win"):
    import asyncio.windows_events

    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


async def get_pool():
    return await asyncpg.create_pool(
        f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
        min_size=10,
        max_size=10,
        max_queries=5000,
        max_inactive_connection_lifetime=300.0,
    )


async def main1():
    connection_string = (
        f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
    )
    pool = AsyncConnectionPool(connection_string, open=False)
    await pool.open()

    try:
        conn = await pool.getconn(timeout=10)
        async with conn.cursor() as curs:
            try:
                await curs.execute('SET search_path = ag_catalog, "$user", public')
                await curs.execute("SELECT create_graph('dickens-2')")
                await conn.commit()
                print("create_graph success")
            except (
                psycopg.errors.InvalidSchemaName,
                psycopg.errors.UniqueViolation,
            ):
                print("create_graph already exists")
                await conn.rollback()
    finally:
        pass


db = PostgreSQLDB(
    config={
        "host": "localhost",
        "port": 15432,
        "user": "rag",
        "password": "rag",
        "database": "r1",
    }
)


async def query_with_age():
    await db.initdb()
    graph = PGGraphStorage(
        namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
        global_config={},
        embedding_func=None,
    )
    graph.db = db
    res = await graph.get_node('"A CHRISTMAS CAROL"')
    print("Node is: ", res)
    res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
    print("Edge is: ", res)
    res = await graph.get_node_edges('"SCROOGE"')
    print("Node Edges are: ", res)


async def create_edge_with_age():
    await db.initdb()
    graph = PGGraphStorage(
        namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
        global_config={},
        embedding_func=None,
    )
    graph.db = db
    await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
    await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
    await graph.upsert_edge(
        '"THE CRATCHITS"',
        '"THE GIRLS"',
        edge_data={
            "weight": 7.0,
            "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
            "keywords": '"family, collective effort"',
            "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
        },
    )
    res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
    print("Edge is: ", res)


async def main():
    pool = await get_pool()
    sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
    # cypher = "MATCH (n:how_are_you_doing) RETURN n"
    async with pool.acquire() as conn:
        try:
            await conn.execute(
                """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
            )
        except asyncpg.exceptions.InvalidSchemaNameError:
            print("create_graph already exists")
        # stmt = await conn.prepare(sql)
        row = await conn.fetch(sql)
        print("row is: ", row)

        row = await conn.fetchrow("select '100'::int + 200 as result")
        print(row)  # <Record result=300>


if __name__ == "__main__":
    asyncio.run(query_with_age())

View File

@@ -5,11 +5,10 @@ from dataclasses import dataclass
import numpy as np
import hashlib
import uuid
from ..utils import logger
from ..base import BaseVectorStorage
import pipmaster as pm
+import configparser

if not pm.is_installed("qdrant_client"):
    pm.install("qdrant_client")
@@ -17,6 +16,9 @@ if not pm.is_installed("qdrant_client"):
from qdrant_client import QdrantClient, models

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+

def compute_mdhash_id_for_qdrant(
    content: str, prefix: str = "", style: str = "simple"
) -> str:
@@ -57,8 +59,8 @@ class QdrantVectorDBStorage(BaseVectorStorage):
    def __post_init__(self):
        self._client = QdrantClient(
-            url=os.environ.get("QDRANT_URL"),
-            api_key=os.environ.get("QDRANT_API_KEY", None),
+            url=os.environ.get("QDRANT_URL", config.get("qdrant", "uri", fallback=None)),
+            api_key=os.environ.get("QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)),
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        QdrantVectorDBStorage.create_collection_if_not_exist(

View File

@@ -3,6 +3,7 @@ from typing import Any, Union
from tqdm.asyncio import tqdm as tqdm_async
from dataclasses import dataclass
import pipmaster as pm
+import configparser

if not pm.is_installed("redis"):
    pm.install("redis")
@@ -14,10 +15,13 @@ from lightrag.base import BaseKVStorage
import json

+config = configparser.ConfigParser()
+config.read("config.ini", "utf-8")
+

@dataclass
class RedisKVStorage(BaseKVStorage):
    def __post_init__(self):
-        redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
+        redis_url = os.environ.get("REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379"))
        self._redis = Redis.from_url(redis_url, decode_responses=True)
        logger.info(f"Use Redis as KV {self.namespace}")

View File

@@ -35,6 +35,89 @@ from .utils import (
    set_logger,
)

+# Storage type and implementation compatibility validation table
+STORAGE_IMPLEMENTATIONS = {
+    "KV_STORAGE": {
+        "implementations": [
+            "JsonKVStorage",
+            "MongoKVStorage",
+            "RedisKVStorage",
+            "TiDBKVStorage",
+            "PGKVStorage",
+            "OracleKVStorage",
+        ],
+        "required_methods": ["get_by_id", "upsert"],
+    },
+    "GRAPH_STORAGE": {
+        "implementations": [
+            "NetworkXStorage",
+            "Neo4JStorage",
+            "MongoGraphStorage",
+            "TiDBGraphStorage",
+            "AGEStorage",
+            "GremlinStorage",
+            "PGGraphStorage",
+            "OracleGraphStorage",
+        ],
+        "required_methods": ["upsert_node", "upsert_edge"],
+    },
+    "VECTOR_STORAGE": {
+        "implementations": [
+            "NanoVectorDBStorage",
+            "MilvusVectorDBStorge",
+            "ChromaVectorDBStorage",
+            "TiDBVectorDBStorage",
+            "PGVectorStorage",
+            "FaissVectorDBStorage",
+            "QdrantVectorDBStorage",
+            "OracleVectorDBStorage",
+        ],
+        "required_methods": ["query", "upsert"],
+    },
+    "DOC_STATUS_STORAGE": {
+        "implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"],
+        "required_methods": ["get_pending_docs"],
+    },
+}
+
+# Required environment variables (those without default values) per storage implementation
+STORAGE_ENV_REQUIREMENTS = {
+    # KV Storage Implementations
+    "JsonKVStorage": [],
+    "MongoKVStorage": [],
+    "RedisKVStorage": ["REDIS_URI"],
+    "TiDBKVStorage": ["TIDB_URI", "TIDB_DATABASE"],
+    "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "OracleKVStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"],
+    # Graph Storage Implementations
+    "NetworkXStorage": [],
+    "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
+    "MongoGraphStorage": [],
+    "TiDBGraphStorage": ["TIDB_URI", "TIDB_DATABASE"],
+    "AGEStorage": [
+        "AGE_POSTGRES_DB",
+        "AGE_POSTGRES_USER",
+        "AGE_POSTGRES_PASSWORD",
+        "AGE_GRAPH_NAME",
+    ],
+    "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"],
+    "PGGraphStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "OracleGraphStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"],
+    # Vector Storage Implementations
+    "NanoVectorDBStorage": [],
+    "MilvusVectorDBStorge": [],
+    "ChromaVectorDBStorage": [],
+    "TiDBVectorDBStorage": ["TIDB_URI", "TIDB_DATABASE"],
+    "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+    "FaissVectorDBStorage": [],
+    "QdrantVectorDBStorage": ["QDRANT_URL"],  # QDRANT_API_KEY defaults to None
+    "OracleVectorDBStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"],
+    # Document Status Storage Implementations
+    "JsonDocStatusStorage": [],
+    "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"],
+}
+
+# Storage implementation module mapping
STORAGES = {
    "NetworkXStorage": ".kg.networkx_impl",
    "JsonKVStorage": ".kg.json_kv_impl",
@@ -193,6 +276,61 @@ class LightRAG:
        list[dict[str, Any]],
    ] = chunking_by_token_size

+    def verify_storage_implementation(
+        self, storage_type: str, storage_name: str
+    ) -> None:
+        """Verify if storage implementation is compatible with specified storage type
+
+        Args:
+            storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.)
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If storage implementation is incompatible or missing required methods
+        """
+        if storage_type not in STORAGE_IMPLEMENTATIONS:
+            raise ValueError(f"Unknown storage type: {storage_type}")
+
+        storage_info = STORAGE_IMPLEMENTATIONS[storage_type]
+        if storage_name not in storage_info["implementations"]:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' is not compatible with {storage_type}. "
+                f"Compatible implementations are: {', '.join(storage_info['implementations'])}"
+            )
+
+        # Get storage class
+        storage_class = self._get_storage_class(storage_name)
+
+        # Check required methods
+        missing_methods = []
+        for method in storage_info["required_methods"]:
+            if not hasattr(storage_class, method):
+                missing_methods.append(method)
+
+        if missing_methods:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' is missing required methods: "
+                f"{', '.join(missing_methods)}"
+            )
+
+    def check_storage_env_vars(self, storage_name: str) -> None:
+        """Check if all required environment variables for storage implementation exist
+
+        Args:
+            storage_name: Storage implementation name
+
+        Raises:
+            ValueError: If required environment variables are missing
+        """
+        required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
+        missing_vars = [var for var in required_vars if var not in os.environ]
+
+        if missing_vars:
+            raise ValueError(
+                f"Storage implementation '{storage_name}' requires the following "
+                f"environment variables: {', '.join(missing_vars)}"
+            )
+
    def __post_init__(self):
        os.makedirs(self.log_dir, exist_ok=True)
        log_file = os.path.join(self.log_dir, "lightrag.log")
@@ -204,6 +342,20 @@ class LightRAG:
logger.info(f"Creating working directory {self.working_dir}") logger.info(f"Creating working directory {self.working_dir}")
os.makedirs(self.working_dir) os.makedirs(self.working_dir)
# Verify storage implementation compatibility and environment variables
storage_configs = [
("KV_STORAGE", self.kv_storage),
("VECTOR_STORAGE", self.vector_storage),
("GRAPH_STORAGE", self.graph_storage),
("DOC_STATUS_STORAGE", self.doc_status_storage),
]
for storage_type, storage_name in storage_configs:
# Verify storage implementation compatibility
self.verify_storage_implementation(storage_type, storage_name)
# Check environment variables
self.check_storage_env_vars(storage_name)
# show config # show config
global_config = asdict(self) global_config = asdict(self)
_print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
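End to end, the effect of the __post_init__ hook above is that a misconfigured instance fails at construction time rather than at first query. A minimal usage sketch, assuming LightRAG can be built with just these keyword arguments (other fields left at their dataclass defaults) and that the Neo4j variables are not exported:

import os

from lightrag import LightRAG

os.environ.pop("NEO4J_URI", None)  # simulate the missing required variable

try:
    rag = LightRAG(
        working_dir="./rag_storage",
        graph_storage="Neo4JStorage",  # compatible with GRAPH_STORAGE, but needs NEO4J_* vars
    )
except ValueError as exc:
    print(exc)
    # e.g. "Storage implementation 'Neo4JStorage' requires the following
    #       environment variables: NEO4J_URI"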