From d0d246bef86ae33c75573caa89ca1149de82fd35 Mon Sep 17 00:00:00 2001
From: Alex Z
Date: Wed, 2 Apr 2025 21:06:49 -0700
Subject: [PATCH 1/4] Fix 'Too many open files' error while using the Redis
 vector DB

Enhance RedisKVStorage: implement connection pooling and error handling.
Refactor async methods to use context managers for Redis operations,
improving resource management and error logging. Add batch processing for
key operations to optimize performance.
---
 lightrag/kg/redis_impl.py | 229 ++++++++++++++++++++++++--------------
 lightrag/operate.py       |  15 ++-
 2 files changed, 156 insertions(+), 88 deletions(-)

diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index 3feb4985..01842a67 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -3,12 +3,14 @@ from typing import Any, final
 from dataclasses import dataclass
 import pipmaster as pm
 import configparser
+from contextlib import asynccontextmanager
 
 if not pm.is_installed("redis"):
     pm.install("redis")
 
 # aioredis is a depricated library, replaced with redis
-from redis.asyncio import Redis
+from redis.asyncio import Redis, ConnectionPool
+from redis.exceptions import RedisError, ConnectionError
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseKVStorage
 import json
@@ -17,6 +19,11 @@ import json
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
 
+# Constants for Redis connection pool
+MAX_CONNECTIONS = 50
+SOCKET_TIMEOUT = 5.0
+SOCKET_CONNECT_TIMEOUT = 3.0
+
 
 @final
 @dataclass
@@ -25,125 +32,177 @@ class RedisKVStorage(BaseKVStorage):
         redis_url = os.environ.get(
            "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
         )
-        self._redis = Redis.from_url(redis_url, decode_responses=True)
-        logger.info(f"Use Redis as KV {self.namespace}")
+        # Create a connection pool with limits
+        self._pool = ConnectionPool.from_url(
+            redis_url,
+            max_connections=MAX_CONNECTIONS,
+            decode_responses=True,
+            socket_timeout=SOCKET_TIMEOUT,
+            socket_connect_timeout=SOCKET_CONNECT_TIMEOUT
+        )
+        self._redis = Redis(connection_pool=self._pool)
+        logger.info(f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections")
+
+    @asynccontextmanager
+    async def _get_redis_connection(self):
+        """Safe context manager for Redis operations."""
+        try:
+            yield self._redis
+        except ConnectionError as e:
+            logger.error(f"Redis connection error in {self.namespace}: {e}")
+            raise
+        except RedisError as e:
+            logger.error(f"Redis operation error in {self.namespace}: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error in Redis operation for {self.namespace}: {e}")
+            raise
+
+    async def close(self):
+        """Close the Redis connection pool to prevent resource leaks."""
+        if hasattr(self, '_redis') and self._redis:
+            await self._redis.close()
+            await self._pool.disconnect()
+            logger.debug(f"Closed Redis connection pool for {self.namespace}")
+
+    async def __aenter__(self):
+        """Support for async context manager."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Ensure Redis resources are cleaned up when exiting context."""
+        await self.close()
 
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        data = await self._redis.get(f"{self.namespace}:{id}")
-        return json.loads(data) if data else None
+        async with self._get_redis_connection() as redis:
+            try:
+                data = await redis.get(f"{self.namespace}:{id}")
+                return json.loads(data) if data else None
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error for id {id}: {e}")
+                return None
 
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        pipe = self._redis.pipeline()
-        for id in ids:
-            pipe.get(f"{self.namespace}:{id}")
-        results = await pipe.execute()
-        return [json.loads(result) if result else None for result in results]
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for id in ids:
+                    pipe.get(f"{self.namespace}:{id}")
+                results = await pipe.execute()
+                return [json.loads(result) if result else None for result in results]
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error in batch get: {e}")
+                return [None] * len(ids)
 
     async def filter_keys(self, keys: set[str]) -> set[str]:
-        pipe = self._redis.pipeline()
-        for key in keys:
-            pipe.exists(f"{self.namespace}:{key}")
-        results = await pipe.execute()
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            for key in keys:
+                pipe.exists(f"{self.namespace}:{key}")
+            results = await pipe.execute()
 
-        existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
-        return set(keys) - existing_ids
+            existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
+            return set(keys) - existing_ids
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        pipe = self._redis.pipeline()
+
+        logger.info(f"Inserting {len(data)} items to {self.namespace}")
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for k, v in data.items():
+                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
+                await pipe.execute()
 
-        for k, v in data.items():
-            pipe.set(f"{self.namespace}:{k}", json.dumps(v))
-        await pipe.execute()
-
-        for k in data:
-            data[k]["_id"] = k
-
-    async def index_done_callback(self) -> None:
-        # Redis handles persistence automatically
-        pass
+                for k in data:
+                    data[k]["_id"] = k
+            except json.JSONEncodeError as e:
+                logger.error(f"JSON encode error during upsert: {e}")
+                raise
 
     async def delete(self, ids: list[str]) -> None:
-        """Delete entries with specified IDs
-
-        Args:
-            ids: List of entry IDs to be deleted
-        """
+        """Delete entries with specified IDs"""
         if not ids:
             return
 
-        pipe = self._redis.pipeline()
-        for id in ids:
-            pipe.delete(f"{self.namespace}:{id}")
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            for id in ids:
+                pipe.delete(f"{self.namespace}:{id}")
 
-        results = await pipe.execute()
-        deleted_count = sum(results)
-        logger.info(
-            f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
-        )
+            results = await pipe.execute()
+            deleted_count = sum(results)
+            logger.info(
+                f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
+            )
 
     async def delete_entity(self, entity_name: str) -> None:
-        """Delete an entity by name
-
-        Args:
-            entity_name: Name of the entity to delete
-        """
-
+        """Delete an entity by name"""
         try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             logger.debug(
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )
 
-            # Delete the entity
-            result = await self._redis.delete(f"{self.namespace}:{entity_id}")
+            async with self._get_redis_connection() as redis:
+                result = await redis.delete(f"{self.namespace}:{entity_id}")
 
-            if result:
-                logger.debug(f"Successfully deleted entity {entity_name}")
-            else:
-                logger.debug(f"Entity {entity_name} not found in storage")
+                if result:
+                    logger.debug(f"Successfully deleted entity {entity_name}")
+                else:
+                    logger.debug(f"Entity {entity_name} not found in storage")
         except Exception as e:
             logger.error(f"Error deleting entity {entity_name}: {e}")
 
     async def delete_entity_relation(self, entity_name: str) -> None:
-        """Delete all relations associated with an entity
-
-        Args:
-            entity_name: Name of the entity whose relations should be deleted
-        """
+        """Delete all relations associated with an entity"""
         try:
-            # Get all keys in this namespace
-            cursor = 0
-            relation_keys = []
-            pattern = f"{self.namespace}:*"
+            async with self._get_redis_connection() as redis:
+                cursor = 0
+                relation_keys = []
+                pattern = f"{self.namespace}:*"
 
-            while True:
-                cursor, keys = await self._redis.scan(cursor, match=pattern)
+                while True:
+                    cursor, keys = await redis.scan(cursor, match=pattern)
+
+                    # Process keys in batches
+                    pipe = redis.pipeline()
+                    for key in keys:
+                        pipe.get(key)
+                    values = await pipe.execute()
+
+                    for key, value in zip(keys, values):
+                        if value:
+                            try:
+                                data = json.loads(value)
+                                if (
+                                    data.get("src_id") == entity_name
+                                    or data.get("tgt_id") == entity_name
+                                ):
+                                    relation_keys.append(key)
+                            except json.JSONDecodeError:
+                                logger.warning(f"Invalid JSON in key {key}")
+                                continue
 
-                # For each key, get the value and check if it's related to entity_name
-                for key in keys:
-                    value = await self._redis.get(key)
-                    if value:
-                        data = json.loads(value)
-                        # Check if this is a relation involving the entity
-                        if (
-                            data.get("src_id") == entity_name
-                            or data.get("tgt_id") == entity_name
-                        ):
-                            relation_keys.append(key)
+                    if cursor == 0:
+                        break
 
-                # Exit loop when cursor returns to 0
-                if cursor == 0:
-                    break
-
-            # Delete the relation keys
-            if relation_keys:
-                deleted = await self._redis.delete(*relation_keys)
-                logger.debug(f"Deleted {deleted} relations for {entity_name}")
-            else:
-                logger.debug(f"No relations found for entity {entity_name}")
+                # Delete relations in batches
+                if relation_keys:
+                    # Delete in chunks to avoid too many arguments
+                    chunk_size = 1000
+                    for i in range(0, len(relation_keys), chunk_size):
+                        chunk = relation_keys[i:i + chunk_size]
+                        deleted = await redis.delete(*chunk)
+                        logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i//chunk_size + 1})")
+                else:
+                    logger.debug(f"No relations found for entity {entity_name}")
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")
+
+    async def index_done_callback(self) -> None:
+        # Redis handles persistence automatically
+        pass
diff --git a/lightrag/operate.py b/lightrag/operate.py
index dcf833c2..af17f8bf 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import traceback
 import json
 import re
 import os
@@ -994,6 +995,7 @@ async def mix_kg_vector_query(
 
         except Exception as e:
             logger.error(f"Error in get_kg_context: {str(e)}")
+            traceback.print_exc()
             return None
 
     async def get_vector_context():
@@ -1382,9 +1384,16 @@ async def _find_most_related_text_unit_from_entities(
                 all_text_units_lookup[c_id] = index
                 tasks.append((c_id, index, this_edges))
 
-    results = await asyncio.gather(
-        *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
-    )
+    # Process in batches of 25 tasks at a time to avoid overwhelming resources
+    batch_size = 25
+    results = []
+
+    for i in range(0, len(tasks), batch_size):
+        batch_tasks = tasks[i:i + batch_size]
+        batch_results = await asyncio.gather(
+            *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks]
+        )
+        results.extend(batch_results)
 
     for (c_id, index, this_edges), data in zip(tasks, results):
         all_text_units_lookup[c_id] = {

From b45c5f930472f63a4ebee15aeb695a125562b5e7 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 6 Apr 2025 17:42:13 +0800
Subject: [PATCH 2/4] Change get_by_id batch size from 25 to 5 to conserve db
 connection resources

---
 lightrag/kg/redis_impl.py | 6 +++---
 lightrag/operate.py       | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index 38a12ab5..db343dee 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -9,9 +9,9 @@ if not pm.is_installed("redis"):
     pm.install("redis")
 
 # aioredis is a depricated library, replaced with redis
-from redis.asyncio import Redis, ConnectionPool
-from redis.exceptions import RedisError, ConnectionError
-from lightrag.utils import logger, compute_mdhash_id
+from redis.asyncio import Redis, ConnectionPool # type: ignore
+from redis.exceptions import RedisError, ConnectionError # type: ignore
+from lightrag.utils import logger
 from lightrag.base import BaseKVStorage
 import json
 
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 97eccca0..1e023004 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1395,7 +1395,7 @@ async def _find_most_related_text_unit_from_entities(
                 tasks.append((c_id, index, this_edges))
 
     # Process in batches of 25 tasks at a time to avoid overwhelming resources
-    batch_size = 25
+    batch_size = 5
     results = []
 
     for i in range(0, len(tasks), batch_size):

From f1ee478cfb5903c58cf37ff0922bb74fd7e386b9 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 6 Apr 2025 17:42:58 +0800
Subject: [PATCH 3/4] Bump api version to 0137

---
 lightrag/api/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py
index 75eb6b64..bf049174 100644
--- a/lightrag/api/__init__.py
+++ b/lightrag/api/__init__.py
@@ -1 +1 @@
-__api_version__ = "0136"
+__api_version__ = "0137"

From b2284c8b9d5a0186fb606c2a5fab53dff1b1f587 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Sun, 6 Apr 2025 17:45:32 +0800
Subject: [PATCH 4/4] Fix linting

---
 lightrag/kg/redis_impl.py | 28 +++++++++++++++++-----------
 lightrag/operate.py       |  4 ++--
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py
index db343dee..65c25bfc 100644
--- a/lightrag/kg/redis_impl.py
+++ b/lightrag/kg/redis_impl.py
@@ -9,8 +9,8 @@ if not pm.is_installed("redis"):
     pm.install("redis")
 
 # aioredis is a depricated library, replaced with redis
-from redis.asyncio import Redis, ConnectionPool # type: ignore
-from redis.exceptions import RedisError, ConnectionError # type: ignore
+from redis.asyncio import Redis, ConnectionPool  # type: ignore
+from redis.exceptions import RedisError, ConnectionError  # type: ignore
 from lightrag.utils import logger
 from lightrag.base import BaseKVStorage
@@ -39,10 +39,12 @@ class RedisKVStorage(BaseKVStorage):
             max_connections=MAX_CONNECTIONS,
             decode_responses=True,
             socket_timeout=SOCKET_TIMEOUT,
-            socket_connect_timeout=SOCKET_CONNECT_TIMEOUT
+            socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
         )
         self._redis = Redis(connection_pool=self._pool)
-        logger.info(f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections")
+        logger.info(
+            f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
+        )
 
     @asynccontextmanager
     async def _get_redis_connection(self):
@@ -56,12 +58,14 @@ class RedisKVStorage(BaseKVStorage):
             logger.error(f"Redis operation error in {self.namespace}: {e}")
             raise
         except Exception as e:
-            logger.error(f"Unexpected error in Redis operation for {self.namespace}: {e}")
+            logger.error(
+                f"Unexpected error in Redis operation for {self.namespace}: {e}"
+            )
             raise
 
     async def close(self):
         """Close the Redis connection pool to prevent resource leaks."""
-        if hasattr(self, '_redis') and self._redis:
+        if hasattr(self, "_redis") and self._redis:
             await self._redis.close()
             await self._pool.disconnect()
             logger.debug(f"Closed Redis connection pool for {self.namespace}")
@@ -108,7 +112,7 @@ class RedisKVStorage(BaseKVStorage):
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         if not data:
             return
-
+
         logger.info(f"Inserting {len(data)} items to {self.namespace}")
         async with self._get_redis_connection() as redis:
             try:
@@ -122,11 +126,11 @@ class RedisKVStorage(BaseKVStorage):
             except json.JSONEncodeError as e:
                 logger.error(f"JSON encode error during upsert: {e}")
                 raise
-
+
     async def index_done_callback(self) -> None:
         # Redis handles persistence automatically
         pass
-
+
     async def delete(self, ids: list[str]) -> None:
         """Delete entries with specified IDs"""
         if not ids:
             return
@@ -183,7 +187,10 @@
                 deleted_count = sum(results)
                 logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
 
-                return {"status": "success", "message": f"{deleted_count} keys dropped"}
+                return {
+                    "status": "success",
+                    "message": f"{deleted_count} keys dropped",
+                }
             else:
                 logger.info(f"No keys found to drop in {self.namespace}")
                 return {"status": "success", "message": "no keys to drop"}
@@ -191,4 +198,3 @@
         except Exception as e:
             logger.error(f"Error dropping keys from {self.namespace}: {e}")
             return {"status": "error", "message": str(e)}
-
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 1e023004..0e223bb6 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1397,9 +1397,9 @@ async def _find_most_related_text_unit_from_entities(
     # Process in batches of 25 tasks at a time to avoid overwhelming resources
     batch_size = 5
     results = []
-
+
     for i in range(0, len(tasks), batch_size):
-        batch_tasks = tasks[i:i + batch_size]
+        batch_tasks = tasks[i : i + batch_size]
         batch_results = await asyncio.gather(
             *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks]
        )
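
Reviewer note: the pool-plus-pipeline pattern this series introduces can be exercised on its own, outside LightRAG. The sketch below is illustrative only; the URL, "demo" key prefix, pool limits, and batch size are placeholders rather than values taken from these patches, but it relies on the same redis.asyncio calls the series uses (ConnectionPool.from_url, Redis(connection_pool=...), pipeline batching, explicit close/disconnect).

import asyncio
import json

from redis.asyncio import ConnectionPool, Redis


async def main() -> None:
    # Bounded pool, mirroring the MAX_CONNECTIONS / timeout settings from PATCH 1/4
    # (the URL and limits here are illustrative placeholders).
    pool = ConnectionPool.from_url(
        "redis://localhost:6379",
        max_connections=10,
        decode_responses=True,
        socket_timeout=5.0,
        socket_connect_timeout=3.0,
    )
    client = Redis(connection_pool=pool)
    try:
        # Batch writes through one pipeline instead of one round trip per key,
        # the same idea the patch applies in upsert() / get_by_ids().
        pipe = client.pipeline()
        for i in range(100):
            pipe.set(f"demo:{i}", json.dumps({"value": i}))
        await pipe.execute()

        # Read back in small batches so concurrent callers cannot exhaust the pool,
        # analogous to the batch_size loop added in operate.py.
        batch_size = 5
        values = []
        for start in range(0, 100, batch_size):
            pipe = client.pipeline()
            for i in range(start, start + batch_size):
                pipe.get(f"demo:{i}")
            values.extend(await pipe.execute())
        print(f"fetched {len(values)} keys")
    finally:
        # Release sockets explicitly, mirroring RedisKVStorage.close().
        await client.close()
        await pool.disconnect()


asyncio.run(main())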