From f6eeedb050a9bdf7d2c598634c76fa152de8e69b Mon Sep 17 00:00:00 2001 From: david Date: Tue, 10 Dec 2024 09:01:21 +0800 Subject: [PATCH] add concurrent embedding limit --- lightrag/utils.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index 0220af06..bdb47592 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -17,6 +17,17 @@ import tiktoken from lightrag.prompt import PROMPTS + +class UnlimitedSemaphore: + """A context manager that allows unlimited access.""" + + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc, tb): + pass + + ENCODER = None logger = logging.getLogger("lightrag") @@ -42,9 +53,17 @@ class EmbeddingFunc: embedding_dim: int max_token_size: int func: callable + concurrent_limit: int = 16 + + def __post_init__(self): + if self.concurrent_limit != 0: + self._semaphore = asyncio.Semaphore(self.concurrent_limit) + else: + self._semaphore = UnlimitedSemaphore() async def __call__(self, *args, **kwargs) -> np.ndarray: - return await self.func(*args, **kwargs) + async with self._semaphore: + return await self.func(*args, **kwargs) def locate_json_string_body_from_string(content: str) -> Union[str, None]: