chore: added pre-commit-hooks and ruff formatting for commit-hooks
This commit is contained in:
@@ -3,10 +3,12 @@ import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Type, cast, Any
|
||||
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
|
||||
from typing import Type, cast
|
||||
|
||||
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
|
||||
from .llm import (
|
||||
gpt_4o_mini_complete,
|
||||
openai_embedding,
|
||||
)
|
||||
from .operate import (
|
||||
chunking_by_token_size,
|
||||
extract_entities,
|
||||
@@ -37,6 +39,7 @@ from .base import (
|
||||
QueryParam,
|
||||
)
|
||||
|
||||
|
||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -69,7 +72,6 @@ class LightRAG:
|
||||
"dimensions": 1536,
|
||||
"num_walks": 10,
|
||||
"walk_length": 40,
|
||||
"num_walks": 10,
|
||||
"window_size": 2,
|
||||
"iterations": 3,
|
||||
"random_seed": 3,
|
||||
@@ -77,13 +79,13 @@ class LightRAG:
|
||||
)
|
||||
|
||||
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
|
||||
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
|
||||
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
|
||||
embedding_batch_num: int = 32
|
||||
embedding_func_max_async: int = 16
|
||||
|
||||
# LLM
|
||||
llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
|
||||
llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
|
||||
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
||||
llm_model_max_token_size: int = 32768
|
||||
llm_model_max_async: int = 16
|
||||
|
||||
@@ -98,11 +100,11 @@ class LightRAG:
|
||||
addon_params: dict = field(default_factory=dict)
|
||||
convert_response_to_json_func: callable = convert_response_to_json
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self):
|
||||
log_file = os.path.join(self.working_dir, "lightrag.log")
|
||||
set_logger(log_file)
|
||||
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
||||
|
||||
|
||||
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
||||
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
||||
|
||||
@@ -133,30 +135,24 @@ class LightRAG:
|
||||
self.embedding_func
|
||||
)
|
||||
|
||||
self.entities_vdb = (
|
||||
self.vector_db_storage_cls(
|
||||
namespace="entities",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"entity_name"}
|
||||
)
|
||||
self.entities_vdb = self.vector_db_storage_cls(
|
||||
namespace="entities",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"entity_name"},
|
||||
)
|
||||
self.relationships_vdb = (
|
||||
self.vector_db_storage_cls(
|
||||
namespace="relationships",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"src_id", "tgt_id"}
|
||||
)
|
||||
self.relationships_vdb = self.vector_db_storage_cls(
|
||||
namespace="relationships",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
meta_fields={"src_id", "tgt_id"},
|
||||
)
|
||||
self.chunks_vdb = (
|
||||
self.vector_db_storage_cls(
|
||||
namespace="chunks",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
self.chunks_vdb = self.vector_db_storage_cls(
|
||||
namespace="chunks",
|
||||
global_config=asdict(self),
|
||||
embedding_func=self.embedding_func,
|
||||
)
|
||||
|
||||
|
||||
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
||||
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
||||
)
|
||||
@@ -177,7 +173,7 @@ class LightRAG:
|
||||
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
|
||||
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
||||
if not len(new_docs):
|
||||
logger.warning(f"All docs are already in the storage")
|
||||
logger.warning("All docs are already in the storage")
|
||||
return
|
||||
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
||||
|
||||
@@ -203,7 +199,7 @@ class LightRAG:
|
||||
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
||||
}
|
||||
if not len(inserting_chunks):
|
||||
logger.warning(f"All chunks are already in the storage")
|
||||
logger.warning("All chunks are already in the storage")
|
||||
return
|
||||
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
|
||||
|
||||
@@ -246,7 +242,7 @@ class LightRAG:
|
||||
def query(self, query: str, param: QueryParam = QueryParam()):
|
||||
loop = always_get_an_event_loop()
|
||||
return loop.run_until_complete(self.aquery(query, param))
|
||||
|
||||
|
||||
async def aquery(self, query: str, param: QueryParam = QueryParam()):
|
||||
if param.mode == "local":
|
||||
response = await local_query(
|
||||
@@ -290,7 +286,6 @@ class LightRAG:
|
||||
raise ValueError(f"Unknown mode {param.mode}")
|
||||
await self._query_done()
|
||||
return response
|
||||
|
||||
|
||||
async def _query_done(self):
|
||||
tasks = []
|
||||
@@ -299,5 +294,3 @@ class LightRAG:
|
||||
continue
|
||||
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user