Merge branch 'shmily1012/main'

yangdx
2025-04-06 17:21:55 +08:00
2 changed files with 129 additions and 59 deletions


@@ -3,13 +3,16 @@ from typing import Any, final
 from dataclasses import dataclass
 import pipmaster as pm
 import configparser
+from contextlib import asynccontextmanager

 if not pm.is_installed("redis"):
     pm.install("redis")

 # aioredis is a depricated library, replaced with redis
-from redis.asyncio import Redis  # type: ignore
-from lightrag.utils import logger
+from redis.asyncio import Redis, ConnectionPool
+from redis.exceptions import RedisError, ConnectionError
+from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseKVStorage
 import json
@@ -17,6 +20,11 @@ import json
 config = configparser.ConfigParser()
 config.read("config.ini", "utf-8")

+# Constants for Redis connection pool
+MAX_CONNECTIONS = 50
+SOCKET_TIMEOUT = 5.0
+SOCKET_CONNECT_TIMEOUT = 3.0
+

 @final
 @dataclass
@@ -25,64 +33,115 @@ class RedisKVStorage(BaseKVStorage):
         redis_url = os.environ.get(
             "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
         )
-        self._redis = Redis.from_url(redis_url, decode_responses=True)
-        logger.info(f"Use Redis as KV {self.namespace}")
+        # Create a connection pool with limits
+        self._pool = ConnectionPool.from_url(
+            redis_url,
+            max_connections=MAX_CONNECTIONS,
+            decode_responses=True,
+            socket_timeout=SOCKET_TIMEOUT,
+            socket_connect_timeout=SOCKET_CONNECT_TIMEOUT
+        )
+        self._redis = Redis(connection_pool=self._pool)
+        logger.info(f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections")
+
+    @asynccontextmanager
+    async def _get_redis_connection(self):
+        """Safe context manager for Redis operations."""
+        try:
+            yield self._redis
+        except ConnectionError as e:
+            logger.error(f"Redis connection error in {self.namespace}: {e}")
+            raise
+        except RedisError as e:
+            logger.error(f"Redis operation error in {self.namespace}: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error in Redis operation for {self.namespace}: {e}")
+            raise
+
+    async def close(self):
+        """Close the Redis connection pool to prevent resource leaks."""
+        if hasattr(self, '_redis') and self._redis:
+            await self._redis.close()
+            await self._pool.disconnect()
+            logger.debug(f"Closed Redis connection pool for {self.namespace}")
+
+    async def __aenter__(self):
+        """Support for async context manager."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Ensure Redis resources are cleaned up when exiting context."""
+        await self.close()

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
-        data = await self._redis.get(f"{self.namespace}:{id}")
-        return json.loads(data) if data else None
+        async with self._get_redis_connection() as redis:
+            try:
+                data = await redis.get(f"{self.namespace}:{id}")
+                return json.loads(data) if data else None
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error for id {id}: {e}")
+                return None

     async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
-        pipe = self._redis.pipeline()
-        for id in ids:
-            pipe.get(f"{self.namespace}:{id}")
-        results = await pipe.execute()
-        return [json.loads(result) if result else None for result in results]
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for id in ids:
+                    pipe.get(f"{self.namespace}:{id}")
+                results = await pipe.execute()
+                return [json.loads(result) if result else None for result in results]
+            except json.JSONDecodeError as e:
+                logger.error(f"JSON decode error in batch get: {e}")
+                return [None] * len(ids)

     async def filter_keys(self, keys: set[str]) -> set[str]:
-        pipe = self._redis.pipeline()
-        for key in keys:
-            pipe.exists(f"{self.namespace}:{key}")
-        results = await pipe.execute()
-
-        existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
-        return set(keys) - existing_ids
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            for key in keys:
+                pipe.exists(f"{self.namespace}:{key}")
+            results = await pipe.execute()
+
+            existing_ids = {keys[i] for i, exists in enumerate(results) if exists}
+            return set(keys) - existing_ids

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return
-        pipe = self._redis.pipeline()
-        for k, v in data.items():
-            pipe.set(f"{self.namespace}:{k}", json.dumps(v))
-        await pipe.execute()
-
-        for k in data:
-            data[k]["_id"] = k
+
+        logger.info(f"Inserting {len(data)} items to {self.namespace}")
+        async with self._get_redis_connection() as redis:
+            try:
+                pipe = redis.pipeline()
+                for k, v in data.items():
+                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
+                await pipe.execute()
+
+                for k in data:
+                    data[k]["_id"] = k
+            except json.JSONEncodeError as e:
+                logger.error(f"JSON encode error during upsert: {e}")
+                raise

     async def index_done_callback(self) -> None:
         # Redis handles persistence automatically
         pass

     async def delete(self, ids: list[str]) -> None:
-        """Delete entries with specified IDs
-
-        Args:
-            ids: List of entry IDs to be deleted
-        """
+        """Delete entries with specified IDs"""
         if not ids:
             return

-        pipe = self._redis.pipeline()
-        for id in ids:
-            pipe.delete(f"{self.namespace}:{id}")
-
-        results = await pipe.execute()
-        deleted_count = sum(results)
-        logger.info(
-            f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
-        )
+        async with self._get_redis_connection() as redis:
+            pipe = redis.pipeline()
+            for id in ids:
+                pipe.delete(f"{self.namespace}:{id}")
+            results = await pipe.execute()
+            deleted_count = sum(results)
+            logger.info(
+                f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
+            )

     async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
         """Delete specific records from storage by by cache mode
@@ -112,22 +171,24 @@ class RedisKVStorage(BaseKVStorage):
         Returns:
             dict[str, str]: Status of the operation with keys 'status' and 'message'
         """
-        try:
-            keys = await self._redis.keys(f"{self.namespace}:*")
-
-            if keys:
-                pipe = self._redis.pipeline()
-                for key in keys:
-                    pipe.delete(key)
-                results = await pipe.execute()
-                deleted_count = sum(results)
-
-                logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
-                return {"status": "success", "message": f"{deleted_count} keys dropped"}
-            else:
-                logger.info(f"No keys found to drop in {self.namespace}")
-                return {"status": "success", "message": "no keys to drop"}
-
-        except Exception as e:
-            logger.error(f"Error dropping keys from {self.namespace}: {e}")
-            return {"status": "error", "message": str(e)}
+        async with self._get_redis_connection() as redis:
+            try:
+                keys = await redis.keys(f"{self.namespace}:*")
+
+                if keys:
+                    pipe = redis.pipeline()
+                    for key in keys:
+                        pipe.delete(key)
+                    results = await pipe.execute()
+                    deleted_count = sum(results)
+
+                    logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
+                    return {"status": "success", "message": f"{deleted_count} keys dropped"}
+                else:
+                    logger.info(f"No keys found to drop in {self.namespace}")
+                    return {"status": "success", "message": "no keys to drop"}
+
+            except Exception as e:
+                logger.error(f"Error dropping keys from {self.namespace}: {e}")
+                return {"status": "error", "message": str(e)}


@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import traceback
 import json
 import re
 import os
@@ -1002,6 +1003,7 @@ async def mix_kg_vector_query(
         except Exception as e:
             logger.error(f"Error in get_kg_context: {str(e)}")
+            traceback.print_exc()
             return None

     async def get_vector_context():
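
The added traceback.print_exc() writes the stack trace to stderr alongside the logger.error message. A common alternative (not what this commit does) is to let the logging module capture the same traceback so it stays in the configured handlers; a minimal sketch assuming a module-level logger:

import logging

logger = logging.getLogger(__name__)

try:
    1 / 0
except Exception as e:
    # exc_info=True attaches the current traceback to the log record,
    # so handlers record the full stack without printing to stderr directly.
    logger.error("Error in get_kg_context: %s", e, exc_info=True)
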
@@ -1392,9 +1394,16 @@ async def _find_most_related_text_unit_from_entities(
                 all_text_units_lookup[c_id] = index
                 tasks.append((c_id, index, this_edges))

-    results = await asyncio.gather(
-        *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in tasks]
-    )
+    # Process in batches of 25 tasks at a time to avoid overwhelming resources
+    batch_size = 25
+    results = []
+    for i in range(0, len(tasks), batch_size):
+        batch_tasks = tasks[i:i + batch_size]
+        batch_results = await asyncio.gather(
+            *[text_chunks_db.get_by_id(c_id) for c_id, _, _ in batch_tasks]
+        )
+        results.extend(batch_results)

     for (c_id, index, this_edges), data in zip(tasks, results):
         all_text_units_lookup[c_id] = {
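
The batching change above caps how many text_chunks_db.get_by_id() coroutines are in flight at once by slicing the task list into groups of 25. Because asyncio.gather preserves input order within each batch and the batches are appended in order, the later zip(tasks, results) pairing stays aligned. The same idea can be factored into a small reusable helper; the sketch below is generic, and gather_in_batches and fetch_chunk are illustrative names rather than functions from the repository.

import asyncio
from collections.abc import Awaitable, Callable, Sequence
from typing import TypeVar

T = TypeVar("T")
R = TypeVar("R")


async def gather_in_batches(
    items: Sequence[T],
    worker: Callable[[T], Awaitable[R]],
    batch_size: int = 25,
) -> list[R]:
    """Run worker over items, awaiting at most batch_size coroutines at a time."""
    results: list[R] = []
    for i in range(0, len(items), batch_size):
        batch = items[i : i + batch_size]
        results.extend(await asyncio.gather(*(worker(item) for item in batch)))
    return results


async def fetch_chunk(chunk_id: str) -> dict:
    """Illustrative stand-in for text_chunks_db.get_by_id()."""
    await asyncio.sleep(0.01)
    return {"id": chunk_id, "content": f"chunk {chunk_id}"}


async def main() -> None:
    ids = [f"chunk-{n}" for n in range(100)]
    chunks = await gather_in_batches(ids, fetch_chunk, batch_size=25)
    print(len(chunks))
    # Alternative: an asyncio.Semaphore(25) wrapped around each call keeps a
    # rolling window of 25 in-flight requests instead of strict sequential batches.


asyncio.run(main())
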