add concurrent embedding limit
This commit is contained in:
@@ -17,6 +17,17 @@ import tiktoken
|
|||||||
|
|
||||||
from lightrag.prompt import PROMPTS
|
from lightrag.prompt import PROMPTS
|
||||||
|
|
||||||
|
|
||||||
|
class UnlimitedSemaphore:
|
||||||
|
"""A context manager that allows unlimited access."""
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
ENCODER = None
|
ENCODER = None
|
||||||
|
|
||||||
logger = logging.getLogger("lightrag")
|
logger = logging.getLogger("lightrag")
|
||||||
@@ -42,8 +53,16 @@ class EmbeddingFunc:
|
|||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
max_token_size: int
|
max_token_size: int
|
||||||
func: callable
|
func: callable
|
||||||
|
concurrent_limit: int = 16
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.concurrent_limit != 0:
|
||||||
|
self._semaphore = asyncio.Semaphore(self.concurrent_limit)
|
||||||
|
else:
|
||||||
|
self._semaphore = UnlimitedSemaphore()
|
||||||
|
|
||||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||||
|
async with self._semaphore:
|
||||||
return await self.func(*args, **kwargs)
|
return await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user