add concurrent embedding limit

This commit is contained in:
david
2024-12-10 09:01:21 +08:00
parent d0a4ef252e
commit f6eeedb050

View File

@@ -17,6 +17,17 @@ import tiktoken
from lightrag.prompt import PROMPTS from lightrag.prompt import PROMPTS
class UnlimitedSemaphore:
    """Async context manager that never blocks.

    Entering and exiting do nothing, so any number of coroutines can be
    inside ``async with UnlimitedSemaphore()`` at the same time.
    """

    async def __aenter__(self):
        """Enter immediately; nothing is acquired and ``None`` is bound."""
        return None

    async def __aexit__(self, exc_type, exc, tb):
        """Exit without releasing anything.

        Returns ``None``, so an exception raised inside the ``async with``
        body is not suppressed.
        """
        return None
ENCODER = None ENCODER = None
logger = logging.getLogger("lightrag") logger = logging.getLogger("lightrag")
@@ -42,9 +53,17 @@ class EmbeddingFunc:
embedding_dim: int embedding_dim: int
max_token_size: int max_token_size: int
func: callable func: callable
concurrent_limit: int = 16
def __post_init__(self):
    """Create the gate that ``__call__`` enters around every invocation.

    A ``concurrent_limit`` of 0 means "no limit" and installs an
    ``UnlimitedSemaphore``; any other value is handed to
    ``asyncio.Semaphore`` as its initial counter.
    """
    limit = self.concurrent_limit
    if limit == 0:
        # 0 is the sentinel for unlimited concurrency.
        self._semaphore = UnlimitedSemaphore()
    else:
        self._semaphore = asyncio.Semaphore(limit)
async def __call__(self, *args, **kwargs) -> np.ndarray:
    """Await the wrapped embedding function while holding the semaphore.

    All positional and keyword arguments are forwarded unchanged to
    ``self.func`` and its awaited result is returned. The call runs inside
    ``async with self._semaphore`` (the object assigned in
    ``__post_init__``), so the gate is exited even when ``self.func``
    raises.
    """
    async with self._semaphore:
        return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]: def locate_json_string_body_from_string(content: str) -> Union[str, None]: