set kg by start param, defaults to networkx
This commit is contained in:
@@ -25,7 +25,7 @@ from .storage import (
|
||||
NetworkXStorage,
|
||||
)
|
||||
|
||||
from .kg.neo4j import (
|
||||
from .kg.neo4j_impl import (
|
||||
GraphStorage as Neo4JStorage
|
||||
)
|
||||
|
||||
@@ -58,10 +58,14 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||
|
||||
@dataclass
|
||||
class LightRAG:
|
||||
|
||||
working_dir: str = field(
|
||||
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
||||
)
|
||||
|
||||
kg: str = field(default="NetworkXStorage")
|
||||
|
||||
|
||||
# text chunking
|
||||
chunk_token_size: int = 1200
|
||||
chunk_overlap_token_size: int = 100
|
||||
@@ -99,20 +103,15 @@ class LightRAG:
|
||||
key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
|
||||
vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
|
||||
vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
# module = importlib.import_module('kg.neo4j')
|
||||
# Neo4JStorage = getattr(module, 'GraphStorage')
|
||||
if True==True:
|
||||
print ("using KG")
|
||||
graph_storage_cls: Type[BaseGraphStorage] = Neo4JStorage
|
||||
else:
|
||||
graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
|
||||
enable_llm_cache: bool = True
|
||||
|
||||
# extension
|
||||
addon_params: dict = field(default_factory=dict)
|
||||
convert_response_to_json_func: callable = convert_response_to_json
|
||||
|
||||
# def get_configured_KG(self):
|
||||
# return self.kg
|
||||
|
||||
def __post_init__(self):
|
||||
log_file = os.path.join(self.working_dir, "lightrag.log")
|
||||
set_logger(log_file)
|
||||
@@ -121,6 +120,10 @@ class LightRAG:
|
||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||
|
||||
#should move all storage setup here to leverage initial start params attached to self.
|
||||
print (f"self.kg set to: {self.kg}")
|
||||
self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
|
||||
|
||||
if not os.path.exists(self.working_dir):
|
||||
logger.info(f"Creating working directory {self.working_dir}")
|
||||
os.makedirs(self.working_dir)
|
||||
@@ -169,6 +172,12 @@ class LightRAG:
|
||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
||||
)
|
||||
def _get_storage_class(self) -> Type[BaseGraphStorage]:
|
||||
return {
|
||||
"Neo4JStorage": Neo4JStorage,
|
||||
"NetworkXStorage": NetworkXStorage,
|
||||
# "new_kg_here": KGClass
|
||||
}
|
||||
|
||||
def insert(self, string_or_strings):
|
||||
loop = always_get_an_event_loop()
|
||||
|
Reference in New Issue
Block a user