chore: add pre-commit hooks and ruff formatting

Author: Sanketh Kumar
Date:   2024-10-19 09:43:17 +05:30
Parent: b854ab4737
Commit: 744dad339d
26 changed files with 635 additions and 393 deletions
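
Note: only one file's diff is shown below (by its content, lightrag/lightrag.py). The pre-commit wiring named in the title lives in a .pre-commit-config.yaml among the other changed files. A minimal sketch of such a config, assuming the standard pre-commit-hooks and ruff-pre-commit repositories (the rev pins and exact hook selection here are assumptions, not the committed file):

    # Hypothetical .pre-commit-config.yaml; rev pins and hook ids are assumed,
    # not copied from this commit.
    repos:
      - repo: https://github.com/pre-commit/pre-commit-hooks
        rev: v4.6.0
        hooks:
          - id: trailing-whitespace  # would produce the whitespace-only hunks below
          - id: end-of-file-fixer    # would drop the blank lines at end of file
      - repo: https://github.com/astral-sh/ruff-pre-commit
        rev: v0.6.9
        hooks:
          - id: ruff          # linter: unused imports, f-strings without placeholders
            args: [--fix]
          - id: ruff-format   # formatter: quote style, call wrapping, blank lines

With such a config in place, pre-commit install registers the git hook, and pre-commit run --all-files produces exactly the kind of repository-wide cleanup this diff shows.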


@@ -3,10 +3,12 @@ import os
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
-from typing import Type, cast, Any
-from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
-from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
+from typing import Type, cast
+from .llm import (
+    gpt_4o_mini_complete,
+    openai_embedding,
+)
 
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -37,6 +39,7 @@ from .base import (
     QueryParam,
 )
 
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_running_loop()
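
The diff view cuts this helper off mid-body. For context, the usual shape of such a helper is sketched below; the fallback branch is an assumption about this repository's code, not a copy of it:

    import asyncio

    def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
        try:
            # Reuse the loop that is already running in this thread, if any.
            return asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: create a fresh one and register it so that
            # later lookups in this thread find it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop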
@@ -69,7 +72,6 @@ class LightRAG:
         "dimensions": 1536,
         "num_walks": 10,
         "walk_length": 40,
-        "num_walks": 10,
         "window_size": 2,
         "iterations": 3,
         "random_seed": 3,
@@ -77,13 +79,13 @@ class LightRAG:
     )
     # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
-    embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
+    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
 
     # LLM
-    llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
-    llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
+    llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
+    llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
@@ -98,11 +100,11 @@ class LightRAG:
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json
 
-    def __post_init__(self):
+    def __post_init__(self):
         log_file = os.path.join(self.working_dir, "lightrag.log")
         set_logger(log_file)
         logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
         _print_config = ",\n  ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n  {_print_config}\n")
@@ -133,30 +135,24 @@ class LightRAG:
             self.embedding_func
         )
-        self.entities_vdb = (
-            self.vector_db_storage_cls(
-                namespace="entities",
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-                meta_fields={"entity_name"}
-            )
-        )
+        self.entities_vdb = self.vector_db_storage_cls(
+            namespace="entities",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+            meta_fields={"entity_name"},
+        )
 
-        self.relationships_vdb = (
-            self.vector_db_storage_cls(
-                namespace="relationships",
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-                meta_fields={"src_id", "tgt_id"}
-            )
-        )
+        self.relationships_vdb = self.vector_db_storage_cls(
+            namespace="relationships",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+            meta_fields={"src_id", "tgt_id"},
+        )
 
-        self.chunks_vdb = (
-            self.vector_db_storage_cls(
-                namespace="chunks",
-                global_config=asdict(self),
-                embedding_func=self.embedding_func,
-            )
-        )
+        self.chunks_vdb = self.vector_db_storage_cls(
+            namespace="chunks",
+            global_config=asdict(self),
+            embedding_func=self.embedding_func,
+        )
         self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
             partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
         )
@@ -177,7 +173,7 @@ class LightRAG:
         _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
         if not len(new_docs):
-            logger.warning(f"All docs are already in the storage")
+            logger.warning("All docs are already in the storage")
             return
         logger.info(f"[New Docs] inserting {len(new_docs)} docs")
@@ -203,7 +199,7 @@ class LightRAG:
                 k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
             }
             if not len(inserting_chunks):
-                logger.warning(f"All chunks are already in the storage")
+                logger.warning("All chunks are already in the storage")
                 return
             logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
@@ -246,7 +242,7 @@ class LightRAG:
     def query(self, query: str, param: QueryParam = QueryParam()):
         loop = always_get_an_event_loop()
         return loop.run_until_complete(self.aquery(query, param))
-
+
     async def aquery(self, query: str, param: QueryParam = QueryParam()):
         if param.mode == "local":
             response = await local_query(
@@ -290,7 +286,6 @@ class LightRAG:
             raise ValueError(f"Unknown mode {param.mode}")
         await self._query_done()
         return response
-
 
     async def _query_done(self):
         tasks = []
@@ -299,5 +294,3 @@ class LightRAG:
                 continue
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
         await asyncio.gather(*tasks)
-
-
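
The query/aquery pair above gives callers both a blocking and an awaitable entry point: query drives aquery to completion on whatever loop always_get_an_event_loop returns. A hedged usage sketch (the working directory and texts are placeholders, and the default gpt_4o_mini_complete/openai_embedding functions assume an OpenAI API key is configured):

    from lightrag import LightRAG, QueryParam

    rag = LightRAG(working_dir="./rag_storage")  # hypothetical path

    # Synchronous wrappers run the async implementations to completion.
    rag.insert("LightRAG builds a knowledge graph over inserted documents.")
    print(rag.query("What does LightRAG build?", param=QueryParam(mode="local")))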