Merge pull request #444 from davidleon/fix/lazy_import

Fix/lazy import
This commit is contained in:
zrguo
2024-12-11 14:19:48 +08:00
committed by GitHub
4 changed files with 44 additions and 11 deletions

View File

@@ -48,18 +48,25 @@ from .storage import (
def lazy_external_import(module_name: str, class_name: str):
"""Lazily import an external module and return a class from it."""
"""Lazily import a class from an external module based on the package of the caller."""
def import_class():
# Get the caller's module and package
import inspect
caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None
def import_class(*args, **kwargs):
import importlib
# Import the module using importlib
module = importlib.import_module(module_name)
module = importlib.import_module(module_name, package=package)
# Get the class from the module
return getattr(module, class_name)
# Get the class from the module and instantiate it
cls = getattr(module, class_name)
return cls(*args, **kwargs)
# Return the import_class function itself, not its result
return import_class

View File

@@ -64,6 +64,7 @@ async def openai_complete_if_cache(
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})

View File

@@ -107,10 +107,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
embeddings = await f
embeddings_list.append(embeddings)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
else:
# sometimes the embedding is not returned correctly. just log it.
logger.error(
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
)
async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query])

View File

@@ -17,6 +17,17 @@ import tiktoken
from lightrag.prompt import PROMPTS
class UnlimitedSemaphore:
"""A context manager that allows unlimited access."""
async def __aenter__(self):
pass
async def __aexit__(self, exc_type, exc, tb):
pass
ENCODER = None
logger = logging.getLogger("lightrag")
@@ -42,9 +53,17 @@ class EmbeddingFunc:
embedding_dim: int
max_token_size: int
func: callable
concurrent_limit: int = 16
def __post_init__(self):
if self.concurrent_limit != 0:
self._semaphore = asyncio.Semaphore(self.concurrent_limit)
else:
self._semaphore = UnlimitedSemaphore()
async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)
async with self._semaphore:
return await self.func(*args, **kwargs)
def locate_json_string_body_from_string(content: str) -> Union[str, None]: