From 9717ad87fcaf591ea27d6b8c23912dafa45944ac Mon Sep 17 00:00:00 2001 From: david Date: Mon, 9 Dec 2024 15:35:35 +0800 Subject: [PATCH] fix extra kwargs error: keyword_extraction. add lazy_external_load to reduce external lib deps whenever it's not necessary for user. --- lightrag/lightrag.py | 32 ++++++++++++++++++++++++-------- lightrag/llm.py | 2 ++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 0a44187e..de5befa4 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -40,14 +40,6 @@ from .storage import ( NetworkXStorage, ) -from .kg.neo4j_impl import Neo4JStorage - -from .kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage - -from .kg.milvus_impl import MilvusVectorDBStorge - -from .kg.mongo_impl import MongoKVStorage - # future KG integrations # from .kg.ArangoDB_impl import ( @@ -55,6 +47,30 @@ from .kg.mongo_impl import MongoKVStorage # ) +def lazy_external_import(module_name: str, class_name: str): + """Lazily import an external module and return a class from it.""" + + def import_class(): + import importlib + + # Import the module using importlib + module = importlib.import_module(module_name) + + # Get the class from the module + return getattr(module, class_name) + + # Return the import_class function itself, not its result + return import_class + + +Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage") +OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage") +OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage") +OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage") +MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge") +MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage") + + def always_get_an_event_loop() -> asyncio.AbstractEventLoop: """ Ensure that there is always an event loop available. diff --git a/lightrag/llm.py b/lightrag/llm.py index 0f8b6ef8..678db48e 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -1074,6 +1074,8 @@ class MultiModel: self, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: kwargs.pop("model", None) # stop from overwriting the custom model name + kwargs.pop("keyword_extraction", None) + kwargs.pop("mode", None) next_model = self._next_model() args = dict( prompt=prompt,