diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 833926e5..3a4276cb 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -48,18 +48,25 @@ from .storage import (
 
 
 def lazy_external_import(module_name: str, class_name: str):
-    """Lazily import an external module and return a class from it."""
+    """Lazily import a class from an external module based on the package of the caller."""
 
-    def import_class():
+    # Get the caller's module and package
+    import inspect
+
+    caller_frame = inspect.currentframe().f_back
+    module = inspect.getmodule(caller_frame)
+    package = module.__package__ if module else None
+
+    def import_class(*args, **kwargs):
         import importlib
 
         # Import the module using importlib
-        module = importlib.import_module(module_name)
+        module = importlib.import_module(module_name, package=package)
 
-        # Get the class from the module
-        return getattr(module, class_name)
+        # Get the class from the module and instantiate it
+        cls = getattr(module, class_name)
+        return cls(*args, **kwargs)
 
-    # Return the import_class function itself, not its result
     return import_class
 
 
diff --git a/lightrag/llm.py b/lightrag/llm.py
index f3fed23f..636f03cb 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -64,6 +64,7 @@ async def openai_complete_if_cache(
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
     kwargs.pop("hashing_kv", None)
+    kwargs.pop("keyword_extraction", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 007d6534..4c043893 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -107,10 +107,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
             embeddings = await f
             embeddings_list.append(embeddings)
         embeddings = np.concatenate(embeddings_list)
-        for i, d in enumerate(list_data):
-            d["__vector__"] = embeddings[i]
-        results = self._client.upsert(datas=list_data)
-        return results
+        if len(embeddings) == len(list_data):
+            for i, d in enumerate(list_data):
+                d["__vector__"] = embeddings[i]
+            results = self._client.upsert(datas=list_data)
+            return results
+        else:
+            # sometimes the embedding is not returned correctly. just log it.
+            logger.error(
+                f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
+            )
 
     async def query(self, query: str, top_k=5):
         embedding = await self.embedding_func([query])
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 0220af06..bdb47592 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -17,6 +17,17 @@ import tiktoken
 
 from lightrag.prompt import PROMPTS
 
+
+class UnlimitedSemaphore:
+    """A context manager that allows unlimited access."""
+
+    async def __aenter__(self):
+        pass
+
+    async def __aexit__(self, exc_type, exc, tb):
+        pass
+
+
 ENCODER = None
 
 logger = logging.getLogger("lightrag")
@@ -42,9 +53,17 @@ class EmbeddingFunc:
     embedding_dim: int
     max_token_size: int
     func: callable
+    concurrent_limit: int = 16
+
+    def __post_init__(self):
+        if self.concurrent_limit != 0:
+            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
+        else:
+            self._semaphore = UnlimitedSemaphore()
 
     async def __call__(self, *args, **kwargs) -> np.ndarray:
-        return await self.func(*args, **kwargs)
+        async with self._semaphore:
+            return await self.func(*args, **kwargs)
 
 
 def locate_json_string_body_from_string(content: str) -> Union[str, None]:
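
Below is a minimal usage sketch, not part of the patch, showing what the revised lazy_external_import now returns: a factory that defers the import until it is first called and then instantiates the class, with relative module names resolved against the caller's package (inside lightrag/lightrag.py that package is "lightrag"). It assumes the patched lightrag package is importable; the standard library's collections.Counter stands in here for a real optional storage backend.

from lightrag.lightrag import lazy_external_import

# The returned callable is a factory, not the class itself: nothing is
# imported until it is called, and the call instantiates the class with
# whatever arguments are passed through.
CounterFactory = lazy_external_import("collections", "Counter")

counter = CounterFactory("abracadabra")  # import + instantiation happen here
print(counter.most_common(2))            # e.g. [('a', 5), ('b', 2)]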
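
A similar sketch, also not part of the patch, for the new concurrent_limit field on EmbeddingFunc: at most that many calls to the wrapped embedding function are awaited at once, and concurrent_limit=0 selects the UnlimitedSemaphore instead. The fake_embed function and the dimensions below are hypothetical placeholders, and the example assumes the patched lightrag.utils is importable.

import asyncio

import numpy as np

from lightrag.utils import EmbeddingFunc


async def fake_embed(texts: list[str]) -> np.ndarray:
    # Placeholder for a real embedding API call: one 8-dim vector per text.
    await asyncio.sleep(0.01)
    return np.zeros((len(texts), 8))


async def main():
    # No more than 4 fake_embed calls run concurrently, regardless of how
    # many coroutines are gathered below.
    embedding_func = EmbeddingFunc(
        embedding_dim=8,
        max_token_size=8192,
        func=fake_embed,
        concurrent_limit=4,
    )
    vectors = await asyncio.gather(*(embedding_func([f"doc {i}"]) for i in range(16)))
    print(len(vectors), vectors[0].shape)  # 16 (1, 8)


asyncio.run(main())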