From 56c1792767fe29962c19150cacf283792fe0c383 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 00:55:52 +0800 Subject: [PATCH 01/35] feat optimize storage configuration and environment variables * add storage type compatibility validation table * add enviroment variables check for storage * modify storage init to get setting from confing.ini and env --- lightrag/api/lightrag_server.py | 115 +++++++++++----------- lightrag/kg/gremlin_impl.py | 4 +- lightrag/kg/milvus_impl.py | 18 ++-- lightrag/kg/mongo_impl.py | 17 ++-- lightrag/kg/neo4j_impl.py | 14 ++- lightrag/kg/postgres_impl_test.py | 138 --------------------------- lightrag/kg/qdrant_impl.py | 10 +- lightrag/kg/redis_impl.py | 6 +- lightrag/lightrag.py | 152 ++++++++++++++++++++++++++++++ 9 files changed, 249 insertions(+), 225 deletions(-) delete mode 100644 lightrag/kg/postgres_impl_test.py diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8d13fab0..8d85c292 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -26,7 +26,6 @@ import shutil import aiofiles from ascii_colors import trace_exception, ASCIIColors import sys -import configparser from fastapi import Depends, Security from fastapi.security import APIKeyHeader from fastapi.middleware.cors import CORSMiddleware @@ -44,13 +43,40 @@ load_dotenv(override=True) class RAGStorageConfig: - KV_STORAGE = "JsonKVStorage" - DOC_STATUS_STORAGE = "JsonDocStatusStorage" - GRAPH_STORAGE = "NetworkXStorage" - VECTOR_STORAGE = "NanoVectorDBStorage" + """存储配置类,支持通过环境变量和命令行参数修改默认值""" + + # 默认存储实现 + DEFAULT_KV_STORAGE = "JsonKVStorage" + DEFAULT_DOC_STATUS_STORAGE = "JsonDocStatusStorage" + DEFAULT_GRAPH_STORAGE = "NetworkXStorage" + DEFAULT_VECTOR_STORAGE = "NanoVectorDBStorage" + + def __init__(self): + # 从环境变量读取配置,如果没有则使用默认值 + self.KV_STORAGE = os.getenv("LIGHTRAG_KV_STORAGE", self.DEFAULT_KV_STORAGE) + self.DOC_STATUS_STORAGE = os.getenv( + "LIGHTRAG_DOC_STATUS_STORAGE", self.DEFAULT_DOC_STATUS_STORAGE + ) + self.GRAPH_STORAGE = os.getenv( + "LIGHTRAG_GRAPH_STORAGE", self.DEFAULT_GRAPH_STORAGE + ) + self.VECTOR_STORAGE = os.getenv( + "LIGHTRAG_VECTOR_STORAGE", self.DEFAULT_VECTOR_STORAGE + ) + + def update_from_args(self, args): + """从命令行参数更新配置""" + if hasattr(args, "kv_storage"): + self.KV_STORAGE = args.kv_storage + if hasattr(args, "doc_status_storage"): + self.DOC_STATUS_STORAGE = args.doc_status_storage + if hasattr(args, "graph_storage"): + self.GRAPH_STORAGE = args.graph_storage + if hasattr(args, "vector_storage"): + self.VECTOR_STORAGE = args.vector_storage -# Initialize rag storage config +# 初始化存储配置 rag_storage_config = RAGStorageConfig() # Global progress tracker @@ -81,60 +107,6 @@ def estimate_tokens(text: str) -> int: return int(tokens) -# read config.ini -config = configparser.ConfigParser() -config.read("config.ini", "utf-8") -# Redis config -redis_uri = config.get("redis", "uri", fallback=None) -if redis_uri: - os.environ["REDIS_URI"] = redis_uri - rag_storage_config.KV_STORAGE = "RedisKVStorage" - rag_storage_config.DOC_STATUS_STORAGE = "RedisKVStorage" - -# Neo4j config -neo4j_uri = config.get("neo4j", "uri", fallback=None) -neo4j_username = config.get("neo4j", "username", fallback=None) -neo4j_password = config.get("neo4j", "password", fallback=None) -if neo4j_uri: - os.environ["NEO4J_URI"] = neo4j_uri - os.environ["NEO4J_USERNAME"] = neo4j_username - os.environ["NEO4J_PASSWORD"] = neo4j_password - rag_storage_config.GRAPH_STORAGE = "Neo4JStorage" - -# Milvus config -milvus_uri = 
config.get("milvus", "uri", fallback=None) -milvus_user = config.get("milvus", "user", fallback=None) -milvus_password = config.get("milvus", "password", fallback=None) -milvus_db_name = config.get("milvus", "db_name", fallback=None) -if milvus_uri: - os.environ["MILVUS_URI"] = milvus_uri - os.environ["MILVUS_USER"] = milvus_user - os.environ["MILVUS_PASSWORD"] = milvus_password - os.environ["MILVUS_DB_NAME"] = milvus_db_name - rag_storage_config.VECTOR_STORAGE = "MilvusVectorDBStorge" - -# Qdrant config -qdrant_uri = config.get("qdrant", "uri", fallback=None) -qdrant_api_key = config.get("qdrant", "apikey", fallback=None) -if qdrant_uri: - os.environ["QDRANT_URL"] = qdrant_uri - if qdrant_api_key: - os.environ["QDRANT_API_KEY"] = qdrant_api_key - rag_storage_config.VECTOR_STORAGE = "QdrantVectorDBStorage" - -# MongoDB config -mongo_uri = config.get("mongodb", "uri", fallback=None) -mongo_database = config.get("mongodb", "database", fallback="LightRAG") -mongo_graph = config.getboolean("mongodb", "graph", fallback=False) -if mongo_uri: - os.environ["MONGO_URI"] = mongo_uri - os.environ["MONGO_DATABASE"] = mongo_database - rag_storage_config.KV_STORAGE = "MongoKVStorage" - rag_storage_config.DOC_STATUS_STORAGE = "MongoKVStorage" - if mongo_graph: - rag_storage_config.GRAPH_STORAGE = "MongoGraphStorage" - - def get_default_host(binding_type: str) -> str: default_hosts = { "ollama": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"), @@ -340,6 +312,27 @@ def parse_args() -> argparse.Namespace: description="LightRAG FastAPI Server with separate working and input directories" ) + parser.add_argument( + "--kv-storage", + default=rag_storage_config.KV_STORAGE, + help=f"KV存储实现 (default: {rag_storage_config.KV_STORAGE})", + ) + parser.add_argument( + "--doc-status-storage", + default=rag_storage_config.DOC_STATUS_STORAGE, + help=f"文档状态存储实现 (default: {rag_storage_config.DOC_STATUS_STORAGE})", + ) + parser.add_argument( + "--graph-storage", + default=rag_storage_config.GRAPH_STORAGE, + help=f"图存储实现 (default: {rag_storage_config.GRAPH_STORAGE})", + ) + parser.add_argument( + "--vector-storage", + default=rag_storage_config.VECTOR_STORAGE, + help=f"向量存储实现 (default: {rag_storage_config.VECTOR_STORAGE})", + ) + # Bindings configuration parser.add_argument( "--llm-binding", @@ -554,6 +547,8 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() + rag_storage_config.update_from_args(args) + ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name return args diff --git a/lightrag/kg/gremlin_impl.py b/lightrag/kg/gremlin_impl.py index d3c64617..f38fd00a 100644 --- a/lightrag/kg/gremlin_impl.py +++ b/lightrag/kg/gremlin_impl.py @@ -47,7 +47,9 @@ class GremlinStorage(BaseGraphStorage): # All vertices will have graph={GRAPH} property, so that we can # have several logical graphs for one source - GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"]) + GRAPH = GremlinStorage._to_value_map( + os.environ.get("GREMLIN_GRAPH", "LightRAG") + ) self.graph_name = GRAPH diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 2995fd9b..35260652 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -5,14 +5,17 @@ from dataclasses import dataclass import numpy as np from lightrag.utils import logger from ..base import BaseVectorStorage - import pipmaster as pm +import configparser if not pm.is_installed("pymilvus"): pm.install("pymilvus") from pymilvus import MilvusClient +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") 
+ @dataclass class MilvusVectorDBStorge(BaseVectorStorage): @staticmethod @@ -27,14 +30,11 @@ class MilvusVectorDBStorge(BaseVectorStorage): def __post_init__(self): self._client = MilvusClient( - uri=os.environ.get( - "MILVUS_URI", - os.path.join(self.global_config["working_dir"], "milvus_lite.db"), - ), - user=os.environ.get("MILVUS_USER", ""), - password=os.environ.get("MILVUS_PASSWORD", ""), - token=os.environ.get("MILVUS_TOKEN", ""), - db_name=os.environ.get("MILVUS_DB_NAME", ""), + uri = os.environ.get("MILVUS_URI", config.get("milvus", "uri", fallback=os.path.join(self.global_config["working_dir"], "milvus_lite.db"))), + user = os.environ.get("MILVUS_USER", config.get("milvus", "user", fallback=None)), + password = os.environ.get("MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)), + token = os.environ.get("MILVUS_TOKEN", config.get("milvus", "token", fallback=None)), + db_name = os.environ.get("MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)), ) self._max_batch_size = self.global_config["embedding_batch_num"] MilvusVectorDBStorge.create_collection_if_not_exist( diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 4f919ecd..142050cf 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -1,8 +1,8 @@ import os from dataclasses import dataclass - import numpy as np import pipmaster as pm +import configparser from tqdm.asyncio import tqdm as tqdm_async if not pm.is_installed("pymongo"): @@ -12,22 +12,23 @@ if not pm.is_installed("motor"): pm.install("motor") from typing import Any, List, Tuple, Union - from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient - from ..base import BaseGraphStorage, BaseKVStorage from ..namespace import NameSpace, is_namespace from ..utils import logger +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): client = MongoClient( - os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/")) ) - database = client.get_database(os.environ.get("MONGO_DATABASE", "LightRAG")) + database = client.get_database(os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG"))) self._data = database.get_collection(self.namespace) logger.info(f"Use MongoDB as KV {self.namespace}") @@ -90,10 +91,10 @@ class MongoGraphStorage(BaseGraphStorage): embedding_func=embedding_func, ) self.client = AsyncIOMotorClient( - os.environ.get("MONGO_URI", "mongodb://root:root@localhost:27017/") + os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/")) ) - self.db = self.client[os.environ.get("MONGO_DATABASE", "LightRAG")] - self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", "MDB_KG")] + self.db = self.client[os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG"))] + self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"))] # # ------------------------------------------------------------------------- diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index fe01aaf3..61f85ad4 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -5,6 +5,7 @@ import re from dataclasses import dataclass from typing import Any, Union, Tuple, 
List, Dict import pipmaster as pm +import configparser if not pm.is_installed("neo4j"): pm.install("neo4j") @@ -27,6 +28,9 @@ from ..utils import logger from ..base import BaseGraphStorage +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + @dataclass class Neo4JStorage(BaseGraphStorage): @staticmethod @@ -41,13 +45,15 @@ class Neo4JStorage(BaseGraphStorage): ) self._driver = None self._driver_lock = asyncio.Lock() - URI = os.environ["NEO4J_URI"] - USERNAME = os.environ["NEO4J_USERNAME"] - PASSWORD = os.environ["NEO4J_PASSWORD"] - MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", 800) + + URI = os.environ["NEO4J_URI", config.get("neo4j", "uri", fallback=None)] + USERNAME = os.environ["NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)] + PASSWORD = os.environ["NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)] + MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", config.get("neo4j", "connection_pool_size", fallback=800)) DATABASE = os.environ.get( "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) ) + self._driver: AsyncDriver = AsyncGraphDatabase.driver( URI, auth=(USERNAME, PASSWORD) ) diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py deleted file mode 100644 index 304d556c..00000000 --- a/lightrag/kg/postgres_impl_test.py +++ /dev/null @@ -1,138 +0,0 @@ -import asyncio -import sys -import os -import pipmaster as pm - -if not pm.is_installed("psycopg-pool"): - pm.install("psycopg-pool") - pm.install("psycopg[binary,pool]") -if not pm.is_installed("asyncpg"): - pm.install("asyncpg") - -import asyncpg -import psycopg -from psycopg_pool import AsyncConnectionPool - -from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage -from ..namespace import NameSpace - -DB = "rag" -USER = "rag" -PASSWORD = "rag" -HOST = "localhost" -PORT = "15432" -os.environ["AGE_GRAPH_NAME"] = "dickens" - -if sys.platform.startswith("win"): - import asyncio.windows_events - - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - -async def get_pool(): - return await asyncpg.create_pool( - f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}", - min_size=10, - max_size=10, - max_queries=5000, - max_inactive_connection_lifetime=300.0, - ) - - -async def main1(): - connection_string = ( - f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" - ) - pool = AsyncConnectionPool(connection_string, open=False) - await pool.open() - - try: - conn = await pool.getconn(timeout=10) - async with conn.cursor() as curs: - try: - await curs.execute('SET search_path = ag_catalog, "$user", public') - await curs.execute("SELECT create_graph('dickens-2')") - await conn.commit() - print("create_graph success") - except ( - psycopg.errors.InvalidSchemaName, - psycopg.errors.UniqueViolation, - ): - print("create_graph already exists") - await conn.rollback() - finally: - pass - - -db = PostgreSQLDB( - config={ - "host": "localhost", - "port": 15432, - "user": "rag", - "password": "rag", - "database": "r1", - } -) - - -async def query_with_age(): - await db.initdb() - graph = PGGraphStorage( - namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION, - global_config={}, - embedding_func=None, - ) - graph.db = db - res = await graph.get_node('"A CHRISTMAS CAROL"') - print("Node is: ", res) - res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG") - print("Edge is: ", res) - res = await graph.get_node_edges('"SCROOGE"') - print("Node Edges are: 
", res) - - -async def create_edge_with_age(): - await db.initdb() - graph = PGGraphStorage( - namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION, - global_config={}, - embedding_func=None, - ) - graph.db = db - await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"}) - await graph.upsert_node('"THE GIRLS"', {"world": "hello"}) - await graph.upsert_edge( - '"THE CRATCHITS"', - '"THE GIRLS"', - edge_data={ - "weight": 7.0, - "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.', - "keywords": '"family, collective effort"', - "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8", - }, - ) - res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"') - print("Edge is: ", res) - - -async def main(): - pool = await get_pool() - sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)" - # cypher = "MATCH (n:how_are_you_doing) RETURN n" - async with pool.acquire() as conn: - try: - await conn.execute( - """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""" - ) - except asyncpg.exceptions.InvalidSchemaNameError: - print("create_graph already exists") - # stmt = await conn.prepare(sql) - row = await conn.fetch(sql) - print("row is: ", row) - - row = await conn.fetchrow("select '100'::int + 200 as result") - print(row) # - - -if __name__ == "__main__": - asyncio.run(query_with_age()) diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index e2f8d3a2..0a971b73 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -5,11 +5,10 @@ from dataclasses import dataclass import numpy as np import hashlib import uuid - from ..utils import logger from ..base import BaseVectorStorage - import pipmaster as pm +import configparser if not pm.is_installed("qdrant_client"): pm.install("qdrant_client") @@ -17,6 +16,9 @@ if not pm.is_installed("qdrant_client"): from qdrant_client import QdrantClient, models +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + def compute_mdhash_id_for_qdrant( content: str, prefix: str = "", style: str = "simple" ) -> str: @@ -57,8 +59,8 @@ class QdrantVectorDBStorage(BaseVectorStorage): def __post_init__(self): self._client = QdrantClient( - url=os.environ.get("QDRANT_URL"), - api_key=os.environ.get("QDRANT_API_KEY", None), + url=os.environ.get("QDRANT_URL", config.get("qdrant", "uri", fallback=None)), + api_key=os.environ.get("QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)), ) self._max_batch_size = self.global_config["embedding_batch_num"] QdrantVectorDBStorage.create_collection_if_not_exist( diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index e97a6afc..fef62d5f 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -3,6 +3,7 @@ from typing import Any, Union from tqdm.asyncio import tqdm as tqdm_async from dataclasses import dataclass import pipmaster as pm +import configparser if not pm.is_installed("redis"): pm.install("redis") @@ -14,10 +15,13 @@ from lightrag.base import BaseKVStorage import json +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + @dataclass class RedisKVStorage(BaseKVStorage): def __post_init__(self): - redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379") + redis_url = os.environ.get("REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")) self._redis = Redis.from_url(redis_url, decode_responses=True) logger.info(f"Use Redis as KV 
{self.namespace}") diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 347f0f4c..de0e4f59 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -35,6 +35,89 @@ from .utils import ( set_logger, ) +# Storage type and implementation compatibility validation table +STORAGE_IMPLEMENTATIONS = { + "KV_STORAGE": { + "implementations": [ + "JsonKVStorage", + "MongoKVStorage", + "RedisKVStorage", + "TiDBKVStorage", + "PGKVStorage", + "OracleKVStorage", + ], + "required_methods": ["get_by_id", "upsert"], + }, + "GRAPH_STORAGE": { + "implementations": [ + "NetworkXStorage", + "Neo4JStorage", + "MongoGraphStorage", + "TiDBGraphStorage", + "AGEStorage", + "GremlinStorage", + "PGGraphStorage", + "OracleGraphStorage", + ], + "required_methods": ["upsert_node", "upsert_edge"], + }, + "VECTOR_STORAGE": { + "implementations": [ + "NanoVectorDBStorage", + "MilvusVectorDBStorge", + "ChromaVectorDBStorage", + "TiDBVectorDBStorage", + "PGVectorStorage", + "FaissVectorDBStorage", + "QdrantVectorDBStorage", + "OracleVectorDBStorage", + ], + "required_methods": ["query", "upsert"], + }, + "DOC_STATUS_STORAGE": { + "implementations": ["JsonDocStatusStorage", "PGDocStatusStorage"], + "required_methods": ["get_pending_docs"], + }, +} + +# Storage implementation environment variable without default value +STORAGE_ENV_REQUIREMENTS = { + # KV Storage Implementations + "JsonKVStorage": [], + "MongoKVStorage": [], + "RedisKVStorage": ["REDIS_URI"], + "TiDBKVStorage": ["TIDB_URI", "TIDB_DATABASE"], + "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "OracleKVStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + # Graph Storage Implementations + "NetworkXStorage": [], + "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], + "MongoGraphStorage": [], + "TiDBGraphStorage": ["TIDB_URI", "TIDB_DATABASE"], + "AGEStorage": [ + "AGE_POSTGRES_DB", + "AGE_POSTGRES_USER", + "AGE_POSTGRES_PASSWORD", + "AGE_GRAPH_NAME", + ], + "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], + "PGGraphStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "OracleGraphStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + # Vector Storage Implementations + "NanoVectorDBStorage": [], + "MilvusVectorDBStorge": [], + "ChromaVectorDBStorage": [], + "TiDBVectorDBStorage": ["TIDB_URI", "TIDB_DATABASE"], + "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "FaissVectorDBStorage": [], + "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None + "OracleVectorDBStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + # Document Status Storage Implementations + "JsonDocStatusStorage": [], + "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], +} + +# Storage implementation module mapping STORAGES = { "NetworkXStorage": ".kg.networkx_impl", "JsonKVStorage": ".kg.json_kv_impl", @@ -193,6 +276,61 @@ class LightRAG: list[dict[str, Any]], ] = chunking_by_token_size + def verify_storage_implementation( + self, storage_type: str, storage_name: str + ) -> None: + """Verify if storage implementation is compatible with specified storage type + + Args: + storage_type: Storage type (KV_STORAGE, GRAPH_STORAGE etc.) 
+ storage_name: Storage implementation name + + Raises: + ValueError: If storage implementation is incompatible or missing required methods + """ + if storage_type not in STORAGE_IMPLEMENTATIONS: + raise ValueError(f"Unknown storage type: {storage_type}") + + storage_info = STORAGE_IMPLEMENTATIONS[storage_type] + if storage_name not in storage_info["implementations"]: + raise ValueError( + f"Storage implementation '{storage_name}' is not compatible with {storage_type}. " + f"Compatible implementations are: {', '.join(storage_info['implementations'])}" + ) + + # Get storage class + storage_class = self._get_storage_class(storage_name) + + # Check required methods + missing_methods = [] + for method in storage_info["required_methods"]: + if not hasattr(storage_class, method): + missing_methods.append(method) + + if missing_methods: + raise ValueError( + f"Storage implementation '{storage_name}' is missing required methods: " + f"{', '.join(missing_methods)}" + ) + + def check_storage_env_vars(self, storage_name: str) -> None: + """Check if all required environment variables for storage implementation exist + + Args: + storage_name: Storage implementation name + + Raises: + ValueError: If required environment variables are missing + """ + required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) + missing_vars = [var for var in required_vars if var not in os.environ] + + if missing_vars: + raise ValueError( + f"Storage implementation '{storage_name}' requires the following " + f"environment variables: {', '.join(missing_vars)}" + ) + def __post_init__(self): os.makedirs(self.log_dir, exist_ok=True) log_file = os.path.join(self.log_dir, "lightrag.log") @@ -204,6 +342,20 @@ class LightRAG: logger.info(f"Creating working directory {self.working_dir}") os.makedirs(self.working_dir) + # Verify storage implementation compatibility and environment variables + storage_configs = [ + ("KV_STORAGE", self.kv_storage), + ("VECTOR_STORAGE", self.vector_storage), + ("GRAPH_STORAGE", self.graph_storage), + ("DOC_STATUS_STORAGE", self.doc_status_storage), + ] + + for storage_type, storage_name in storage_configs: + # Verify storage implementation compatibility + self.verify_storage_implementation(storage_type, storage_name) + # Check environment variables + self.check_storage_env_vars(storage_name) + # show config global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) From 8cfca5a141e615deee13bbb5557b77825784dd8a Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 03:29:40 +0800 Subject: [PATCH 02/35] Fix linting --- lightrag/kg/milvus_impl.py | 28 +++++++++++++++++++++++----- lightrag/kg/mongo_impl.py | 36 +++++++++++++++++++++++++++++++----- lightrag/kg/neo4j_impl.py | 14 +++++++++++--- lightrag/kg/qdrant_impl.py | 9 +++++++-- lightrag/kg/redis_impl.py | 5 ++++- 5 files changed, 76 insertions(+), 16 deletions(-) diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 35260652..ae0daac2 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -16,6 +16,7 @@ from pymilvus import MilvusClient config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @dataclass class MilvusVectorDBStorge(BaseVectorStorage): @staticmethod @@ -30,11 +31,28 @@ class MilvusVectorDBStorge(BaseVectorStorage): def __post_init__(self): self._client = MilvusClient( - uri = os.environ.get("MILVUS_URI", config.get("milvus", "uri", fallback=os.path.join(self.global_config["working_dir"], "milvus_lite.db"))), - user = 
os.environ.get("MILVUS_USER", config.get("milvus", "user", fallback=None)), - password = os.environ.get("MILVUS_PASSWORD", config.get("milvus", "password", fallback=None)), - token = os.environ.get("MILVUS_TOKEN", config.get("milvus", "token", fallback=None)), - db_name = os.environ.get("MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None)), + uri=os.environ.get( + "MILVUS_URI", + config.get( + "milvus", + "uri", + fallback=os.path.join( + self.global_config["working_dir"], "milvus_lite.db" + ), + ), + ), + user=os.environ.get( + "MILVUS_USER", config.get("milvus", "user", fallback=None) + ), + password=os.environ.get( + "MILVUS_PASSWORD", config.get("milvus", "password", fallback=None) + ), + token=os.environ.get( + "MILVUS_TOKEN", config.get("milvus", "token", fallback=None) + ), + db_name=os.environ.get( + "MILVUS_DB_NAME", config.get("milvus", "db_name", fallback=None) + ), ) self._max_batch_size = self.global_config["embedding_batch_num"] MilvusVectorDBStorge.create_collection_if_not_exist( diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 142050cf..017f57bd 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -22,13 +22,24 @@ from ..utils import logger config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @dataclass class MongoKVStorage(BaseKVStorage): def __post_init__(self): client = MongoClient( - os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/")) + os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), + ) + ) + database = client.get_database( + os.environ.get( + "MONGO_DATABASE", + mongo_database=config.get("mongodb", "database", fallback="LightRAG"), + ) ) - database = client.get_database(os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG"))) self._data = database.get_collection(self.namespace) logger.info(f"Use MongoDB as KV {self.namespace}") @@ -91,10 +102,25 @@ class MongoGraphStorage(BaseGraphStorage): embedding_func=embedding_func, ) self.client = AsyncIOMotorClient( - os.environ.get("MONGO_URI", config.get("mongodb", "uri", fallback="mongodb://root:root@localhost:27017/")) + os.environ.get( + "MONGO_URI", + config.get( + "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/" + ), + ) ) - self.db = self.client[os.environ.get("MONGO_DATABASE", mongo_database = config.get("mongodb", "database", fallback="LightRAG"))] - self.collection = self.db[os.environ.get("MONGO_KG_COLLECTION", config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"))] + self.db = self.client[ + os.environ.get( + "MONGO_DATABASE", + mongo_database=config.get("mongodb", "database", fallback="LightRAG"), + ) + ] + self.collection = self.db[ + os.environ.get( + "MONGO_KG_COLLECTION", + config.getboolean("mongodb", "kg_collection", fallback="MDB_KG"), + ) + ] # # ------------------------------------------------------------------------- diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 61f85ad4..dfb58825 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -31,6 +31,7 @@ from ..base import BaseGraphStorage config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @dataclass class Neo4JStorage(BaseGraphStorage): @staticmethod @@ -47,9 +48,16 @@ class Neo4JStorage(BaseGraphStorage): self._driver_lock = asyncio.Lock() URI = os.environ["NEO4J_URI", config.get("neo4j", "uri", fallback=None)] - 
USERNAME = os.environ["NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)] - PASSWORD = os.environ["NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)] - MAX_CONNECTION_POOL_SIZE = os.environ.get("NEO4J_MAX_CONNECTION_POOL_SIZE", config.get("neo4j", "connection_pool_size", fallback=800)) + USERNAME = os.environ[ + "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) + ] + PASSWORD = os.environ[ + "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None) + ] + MAX_CONNECTION_POOL_SIZE = os.environ.get( + "NEO4J_MAX_CONNECTION_POOL_SIZE", + config.get("neo4j", "connection_pool_size", fallback=800), + ) DATABASE = os.environ.get( "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", namespace) ) diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 0a971b73..bda23f8d 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -19,6 +19,7 @@ from qdrant_client import QdrantClient, models config = configparser.ConfigParser() config.read("config.ini", "utf-8") + def compute_mdhash_id_for_qdrant( content: str, prefix: str = "", style: str = "simple" ) -> str: @@ -59,8 +60,12 @@ class QdrantVectorDBStorage(BaseVectorStorage): def __post_init__(self): self._client = QdrantClient( - url=os.environ.get("QDRANT_URL", config.get("qdrant", "uri", fallback=None)), - api_key=os.environ.get("QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None)), + url=os.environ.get( + "QDRANT_URL", config.get("qdrant", "uri", fallback=None) + ), + api_key=os.environ.get( + "QDRANT_API_KEY", config.get("qdrant", "apikey", fallback=None) + ), ) self._max_batch_size = self.global_config["embedding_batch_num"] QdrantVectorDBStorage.create_collection_if_not_exist( diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index fef62d5f..ed8f46f9 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -18,10 +18,13 @@ import json config = configparser.ConfigParser() config.read("config.ini", "utf-8") + @dataclass class RedisKVStorage(BaseKVStorage): def __post_init__(self): - redis_url = os.environ.get("REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")) + redis_url = os.environ.get( + "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") + ) self._redis = Redis.from_url(redis_url, decode_responses=True) logger.info(f"Use Redis as KV {self.namespace}") From a4cf7e66d3106e6dd45a6e966832ed4d6c2433c6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 03:54:54 +0800 Subject: [PATCH 03/35] Inject oracle db to LightRag storage class when needed --- lightrag/kg/oracle_impl.py | 3 ++- lightrag/lightrag.py | 52 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index ca6bcfb2..f34fe4b1 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -361,7 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage): @dataclass class OracleGraphStorage(BaseGraphStorage): - """基于Oracle的图存储模块""" + # should pass db object to self.db + db: OracleDB = None def __post_init__(self): """从graphml文件加载图""" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index de0e4f59..6b9161be 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -1,5 +1,6 @@ import asyncio import os +import configparser from dataclasses import asdict, dataclass, field from datetime import datetime from functools import partial @@ -15,6 +16,7 @@ from .base import ( QueryParam, 
StorageNameSpace, ) +from .kg.oracle_impl import OracleDB from .namespace import NameSpace, make_namespace from .operate import ( chunking_by_token_size, @@ -35,6 +37,9 @@ from .utils import ( set_logger, ) +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + # Storage type and implementation compatibility validation table STORAGE_IMPLEMENTATIONS = { "KV_STORAGE": { @@ -389,6 +394,53 @@ class LightRAG: self.graph_storage_cls, global_config=global_config ) + # 检查是否使用了 Oracle 存储实现 + if ( + self.kv_storage == "OracleKVStorage" + or self.vector_storage == "OracleVectorDBStorage" + or self.graph_storage == "OracleGraphStorage" + ): + # 从环境变量或配置文件获取参数 + dbconfig = { + "user": os.environ.get( + "ORACLE_USER", config.get("oracle", "user", fallback=None) + ), + "password": os.environ.get( + "ORACLE_PASSWORD", + config.get("oracle", "password", fallback=None), + ), + "dsn": os.environ.get( + "ORACLE_DSN", config.get("oracle", "dsn", fallback=None) + ), + "config_dir": os.environ.get( + "ORACLE_CONFIG_DIR", + config.get("oracle", "config_dir", fallback=None), + ), + "wallet_location": os.environ.get( + "ORACLE_WALLET_LOCATION", + config.get("oracle", "wallet_location", fallback=None), + ), + "wallet_password": os.environ.get( + "ORACLE_WALLET_PASSWORD", + config.get("oracle", "wallet_password", fallback=None), + ), + "workspace": os.environ.get( + "ORACLE_WORKSPACE", + config.get("oracle", "workspace", fallback="default"), + ), + } + + # 初始化 OracleDB 对象 + oracle_db = OracleDB(dbconfig) + + # 只对 Oracle 实现的存储类注入 db 对象 + if self.kv_storage == "OracleKVStorage": + self.key_string_value_json_storage_cls.db = oracle_db + if self.vector_storage == "OracleVectorDBStorage": + self.vector_db_storage_cls.db = oracle_db + if self.graph_storage == "OracleGraphStorage": + self.graph_storage_cls.db = oracle_db + self.json_doc_status_storage = self.key_string_value_json_storage_cls( namespace=self.namespace_prefix + "json_doc_status_storage", embedding_func=None, From 7ec769456cc561e6c221933b91269ec2d92b287a Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 03:55:15 +0800 Subject: [PATCH 04/35] Inject Postgres to LightRag storage class when needed --- lightrag/lightrag.py | 80 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 6b9161be..3603509a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -17,6 +17,7 @@ from .base import ( StorageNameSpace, ) from .kg.oracle_impl import OracleDB +from .kg.postgres_impl import PostgreSQLDB from .namespace import NameSpace, make_namespace from .operate import ( chunking_by_token_size, @@ -394,6 +395,18 @@ class LightRAG: self.graph_storage_cls, global_config=global_config ) + self.json_doc_status_storage = self.key_string_value_json_storage_cls( + namespace=self.namespace_prefix + "json_doc_status_storage", + embedding_func=None, + ) + + self.llm_response_cache = self.key_string_value_json_storage_cls( + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), + embedding_func=self.embedding_func, + ) + # 检查是否使用了 Oracle 存储实现 if ( self.kv_storage == "OracleKVStorage" @@ -403,14 +416,16 @@ class LightRAG: # 从环境变量或配置文件获取参数 dbconfig = { "user": os.environ.get( - "ORACLE_USER", config.get("oracle", "user", fallback=None) + "ORACLE_USER", + config.get("oracle", "user", fallback=None), ), "password": os.environ.get( "ORACLE_PASSWORD", config.get("oracle", "password", fallback=None), ), "dsn": 
os.environ.get( - "ORACLE_DSN", config.get("oracle", "dsn", fallback=None) + "ORACLE_DSN", + config.get("oracle", "dsn", fallback=None), ), "config_dir": os.environ.get( "ORACLE_CONFIG_DIR", @@ -441,17 +456,58 @@ class LightRAG: if self.graph_storage == "OracleGraphStorage": self.graph_storage_cls.db = oracle_db - self.json_doc_status_storage = self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "json_doc_status_storage", - embedding_func=None, - ) + # 检查是否使用了 PostgreSQL 存储实现 + if ( + self.kv_storage == "PGKVStorage" + or self.vector_storage == "PGVectorStorage" + or self.graph_storage == "PGGraphStorage" + or self.json_doc_status_storage == "PGDocStatusStorage" + ): + # 读取配置文件 + config_parser = configparser.ConfigParser() + if os.path.exists("config.ini"): + config_parser.read("config.ini") - self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - embedding_func=self.embedding_func, - ) + # 从环境变量或配置文件获取参数 + dbconfig = { + "host": os.environ.get( + "POSTGRES_HOST", + config.get("postgres", "host", fallback="localhost"), + ), + "port": os.environ.get( + "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) + ), + "user": os.environ.get( + "POSTGRES_USER", config.get("postgres", "user", fallback=None) + ), + "password": os.environ.get( + "POSTGRES_PASSWORD", + config.get("postgres", "password", fallback=None), + ), + "database": os.environ.get( + "POSTGRES_DATABASE", + config.get("postgres", "database", fallback=None), + ), + "workspace": os.environ.get( + "POSTGRES_WORKSPACE", + config.get("postgres", "workspace", fallback="default"), + ), + } + + # 初始化 PostgreSQLDB 对象 + postgres_db = PostgreSQLDB(dbconfig) + loop = always_get_an_event_loop() + loop.run_until_complete(postgres_db.initdb()) + + # 只对 PostgreSQL 实现的存储类注入 db 对象 + if self.kv_storage == "PGKVStorage": + self.key_string_value_json_storage_cls.db = postgres_db + if self.vector_storage == "PGVectorStorage": + self.vector_db_storage_cls.db = postgres_db + if self.graph_storage == "PGGraphStorage": + self.graph_storage_cls.db = postgres_db + if self.json_doc_status_storage == "OracleGraphStorage": + self.json_doc_status_storage = postgres_db #### # add embedding func by walter From 5408e7ea02b5c1e0aa199f3c68f8df352a07a8e5 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 04:19:12 +0800 Subject: [PATCH 05/35] Add table existence check for Oracle and PostgreSQL DB initialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add Oracle table check on startup • Add PostgreSQL table check on startup • Use event loop for async DB operations • Ensure tables exist before DB operations --- lightrag/lightrag.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 3603509a..5518023d 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -447,6 +447,9 @@ class LightRAG: # 初始化 OracleDB 对象 oracle_db = OracleDB(dbconfig) + # Check if DB tables exist, if not, tables will be created + loop = always_get_an_event_loop() + loop.run_until_complete(oracle_db.check_tables()) # 只对 Oracle 实现的存储类注入 db 对象 if self.kv_storage == "OracleKVStorage": @@ -496,8 +499,11 @@ class LightRAG: # 初始化 PostgreSQLDB 对象 postgres_db = PostgreSQLDB(dbconfig) + # Initialize and check tables loop = always_get_an_event_loop() loop.run_until_complete(postgres_db.initdb()) + # Check if DB tables exist, if not, tables will be 
created + loop.run_until_complete(postgres_db.check_tables()) # 只对 PostgreSQL 实现的存储类注入 db 对象 if self.kv_storage == "PGKVStorage": From c5c606f4919060bc664957acada28bc33f215159 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 04:27:45 +0800 Subject: [PATCH 06/35] =?UTF-8?q?Inject=20TiDB=E5=90=8CLightRAG=20storage?= =?UTF-8?q?=20when=20needed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/kg/tidb_impl.py | 7 ++++++ lightrag/lightrag.py | 53 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index d9eeb2dd..7c75e2d8 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -102,6 +102,8 @@ class TiDB: @dataclass class TiDBKVStorage(BaseKVStorage): # should pass db object to self.db + db: TiDB = None + def __post_init__(self): self._data = {} self._max_batch_size = self.global_config["embedding_batch_num"] @@ -208,6 +210,8 @@ class TiDBKVStorage(BaseKVStorage): @dataclass class TiDBVectorDBStorage(BaseVectorStorage): + # should pass db object to self.db + db: TiDB = None cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): @@ -329,6 +333,9 @@ class TiDBVectorDBStorage(BaseVectorStorage): @dataclass class TiDBGraphStorage(BaseGraphStorage): + # should pass db object to self.db + db: TiDB = None + def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5518023d..8ef65951 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -16,8 +16,6 @@ from .base import ( QueryParam, StorageNameSpace, ) -from .kg.oracle_impl import OracleDB -from .kg.postgres_impl import PostgreSQLDB from .namespace import NameSpace, make_namespace from .operate import ( chunking_by_token_size, @@ -446,6 +444,7 @@ class LightRAG: } # 初始化 OracleDB 对象 + from .kg.oracle_impl import OracleDB oracle_db = OracleDB(dbconfig) # Check if DB tables exist, if not, tables will be created loop = always_get_an_event_loop() @@ -459,6 +458,55 @@ class LightRAG: if self.graph_storage == "OracleGraphStorage": self.graph_storage_cls.db = oracle_db + # 检查是否使用了 TiDB 存储实现 + if ( + self.kv_storage == "TiDBKVStorage" + or self.vector_storage == "TiDBVectorDBStorage" + or self.graph_storage == "TiDBGraphStorage" + ): + # 从环境变量或配置文件获取参数 + dbconfig = { + "host": os.environ.get( + "TIDB_HOST", + config.get("tidb", "host", fallback="localhost"), + ), + "port": os.environ.get( + "TIDB_PORT", + config.get("tidb", "port", fallback=4000) + ), + "user": os.environ.get( + "TIDB_USER", + config.get("tidb", "user", fallback=None), + ), + "password": os.environ.get( + "TIDB_PASSWORD", + config.get("tidb", "password", fallback=None), + ), + "database": os.environ.get( + "TIDB_DATABASE", + config.get("tidb", "database", fallback=None), + ), + "workspace": os.environ.get( + "TIDB_WORKSPACE", + config.get("tidb", "workspace", fallback="default"), + ), + } + + # 初始化 TiDB 对象 + from .kg.tidb_impl import TiDB + tidb_db = TiDB(dbconfig) + # Check if DB tables exist, if not, tables will be created + loop = always_get_an_event_loop() + loop.run_until_complete(tidb_db.check_tables()) + + # 只对 TiDB 实现的存储类注入 db 对象 + if self.kv_storage == "TiDBKVStorage": + self.key_string_value_json_storage_cls.db = tidb_db + if self.vector_storage == "TiDBVectorDBStorage": + self.vector_db_storage_cls.db = tidb_db + if self.graph_storage == 
"TiDBGraphStorage": + self.graph_storage_cls.db = tidb_db + # 检查是否使用了 PostgreSQL 存储实现 if ( self.kv_storage == "PGKVStorage" @@ -498,6 +546,7 @@ class LightRAG: } # 初始化 PostgreSQLDB 对象 + from .kg.postgres_impl import PostgreSQLDB postgres_db = PostgreSQLDB(dbconfig) # Initialize and check tables loop = always_get_an_event_loop() From 1f7990646ef3c2261de064386985b2573dddb0ba Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 04:42:10 +0800 Subject: [PATCH 07/35] update storage environment requirements for TiDB and Oracle MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Changed TiDB URI to user/password • Added Oracle config dir requirement • Fixed Oracle DSN naming • Fixed extra comma in Oracle Graph config • Made env vars more explicit --- lightrag/lightrag.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8ef65951..ed7e4538 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -90,14 +90,14 @@ STORAGE_ENV_REQUIREMENTS = { "JsonKVStorage": [], "MongoKVStorage": [], "RedisKVStorage": ["REDIS_URI"], - "TiDBKVStorage": ["TIDB_URI", "TIDB_DATABASE"], + "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "PGKVStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "OracleKVStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + "OracleKVStorage": ["ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD", "ORACLE_CONFIG_DIR"], # Graph Storage Implementations "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "MongoGraphStorage": [], - "TiDBGraphStorage": ["TIDB_URI", "TIDB_DATABASE"], + "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "AGEStorage": [ "AGE_POSTGRES_DB", "AGE_POSTGRES_USER", @@ -106,16 +106,16 @@ STORAGE_ENV_REQUIREMENTS = { ], "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], "PGGraphStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "OracleGraphStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + "OracleGraphStorage": ["ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD", , "ORACLE_CONFIG_DIR"], # Vector Storage Implementations "NanoVectorDBStorage": [], "MilvusVectorDBStorge": [], "ChromaVectorDBStorage": [], - "TiDBVectorDBStorage": ["TIDB_URI", "TIDB_DATABASE"], + "TiDBVectorDBStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "FaissVectorDBStorage": [], "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None - "OracleVectorDBStorage": ["ORACLE_URI", "ORACLE_USER", "ORACLE_PASSWORD"], + "OracleVectorDBStorage": ["ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD", "ORACLE_CONFIG_DIR"], # Document Status Storage Implementations "JsonDocStatusStorage": [], "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], From 17abd214a2e206dbe7c5e72c8f49e546f2e5e744 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 05:18:09 +0800 Subject: [PATCH 08/35] Fix linting --- lightrag/lightrag.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ed7e4538..5b334fe0 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -92,7 +92,12 @@ STORAGE_ENV_REQUIREMENTS = { "RedisKVStorage": ["REDIS_URI"], "TiDBKVStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "PGKVStorage": 
["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "OracleKVStorage": ["ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD", "ORACLE_CONFIG_DIR"], + "OracleKVStorage": [ + "ORACLE_DSN", + "ORACLE_USER", + "ORACLE_PASSWORD", + "ORACLE_CONFIG_DIR", + ], # Graph Storage Implementations "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], @@ -106,7 +111,12 @@ STORAGE_ENV_REQUIREMENTS = { ], "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], "PGGraphStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], - "OracleGraphStorage": ["ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD", , "ORACLE_CONFIG_DIR"], + "OracleGraphStorage": [ + "ORACLE_DSN", + "ORACLE_USER", + "ORACLE_PASSWORD", + "ORACLE_CONFIG_DIR", + ], # Vector Storage Implementations "NanoVectorDBStorage": [], "MilvusVectorDBStorge": [], @@ -115,7 +125,12 @@ STORAGE_ENV_REQUIREMENTS = { "PGVectorStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], "FaissVectorDBStorage": [], "QdrantVectorDBStorage": ["QDRANT_URL"], # QDRANT_API_KEY has default value None - "OracleVectorDBStorage": ["ORACLE_DSN", "ORACLE_USER", "ORACLE_PASSWORD", "ORACLE_CONFIG_DIR"], + "OracleVectorDBStorage": [ + "ORACLE_DSN", + "ORACLE_USER", + "ORACLE_PASSWORD", + "ORACLE_CONFIG_DIR", + ], # Document Status Storage Implementations "JsonDocStatusStorage": [], "PGDocStatusStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], @@ -445,6 +460,7 @@ class LightRAG: # 初始化 OracleDB 对象 from .kg.oracle_impl import OracleDB + oracle_db = OracleDB(dbconfig) # Check if DB tables exist, if not, tables will be created loop = always_get_an_event_loop() @@ -471,8 +487,7 @@ class LightRAG: config.get("tidb", "host", fallback="localhost"), ), "port": os.environ.get( - "TIDB_PORT", - config.get("tidb", "port", fallback=4000) + "TIDB_PORT", config.get("tidb", "port", fallback=4000) ), "user": os.environ.get( "TIDB_USER", @@ -494,6 +509,7 @@ class LightRAG: # 初始化 TiDB 对象 from .kg.tidb_impl import TiDB + tidb_db = TiDB(dbconfig) # Check if DB tables exist, if not, tables will be created loop = always_get_an_event_loop() @@ -547,6 +563,7 @@ class LightRAG: # 初始化 PostgreSQLDB 对象 from .kg.postgres_impl import PostgreSQLDB + postgres_db = PostgreSQLDB(dbconfig) # Initialize and check tables loop = always_get_an_event_loop() From f20a164467f863117889457932b1ccf0e818d2e8 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 05:22:40 +0800 Subject: [PATCH 09/35] Translate Chinese comments to English --- lightrag/lightrag.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5b334fe0..fa7ce2a8 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -420,13 +420,13 @@ class LightRAG: embedding_func=self.embedding_func, ) - # 检查是否使用了 Oracle 存储实现 + # Check if Oracle storage implementation is used if ( self.kv_storage == "OracleKVStorage" or self.vector_storage == "OracleVectorDBStorage" or self.graph_storage == "OracleGraphStorage" ): - # 从环境变量或配置文件获取参数 + # Get parameters from environment variables or config file dbconfig = { "user": os.environ.get( "ORACLE_USER", @@ -458,7 +458,7 @@ class LightRAG: ), } - # 初始化 OracleDB 对象 + # Initialize OracleDB object from .kg.oracle_impl import OracleDB oracle_db = OracleDB(dbconfig) @@ -466,7 +466,7 @@ class LightRAG: loop = always_get_an_event_loop() loop.run_until_complete(oracle_db.check_tables()) - # 只对 Oracle 实现的存储类注入 db 对象 + # Only inject db 
object for Oracle storage implementations if self.kv_storage == "OracleKVStorage": self.key_string_value_json_storage_cls.db = oracle_db if self.vector_storage == "OracleVectorDBStorage": @@ -474,13 +474,13 @@ class LightRAG: if self.graph_storage == "OracleGraphStorage": self.graph_storage_cls.db = oracle_db - # 检查是否使用了 TiDB 存储实现 + # Check if TiDB storage implementation is used if ( self.kv_storage == "TiDBKVStorage" or self.vector_storage == "TiDBVectorDBStorage" or self.graph_storage == "TiDBGraphStorage" ): - # 从环境变量或配置文件获取参数 + # Get parameters from environment variables or config file dbconfig = { "host": os.environ.get( "TIDB_HOST", @@ -507,7 +507,7 @@ class LightRAG: ), } - # 初始化 TiDB 对象 + # Initialize TiDB object from .kg.tidb_impl import TiDB tidb_db = TiDB(dbconfig) @@ -515,7 +515,7 @@ class LightRAG: loop = always_get_an_event_loop() loop.run_until_complete(tidb_db.check_tables()) - # 只对 TiDB 实现的存储类注入 db 对象 + # Only inject db object for TiDB storage implementations if self.kv_storage == "TiDBKVStorage": self.key_string_value_json_storage_cls.db = tidb_db if self.vector_storage == "TiDBVectorDBStorage": @@ -523,19 +523,19 @@ class LightRAG: if self.graph_storage == "TiDBGraphStorage": self.graph_storage_cls.db = tidb_db - # 检查是否使用了 PostgreSQL 存储实现 + # Check if PostgreSQL storage implementation is used if ( self.kv_storage == "PGKVStorage" or self.vector_storage == "PGVectorStorage" or self.graph_storage == "PGGraphStorage" or self.json_doc_status_storage == "PGDocStatusStorage" ): - # 读取配置文件 + # Read configuration file config_parser = configparser.ConfigParser() if os.path.exists("config.ini"): config_parser.read("config.ini") - # 从环境变量或配置文件获取参数 + # Get parameters from environment variables or config file dbconfig = { "host": os.environ.get( "POSTGRES_HOST", @@ -561,7 +561,7 @@ class LightRAG: ), } - # 初始化 PostgreSQLDB 对象 + # Initialize PostgreSQLDB object from .kg.postgres_impl import PostgreSQLDB postgres_db = PostgreSQLDB(dbconfig) @@ -571,7 +571,7 @@ class LightRAG: # Check if DB tables exist, if not, tables will be created loop.run_until_complete(postgres_db.check_tables()) - # 只对 PostgreSQL 实现的存储类注入 db 对象 + # Only inject db object for PostgreSQL storage implementations if self.kv_storage == "PGKVStorage": self.key_string_value_json_storage_cls.db = postgres_db if self.vector_storage == "PGVectorStorage": @@ -582,7 +582,7 @@ class LightRAG: self.json_doc_status_storage = postgres_db #### - # add embedding func by walter + # Add embedding function by walter #### self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( namespace=make_namespace( @@ -603,7 +603,7 @@ class LightRAG: embedding_func=self.embedding_func, ) #### - # add embedding func by walter over + # End of adding embedding function by walter #### self.entities_vdb = self.vector_db_storage_cls( From cddde8053ddd4b6581fabdc1c4d198d04a9db296 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 06:31:59 +0800 Subject: [PATCH 10/35] Add configuration examples for Oracle, TiDB, PostgreSQL and storage backends --- .env.example | 41 ++++++++++++++- config.ini.example | 25 +++++++++ lightrag/api/README.md | 117 ++++++++++++++++++++++++----------------- 3 files changed, 134 insertions(+), 49 deletions(-) diff --git a/.env.example b/.env.example index 6f868212..114f8554 100644 --- a/.env.example +++ b/.env.example @@ -72,6 +72,45 @@ LOG_LEVEL=INFO # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large # AZURE_EMBEDDING_API_VERSION=2023-05-15 - # Ollama Emulating Model Tag # 
OLLAMA_EMULATING_MODEL_TAG=latest + +# Oracle Database Configuration +ORACLE_DSN=localhost:1521/XEPDB1 +ORACLE_USER=your_username +ORACLE_PASSWORD=your_password +ORACLE_CONFIG_DIR=/path/to/oracle/config +ORACLE_WALLET_LOCATION=/path/to/wallet # 可选 +ORACLE_WALLET_PASSWORD=your_wallet_password # 可选 +ORACLE_WORKSPACE=default # 可选,默认为default + +# TiDB Configuration +TIDB_HOST=localhost +TIDB_PORT=4000 +TIDB_USER=your_username +TIDB_PASSWORD=your_password +TIDB_DATABASE=your_database +TIDB_WORKSPACE=default # 可选,默认为default + +# PostgreSQL Configuration +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_USER=your_username +POSTGRES_PASSWORD=your_password +POSTGRES_DATABASE=your_database +POSTGRES_WORKSPACE=default # 可选,默认为default + +# Database Configurations +# Neo4j +NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io +NEO4J_USERNAME=neo4j +NEO4J_PASSWORD=your-password + +# MongoDB (可选) +MONGODB_URI=mongodb+srv://name:password@your-cluster-address +MONGODB_DATABASE=lightrag +MONGODB_GRAPH=false + +# Qdrant +QDRANT_URL=http://localhost:16333 +QDRANT_API_KEY=your-api-key # 可选 diff --git a/config.ini.example b/config.ini.example index e7916b01..e6ceed0a 100644 --- a/config.ini.example +++ b/config.ini.example @@ -13,3 +13,28 @@ uri=redis://localhost:6379/1 [qdrant] uri = http://localhost:16333 + +[oracle] +dsn = localhost:1521/XEPDB1 +user = your_username +password = your_password +config_dir = /path/to/oracle/config +wallet_location = /path/to/wallet # 可选 +wallet_password = your_wallet_password # 可选 +workspace = default # 可选,默认为default + +[tidb] +host = localhost +port = 4000 +user = your_username +password = your_password +database = your_database +workspace = default # 可选,默认为default + +[postgres] +host = localhost +port = 5432 +user = your_username +password = your_password +database = your_database +workspace = default # 可选,默认为default diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 2f871cdc..9f5580fb 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -107,62 +107,19 @@ For better performance, the API server's default values for TOP_K and COSINE_THR ### Environment Variables -You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. Here's a complete example of available environment variables: +You can configure LightRAG using environment variables by creating a `.env` file in your project root directory. A sample file `.env.example` is provided for your convenience. -```env -# Server Configuration -HOST=0.0.0.0 -PORT=9621 +### Config.ini -# Directory Configuration -WORKING_DIR=/app/data/rag_storage -INPUT_DIR=/app/data/inputs - -# RAG Configuration -MAX_ASYNC=4 -MAX_TOKENS=32768 -EMBEDDING_DIM=1024 -MAX_EMBED_TOKENS=8192 -#HISTORY_TURNS=3 -#CHUNK_SIZE=1200 -#CHUNK_OVERLAP_SIZE=100 -#COSINE_THRESHOLD=0.4 -#TOP_K=50 - -# LLM Configuration -LLM_BINDING=ollama -LLM_BINDING_HOST=http://localhost:11434 -LLM_MODEL=mistral-nemo:latest - -# must be set if using OpenAI LLM (LLM_MODEL must be set or set by command line parms) -OPENAI_API_KEY=you_api_key - -# Embedding Configuration -EMBEDDING_BINDING=ollama -EMBEDDING_BINDING_HOST=http://localhost:11434 -EMBEDDING_MODEL=bge-m3:latest - -# Security -#LIGHTRAG_API_KEY=you-api-key-for-accessing-LightRAG - -# Logging -LOG_LEVEL=INFO - -# Optional SSL Configuration -#SSL=true -#SSL_CERTFILE=/path/to/cert.pem -#SSL_KEYFILE=/path/to/key.pem - -# Optional Timeout -#TIMEOUT=30 -``` +Datastorage configuration can be also set by config.ini. 
A sample file `config.ini.example` is provided for your convenience. ### Configuration Priority The configuration values are loaded in the following order (highest priority first): 1. Command-line arguments 2. Environment variables -3. Default values +3. Config.ini +4. Defaul values For example: ```bash @@ -173,6 +130,66 @@ python lightrag.py --port 8080 PORT=7000 python lightrag.py ``` +#### Storage Types Supported + +LightRAG uses 4 types of storage for difference purposes: + +* KV_STORAGE:llm response cache, text chunks, document information +* VECTOR_STORAGE:entities vectors, relation vectors, chunks vectors +* GRAPH_STORAGE:entity relation graph +* DOC_STATUS_STORAGE:documents indexing status + +Each storage type have servals implementations: + +* KV_STORAGE supported implement-name + +``` +JsonKVStorage JsonFile(default) +MongoKVStorage MogonDB +RedisKVStorage Redis +TiDBKVStorage TiDB +PGKVStorage Postgres +OracleKVStorage Oracle +``` + +* GRAPH_STORAGE supported implement-name + +``` +NetworkXStorage NetworkX(defualt) +Neo4JStorage Neo4J +MongoGraphStorage MongoDB +TiDBGraphStorage TiDB +AGEStorage AGE +GremlinStorage Gremlin +PGGraphStorage Postgres +OracleGraphStorage Postgres +``` + +* VECTOR_STORAGE supported implement-name + +``` +NanoVectorDBStorage NanoVector(default) +MilvusVectorDBStorge Milvus +ChromaVectorDBStorage Chroma +TiDBVectorDBStorage TiDB +PGVectorStorage Postgres +FaissVectorDBStorage Faiss +QdrantVectorDBStorage Qdrant +OracleVectorDBStorag Oracle +``` + +* DOC_STATUS_STORAGE:supported implement-name + +``` +JsonDocStatusStorage JsonFile(default) +PGDocStatusStorage Postgres +``` + +#### How Select Storage Type + +* Bye enviroment variables +* By command line arguments + #### LightRag Server Options | Parameter | Default | Description | @@ -200,6 +217,10 @@ PORT=7000 python lightrag.py | --ssl-keyfile | None | Path to SSL private key file (required if --ssl is enabled) | | --top-k | 50 | Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. | | --cosine-threshold | 0.4 | The cossine threshold for nodes and relations retrieval, works with top-k to control the retrieval of nodes and relations. 
| +| --kv-storage | JsonKVStorage | implement-name of KV_STORAGE | +| --graph-storage | NetworkXStorage | implement-name of GRAPH_STORAGE | +| --vector-storage | NanoVectorDBStorage | implement-name of VECTOR_STORAGE | +| --doc-status-storage | JsonDocStatusStorage | implement-name of DOC_STATUS_STORAGE | ### Example Usage From 0e660d400071cdd29835a16108c75ff2e66291c2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 10:17:51 +0800 Subject: [PATCH 11/35] Fix doc_status error --- lightrag/lightrag.py | 52 ++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index fa7ce2a8..8e2ed3a8 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -408,17 +408,8 @@ class LightRAG: self.graph_storage_cls, global_config=global_config ) - self.json_doc_status_storage = self.key_string_value_json_storage_cls( - namespace=self.namespace_prefix + "json_doc_status_storage", - embedding_func=None, - ) - - self.llm_response_cache = self.key_string_value_json_storage_cls( - namespace=make_namespace( - self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ), - embedding_func=self.embedding_func, - ) + # Initialize document status storage + self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) # Check if Oracle storage implementation is used if ( @@ -528,7 +519,7 @@ class LightRAG: self.kv_storage == "PGKVStorage" or self.vector_storage == "PGVectorStorage" or self.graph_storage == "PGGraphStorage" - or self.json_doc_status_storage == "PGDocStatusStorage" + or self.doc_status_storage == "PGDocStatusStorage" ): # Read configuration file config_parser = configparser.ConfigParser() @@ -578,12 +569,16 @@ class LightRAG: self.vector_db_storage_cls.db = postgres_db if self.graph_storage == "PGGraphStorage": self.graph_storage_cls.db = postgres_db - if self.json_doc_status_storage == "OracleGraphStorage": - self.json_doc_status_storage = postgres_db + if self.doc_status_storage == "OracleGraphStorage": + self.doc_status_storage_cls = postgres_db + + self.llm_response_cache = self.key_string_value_json_storage_cls( + namespace=make_namespace( + self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE + ), + embedding_func=self.embedding_func, + ) - #### - # Add embedding function by walter - #### self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS @@ -602,9 +597,6 @@ class LightRAG: ), embedding_func=self.embedding_func, ) - #### - # End of adding embedding function by walter - #### self.entities_vdb = self.vector_db_storage_cls( namespace=make_namespace( @@ -627,6 +619,7 @@ class LightRAG: embedding_func=self.embedding_func, ) + # What's for, Is this nessisary ? 
if self.llm_response_cache and hasattr( self.llm_response_cache, "global_config" ): @@ -639,6 +632,17 @@ class LightRAG: embedding_func=self.embedding_func, ) + # self.json_doc_status_storage = self.key_string_value_json_storage_cls( + # namespace=self.namespace_prefix + "json_doc_status_storage", + # embedding_func=None, + # ) + + self.doc_status: DocStatusStorage = self.doc_status_storage_cls( + namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), + global_config=global_config, + embedding_func=None, + ) + self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, @@ -647,14 +651,6 @@ class LightRAG: ) ) - # Initialize document status storage - self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) - self.doc_status: DocStatusStorage = self.doc_status_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), - global_config=global_config, - embedding_func=None, - ) - async def get_graph_labels(self): text = await self.chunk_entity_relation_graph.get_all_labels() return text From 14fd89c168684b9d3888c705739b50f209adfaf4 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 12:05:37 +0800 Subject: [PATCH 12/35] Remove storage class method check --- lightrag/lightrag.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 8e2ed3a8..a6ab4e26 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -317,21 +317,6 @@ class LightRAG: f"Compatible implementations are: {', '.join(storage_info['implementations'])}" ) - # Get storage class - storage_class = self._get_storage_class(storage_name) - - # Check required methods - missing_methods = [] - for method in storage_info["required_methods"]: - if not hasattr(storage_class, method): - missing_methods.append(method) - - if missing_methods: - raise ValueError( - f"Storage implementation '{storage_name}' is missing required methods: " - f"{', '.join(missing_methods)}" - ) - def check_storage_env_vars(self, storage_name: str) -> None: """Check if all required environment variables for storage implementation exist From aaddc08336ff24098b8b46f7551768b5f48e44c5 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 14:57:37 +0800 Subject: [PATCH 13/35] Add storage info to splash screen --- .env.example | 25 +++++++++++++++++-------- lightrag/api/lightrag_server.py | 14 ++++++++++++-- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/.env.example b/.env.example index 114f8554..135852a5 100644 --- a/.env.example +++ b/.env.example @@ -75,36 +75,45 @@ LOG_LEVEL=INFO # Ollama Emulating Model Tag # OLLAMA_EMULATING_MODEL_TAG=latest + +# Data storage selection +# LIGHTRAG_KV_STORAGE=PGKVStorage +# LIGHTRAG_VECTOR_STORAGE=PGVectorStorage +# LIGHTRAG_GRAPH_STORAGE=PGGraphStorage +# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage + # Oracle Database Configuration ORACLE_DSN=localhost:1521/XEPDB1 ORACLE_USER=your_username -ORACLE_PASSWORD=your_password +ORACLE_PASSWORD='your_password' ORACLE_CONFIG_DIR=/path/to/oracle/config ORACLE_WALLET_LOCATION=/path/to/wallet # 可选 -ORACLE_WALLET_PASSWORD=your_wallet_password # 可选 -ORACLE_WORKSPACE=default # 可选,默认为default +#ORACLE_WALLET_PASSWORD='your_password' # 可选 +#ORACLE_WORKSPACE=default # 可选,默认为default # TiDB Configuration TIDB_HOST=localhost TIDB_PORT=4000 TIDB_USER=your_username -TIDB_PASSWORD=your_password +TIDB_PASSWORD='your_password' TIDB_DATABASE=your_database -TIDB_WORKSPACE=default # 可选,默认为default 
+#TIDB_WORKSPACE=default # 可选,默认为default # PostgreSQL Configuration POSTGRES_HOST=localhost POSTGRES_PORT=5432 POSTGRES_USER=your_username -POSTGRES_PASSWORD=your_password +POSTGRES_PASSWORD='your_password' POSTGRES_DATABASE=your_database -POSTGRES_WORKSPACE=default # 可选,默认为default +#POSTGRES_WORKSPACE=default # 可选,默认为default +# AGE Configuration +AGE_GRAPH_NAME=dickens # Database Configurations # Neo4j NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io NEO4J_USERNAME=neo4j -NEO4J_PASSWORD=your-password +NEO4J_PASSWORD='your_password' # MongoDB (可选) MONGODB_URI=mongodb+srv://name:password@your-cluster-address diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8d85c292..97b3f5a5 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -47,9 +47,9 @@ class RAGStorageConfig: # 默认存储实现 DEFAULT_KV_STORAGE = "JsonKVStorage" - DEFAULT_DOC_STATUS_STORAGE = "JsonDocStatusStorage" - DEFAULT_GRAPH_STORAGE = "NetworkXStorage" DEFAULT_VECTOR_STORAGE = "NanoVectorDBStorage" + DEFAULT_GRAPH_STORAGE = "NetworkXStorage" + DEFAULT_DOC_STATUS_STORAGE = "JsonDocStatusStorage" def __init__(self): # 从环境变量读取配置,如果没有则使用默认值 @@ -219,6 +219,16 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.yellow(f"{args.top_k}") # System Configuration + ASCIIColors.magenta("\n💾 Storage Configuration:") + ASCIIColors.white(" ├─ KV Storage: ", end="") + ASCIIColors.yellow(f"{rag_storage_config.KV_STORAGE}") + ASCIIColors.white(" ├─ Document Status Storage: ", end="") + ASCIIColors.yellow(f"{rag_storage_config.DOC_STATUS_STORAGE}") + ASCIIColors.white(" ├─ Graph Storage: ", end="") + ASCIIColors.yellow(f"{rag_storage_config.GRAPH_STORAGE}") + ASCIIColors.white(" └─ Vector Storage: ", end="") + ASCIIColors.yellow(f"{rag_storage_config.VECTOR_STORAGE}") + ASCIIColors.magenta("\n🛠️ System Configuration:") ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="") ASCIIColors.yellow(f"{ollama_server_infos.LIGHTRAG_MODEL}") From 8a56a5ea6ccacda50b3df64f82792c6a8483a636 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 16:11:15 +0800 Subject: [PATCH 14/35] fix: Add content column to doc status and fix SQL parameter indexing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add content column to doc status table • Fix SQL param index in get_by_status query • Update insert SQL to include content field --- lightrag/kg/postgres_impl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index d319f6f9..5bd0a949 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -471,7 +471,7 @@ class PGDocStatusStorage(DocStatusStorage): self, status: DocStatus ) -> Dict[str, DocProcessingStatus]: """Get all documents by status""" - sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1" + sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" params = {"workspace": self.db.workspace, "status": status} result = await self.db.query(sql, params, True) return { @@ -505,8 +505,8 @@ class PGDocStatusStorage(DocStatusStorage): Args: data: Dictionary of document IDs and their status data """ - sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status) - values($1,$2,$3,$4,$5,$6) + sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status) + values($1,$2,$3,$4,$5,$6,$7) on 
conflict(id,workspace) do update set content = EXCLUDED.content, content_summary = EXCLUDED.content_summary, @@ -1103,6 +1103,7 @@ TABLES = { "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS ( workspace varchar(255) NOT NULL, id varchar(255) NOT NULL, + content TEXT NULL, content_summary varchar(255) NULL, content_length int4 NULL, chunks_count int4 NULL, From afb527281659a4ab3710234cab9208789147f7c1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Tue, 11 Feb 2025 17:07:53 +0800 Subject: [PATCH 15/35] Fix type and class name in doc status storage class assignment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Corrected PGDocStatusStorage class name • Fixed db assignment to class not instance • Fixed incorrect Oracle reference --- lightrag/lightrag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index a6ab4e26..b1670850 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -554,8 +554,8 @@ class LightRAG: self.vector_db_storage_cls.db = postgres_db if self.graph_storage == "PGGraphStorage": self.graph_storage_cls.db = postgres_db - if self.doc_status_storage == "OracleGraphStorage": - self.doc_status_storage_cls = postgres_db + if self.doc_status_storage == "PGDocStatusStorage": + self.doc_status_storage_cls.db = postgres_db self.llm_response_cache = self.key_string_value_json_storage_cls( namespace=make_namespace( From cf61bed62c3d63690d81eaa53c1dfe10df3517a9 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 12 Feb 2025 21:48:48 +0800 Subject: [PATCH 16/35] Reorganize env config sections, add data store config to env file. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add namespace prefix config option • Make AGE graph name optional • Update env variable requirements • Add comments for deprecated options --- .env.example | 109 +++++++++++++++++++++------------------- lightrag/kg/age_impl.py | 4 +- lightrag/lightrag.py | 7 ++- 3 files changed, 65 insertions(+), 55 deletions(-) diff --git a/.env.example b/.env.example index 135852a5..369bde4b 100644 --- a/.env.example +++ b/.env.example @@ -1,12 +1,30 @@ -# Server Configuration -HOST=0.0.0.0 -PORT=9621 +### Server Configuration +#HOST=0.0.0.0 +#PORT=9621 +#NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances -# Directory Configuration -WORKING_DIR=/app/data/rag_storage -INPUT_DIR=/app/data/inputs +### Optional SSL Configuration +#SSL=true +#SSL_CERTFILE=/path/to/cert.pem +#SSL_KEYFILE=/path/to/key.pem -# RAG Configuration +### Security (empty for no api-key is needed) +# LIGHTRAG_API_KEY=your-secure-api-key-here + +### Directory Configuration +# WORKING_DIR=./rag_storage +# INPUT_DIR=./inputs + +### Logging level +LOG_LEVEL=INFO + +### Optional Timeout +TIMEOUT=300 + +# Ollama Emulating Model Tag +# OLLAMA_EMULATING_MODEL_TAG=latest + +### RAG Configuration MAX_ASYNC=4 MAX_TOKENS=32768 EMBEDDING_DIM=1024 @@ -17,53 +35,39 @@ MAX_EMBED_TOKENS=8192 #COSINE_THRESHOLD=0.4 # 0.2 while not running API server #TOP_K=50 # 60 while not running API server -# LLM Configuration (Use valid host. For local services, you can use host.docker.internal) -# Ollama example +### LLM Configuration (Use valid host. 
For local services, you can use host.docker.internal) +### Ollama example LLM_BINDING=ollama LLM_BINDING_HOST=http://host.docker.internal:11434 LLM_MODEL=mistral-nemo:latest -# OpenAI alike example +### OpenAI alike example # LLM_BINDING=openai # LLM_MODEL=deepseek-chat # LLM_BINDING_HOST=https://api.deepseek.com # LLM_BINDING_API_KEY=your_api_key -# for OpenAI LLM (LLM_BINDING_API_KEY take priority) +### for OpenAI LLM (LLM_BINDING_API_KEY take priority) # OPENAI_API_KEY=your_api_key -# Lollms example +### Lollms example # LLM_BINDING=lollms # LLM_BINDING_HOST=http://host.docker.internal:9600 # LLM_MODEL=mistral-nemo:latest -# Embedding Configuration (Use valid host. For local services, you can use host.docker.internal) +### Embedding Configuration (Use valid host. For local services, you can use host.docker.internal) # Ollama example EMBEDDING_BINDING=ollama EMBEDDING_BINDING_HOST=http://host.docker.internal:11434 EMBEDDING_MODEL=bge-m3:latest -# Lollms example +### Lollms example # EMBEDDING_BINDING=lollms # EMBEDDING_BINDING_HOST=http://host.docker.internal:9600 # EMBEDDING_MODEL=bge-m3:latest -# Security (empty for no key) -LIGHTRAG_API_KEY=your-secure-api-key-here - -# Logging -LOG_LEVEL=INFO - -# Optional SSL Configuration -#SSL=true -#SSL_CERTFILE=/path/to/cert.pem -#SSL_KEYFILE=/path/to/key.pem - -# Optional Timeout -#TIMEOUT=30 - -# Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority) +### Optional for Azure (LLM_BINDING_HOST, LLM_BINDING_API_KEY take priority) # AZURE_OPENAI_API_VERSION=2024-08-01-preview # AZURE_OPENAI_DEPLOYMENT=gpt-4o # AZURE_OPENAI_API_KEY=myapikey @@ -72,54 +76,57 @@ LOG_LEVEL=INFO # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large # AZURE_EMBEDDING_API_VERSION=2023-05-15 -# Ollama Emulating Model Tag -# OLLAMA_EMULATING_MODEL_TAG=latest - - -# Data storage selection +### Data storage selection # LIGHTRAG_KV_STORAGE=PGKVStorage # LIGHTRAG_VECTOR_STORAGE=PGVectorStorage # LIGHTRAG_GRAPH_STORAGE=PGGraphStorage # LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage -# Oracle Database Configuration +### Oracle Database Configuration ORACLE_DSN=localhost:1521/XEPDB1 ORACLE_USER=your_username ORACLE_PASSWORD='your_password' ORACLE_CONFIG_DIR=/path/to/oracle/config -ORACLE_WALLET_LOCATION=/path/to/wallet # 可选 -#ORACLE_WALLET_PASSWORD='your_password' # 可选 -#ORACLE_WORKSPACE=default # 可选,默认为default +#ORACLE_WALLET_LOCATION=/path/to/wallet # optional +#ORACLE_WALLET_PASSWORD='your_password' # optional +#ORACLE_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future) -# TiDB Configuration +### TiDB Configuration TIDB_HOST=localhost TIDB_PORT=4000 TIDB_USER=your_username TIDB_PASSWORD='your_password' TIDB_DATABASE=your_database -#TIDB_WORKSPACE=default # 可选,默认为default +#TIDB_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future) -# PostgreSQL Configuration +### PostgreSQL Configuration POSTGRES_HOST=localhost POSTGRES_PORT=5432 POSTGRES_USER=your_username POSTGRES_PASSWORD='your_password' POSTGRES_DATABASE=your_database -#POSTGRES_WORKSPACE=default # 可选,默认为default -# AGE Configuration -AGE_GRAPH_NAME=dickens +#POSTGRES_WORKSPACE=default # separating all data from difference Lightrag instances(deprecated, use NAMESPACE_PREFIX in future) -# Database Configurations -# Neo4j +### Independent AGM Configuration(not for AMG embedded in PostreSQL) +AGE_POSTGRES_DB= +AGE_POSTGRES_USER= +AGE_POSTGRES_PASSWORD= +AGE_POSTGRES_HOST= +# 
AGE_POSTGRES_PORT=8529 + +# AGE Graph Name(apply to PostgreSQL and independent AGM) +# AGE_GRAPH_NAME=lightrag # deprecated, use NAME_SPACE_PREFIX instead + +### Neo4j Configuration NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io NEO4J_USERNAME=neo4j NEO4J_PASSWORD='your_password' -# MongoDB (可选) -MONGODB_URI=mongodb+srv://name:password@your-cluster-address -MONGODB_DATABASE=lightrag -MONGODB_GRAPH=false +### MongoDB Configuration +MONGODB_URI=mongodb://root:root@localhost:27017/ +MONGODB_DATABASE=LightRAG +MONGODB_GRAPH=false # deprecated (keep for backward compatibility) -# Qdrant +### Qdrant QDRANT_URL=http://localhost:16333 QDRANT_API_KEY=your-api-key # 可选 diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index df32b7cb..a6857f22 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -75,8 +75,8 @@ class AGEStorage(BaseGraphStorage): .replace("'", "\\'") ) HOST = os.environ["AGE_POSTGRES_HOST"].replace("\\", "\\\\").replace("'", "\\'") - PORT = int(os.environ["AGE_POSTGRES_PORT"]) - self.graph_name = os.environ["AGE_GRAPH_NAME"] + PORT = os.environ.get("AGE_POSTGRES_PORT", "8529") + self.graph_name = namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cb1aa195..48c20428 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -107,10 +107,13 @@ STORAGE_ENV_REQUIREMENTS = { "AGE_POSTGRES_DB", "AGE_POSTGRES_USER", "AGE_POSTGRES_PASSWORD", - "AGE_GRAPH_NAME", ], "GremlinStorage": ["GREMLIN_HOST", "GREMLIN_PORT", "GREMLIN_GRAPH"], - "PGGraphStorage": ["POSTGRES_USER", "POSTGRES_PASSWORD", "POSTGRES_DATABASE"], + "PGGraphStorage": [ + "POSTGRES_USER", + "POSTGRES_PASSWORD", + "POSTGRES_DATABASE", + ], "OracleGraphStorage": [ "ORACLE_DSN", "ORACLE_USER", From 7b79427097fe3d5678d824b390c4cc8fe409c0b7 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 12 Feb 2025 22:25:34 +0800 Subject: [PATCH 17/35] refactor: improve database initialization by centralizing db instance injection - Move db configs to separate methods - Remove db field defaults in storage classes - Add _initialize_database_if_needed method - Inject db instances during initialization - Clean up storage implementation code --- lightrag/kg/oracle_impl.py | 12 +- lightrag/kg/postgres_impl.py | 24 +-- lightrag/kg/tidb_impl.py | 12 +- lightrag/lightrag.py | 357 ++++++++++++++++++----------------- 4 files changed, 205 insertions(+), 200 deletions(-) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index f34fe4b1..c2859829 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -172,8 +172,8 @@ class OracleDB: @dataclass class OracleKVStorage(BaseKVStorage): - # should pass db object to self.db - db: OracleDB = None + # db instance must be injected before use + # db: OracleDB meta_fields = None def __post_init__(self): @@ -318,8 +318,8 @@ class OracleKVStorage(BaseKVStorage): @dataclass class OracleVectorDBStorage(BaseVectorStorage): - # should pass db object to self.db - db: OracleDB = None + # db instance must be injected before use + # db: OracleDB cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): @@ -361,8 +361,8 @@ class OracleVectorDBStorage(BaseVectorStorage): @dataclass class OracleGraphStorage(BaseGraphStorage): - # should pass db object to self.db - db: OracleDB = None + # db instance must be injected before use + # db: OracleDB def 
__post_init__(self): """从graphml文件加载图""" diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 526f54a7..221202ab 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -177,7 +177,8 @@ class PostgreSQLDB: @dataclass class PGKVStorage(BaseKVStorage): - db: PostgreSQLDB = None + # db instance must be injected before use + # db: PostgreSQLDB def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] @@ -296,8 +297,9 @@ class PGKVStorage(BaseKVStorage): @dataclass class PGVectorStorage(BaseVectorStorage): + # db instance must be injected before use + # db: PostgreSQLDB cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) - db: PostgreSQLDB = None def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] @@ -418,10 +420,8 @@ class PGVectorStorage(BaseVectorStorage): class PGDocStatusStorage(DocStatusStorage): """PostgreSQL implementation of document status storage""" - db: PostgreSQLDB = None - - def __post_init__(self): - pass + # db instance must be injected before use + db: PostgreSQLDB async def filter_keys(self, data: set[str]) -> set[str]: """Return keys that don't exist in storage""" @@ -577,19 +577,15 @@ class PGGraphQueryException(Exception): @dataclass class PGGraphStorage(BaseGraphStorage): - db: PostgreSQLDB = None + # db instance must be injected before use + # db: PostgreSQLDB @staticmethod def load_nx_graph(file_name): print("no preloading of graph with AGE in production") - def __init__(self, namespace, global_config, embedding_func): - super().__init__( - namespace=namespace, - global_config=global_config, - embedding_func=embedding_func, - ) - self.graph_name = os.environ["AGE_GRAPH_NAME"] + def __post_init__(self): + self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag") self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 7c75e2d8..ba5a6240 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -101,8 +101,8 @@ class TiDB: @dataclass class TiDBKVStorage(BaseKVStorage): - # should pass db object to self.db - db: TiDB = None + # db instance must be injected before use + # db: TiDB def __post_init__(self): self._data = {} @@ -210,8 +210,8 @@ class TiDBKVStorage(BaseKVStorage): @dataclass class TiDBVectorDBStorage(BaseVectorStorage): - # should pass db object to self.db - db: TiDB = None + # db instance must be injected before use + # db: TiDB cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) def __post_init__(self): @@ -333,8 +333,8 @@ class TiDBVectorDBStorage(BaseVectorStorage): @dataclass class TiDBGraphStorage(BaseGraphStorage): - # should pass db object to self.db - db: TiDB = None + # db instance must be injected before use + # db: TiDB def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 48c20428..5648c85d 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -243,6 +243,9 @@ class LightRAG: graph_storage: str = field(default="NetworkXStorage") """Storage backend for knowledge graphs.""" + doc_status_storage: str = field(default="JsonDocStatusStorage") + """Storage type for tracking document processing statuses.""" + # Logging current_log_level = logger.level log_level: int = field(default=current_log_level) @@ -339,9 +342,6 @@ class LightRAG: convert_response_to_json ) - 
doc_status_storage: str = field(default="JsonDocStatusStorage") - """Storage type for tracking document processing statuses.""" - # Custom Chunking Function chunking_func: Callable[ [ @@ -355,6 +355,91 @@ class LightRAG: list[dict[str, Any]], ] = chunking_by_token_size + def _get_postgres_config(self): + return { + "host": os.environ.get( + "POSTGRES_HOST", + config.get("postgres", "host", fallback="localhost"), + ), + "port": os.environ.get( + "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) + ), + "user": os.environ.get( + "POSTGRES_USER", config.get("postgres", "user", fallback=None) + ), + "password": os.environ.get( + "POSTGRES_PASSWORD", + config.get("postgres", "password", fallback=None), + ), + "database": os.environ.get( + "POSTGRES_DATABASE", + config.get("postgres", "database", fallback=None), + ), + "workspace": os.environ.get( + "POSTGRES_WORKSPACE", + config.get("postgres", "workspace", fallback="default"), + ), + } + + def _get_oracle_config(self): + return { + "user": os.environ.get( + "ORACLE_USER", + config.get("oracle", "user", fallback=None), + ), + "password": os.environ.get( + "ORACLE_PASSWORD", + config.get("oracle", "password", fallback=None), + ), + "dsn": os.environ.get( + "ORACLE_DSN", + config.get("oracle", "dsn", fallback=None), + ), + "config_dir": os.environ.get( + "ORACLE_CONFIG_DIR", + config.get("oracle", "config_dir", fallback=None), + ), + "wallet_location": os.environ.get( + "ORACLE_WALLET_LOCATION", + config.get("oracle", "wallet_location", fallback=None), + ), + "wallet_password": os.environ.get( + "ORACLE_WALLET_PASSWORD", + config.get("oracle", "wallet_password", fallback=None), + ), + "workspace": os.environ.get( + "ORACLE_WORKSPACE", + config.get("oracle", "workspace", fallback="default"), + ), + } + + def _get_tidb_config(self): + return { + "host": os.environ.get( + "TIDB_HOST", + config.get("tidb", "host", fallback="localhost"), + ), + "port": os.environ.get( + "TIDB_PORT", config.get("tidb", "port", fallback=4000) + ), + "user": os.environ.get( + "TIDB_USER", + config.get("tidb", "user", fallback=None), + ), + "password": os.environ.get( + "TIDB_PASSWORD", + config.get("tidb", "password", fallback=None), + ), + "database": os.environ.get( + "TIDB_DATABASE", + config.get("tidb", "database", fallback=None), + ), + "workspace": os.environ.get( + "TIDB_WORKSPACE", + config.get("tidb", "workspace", fallback="default"), + ), + } + def verify_storage_implementation( self, storage_type: str, storage_name: str ) -> None: @@ -456,167 +541,6 @@ class LightRAG: # Initialize document status storage self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage) - # Check if Oracle storage implementation is used - if ( - self.kv_storage == "OracleKVStorage" - or self.vector_storage == "OracleVectorDBStorage" - or self.graph_storage == "OracleGraphStorage" - ): - # Get parameters from environment variables or config file - dbconfig = { - "user": os.environ.get( - "ORACLE_USER", - config.get("oracle", "user", fallback=None), - ), - "password": os.environ.get( - "ORACLE_PASSWORD", - config.get("oracle", "password", fallback=None), - ), - "dsn": os.environ.get( - "ORACLE_DSN", - config.get("oracle", "dsn", fallback=None), - ), - "config_dir": os.environ.get( - "ORACLE_CONFIG_DIR", - config.get("oracle", "config_dir", fallback=None), - ), - "wallet_location": os.environ.get( - "ORACLE_WALLET_LOCATION", - config.get("oracle", "wallet_location", fallback=None), - ), - "wallet_password": os.environ.get( - "ORACLE_WALLET_PASSWORD", - 
config.get("oracle", "wallet_password", fallback=None), - ), - "workspace": os.environ.get( - "ORACLE_WORKSPACE", - config.get("oracle", "workspace", fallback="default"), - ), - } - - # Initialize OracleDB object - from .kg.oracle_impl import OracleDB - - oracle_db = OracleDB(dbconfig) - # Check if DB tables exist, if not, tables will be created - loop = always_get_an_event_loop() - loop.run_until_complete(oracle_db.check_tables()) - - # Only inject db object for Oracle storage implementations - if self.kv_storage == "OracleKVStorage": - self.key_string_value_json_storage_cls.db = oracle_db - if self.vector_storage == "OracleVectorDBStorage": - self.vector_db_storage_cls.db = oracle_db - if self.graph_storage == "OracleGraphStorage": - self.graph_storage_cls.db = oracle_db - - # Check if TiDB storage implementation is used - if ( - self.kv_storage == "TiDBKVStorage" - or self.vector_storage == "TiDBVectorDBStorage" - or self.graph_storage == "TiDBGraphStorage" - ): - # Get parameters from environment variables or config file - dbconfig = { - "host": os.environ.get( - "TIDB_HOST", - config.get("tidb", "host", fallback="localhost"), - ), - "port": os.environ.get( - "TIDB_PORT", config.get("tidb", "port", fallback=4000) - ), - "user": os.environ.get( - "TIDB_USER", - config.get("tidb", "user", fallback=None), - ), - "password": os.environ.get( - "TIDB_PASSWORD", - config.get("tidb", "password", fallback=None), - ), - "database": os.environ.get( - "TIDB_DATABASE", - config.get("tidb", "database", fallback=None), - ), - "workspace": os.environ.get( - "TIDB_WORKSPACE", - config.get("tidb", "workspace", fallback="default"), - ), - } - - # Initialize TiDB object - from .kg.tidb_impl import TiDB - - tidb_db = TiDB(dbconfig) - # Check if DB tables exist, if not, tables will be created - loop = always_get_an_event_loop() - loop.run_until_complete(tidb_db.check_tables()) - - # Only inject db object for TiDB storage implementations - if self.kv_storage == "TiDBKVStorage": - self.key_string_value_json_storage_cls.db = tidb_db - if self.vector_storage == "TiDBVectorDBStorage": - self.vector_db_storage_cls.db = tidb_db - if self.graph_storage == "TiDBGraphStorage": - self.graph_storage_cls.db = tidb_db - - # Check if PostgreSQL storage implementation is used - if ( - self.kv_storage == "PGKVStorage" - or self.vector_storage == "PGVectorStorage" - or self.graph_storage == "PGGraphStorage" - or self.doc_status_storage == "PGDocStatusStorage" - ): - # Read configuration file - config_parser = configparser.ConfigParser() - if os.path.exists("config.ini"): - config_parser.read("config.ini") - - # Get parameters from environment variables or config file - dbconfig = { - "host": os.environ.get( - "POSTGRES_HOST", - config.get("postgres", "host", fallback="localhost"), - ), - "port": os.environ.get( - "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) - ), - "user": os.environ.get( - "POSTGRES_USER", config.get("postgres", "user", fallback=None) - ), - "password": os.environ.get( - "POSTGRES_PASSWORD", - config.get("postgres", "password", fallback=None), - ), - "database": os.environ.get( - "POSTGRES_DATABASE", - config.get("postgres", "database", fallback=None), - ), - "workspace": os.environ.get( - "POSTGRES_WORKSPACE", - config.get("postgres", "workspace", fallback="default"), - ), - } - - # Initialize PostgreSQLDB object - from .kg.postgres_impl import PostgreSQLDB - - postgres_db = PostgreSQLDB(dbconfig) - # Initialize and check tables - loop = always_get_an_event_loop() - 
loop.run_until_complete(postgres_db.initdb()) - # Check if DB tables exist, if not, tables will be created - loop.run_until_complete(postgres_db.check_tables()) - - # Only inject db object for PostgreSQL storage implementations - if self.kv_storage == "PGKVStorage": - self.key_string_value_json_storage_cls.db = postgres_db - if self.vector_storage == "PGVectorStorage": - self.vector_db_storage_cls.db = postgres_db - if self.graph_storage == "PGGraphStorage": - self.graph_storage_cls.db = postgres_db - if self.doc_status_storage == "PGDocStatusStorage": - self.doc_status_storage_cls.db = postgres_db - self.llm_response_cache = self.key_string_value_json_storage_cls( namespace=make_namespace( self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE @@ -664,6 +588,13 @@ class LightRAG: embedding_func=self.embedding_func, ) + # Initialize document status storage + self.doc_status: DocStatusStorage = self.doc_status_storage_cls( + namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), + global_config=global_config, + embedding_func=None, + ) + # What's for, Is this nessisary ? if self.llm_response_cache and hasattr( self.llm_response_cache, "global_config" @@ -677,16 +608,21 @@ class LightRAG: embedding_func=self.embedding_func, ) - # self.json_doc_status_storage = self.key_string_value_json_storage_cls( - # namespace=self.namespace_prefix + "json_doc_status_storage", - # embedding_func=None, - # ) - self.doc_status: DocStatusStorage = self.doc_status_storage_cls( - namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS), - global_config=global_config, - embedding_func=None, - ) + # Collect all storage instances + storage_instances = [ + self.full_docs, + self.text_chunks, + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.chunks_vdb, + self.doc_status, + ] + + # Initialize database connections if needed + loop = always_get_an_event_loop() + loop.run_until_complete(self._initialize_database_if_needed(storage_instances)) self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -710,8 +646,81 @@ class LightRAG: storage_class = lazy_external_import(import_path, storage_name) return storage_class + async def _initialize_database_if_needed(self, storage_instances: list): + """Intialize database connection and inject it to storage implementation if needed""" + from .kg.postgres_impl import PostgreSQLDB + from .kg.oracle_impl import OracleDB + from .kg.tidb_impl import TiDB + from .kg.postgres_impl import ( + PGKVStorage, + PGVectorStorage, + PGGraphStorage, + PGDocStatusStorage, + ) + from .kg.oracle_impl import ( + OracleKVStorage, + OracleVectorDBStorage, + OracleGraphStorage, + ) + from .kg.tidb_impl import ( + TiDBKVStorage, + TiDBVectorDBStorage, + TiDBGraphStorage) + + # Checking if PostgreSQL is needed + if any( + isinstance( + storage, + (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + ) + for storage in storage_instances + ): + postgres_db = PostgreSQLDB(self._get_postgres_config()) + await postgres_db.initdb() + await postgres_db.check_tables() + for storage in storage_instances: + if isinstance( + storage, + (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + ): + storage.db = postgres_db + logger.info(f"Injected postgres_db to {storage.__class__.__name__}") + + # Checking if Oracle is needed + if any( + isinstance( + storage, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage) + ) + for storage in storage_instances + ): + oracle_db = 
OracleDB(self._get_oracle_config()) + await oracle_db.check_tables() + for storage in storage_instances: + if isinstance( + storage, + (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), + ): + storage.db = oracle_db + logger.info(f"Injected oracle_db to {storage.__class__.__name__}") + + # Checking if TiDB is needed + if any( + isinstance(storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)) + for storage in storage_instances + ): + tidb_db = TiDB(self._get_tidb_config()) + await tidb_db.check_tables() + # 注入db实例 + for storage in storage_instances: + if isinstance( + storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage) + ): + storage.db = tidb_db + logger.info(f"Injected tidb_db to {storage.__class__.__name__}") + def set_storage_client(self, db_client): - # Now only tested on Oracle Database + # Inject db to storage implementation (only tested on Oracle Database + # Deprecated, seting correct value to *_storage creating LightRAG insteaded for storage in [ self.vector_db_storage_cls, self.graph_storage_cls, From 3372af7c3d6cf7e181403eba07b9739d09b26110 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 12 Feb 2025 22:54:22 +0800 Subject: [PATCH 18/35] refactor: remove injected db field from PGDocStatusStorage, it must be injected after object is created --- lightrag/kg/postgres_impl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 221202ab..eaa1dd92 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -418,10 +418,8 @@ class PGVectorStorage(BaseVectorStorage): @dataclass class PGDocStatusStorage(DocStatusStorage): - """PostgreSQL implementation of document status storage""" - # db instance must be injected before use - db: PostgreSQLDB + # db: PostgreSQLDB async def filter_keys(self, data: set[str]) -> set[str]: """Return keys that don't exist in storage""" From 274cd73a8f98a4965df43d9eb41309c1b782a524 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 12 Feb 2025 22:55:47 +0800 Subject: [PATCH 19/35] refactor: improve storage initialization with named instances to aid logging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add storage names to instance list • Use tuples to store name with instance • Update type hints for storage instances • Improve logging with actual storage names • Clean up loop variable naming --- lightrag/lightrag.py | 54 ++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 5648c85d..dcd829eb 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -609,15 +609,15 @@ class LightRAG: ) - # Collect all storage instances + # Collect all storage instances with their names storage_instances = [ - self.full_docs, - self.text_chunks, - self.chunk_entity_relation_graph, - self.entities_vdb, - self.relationships_vdb, - self.chunks_vdb, - self.doc_status, + ("full_docs", self.full_docs), + ("text_chunks", self.text_chunks), + ("chunk_entity_relation_graph", self.chunk_entity_relation_graph), + ("entities_vdb", self.entities_vdb), + ("relationships_vdb", self.relationships_vdb), + ("chunks_vdb", self.chunks_vdb), + ("doc_status", self.doc_status), ] # Initialize database connections if needed @@ -646,7 +646,7 @@ class LightRAG: storage_class = lazy_external_import(import_path, storage_name) return storage_class - async def _initialize_database_if_needed(self, storage_instances: 
list): + async def _initialize_database_if_needed(self, storage_instances: list[tuple[str, Any]]): """Intialize database connection and inject it to storage implementation if needed""" from .kg.postgres_impl import PostgreSQLDB from .kg.oracle_impl import OracleDB @@ -670,53 +670,53 @@ class LightRAG: # Checking if PostgreSQL is needed if any( isinstance( - storage, + storage_instance, (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), ) - for storage in storage_instances + for _, storage_instance in storage_instances ): postgres_db = PostgreSQLDB(self._get_postgres_config()) await postgres_db.initdb() await postgres_db.check_tables() - for storage in storage_instances: + for storage_name, storage_instance in storage_instances: if isinstance( - storage, + storage_instance, (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), ): - storage.db = postgres_db - logger.info(f"Injected postgres_db to {storage.__class__.__name__}") + storage_instance.db = postgres_db + logger.info(f"Injected postgres_db to {storage_name}") # Checking if Oracle is needed if any( isinstance( - storage, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage) + storage_instance, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage) ) - for storage in storage_instances + for _, storage_instance in storage_instances ): oracle_db = OracleDB(self._get_oracle_config()) await oracle_db.check_tables() - for storage in storage_instances: + for storage_name, storage_instance in storage_instances: if isinstance( - storage, + storage_instance, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), ): - storage.db = oracle_db - logger.info(f"Injected oracle_db to {storage.__class__.__name__}") + storage_instance.db = oracle_db + logger.info(f"Injected oracle_db to {storage_name}") # Checking if TiDB is needed if any( - isinstance(storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)) - for storage in storage_instances + isinstance(storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)) + for _, storage_instance in storage_instances ): tidb_db = TiDB(self._get_tidb_config()) await tidb_db.check_tables() # 注入db实例 - for storage in storage_instances: + for storage_name, storage_instance in storage_instances: if isinstance( - storage, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage) + storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage) ): - storage.db = tidb_db - logger.info(f"Injected tidb_db to {storage.__class__.__name__}") + storage_instance.db = tidb_db + logger.info(f"Injected tidb_db to {storage_name}") def set_storage_client(self, db_client): # Inject db to storage implementation (only tested on Oracle Database From 7c7cac1cfd0dfbe35ebd5cf07e868f58d0f867f1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 00:39:40 +0800 Subject: [PATCH 20/35] fix: remove unnecessary param binding, use direct workspace string interpolation --- lightrag/kg/postgres_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index eaa1dd92..08d559fb 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -425,9 +425,9 @@ class PGDocStatusStorage(DocStatusStorage): """Return keys that don't exist in storage""" keys = ",".join([f"'{_id}'" for _id in data]) sql = ( - f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})" + f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN 
({keys})" ) - result = await self.db.query(sql, {"workspace": self.db.workspace}, True) + result = await self.db.query(sql, multirows=True) # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. if result is None: return set(data) From 4c39cf399d4f6fdf4dede1b03afecdd114d4d60b Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 01:11:09 +0800 Subject: [PATCH 21/35] refactor: move database connection pool initialization to lifespan of FastAPI - Add proper database connection lifecycle management - Add connection pool cleanup in FastAPI lifespan --- lightrag/api/lightrag_server.py | 230 +++++++++++++++++++++++++++++--- lightrag/lightrag.py | 171 ------------------------ 2 files changed, 213 insertions(+), 188 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 97b3f5a5..0839c1f8 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -33,14 +33,39 @@ from contextlib import asynccontextmanager from starlette.status import HTTP_403_FORBIDDEN import pipmaster as pm from dotenv import load_dotenv +import configparser +from lightrag.utils import logger from .ollama_api import ( OllamaAPI, ) from .ollama_api import ollama_server_infos +from ..kg.postgres_impl import ( + PostgreSQLDB, + PGKVStorage, + PGVectorStorage, + PGGraphStorage, + PGDocStatusStorage, +) +from ..kg.oracle_impl import ( + OracleDB, + OracleKVStorage, + OracleVectorDBStorage, + OracleGraphStorage, +) +from ..kg.tidb_impl import ( + TiDB, + TiDBKVStorage, + TiDBVectorDBStorage, + TiDBGraphStorage, +) # Load environment variables load_dotenv(override=True) +# Initialize config parser +config = configparser.ConfigParser() +config.read("config.ini") + class RAGStorageConfig: """存储配置类,支持通过环境变量和命令行参数修改默认值""" @@ -714,25 +739,99 @@ def create_app(args): @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for startup and shutdown events""" - # Startup logic - if args.auto_scan_at_startup: - try: - new_files = doc_manager.scan_directory_for_new_files() - for file_path in new_files: - try: - await index_file(file_path) - except Exception as e: - trace_exception(e) - logging.error(f"Error indexing file {file_path}: {str(e)}") + # Initialize database connections + postgres_db = None + oracle_db = None + tidb_db = None - ASCIIColors.info( - f"Indexed {len(new_files)} documents from {args.input_dir}" + try: + # Check if PostgreSQL is needed + if any( + isinstance( + storage_instance, + (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), ) - except Exception as e: - logging.error(f"Error during startup indexing: {str(e)}") - yield - # Cleanup logic (if needed) - pass + for _, storage_instance in storage_instances + ): + postgres_db = PostgreSQLDB(_get_postgres_config()) + await postgres_db.initdb() + await postgres_db.check_tables() + for storage_name, storage_instance in storage_instances: + if isinstance( + storage_instance, + (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + ): + storage_instance.db = postgres_db + logger.info(f"Injected postgres_db to {storage_name}") + + # Check if Oracle is needed + if any( + isinstance( + storage_instance, + (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), + ) + for _, storage_instance in storage_instances + ): + oracle_db = OracleDB(_get_oracle_config()) + await oracle_db.check_tables() + for storage_name, storage_instance in storage_instances: + if isinstance( + storage_instance, + (OracleKVStorage, OracleVectorDBStorage, 
OracleGraphStorage), + ): + storage_instance.db = oracle_db + logger.info(f"Injected oracle_db to {storage_name}") + + # Check if TiDB is needed + if any( + isinstance( + storage_instance, + (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), + ) + for _, storage_instance in storage_instances + ): + tidb_db = TiDB(_get_tidb_config()) + await tidb_db.check_tables() + for storage_name, storage_instance in storage_instances: + if isinstance( + storage_instance, + (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage), + ): + storage_instance.db = tidb_db + logger.info(f"Injected tidb_db to {storage_name}") + + # Auto scan documents if enabled + if args.auto_scan_at_startup: + try: + new_files = doc_manager.scan_directory_for_new_files() + for file_path in new_files: + try: + await index_file(file_path) + except Exception as e: + trace_exception(e) + logging.error(f"Error indexing file {file_path}: {str(e)}") + + ASCIIColors.info( + f"Indexed {len(new_files)} documents from {args.input_dir}" + ) + except Exception as e: + logging.error(f"Error during startup indexing: {str(e)}") + + yield + + finally: + # Cleanup database connections + if postgres_db and hasattr(postgres_db, "pool"): + await postgres_db.pool.close() + logger.info("Closed PostgreSQL connection pool") + + if oracle_db and hasattr(oracle_db, "pool"): + await oracle_db.pool.close() + logger.info("Closed Oracle connection pool") + + if tidb_db and hasattr(tidb_db, "pool"): + await tidb_db.pool.close() + logger.info("Closed TiDB connection pool") # Initialize FastAPI app = FastAPI( @@ -755,6 +854,92 @@ def create_app(args): allow_headers=["*"], ) + # Database configuration functions + def _get_postgres_config(): + return { + "host": os.environ.get( + "POSTGRES_HOST", + config.get("postgres", "host", fallback="localhost"), + ), + "port": os.environ.get( + "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) + ), + "user": os.environ.get( + "POSTGRES_USER", config.get("postgres", "user", fallback=None) + ), + "password": os.environ.get( + "POSTGRES_PASSWORD", + config.get("postgres", "password", fallback=None), + ), + "database": os.environ.get( + "POSTGRES_DATABASE", + config.get("postgres", "database", fallback=None), + ), + "workspace": os.environ.get( + "POSTGRES_WORKSPACE", + config.get("postgres", "workspace", fallback="default"), + ), + } + + def _get_oracle_config(): + return { + "user": os.environ.get( + "ORACLE_USER", + config.get("oracle", "user", fallback=None), + ), + "password": os.environ.get( + "ORACLE_PASSWORD", + config.get("oracle", "password", fallback=None), + ), + "dsn": os.environ.get( + "ORACLE_DSN", + config.get("oracle", "dsn", fallback=None), + ), + "config_dir": os.environ.get( + "ORACLE_CONFIG_DIR", + config.get("oracle", "config_dir", fallback=None), + ), + "wallet_location": os.environ.get( + "ORACLE_WALLET_LOCATION", + config.get("oracle", "wallet_location", fallback=None), + ), + "wallet_password": os.environ.get( + "ORACLE_WALLET_PASSWORD", + config.get("oracle", "wallet_password", fallback=None), + ), + "workspace": os.environ.get( + "ORACLE_WORKSPACE", + config.get("oracle", "workspace", fallback="default"), + ), + } + + def _get_tidb_config(): + return { + "host": os.environ.get( + "TIDB_HOST", + config.get("tidb", "host", fallback="localhost"), + ), + "port": os.environ.get( + "TIDB_PORT", config.get("tidb", "port", fallback=4000) + ), + "user": os.environ.get( + "TIDB_USER", + config.get("tidb", "user", fallback=None), + ), + "password": os.environ.get( + "TIDB_PASSWORD", + 
config.get("tidb", "password", fallback=None), + ), + "database": os.environ.get( + "TIDB_DATABASE", + config.get("tidb", "database", fallback=None), + ), + "workspace": os.environ.get( + "TIDB_WORKSPACE", + config.get("tidb", "workspace", fallback="default"), + ), + } + # Create the optional API key dependency optional_api_key = get_api_key_dependency(api_key) @@ -921,6 +1106,17 @@ def create_app(args): namespace_prefix=args.namespace_prefix, ) + # Collect all storage instances + storage_instances = [ + ("full_docs", rag.full_docs), + ("text_chunks", rag.text_chunks), + ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph), + ("entities_vdb", rag.entities_vdb), + ("relationships_vdb", rag.relationships_vdb), + ("chunks_vdb", rag.chunks_vdb), + ("doc_status", rag.doc_status), + ] + async def index_file(file_path: Union[str, Path]) -> None: """Index all files inside the folder with support for multiple file formats diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index dcd829eb..e6217572 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -355,91 +355,6 @@ class LightRAG: list[dict[str, Any]], ] = chunking_by_token_size - def _get_postgres_config(self): - return { - "host": os.environ.get( - "POSTGRES_HOST", - config.get("postgres", "host", fallback="localhost"), - ), - "port": os.environ.get( - "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) - ), - "user": os.environ.get( - "POSTGRES_USER", config.get("postgres", "user", fallback=None) - ), - "password": os.environ.get( - "POSTGRES_PASSWORD", - config.get("postgres", "password", fallback=None), - ), - "database": os.environ.get( - "POSTGRES_DATABASE", - config.get("postgres", "database", fallback=None), - ), - "workspace": os.environ.get( - "POSTGRES_WORKSPACE", - config.get("postgres", "workspace", fallback="default"), - ), - } - - def _get_oracle_config(self): - return { - "user": os.environ.get( - "ORACLE_USER", - config.get("oracle", "user", fallback=None), - ), - "password": os.environ.get( - "ORACLE_PASSWORD", - config.get("oracle", "password", fallback=None), - ), - "dsn": os.environ.get( - "ORACLE_DSN", - config.get("oracle", "dsn", fallback=None), - ), - "config_dir": os.environ.get( - "ORACLE_CONFIG_DIR", - config.get("oracle", "config_dir", fallback=None), - ), - "wallet_location": os.environ.get( - "ORACLE_WALLET_LOCATION", - config.get("oracle", "wallet_location", fallback=None), - ), - "wallet_password": os.environ.get( - "ORACLE_WALLET_PASSWORD", - config.get("oracle", "wallet_password", fallback=None), - ), - "workspace": os.environ.get( - "ORACLE_WORKSPACE", - config.get("oracle", "workspace", fallback="default"), - ), - } - - def _get_tidb_config(self): - return { - "host": os.environ.get( - "TIDB_HOST", - config.get("tidb", "host", fallback="localhost"), - ), - "port": os.environ.get( - "TIDB_PORT", config.get("tidb", "port", fallback=4000) - ), - "user": os.environ.get( - "TIDB_USER", - config.get("tidb", "user", fallback=None), - ), - "password": os.environ.get( - "TIDB_PASSWORD", - config.get("tidb", "password", fallback=None), - ), - "database": os.environ.get( - "TIDB_DATABASE", - config.get("tidb", "database", fallback=None), - ), - "workspace": os.environ.get( - "TIDB_WORKSPACE", - config.get("tidb", "workspace", fallback="default"), - ), - } - def verify_storage_implementation( self, storage_type: str, storage_name: str ) -> None: @@ -609,20 +524,6 @@ class LightRAG: ) - # Collect all storage instances with their names - storage_instances = [ - ("full_docs", 
self.full_docs), - ("text_chunks", self.text_chunks), - ("chunk_entity_relation_graph", self.chunk_entity_relation_graph), - ("entities_vdb", self.entities_vdb), - ("relationships_vdb", self.relationships_vdb), - ("chunks_vdb", self.chunks_vdb), - ("doc_status", self.doc_status), - ] - - # Initialize database connections if needed - loop = always_get_an_event_loop() - loop.run_until_complete(self._initialize_database_if_needed(storage_instances)) self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( @@ -646,78 +547,6 @@ class LightRAG: storage_class = lazy_external_import(import_path, storage_name) return storage_class - async def _initialize_database_if_needed(self, storage_instances: list[tuple[str, Any]]): - """Intialize database connection and inject it to storage implementation if needed""" - from .kg.postgres_impl import PostgreSQLDB - from .kg.oracle_impl import OracleDB - from .kg.tidb_impl import TiDB - from .kg.postgres_impl import ( - PGKVStorage, - PGVectorStorage, - PGGraphStorage, - PGDocStatusStorage, - ) - from .kg.oracle_impl import ( - OracleKVStorage, - OracleVectorDBStorage, - OracleGraphStorage, - ) - from .kg.tidb_impl import ( - TiDBKVStorage, - TiDBVectorDBStorage, - TiDBGraphStorage) - - # Checking if PostgreSQL is needed - if any( - isinstance( - storage_instance, - (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), - ) - for _, storage_instance in storage_instances - ): - postgres_db = PostgreSQLDB(self._get_postgres_config()) - await postgres_db.initdb() - await postgres_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), - ): - storage_instance.db = postgres_db - logger.info(f"Injected postgres_db to {storage_name}") - - # Checking if Oracle is needed - if any( - isinstance( - storage_instance, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage) - ) - for _, storage_instance in storage_instances - ): - oracle_db = OracleDB(self._get_oracle_config()) - await oracle_db.check_tables() - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, - (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage), - ): - storage_instance.db = oracle_db - logger.info(f"Injected oracle_db to {storage_name}") - - # Checking if TiDB is needed - if any( - isinstance(storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)) - for _, storage_instance in storage_instances - ): - tidb_db = TiDB(self._get_tidb_config()) - await tidb_db.check_tables() - # 注入db实例 - for storage_name, storage_instance in storage_instances: - if isinstance( - storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage) - ): - storage_instance.db = tidb_db - logger.info(f"Injected tidb_db to {storage_name}") - def set_storage_client(self, db_client): # Inject db to storage implementation (only tested on Oracle Database # Deprecated, seting correct value to *_storage creating LightRAG insteaded From 7a89916bab12966f1b1398afcf9fb048b1704cf6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 01:27:27 +0800 Subject: [PATCH 22/35] Add method to retrieve in-progress documents in DocStatusStorage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add get_processing_docs() abstract method • Override get_processing_docs() in PG storage • Method retrieves docs with PROCESSING status • Keep consistent with existing status 
methods --- lightrag/base.py | 4 ++++ lightrag/kg/postgres_impl.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/lightrag/base.py b/lightrag/base.py index bd79d990..147cb444 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -227,6 +227,10 @@ class DocStatusStorage(BaseKVStorage): """Get all pending documents""" raise NotImplementedError + async def get_processing_docs(self) -> dict[str, DocProcessingStatus]: + """Get all documents that are currently being processed""" + raise NotImplementedError + async def update_doc_status(self, data: dict[str, Any]) -> None: """Updates the status of a document. By default, it calls upsert.""" await self.upsert(data) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 08d559fb..4b6f524f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -493,6 +493,10 @@ class PGDocStatusStorage(DocStatusStorage): """Get all pending documents""" return await self.get_docs_by_status(DocStatus.PENDING) + async def get_processing_docs(self) -> Dict[str, DocProcessingStatus]: + """Get all documents that are currently being processed""" + return await self.get_docs_by_status(DocStatus.PROCESSING) + async def index_done_callback(self): """Save data after indexing, but for PostgreSQL, we already saved them during the upsert stage, so no action to take here""" logger.info("Doc status had been saved into postgresql db!") From 9a77d9102390ecd9ea4e570817705e4e847caaa1 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 01:30:21 +0800 Subject: [PATCH 23/35] Add LLM response cache to registered RagServer components --- lightrag/api/lightrag_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 0839c1f8..44443440 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -1115,6 +1115,7 @@ def create_app(args): ("relationships_vdb", rag.relationships_vdb), ("chunks_vdb", rag.chunks_vdb), ("doc_status", rag.doc_status), + ("llm_response_cache", rag.llm_response_cache), ] async def index_file(file_path: Union[str, Path]) -> None: From 3308ecfa69f9ecac3a39d845fb7d3669491b7092 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 02:14:32 +0800 Subject: [PATCH 24/35] Refactor logging for vector similarity search with configurable threshold --- lightrag/kg/nano_vector_db_impl.py | 3 --- lightrag/operate.py | 6 ++++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 6e8873fc..1cbd1b0b 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -139,9 +139,6 @@ class NanoVectorDBStorage(BaseVectorStorage): async def query(self, query: str, top_k=5): embedding = await self.embedding_func([query]) embedding = embedding[0] - logger.info( - f"Query: {query}, top_k: {top_k}, cosine: {self.cosine_better_than_threshold}" - ) results = self._client.query( query=embedding, top_k=top_k, diff --git a/lightrag/operate.py b/lightrag/operate.py index db7f59a5..ee3c4512 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1,5 +1,6 @@ import asyncio import json +import os import re from tqdm.asyncio import tqdm as tqdm_async from typing import Any, Union @@ -34,6 +35,9 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS import time +COSINE_THRESHOLD = float(os.getenv("COSINE_THRESHOLD", "0.2")) + + def chunking_by_token_size( content: str, split_by_character: Union[str, None] = None, @@ 
-1055,6 +1059,7 @@ async def _get_node_data( query_param: QueryParam, ): # get similar entities + logger.info(f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {COSINE_THRESHOLD}") results = await entities_vdb.query(query, top_k=query_param.top_k) if not len(results): return "", "", "" @@ -1270,6 +1275,7 @@ async def _get_edge_data( text_chunks_db: BaseKVStorage, query_param: QueryParam, ): + logger.info(f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {COSINE_THRESHOLD}") results = await relationships_vdb.query(keywords, top_k=query_param.top_k) if not len(results): From f01f57d0daea8a78cbfeda284811e1a2cccf2ec2 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 03:25:48 +0800 Subject: [PATCH 25/35] refactor: make cosine similarity threshold a required config parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Remove default threshold from env var • Add validation for missing threshold • Move default to lightrag.py config init • Update all vector DB implementations • Improve threshold validation consistency --- lightrag/kg/chroma_impl.py | 10 +++++----- lightrag/kg/faiss_impl.py | 9 +++++---- lightrag/kg/milvus_impl.py | 10 +++++++++- lightrag/kg/nano_vector_db_impl.py | 9 +++++---- lightrag/kg/oracle_impl.py | 10 +++++----- lightrag/kg/postgres_impl.py | 10 +++++----- lightrag/kg/qdrant_impl.py | 12 +++++++++++- lightrag/kg/tidb_impl.py | 10 +++++----- lightrag/lightrag.py | 9 +++++++++ 9 files changed, 59 insertions(+), 30 deletions(-) diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 72a2627a..242c93ea 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -13,15 +13,15 @@ from lightrag.utils import logger class ChromaVectorDBStorage(BaseVectorStorage): """ChromaDB vector storage implementation.""" - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): try: - # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold user_collection_settings = config.get("collection_settings", {}) # Default HNSW index settings for ChromaDB diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index fc6aa779..47111a47 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -23,14 +23,15 @@ class FaissVectorDBStorage(BaseVectorStorage): Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search. 
""" - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): # Grab config values if available config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold # Where to save index file if you want persistent storage self._faiss_index_file = os.path.join( diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index ae0daac2..dd50c026 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -19,6 +19,8 @@ config.read("config.ini", "utf-8") @dataclass class MilvusVectorDBStorge(BaseVectorStorage): + cosine_better_than_threshold: float = None + @staticmethod def create_collection_if_not_exist( client: MilvusClient, collection_name: str, **kwargs @@ -30,6 +32,12 @@ class MilvusVectorDBStorge(BaseVectorStorage): ) def __post_init__(self): + config = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold + self._client = MilvusClient( uri=os.environ.get( "MILVUS_URI", @@ -103,7 +111,7 @@ class MilvusVectorDBStorge(BaseVectorStorage): data=embedding, limit=top_k, output_fields=list(self.meta_fields), - search_params={"metric_type": "COSINE", "params": {"radius": 0.2}}, + search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}}, ) print(results) return [ diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 1cbd1b0b..5a61bf4f 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -73,16 +73,17 @@ from lightrag.base import ( @dataclass class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): # Initialize lock only for file operations self._save_lock = asyncio.Lock() # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index c2859829..5a1e0616 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -320,14 +320,14 @@ class OracleKVStorage(BaseKVStorage): class OracleVectorDBStorage(BaseVectorStorage): # db instance must be injected before use # db: OracleDB - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: 
float = None def __post_init__(self): - # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold async def upsert(self, data: dict[str, dict]): """向向量数据库中插入数据""" diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 4b6f524f..dde88739 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -299,15 +299,15 @@ class PGKVStorage(BaseKVStorage): class PGVectorStorage(BaseVectorStorage): # db instance must be injected before use # db: PostgreSQLDB - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = None def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] - # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold def _upsert_chunks(self, item: dict): try: diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index bda23f8d..88dce27f 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -50,6 +50,8 @@ def compute_mdhash_id_for_qdrant( @dataclass class QdrantVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = None + @staticmethod def create_collection_if_not_exist( client: QdrantClient, collection_name: str, **kwargs @@ -59,6 +61,12 @@ class QdrantVectorDBStorage(BaseVectorStorage): client.create_collection(collection_name, **kwargs) def __post_init__(self): + config = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold + self._client = QdrantClient( url=os.environ.get( "QDRANT_URL", config.get("qdrant", "uri", fallback=None) @@ -131,4 +139,6 @@ class QdrantVectorDBStorage(BaseVectorStorage): with_payload=True, ) logger.debug(f"query result: {results}") - return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in results] + # 添加余弦相似度过滤 + filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold] + return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results] diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index ba5a6240..248f2c85 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -212,18 +212,18 @@ class TiDBKVStorage(BaseKVStorage): class TiDBVectorDBStorage(BaseVectorStorage): # db instance must be injected before use # db: TiDB - cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2")) + cosine_better_than_threshold: float = 
None def __post_init__(self): self._client_file_name = os.path.join( self.global_config["working_dir"], f"vdb_{self.namespace}.json" ) self._max_batch_size = self.global_config["embedding_batch_num"] - # Use global config value if specified, otherwise use default config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.cosine_better_than_threshold = config.get( - "cosine_better_than_threshold", self.cosine_better_than_threshold - ) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + self.cosine_better_than_threshold = cosine_threshold async def query(self, query: str, top_k: int) -> list[dict]: """Search from tidb vector""" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index e6217572..66508faf 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -420,6 +420,15 @@ class LightRAG: # Check environment variables self.check_storage_env_vars(storage_name) + # Ensure vector_db_storage_cls_kwargs has required fields + default_vector_db_kwargs = { + "cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2")) + } + self.vector_db_storage_cls_kwargs = { + **default_vector_db_kwargs, + **self.vector_db_storage_cls_kwargs + } + # show config global_config = asdict(self) _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()]) From 11c7af7fd86f552d00d3a1265574a3b8ced2fa33 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 03:34:31 +0800 Subject: [PATCH 26/35] refactor: use vdb instance's cosine threshold instead of global constant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Remove global COSINE_THRESHOLD • Use instance-level threshold config • Update logging statements • Reference vdb threshold directly --- lightrag/operate.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index ee3c4512..f8d484af 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -35,8 +35,6 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS import time -COSINE_THRESHOLD = float(os.getenv("COSINE_THRESHOLD", "0.2")) - def chunking_by_token_size( content: str, @@ -1059,7 +1057,7 @@ async def _get_node_data( query_param: QueryParam, ): # get similar entities - logger.info(f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {COSINE_THRESHOLD}") + logger.info(f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}") results = await entities_vdb.query(query, top_k=query_param.top_k) if not len(results): return "", "", "" @@ -1275,7 +1273,7 @@ async def _get_edge_data( text_chunks_db: BaseKVStorage, query_param: QueryParam, ): - logger.info(f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {COSINE_THRESHOLD}") + logger.info(f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}") results = await relationships_vdb.query(keywords, top_k=query_param.top_k) if not len(results): From d25386ff1b2f4acdc195418723b1795f786686d6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 04:04:51 +0800 Subject: [PATCH 27/35] refactor: simplify storage configuration handling while maintaining the same functionality --- lightrag/api/lightrag_server.py | 96 +++++++++++---------------------- 1 file changed, 31 insertions(+), 65 deletions(-) diff --git 
a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 44443440..1f531c4f 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -66,43 +66,11 @@ load_dotenv(override=True) config = configparser.ConfigParser() config.read("config.ini") - -class RAGStorageConfig: - """存储配置类,支持通过环境变量和命令行参数修改默认值""" - - # 默认存储实现 - DEFAULT_KV_STORAGE = "JsonKVStorage" - DEFAULT_VECTOR_STORAGE = "NanoVectorDBStorage" - DEFAULT_GRAPH_STORAGE = "NetworkXStorage" - DEFAULT_DOC_STATUS_STORAGE = "JsonDocStatusStorage" - - def __init__(self): - # 从环境变量读取配置,如果没有则使用默认值 - self.KV_STORAGE = os.getenv("LIGHTRAG_KV_STORAGE", self.DEFAULT_KV_STORAGE) - self.DOC_STATUS_STORAGE = os.getenv( - "LIGHTRAG_DOC_STATUS_STORAGE", self.DEFAULT_DOC_STATUS_STORAGE - ) - self.GRAPH_STORAGE = os.getenv( - "LIGHTRAG_GRAPH_STORAGE", self.DEFAULT_GRAPH_STORAGE - ) - self.VECTOR_STORAGE = os.getenv( - "LIGHTRAG_VECTOR_STORAGE", self.DEFAULT_VECTOR_STORAGE - ) - - def update_from_args(self, args): - """从命令行参数更新配置""" - if hasattr(args, "kv_storage"): - self.KV_STORAGE = args.kv_storage - if hasattr(args, "doc_status_storage"): - self.DOC_STATUS_STORAGE = args.doc_status_storage - if hasattr(args, "graph_storage"): - self.GRAPH_STORAGE = args.graph_storage - if hasattr(args, "vector_storage"): - self.VECTOR_STORAGE = args.vector_storage - - -# 初始化存储配置 -rag_storage_config = RAGStorageConfig() +class DefaultRAGStorageConfig: + KV_STORAGE = "JsonKVStorage" + VECTOR_STORAGE = "NanoVectorDBStorage" + GRAPH_STORAGE = "NetworkXStorage" + DOC_STATUS_STORAGE = "JsonDocStatusStorage" # Global progress tracker scan_progress: Dict = { @@ -246,13 +214,13 @@ def display_splash_screen(args: argparse.Namespace) -> None: # System Configuration ASCIIColors.magenta("\n💾 Storage Configuration:") ASCIIColors.white(" ├─ KV Storage: ", end="") - ASCIIColors.yellow(f"{rag_storage_config.KV_STORAGE}") - ASCIIColors.white(" ├─ Document Status Storage: ", end="") - ASCIIColors.yellow(f"{rag_storage_config.DOC_STATUS_STORAGE}") + ASCIIColors.yellow(f"{args.kv_storage}") + ASCIIColors.white(" ├─ Vector Storage: ", end="") + ASCIIColors.yellow(f"{args.vector_storage}") ASCIIColors.white(" ├─ Graph Storage: ", end="") - ASCIIColors.yellow(f"{rag_storage_config.GRAPH_STORAGE}") - ASCIIColors.white(" └─ Vector Storage: ", end="") - ASCIIColors.yellow(f"{rag_storage_config.VECTOR_STORAGE}") + ASCIIColors.yellow(f"{args.graph_storage}") + ASCIIColors.white(" └─ Document Status Storage: ", end="") + ASCIIColors.yellow(f"{args.doc_status_storage}") ASCIIColors.magenta("\n🛠️ System Configuration:") ASCIIColors.white(" ├─ Ollama Emulating Model: ", end="") @@ -349,23 +317,23 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--kv-storage", - default=rag_storage_config.KV_STORAGE, - help=f"KV存储实现 (default: {rag_storage_config.KV_STORAGE})", + default=get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE), + help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})", ) parser.add_argument( "--doc-status-storage", - default=rag_storage_config.DOC_STATUS_STORAGE, - help=f"文档状态存储实现 (default: {rag_storage_config.DOC_STATUS_STORAGE})", + default=get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE), + help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", ) parser.add_argument( "--graph-storage", - default=rag_storage_config.GRAPH_STORAGE, - help=f"图存储实现 (default: {rag_storage_config.GRAPH_STORAGE})", + default=get_env_value("LIGHTRAG_GRAPH_STORAGE", 
DefaultRAGStorageConfig.GRAPH_STORAGE), + help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", ) parser.add_argument( "--vector-storage", - default=rag_storage_config.VECTOR_STORAGE, - help=f"向量存储实现 (default: {rag_storage_config.VECTOR_STORAGE})", + default=get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE), + help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", ) # Bindings configuration @@ -582,8 +550,6 @@ def parse_args() -> argparse.Namespace: args = parser.parse_args() - rag_storage_config.update_from_args(args) - ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name return args @@ -1058,10 +1024,10 @@ def create_app(args): if args.llm_binding == "lollms" or args.llm_binding == "ollama" else {}, embedding_func=embedding_func, - kv_storage=rag_storage_config.KV_STORAGE, - graph_storage=rag_storage_config.GRAPH_STORAGE, - vector_storage=rag_storage_config.VECTOR_STORAGE, - doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE, + kv_storage=args.kv_storage, + graph_storage=args.graph_storage, + vector_storage=args.vector_storage, + doc_status_storage=args.doc_status_storage, vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, @@ -1089,10 +1055,10 @@ def create_app(args): llm_model_max_async=args.max_async, llm_model_max_token_size=args.max_tokens, embedding_func=embedding_func, - kv_storage=rag_storage_config.KV_STORAGE, - graph_storage=rag_storage_config.GRAPH_STORAGE, - vector_storage=rag_storage_config.VECTOR_STORAGE, - doc_status_storage=rag_storage_config.DOC_STATUS_STORAGE, + kv_storage=args.kv_storage, + graph_storage=args.graph_storage, + vector_storage=args.vector_storage, + doc_status_storage=args.doc_status_storage, vector_db_storage_cls_kwargs={ "cosine_better_than_threshold": args.cosine_threshold }, @@ -1658,10 +1624,10 @@ def create_app(args): "embedding_binding_host": args.embedding_binding_host, "embedding_model": args.embedding_model, "max_tokens": args.max_tokens, - "kv_storage": rag_storage_config.KV_STORAGE, - "doc_status_storage": rag_storage_config.DOC_STATUS_STORAGE, - "graph_storage": rag_storage_config.GRAPH_STORAGE, - "vector_storage": rag_storage_config.VECTOR_STORAGE, + "kv_storage": args.kv_storage, + "doc_status_storage": args.doc_status_storage, + "graph_storage": args.graph_storage, + "vector_storage": args.vector_storage, }, } From ed73ea407643a9c004fed56b9383ee42ce741e66 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 04:12:00 +0800 Subject: [PATCH 28/35] Fix linting --- lightrag/api/lightrag_server.py | 29 ++++++++++++++++++++++------- lightrag/kg/chroma_impl.py | 5 +++-- lightrag/kg/faiss_impl.py | 4 +++- lightrag/kg/milvus_impl.py | 9 +++++++-- lightrag/kg/nano_vector_db_impl.py | 4 +++- lightrag/kg/oracle_impl.py | 5 +++-- lightrag/kg/postgres_impl.py | 8 ++++---- lightrag/kg/qdrant_impl.py | 12 +++++++++--- lightrag/kg/tidb_impl.py | 4 +++- lightrag/lightrag.py | 4 +--- lightrag/operate.py | 10 ++++++---- 11 files changed, 64 insertions(+), 30 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 1f531c4f..b8e4f1e6 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -66,12 +66,14 @@ load_dotenv(override=True) config = configparser.ConfigParser() config.read("config.ini") + class DefaultRAGStorageConfig: KV_STORAGE = "JsonKVStorage" VECTOR_STORAGE = "NanoVectorDBStorage" GRAPH_STORAGE = "NetworkXStorage" DOC_STATUS_STORAGE = "JsonDocStatusStorage" + # 
Global progress tracker scan_progress: Dict = { "is_scanning": False, @@ -317,22 +319,30 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--kv-storage", - default=get_env_value("LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE), + default=get_env_value( + "LIGHTRAG_KV_STORAGE", DefaultRAGStorageConfig.KV_STORAGE + ), help=f"KV存储实现 (default: {DefaultRAGStorageConfig.KV_STORAGE})", ) parser.add_argument( "--doc-status-storage", - default=get_env_value("LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE), + default=get_env_value( + "LIGHTRAG_DOC_STATUS_STORAGE", DefaultRAGStorageConfig.DOC_STATUS_STORAGE + ), help=f"文档状态存储实现 (default: {DefaultRAGStorageConfig.DOC_STATUS_STORAGE})", ) parser.add_argument( "--graph-storage", - default=get_env_value("LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE), + default=get_env_value( + "LIGHTRAG_GRAPH_STORAGE", DefaultRAGStorageConfig.GRAPH_STORAGE + ), help=f"图存储实现 (default: {DefaultRAGStorageConfig.GRAPH_STORAGE})", ) parser.add_argument( "--vector-storage", - default=get_env_value("LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE), + default=get_env_value( + "LIGHTRAG_VECTOR_STORAGE", DefaultRAGStorageConfig.VECTOR_STORAGE + ), help=f"向量存储实现 (default: {DefaultRAGStorageConfig.VECTOR_STORAGE})", ) @@ -725,7 +735,12 @@ def create_app(args): for storage_name, storage_instance in storage_instances: if isinstance( storage_instance, - (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage), + ( + PGKVStorage, + PGVectorStorage, + PGGraphStorage, + PGDocStatusStorage, + ), ): storage_instance.db = postgres_db logger.info(f"Injected postgres_db to {storage_name}") @@ -790,11 +805,11 @@ def create_app(args): if postgres_db and hasattr(postgres_db, "pool"): await postgres_db.pool.close() logger.info("Closed PostgreSQL connection pool") - + if oracle_db and hasattr(oracle_db, "pool"): await oracle_db.pool.close() logger.info("Closed Oracle connection pool") - + if tidb_db and hasattr(tidb_db, "pool"): await tidb_db.pool.close() logger.info("Closed TiDB connection pool") diff --git a/lightrag/kg/chroma_impl.py b/lightrag/kg/chroma_impl.py index 242c93ea..82e723a1 100644 --- a/lightrag/kg/chroma_impl.py +++ b/lightrag/kg/chroma_impl.py @@ -1,4 +1,3 @@ -import os import asyncio from dataclasses import dataclass from typing import Union @@ -20,7 +19,9 @@ class ChromaVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold user_collection_settings = config.get("collection_settings", {}) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 47111a47..0dca9e4c 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -30,7 +30,9 @@ class FaissVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) 
self.cosine_better_than_threshold = cosine_threshold # Where to save index file if you want persistent storage diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index dd50c026..1abec502 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -35,7 +35,9 @@ class MilvusVectorDBStorge(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold self._client = MilvusClient( @@ -111,7 +113,10 @@ class MilvusVectorDBStorge(BaseVectorStorage): data=embedding, limit=top_k, output_fields=list(self.meta_fields), - search_params={"metric_type": "COSINE", "params": {"radius": self.cosine_better_than_threshold}}, + search_params={ + "metric_type": "COSINE", + "params": {"radius": self.cosine_better_than_threshold}, + }, ) print(results) return [ diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 5a61bf4f..2db8f72a 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -82,7 +82,9 @@ class NanoVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold self._client_file_name = os.path.join( diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 5a1e0616..65f1060c 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -1,6 +1,5 @@ import array import asyncio -import os # import html # import os @@ -326,7 +325,9 @@ class OracleVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold async def upsert(self, data: dict[str, dict]): diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index dde88739..cb636d7f 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -306,7 +306,9 @@ class PGVectorStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold def _upsert_chunks(self, item: dict): @@ -424,9 +426,7 @@ class PGDocStatusStorage(DocStatusStorage): async def filter_keys(self, data: set[str]) -> set[str]: """Return keys that don't exist in storage""" keys = 
",".join([f"'{_id}'" for _id in data]) - sql = ( - f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})" - ) + sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace='{self.db.workspace}' AND id IN ({keys})" result = await self.db.query(sql, multirows=True) # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...]. if result is None: diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 88dce27f..7c9f21a0 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -64,7 +64,9 @@ class QdrantVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold self._client = QdrantClient( @@ -140,5 +142,9 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) logger.debug(f"query result: {results}") # 添加余弦相似度过滤 - filtered_results = [dp for dp in results if dp.score >= self.cosine_better_than_threshold] - return [{**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results] + filtered_results = [ + dp for dp in results if dp.score >= self.cosine_better_than_threshold + ] + return [ + {**dp.payload, "id": dp.id, "distance": dp.score} for dp in filtered_results + ] diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index 248f2c85..00b8003d 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -222,7 +222,9 @@ class TiDBVectorDBStorage(BaseVectorStorage): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = config.get("cosine_better_than_threshold") if cosine_threshold is None: - raise ValueError("cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs") + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) self.cosine_better_than_threshold = cosine_threshold async def query(self, query: str, top_k: int) -> list[dict]: diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 66508faf..cdb0462e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -426,7 +426,7 @@ class LightRAG: } self.vector_db_storage_cls_kwargs = { **default_vector_db_kwargs, - **self.vector_db_storage_cls_kwargs + **self.vector_db_storage_cls_kwargs, } # show config @@ -532,8 +532,6 @@ class LightRAG: embedding_func=self.embedding_func, ) - - self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( partial( self.llm_model_func, diff --git a/lightrag/operate.py b/lightrag/operate.py index f8d484af..04aad0d4 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1,6 +1,5 @@ import asyncio import json -import os import re from tqdm.asyncio import tqdm as tqdm_async from typing import Any, Union @@ -35,7 +34,6 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS import time - def chunking_by_token_size( content: str, split_by_character: Union[str, None] = None, @@ -1057,7 +1055,9 @@ async def _get_node_data( query_param: QueryParam, ): # get similar entities - logger.info(f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}") + logger.info( + f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: 
{entities_vdb.cosine_better_than_threshold}" + ) results = await entities_vdb.query(query, top_k=query_param.top_k) if not len(results): return "", "", "" @@ -1273,7 +1273,9 @@ async def _get_edge_data( text_chunks_db: BaseKVStorage, query_param: QueryParam, ): - logger.info(f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}") + logger.info( + f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}" + ) results = await relationships_vdb.query(keywords, top_k=query_param.top_k) if not len(results): From 76164a1b17a301e7e4eef8397262d799dc505456 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 04:52:54 +0800 Subject: [PATCH 29/35] Use namespace for graph_name before falling back to env or default value - Update graph_name initialization - Add namespace override support - Maintain backward compatibility - Prioritize namespace over env variable --- lightrag/kg/postgres_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index cb636d7f..377d5979 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -587,7 +587,7 @@ class PGGraphStorage(BaseGraphStorage): print("no preloading of graph with AGE in production") def __post_init__(self): - self.graph_name = os.environ.get("AGE_GRAPH_NAME", "lightrag") + self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } From 28b17b327b38e492c2e9fb3257310b01738d7a03 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 06:05:21 +0800 Subject: [PATCH 30/35] Fix: top_k param handling error, unify top_k and cosine default value. --- .env.example | 4 ++-- lightrag/api/README.md | 2 +- lightrag/api/lightrag_server.py | 18 ++++++++++++------ lightrag/api/ollama_api.py | 5 +++-- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/.env.example b/.env.example index 369bde4b..4b64ecb4 100644 --- a/.env.example +++ b/.env.example @@ -32,8 +32,8 @@ MAX_EMBED_TOKENS=8192 #HISTORY_TURNS=3 #CHUNK_SIZE=1200 #CHUNK_OVERLAP_SIZE=100 -#COSINE_THRESHOLD=0.4 # 0.2 while not running API server -#TOP_K=50 # 60 while not running API server +#COSINE_THRESHOLD=0.2 +#TOP_K=60 ### LLM Configuration (Use valid host. For local services, you can use host.docker.internal) ### Ollama example diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 9f5580fb..b68476d4 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -103,7 +103,7 @@ After starting the lightrag-server, you can add an Ollama-type connection in the LightRAG can be configured using either command-line arguments or environment variables. When both are provided, command-line arguments take precedence over environment variables. -For better performance, the API server's default values for TOP_K and COSINE_THRESHOLD are set to 50 and 0.4 respectively. If COSINE_THRESHOLD remains at its default value of 0.2 in LightRAG, many irrelevant entities and relations would be retrieved and sent to the LLM. +Default `TOP_K` is set to `60`. Default `COSINE_THRESHOLD` is set to `0.2`.
### Environment Variables diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index b8e4f1e6..4a58abe2 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -530,13 +530,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--top-k", type=int, - default=get_env_value("TOP_K", 50, int), - help="Number of most similar results to return (default: from env or 50)", + default=get_env_value("TOP_K", 60, int), + help="Number of most similar results to return (default: from env or 60)", ) parser.add_argument( "--cosine-threshold", type=float, - default=get_env_value("COSINE_THRESHOLD", 0.4, float), + default=get_env_value("COSINE_THRESHOLD", 0.2, float), help="Cosine similarity threshold (default: from env or 0.4)", ) @@ -669,7 +669,13 @@ def get_api_key_dependency(api_key: Optional[str]): return api_key_auth +# Global configuration +global_top_k = 60 # default value + def create_app(args): + global global_top_k + global_top_k = args.top_k # save top_k from args + # Verify that bindings are correctly setup if args.llm_binding not in [ "lollms", @@ -1279,7 +1285,7 @@ def create_app(args): mode=request.mode, stream=request.stream, only_need_context=request.only_need_context, - top_k=args.top_k, + top_k=global_top_k, ), ) @@ -1321,7 +1327,7 @@ def create_app(args): mode=request.mode, stream=True, only_need_context=request.only_need_context, - top_k=args.top_k, + top_k=global_top_k, ), ) @@ -1611,7 +1617,7 @@ def create_app(args): return await rag.get_graps(nodel_label=label, max_depth=100) # Add Ollama API routes - ollama_api = OllamaAPI(rag) + ollama_api = OllamaAPI(rag, top_k=args.top_k) app.include_router(ollama_api.router, prefix="/api") @app.get("/documents", dependencies=[Depends(optional_api_key)]) diff --git a/lightrag/api/ollama_api.py b/lightrag/api/ollama_api.py index 0d96e16d..01a883ca 100644 --- a/lightrag/api/ollama_api.py +++ b/lightrag/api/ollama_api.py @@ -148,9 +148,10 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode]: class OllamaAPI: - def __init__(self, rag: LightRAG): + def __init__(self, rag: LightRAG, top_k: int = 60): self.rag = rag self.ollama_server_infos = ollama_server_infos + self.top_k = top_k self.router = APIRouter() self.setup_routes() @@ -381,7 +382,7 @@ class OllamaAPI: "stream": request.stream, "only_need_context": False, "conversation_history": conversation_history, - "top_k": self.rag.args.top_k if hasattr(self.rag, "args") else 50, + "top_k": self.top_k, } if ( From e5adb2e0f3331804b4aa50f463d04661abd94076 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 13:53:52 +0800 Subject: [PATCH 31/35] Improve cache logging and add more detailed log messages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add cache type to log data structure • Make debug logs more detailed • Add high-level info logs for cache hits • Add null check for best_response • Improve log message readability --- lightrag/utils.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index 28d9bfaa..f3b62d39 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -416,7 +416,7 @@ async def get_best_cached_response( if best_similarity > similarity_threshold: # If LLM check is enabled and all required parameters are provided - if use_llm_check and llm_func and original_prompt and best_prompt: + if use_llm_check and llm_func and original_prompt and best_prompt and best_response is not None: 
compare_prompt = PROMPTS["similarity_check"].format( original_prompt=original_prompt, cached_prompt=best_prompt ) @@ -430,7 +430,9 @@ best_similarity = llm_similarity if best_similarity < similarity_threshold: log_data = { - "event": "llm_check_cache_rejected", + "event": "cache_rejected_by_llm", + "type": cache_type, + "mode": mode, "original_question": original_prompt[:100] + "..." if len(original_prompt) > 100 else original_prompt, @@ -440,7 +442,8 @@ "similarity_score": round(best_similarity, 4), "threshold": similarity_threshold, } - logger.info(json.dumps(log_data, ensure_ascii=False)) + logger.debug(json.dumps(log_data, ensure_ascii=False)) + logger.info(f"Cache rejected by LLM(mode:{mode} type:{cache_type})") return None except Exception as e: # Catch all possible exceptions logger.warning(f"LLM similarity check failed: {e}") @@ -451,12 +454,13 @@ ) log_data = { "event": "cache_hit", + "type": cache_type, "mode": mode, "similarity": round(best_similarity, 4), "cache_id": best_cache_id, "original_prompt": prompt_display, } - logger.info(json.dumps(log_data, ensure_ascii=False)) + logger.debug(json.dumps(log_data, ensure_ascii=False)) return best_response return None @@ -534,19 +538,24 @@ async def handle_cache( cache_type=cache_type, ) if best_cached_response is not None: + logger.info(f"Embedding cache hit(mode:{mode} type:{cache_type})") return best_cached_response, None, None, None else: + # if caching keyword embedding is enabled, return the quantized embedding for saving it later + logger.info(f"Embedding cache miss(mode:{mode} type:{cache_type})") return None, quantized, min_val, max_val - # For default mode(extract_entities or naive query) or is_embedding_cache_enabled is False - # Use regular cache + # For default mode or is_embedding_cache_enabled is False, use regular cache + # default mode is for extract_entities or naive query if exists_func(hashing_kv, "get_by_mode_and_id"): mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {} else: mode_cache = await hashing_kv.get_by_id(mode) or {} if args_hash in mode_cache: + logger.info(f"Non-embedding cache hit(mode:{mode} type:{cache_type})") return mode_cache[args_hash]["return"], None, None, None + logger.info(f"Non-embedding cache miss(mode:{mode} type:{cache_type})") return None, None, None, None From cdd52809b0580ed2b251a62277e305a61f2a9599 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 14:07:36 +0800 Subject: [PATCH 32/35] Fix linting --- lightrag/api/lightrag_server.py | 3 ++- lightrag/utils.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 4a58abe2..5fff5851 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -672,10 +672,11 @@ def get_api_key_dependency(api_key: Optional[str]): # Global configuration global_top_k = 60 # default value + def create_app(args): global global_top_k global_top_k = args.top_k # save top_k from args - + # Verify that bindings are correctly setup if args.llm_binding not in [ "lollms", diff --git a/lightrag/utils.py b/lightrag/utils.py index f3b62d39..9df325ca 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -416,7 +416,13 @@ async def get_best_cached_response( if best_similarity > similarity_threshold: # If LLM check is enabled and all required parameters are provided - if use_llm_check and llm_func and
original_prompt and best_prompt and best_response is not None: + if ( + use_llm_check + and llm_func + and original_prompt + and best_prompt + and best_response is not None + ): compare_prompt = PROMPTS["similarity_check"].format( original_prompt=original_prompt, cached_prompt=best_prompt ) From 5ad3555f4c252d18925b5d9a042ce9567fcff183 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 14:31:12 +0800 Subject: [PATCH 33/35] docs: add MongoDB storage support and improve storage client comment --- lightrag/api/README.md | 1 + lightrag/lightrag.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index b68476d4..485a8dfa 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -183,6 +183,7 @@ OracleVectorDBStorag Oracle ``` JsonDocStatusStorage JsonFile(default) PGDocStatusStorage Postgres +MongoDocStatusStorage MongoDB ``` #### How Select Storage Type diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 2544fe49..bfbc4c75 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -557,8 +557,8 @@ class LightRAG: return storage_class def set_storage_client(self, db_client): - # Inject db to storage implementation (only tested on Oracle Database - # Deprecated, seting correct value to *_storage creating LightRAG insteaded + # Deprecated: set the correct value for *_storage when creating LightRAG instead + # Inject db to storage implementation (only tested on Oracle Database) for storage in [ self.vector_db_storage_cls, self.graph_storage_cls, From 906a9f664148de7091749382f1f02cd44c2d5455 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 15:22:18 +0800 Subject: [PATCH 34/35] docs: improve organization and clarity of API documentation --- lightrag/api/README.md | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/lightrag/api/README.md b/lightrag/api/README.md index 485a8dfa..8e5a61d5 100644 --- a/lightrag/api/README.md +++ b/lightrag/api/README.md @@ -130,7 +130,9 @@ python lightrag.py --port 8080 PORT=7000 python lightrag.py ``` -#### Storage Types Supported +> Best practice: set your database settings in config.ini while testing, and use .env for production. + +### Storage Types Supported LightRAG uses 4 types of storage for difference purposes: @@ -186,12 +188,11 @@ PGDocStatusStorage Postgres MongoDocStatusStorage MongoDB ``` -#### How Select Storage Type +### How to Select Storage Implementation -* Bye enviroment variables -* By command line arguments +You can select the storage implementation via environment variables or command line arguments. You cannot change the storage implementation after documents have been added to LightRAG. Data migration from one storage implementation to another is not supported yet. For further information please read the sample env file or config.ini file. -#### LightRag Server Options +### LightRAG API Server Command Line Options | Parameter | Default | Description | |-----------|---------|-------------| @@ -365,6 +366,14 @@ curl -X POST "http://localhost:9621/documents/scan" --max-time 1800 > Ajust max-time according to the estimated index time for all new files. +#### DELETE /documents + +Clear all documents from the RAG system. + +```bash +curl -X DELETE "http://localhost:9621/documents" +``` + ### Ollama Emulation Endpoints #### GET /api/version @@ -394,14 +403,6 @@ curl -N -X POST http://localhost:9621/api/chat -H "Content-Type: application/jso > For more information about Ollama API pls.
visit : [Ollama API documentation](https://github.com/ollama/ollama/blob/main/docs/api.md) -#### DELETE /documents - -Clear all documents from the RAG system. - -```bash -curl -X DELETE "http://localhost:9621/documents" -``` - ### Utility Endpoints #### GET /health From fab33cfc91b325189788e164915ebab5a31ccec5 Mon Sep 17 00:00:00 2001 From: yangdx Date: Thu, 13 Feb 2025 19:33:15 +0800 Subject: [PATCH 35/35] chore: bump api version from 1.0.4 to 1.0.5 --- lightrag/api/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py index 9f0b3540..7eadf511 100644 --- a/lightrag/api/__init__.py +++ b/lightrag/api/__init__.py @@ -1 +1 @@ -__api_version__ = "1.0.4" +__api_version__ = "1.0.5"
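Taken together, patches 25-30 leave storage selection to the `LIGHTRAG_*_STORAGE` environment variables (with JsonKVStorage, NanoVectorDBStorage, NetworkXStorage and JsonDocStatusStorage as defaults) and make `cosine_better_than_threshold` a required key in `vector_db_storage_cls_kwargs`, which LightRAG itself fills from `COSINE_THRESHOLD` (default `0.2`) when the caller does not set it. The sketch below shows how a caller outside the API server might wire this up; the helper name `build_rag`, the working directory, and the omission of `llm_model_func`/`embedding_func` (which a real deployment still has to supply) are illustrative assumptions rather than part of the patch set.

```python
import os

from lightrag import LightRAG


def build_rag(working_dir: str = "./rag_storage") -> LightRAG:
    """Illustrative sketch: mirror the server's env-driven storage selection."""
    return LightRAG(
        working_dir=working_dir,
        # Storage backends are chosen via environment variables, falling back
        # to the same defaults as DefaultRAGStorageConfig in the API server.
        kv_storage=os.getenv("LIGHTRAG_KV_STORAGE", "JsonKVStorage"),
        vector_storage=os.getenv("LIGHTRAG_VECTOR_STORAGE", "NanoVectorDBStorage"),
        graph_storage=os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage"),
        doc_status_storage=os.getenv("LIGHTRAG_DOC_STATUS_STORAGE", "JsonDocStatusStorage"),
        # Vector storages now require this key; LightRAG fills it from
        # COSINE_THRESHOLD (default 0.2) when it is not provided explicitly.
        vector_db_storage_cls_kwargs={
            "cosine_better_than_threshold": float(os.getenv("COSINE_THRESHOLD", "0.2"))
        },
        # llm_model_func / embedding_func are omitted here; supply the ones
        # your deployment uses, as lightrag_server.py does.
    )
```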