feat optimize storage configuration and environment variables

* add storage type compatibility validation table
* add environment variable checks for storage
* modify storage init to read settings from config.ini and env
This commit is contained in:
yangdx
2025-02-11 00:55:52 +08:00
parent d0779209d9
commit 56c1792767
9 changed files with 249 additions and 225 deletions

View File

@@ -26,7 +26,6 @@ import shutil
import aiofiles
from ascii_colors import trace_exception, ASCIIColors
import sys
import configparser
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
@@ -44,13 +43,40 @@ load_dotenv(override=True)
class RAGStorageConfig:
    """Storage backend selection, overridable by environment and CLI arguments.

    Each of the four storage roles (KV, document status, graph, vector) starts
    from a built-in default. A ``LIGHTRAG_*_STORAGE`` environment variable
    overrides the default when the instance is created, and
    :meth:`update_from_args` lets parsed command-line arguments override the
    result afterwards.
    """

    # Built-in default storage implementations.
    DEFAULT_KV_STORAGE = "JsonKVStorage"
    DEFAULT_DOC_STATUS_STORAGE = "JsonDocStatusStorage"
    DEFAULT_GRAPH_STORAGE = "NetworkXStorage"
    DEFAULT_VECTOR_STORAGE = "NanoVectorDBStorage"

    # Class-level values, so ``RAGStorageConfig.KV_STORAGE`` style access keeps
    # resolving; every instance shadows them with its own values in __init__.
    KV_STORAGE = "JsonKVStorage"
    DOC_STATUS_STORAGE = "JsonDocStatusStorage"
    GRAPH_STORAGE = "NetworkXStorage"
    VECTOR_STORAGE = "NanoVectorDBStorage"

    # (instance attribute, environment variable, argparse attribute) per role.
    _ROLES = (
        ("KV_STORAGE", "LIGHTRAG_KV_STORAGE", "kv_storage"),
        ("DOC_STATUS_STORAGE", "LIGHTRAG_DOC_STATUS_STORAGE", "doc_status_storage"),
        ("GRAPH_STORAGE", "LIGHTRAG_GRAPH_STORAGE", "graph_storage"),
        ("VECTOR_STORAGE", "LIGHTRAG_VECTOR_STORAGE", "vector_storage"),
    )

    def __init__(self):
        """Read each role from its environment variable, else use the default."""
        for attr_name, env_name, _ in self._ROLES:
            default = getattr(self, f"DEFAULT_{attr_name}")
            # ``or`` rather than a getenv default: a variable that is set but
            # empty (e.g. a blank line in .env) must not produce an empty
            # storage name.
            setattr(self, attr_name, os.getenv(env_name) or default)

    def update_from_args(self, args):
        """Override the configuration from parsed command-line arguments.

        Args:
            args: Namespace-like object. The attributes ``kv_storage``,
                ``doc_status_storage``, ``graph_storage`` and
                ``vector_storage`` are used when present.

        Attributes that are missing, ``None`` or empty are ignored, so an unset
        option never clobbers a value obtained from the environment or the
        defaults.
        """
        for attr_name, _, arg_name in self._ROLES:
            value = getattr(args, arg_name, None)
            if value:
                setattr(self, attr_name, value)
# Initialize rag storage config: a single module-level instance, created when
# this module is executed, that the rest of the module reads and updates.
rag_storage_config = RAGStorageConfig()
# Global progress tracker
@@ -81,60 +107,6 @@ def estimate_tokens(text: str) -> int:
return int(tokens)
# read config.ini
# Settings are read from "config.ini" (UTF-8), resolved relative to the current
# working directory. For each backend below, a configured URI is exported through
# os.environ and the matching rag_storage_config field is switched to that backend.
# NOTE(review): indentation is not visible in this view; the statements following
# each `if <uri>:` are presumably its body — confirm against the real file.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Redis config
# Exports REDIS_URI; KV and doc-status storage both become "RedisKVStorage".
redis_uri = config.get("redis", "uri", fallback=None)
if redis_uri:
os.environ["REDIS_URI"] = redis_uri
rag_storage_config.KV_STORAGE = "RedisKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage"
# Neo4j config
# Exports NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD; graph storage becomes
# "Neo4JStorage".
# NOTE(review): username and password fall back to None, and assigning None to
# os.environ raises TypeError — confirm both keys are always present with the uri.
neo4j_uri = config.get("neo4j", "uri", fallback=None)
neo4j_username = config.get("neo4j", "username", fallback=None)
neo4j_password = config.get("neo4j", "password", fallback=None)
if neo4j_uri:
os.environ["NEO4J_URI"] = neo4j_uri
os.environ["NEO4J_USERNAME"] = neo4j_username
os.environ["NEO4J_PASSWORD"] = neo4j_password
rag_storage_config.GRAPH_STORAGE = "Neo4JStorage"
# Milvus config
# Exports MILVUS_URI / MILVUS_USER / MILVUS_PASSWORD / MILVUS_DB_NAME; vector
# storage becomes "MilvusVectorDBStorge".
# NOTE(review): user, password and db_name fall back to None, with the same
# TypeError risk on os.environ assignment as above. The "Storge" spelling is
# presumably the registered implementation name — verify before changing it.
milvus_uri = config.get("milvus", "uri", fallback=None)
milvus_user = config.get("milvus", "user", fallback=None)
milvus_password = config.get("milvus", "password", fallback=None)
milvus_db_name = config.get("milvus", "db_name", fallback=None)
if milvus_uri:
os.environ["MILVUS_URI"] = milvus_uri
os.environ["MILVUS_USER"] = milvus_user
os.environ["MILVUS_PASSWORD"] = milvus_password
os.environ["MILVUS_DB_NAME"] = milvus_db_name
rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge"
# Qdrant config
# Exports QDRANT_URL, and QDRANT_API_KEY only when an api key is configured;
# vector storage becomes "QdrantVectorDBStorage". This section runs after the
# Milvus one, so its VECTOR_STORAGE assignment is the later of the two.
qdrant_uri = config.get("qdrant", "uri", fallback=None)
qdrant_api_key = config.get("qdrant", "apikey", fallback=None)
if qdrant_uri:
os.environ["QDRANT_URL"] = qdrant_uri
if qdrant_api_key:
os.environ["QDRANT_API_KEY"] = qdrant_api_key
rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage"
# MongoDB config
# Exports MONGO_URI and MONGO_DATABASE (database name falls back to "LightRAG");
# KV and doc-status storage become "MongoKVStorage". The boolean `graph` option
# (fallback False) additionally selects "MongoGraphStorage" for graph storage.
# This section runs after the Redis and Neo4j ones, so its assignments are the
# later ones for the fields they share.
mongo_uri = config.get("mongodb", "uri", fallback=None)
mongo_database = config.get("mongodb", "database", fallback="LightRAG")
mongo_graph = config.getboolean("mongodb", "graph", fallback=False)
if mongo_uri:
os.environ["MONGO_URI"] = mongo_uri
os.environ["MONGO_DATABASE"] = mongo_database
rag_storage_config.KV_STORAGE = "MongoKVStorage"
rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage"
if mongo_graph:
rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage"
def get_default_host(binding_type: str) -> str:
default_hosts = {
"ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
@@ -340,6 +312,27 @@ def parse_args() -> argparse.Namespace:
description="LightRAG FastAPI Server with separate working and input directories"
)
parser.add_argument(
"--kv-storage",
default=rag_storage_config.KV_STORAGE,
help=f"KV存储实现 (default: {rag_storage_config.KV_STORAGE})",
)
parser.add_argument(
"--doc-status-storage",
default=rag_storage_config.DOC_STATUS_STORAGE,
help=f"文档状态存储实现 (default: {rag_storage_config.DOC_STATUS_STORAGE})",
)
parser.add_argument(
"--graph-storage",
default=rag_storage_config.GRAPH_STORAGE,
help=f"图存储实现 (default: {rag_storage_config.GRAPH_STORAGE})",
)
parser.add_argument(
"--vector-storage",
default=rag_storage_config.VECTOR_STORAGE,
help=f"向量存储实现 (default: {rag_storage_config.VECTOR_STORAGE})",
)
# Bindings configuration
parser.add_argument(
"--llm-binding",
@@ -554,6 +547,8 @@ def parse_args() -> argparse.Namespace:
args = parser.parse_args()
rag_storage_config.update_from_args(args)
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
return args