Fix 'Too many open files' error while using the Redis vector DB:

Enhance RedisKVStorage: implement connection pooling and error handling. Refactor async methods to use context managers for Redis operations, improving resource management and error logging. Add batch processing for key operations to optimize performance.
Author: Alex Z
Date: 2025-04-02 21:06:49 -07:00
parent 7a67f6c2fd
commit d0d246bef8
2 changed files with 156 additions and 88 deletions
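Before the diff, a minimal standalone sketch of the pooling idea (illustrative only, not part of the commit): every call to Redis.from_url(...) builds its own connection pool, so creating clients ad hoc under load opens more and more sockets until the process hits its file-descriptor limit ("Too many open files"). Sharing one bounded ConnectionPool caps the socket count.

import asyncio
from redis.asyncio import Redis, ConnectionPool

# One bounded pool shared by every client: at most max_connections
# sockets (file descriptors) are ever opened to Redis.
pool = ConnectionPool.from_url(
    "redis://localhost:6379",
    max_connections=50,
    decode_responses=True,
    socket_timeout=5.0,
    socket_connect_timeout=3.0,
)

async def worker(i: int) -> None:
    # Clients built on the same pool reuse its connections
    # instead of opening a new socket per task.
    redis = Redis(connection_pool=pool)
    await redis.set(f"demo:{i}", str(i))

async def main() -> None:
    await asyncio.gather(*(worker(i) for i in range(20)))
    await pool.disconnect()  # release all pooled sockets

asyncio.run(main())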

View File

@@ -3,12 +3,14 @@ from typing import Any, final
 from dataclasses import dataclass
 import pipmaster as pm
 import configparser
+from contextlib import asynccontextmanager

 if not pm.is_installed("redis"):
     pm.install("redis")

 # aioredis is a deprecated library, replaced with redis
-from redis.asyncio import Redis
+from redis.asyncio import Redis, ConnectionPool
+from redis.exceptions import RedisError, ConnectionError
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseKVStorage
 import json
@@ -17,6 +19,11 @@ import json
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")
+
+# Constants for Redis connection pool
+MAX_CONNECTIONS = 50
+SOCKET_TIMEOUT = 5.0
+SOCKET_CONNECT_TIMEOUT = 3.0

 @final
 @dataclass
@@ -25,125 +32,177 @@ class RedisKVStorage(BaseKVStorage):
         redis_url = os.environ.get(
             "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
         )
-        self._redis = Redis.from_url(redis_url, decode_responses=True)
-        logger.info(f"Use Redis as KV {self.namespace}")
+        # Create a connection pool with limits
+        self._pool = ConnectionPool.from_url(
+            redis_url,
+            max_connections=MAX_CONNECTIONS,
+            decode_responses=True,
+            socket_timeout=SOCKET_TIMEOUT,
+            socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
+        )
+        self._redis = Redis(connection_pool=self._pool)
+        logger.info(f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections")
+
+    @asynccontextmanager
+    async def _get_redis_connection(self):
+        """Safe context manager for Redis operations."""
+        try:
+            yield self._redis
+        except ConnectionError as e:
+            logger.error(f"Redis connection error in {self.namespace}: {e}")
+            raise
+        except RedisError as e:
+            logger.error(f"Redis operation error in {self.namespace}: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error in Redis operation for {self.namespace}: {e}")
+            raise
+
+    async def close(self):
+        """Close the Redis connection pool to prevent resource leaks."""
+        if hasattr(self, "_redis") and self._redis:
+            await self._redis.close()
+            await self._pool.disconnect()
+            logger.debug(f"Closed Redis connection pool for {self.namespace}")
+
+    async def __aenter__(self):
+        """Support for async context manager."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Ensure Redis resources are cleaned up when exiting context."""
+        await self.close()
     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        data = await self._redis.get(f"{self.namespace}:{id}")
-        return json.loads(data) if data else None
+        async with self._get_redis_connection() as redis:
+            try:
+                data = await redis.get(f"{self.namespace}:{id}")
+                return json.loads(data) if data else None
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error for id {id}: {e}")
+                return None
     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        pipe = self._redis.pipeline()
-        for id in ids:
-            pipe.get(f"{self.namespace}:{id}")
-        results = await pipe.execute()
-        return [json.loads(result) if result else None for result in results]
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for id in ids:
+                    pipe.get(f"{self.namespace}:{id}")
+                results = await pipe.execute()
+                return [json.loads(result) if result else None for result in results]
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error in batch get: {e}")
+                return [None] * len(ids)
     async def filter_keys(self, keys: set[str]) -> set[str]:
-        pipe = self._redis.pipeline()
-        for key in keys:
-            pipe.exists(f"{self.namespace}:{key}")
-        results = await pipe.execute()
-
-        existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
-        return set(keys) - existing_ids
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            # A set is not indexable; materialize it so results can be
+            # matched back to keys by position.
+            keys_list = list(keys)
+            for key in keys_list:
+                pipe.exists(f"{self.namespace}:{key}")
+            results = await pipe.execute()
+
+            existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
+            return keys - existing_ids
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.info(f"Inserting {len(data)} items to {self.namespace}")
         if not data:
             return
-        pipe = self._redis.pipeline()
-        for k, v in data.items():
-            pipe.set(f"{self.namespace}:{k}", json.dumps(v))
-        await pipe.execute()
-        for k in data:
-            data[k]["_id"] = k
-
-    async def index_done_callback(self) -> None:
-        # Redis handles persistence automatically
-        pass
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for k, v in data.items():
+                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
+                await pipe.execute()
+
+                for k in data:
+                    data[k]["_id"] = k
+            except (TypeError, ValueError) as e:
+                # json.JSONEncodeError does not exist; json.dumps raises
+                # TypeError/ValueError for unserializable data.
+                logger.error(f"JSON encode error during upsert: {e}")
+                raise
     async def delete(self, ids: list[str]) -> None:
-        """Delete entries with specified IDs
-
-        Args:
-            ids: List of entry IDs to be deleted
-        """
+        """Delete entries with specified IDs"""
         if not ids:
             return
-        pipe = self._redis.pipeline()
-        for id in ids:
-            pipe.delete(f"{self.namespace}:{id}")
-        results = await pipe.execute()
-        deleted_count = sum(results)
-        logger.info(
-            f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
-        )
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            for id in ids:
+                pipe.delete(f"{self.namespace}:{id}")
+            results = await pipe.execute()
+            deleted_count = sum(results)
+            logger.info(
+                f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
+            )
     async def delete_entity(self, entity_name: str) -> None:
-        """Delete an entity by name
-
-        Args:
-            entity_name: Name of the entity to delete
-        """
+        """Delete an entity by name"""
         try:
             entity_id = compute_mdhash_id(entity_name, prefix="ent-")
             logger.debug(
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )
-            # Delete the entity
-            result = await self._redis.delete(f"{self.namespace}:{entity_id}")
-
-            if result:
-                logger.debug(f"Successfully deleted entity {entity_name}")
-            else:
-                logger.debug(f"Entity {entity_name} not found in storage")
+            async with self._get_redis_connection() as redis:
+                result = await redis.delete(f"{self.namespace}:{entity_id}")
+
+                if result:
+                    logger.debug(f"Successfully deleted entity {entity_name}")
+                else:
+                    logger.debug(f"Entity {entity_name} not found in storage")
         except Exception as e:
             logger.error(f"Error deleting entity {entity_name}: {e}")
     async def delete_entity_relation(self, entity_name: str) -> None:
-        """Delete all relations associated with an entity
-
-        Args:
-            entity_name: Name of the entity whose relations should be deleted
-        """
+        """Delete all relations associated with an entity"""
         try:
-            # Get all keys in this namespace
-            cursor = 0
-            relation_keys = []
-            pattern = f"{self.namespace}:*"
-
-            while True:
-                cursor, keys = await self._redis.scan(cursor, match=pattern)
-
-                # For each key, get the value and check if it's related to entity_name
-                for key in keys:
-                    value = await self._redis.get(key)
-                    if value:
-                        data = json.loads(value)
-                        # Check if this is a relation involving the entity
-                        if (
-                            data.get("src_id") == entity_name
-                            or data.get("tgt_id") == entity_name
-                        ):
-                            relation_keys.append(key)
-
-                if cursor == 0:
-                    break
-
-            # Delete the relation keys
-            if relation_keys:
-                deleted = await self._redis.delete(*relation_keys)
-                logger.debug(f"Deleted {deleted} relations for {entity_name}")
-            else:
-                logger.debug(f"No relations found for entity {entity_name}")
+            async with self._get_redis_connection() as redis:
+                cursor = 0
+                relation_keys = []
+                pattern = f"{self.namespace}:*"
+
+                while True:
+                    cursor, keys = await redis.scan(cursor, match=pattern)
+
+                    # Process keys in batches
+                    pipe = redis.pipeline()
+                    for key in keys:
+                        pipe.get(key)
+                    values = await pipe.execute()
+
+                    for key, value in zip(keys, values):
+                        if value:
+                            try:
+                                data = json.loads(value)
+                                if (
+                                    data.get("src_id") == entity_name
+                                    or data.get("tgt_id") == entity_name
+                                ):
+                                    relation_keys.append(key)
+                            except json.JSONDecodeError:
+                                logger.warning(f"Invalid JSON in key {key}")
+                                continue
+
+                    # Exit loop when cursor returns to 0
+                    if cursor == 0:
+                        break
+
+                # Delete relations in batches
+                if relation_keys:
+                    # Delete in chunks to avoid too many arguments
+                    chunk_size = 1000
+                    for i in range(0, len(relation_keys), chunk_size):
+                        chunk = relation_keys[i : i + chunk_size]
+                        deleted = await redis.delete(*chunk)
+                        logger.debug(f"Deleted {deleted} relations for {entity_name} (batch {i // chunk_size + 1})")
+                else:
+                    logger.debug(f"No relations found for entity {entity_name}")
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")
+
+    async def index_done_callback(self) -> None:
+        # Redis handles persistence automatically
+        pass
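A hypothetical usage sketch of the refactored storage (the import path and the constructor fields besides namespace are assumptions about BaseKVStorage, not confirmed by the diff): the __aenter__/__aexit__ pair added above lets callers scope the storage with async with, so close() releases the pool even if an operation raises.

import asyncio
from lightrag.kv.redis_impl import RedisKVStorage  # import path is an assumption

async def main() -> None:
    storage = RedisKVStorage(
        namespace="chunks",   # field names below are assumed from BaseKVStorage
        global_config={},
        embedding_func=None,
    )
    async with storage as kv:
        await kv.upsert({"doc-1": {"content": "hello"}})
        print(await kv.get_by_id("doc-1"))
    # close() has run here: connections are returned and the pool disconnected

asyncio.run(main())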

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import traceback
 import json
 import re
 import os
@@ -994,6 +995,7 @@ async def mix_kg_vector_query(
         except Exception as e:
             logger.error(f"Error in get_kg_context: {str(e)}")
+            traceback.print_exc()
             return None

     async def get_vector_context():
@@ -1382,9 +1384,16 @@ async def _find_most_related_text_unit_from_entities(
                 all_text_units_lookup[c_id] = index
                 tasks.append((c_id, index, this_edges))

-    results = await asyncio.gather(
-        *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
-    )
+    # Process in batches of 25 tasks at a time to avoid overwhelming resources
+    batch_size = 25
+    results = []
+    for i in range(0, len(tasks), batch_size):
+        batch_tasks = tasks[i : i + batch_size]
+        batch_results = await asyncio.gather(
+            *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks]
+        )
+        results.extend(batch_results)

     for (c_id, index, this_edges), data in zip(tasks, results):
         all_text_units_lookup[c_id] = {
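The chunked asyncio.gather above generalizes to a small helper (a sketch; gather_in_batches is a hypothetical name, not part of this commit). An asyncio.Semaphore would smooth out stragglers within a batch, but fixed-size chunks match what the diff does: at most batch_size lookups are in flight at once, which is what keeps the bounded connection pool from being exhausted.

import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")
R = TypeVar("R")

async def gather_in_batches(
    items: list[T],
    fn: Callable[[T], Awaitable[R]],
    batch_size: int = 25,
) -> list[R]:
    """Run fn over items with at most batch_size coroutines in flight."""
    results: list[R] = []
    for i in range(0, len(items), batch_size):
        batch = items[i : i + batch_size]
        results.extend(await asyncio.gather(*(fn(x) for x in batch)))
    return results

# e.g. results = await gather_in_batches(
#     tasks, lambda t: text_chunks_db.get_by_id(t[0]), batch_size=25
# )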