From cd7648791a72af93efc04031c3fd7397550fe2ab Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 28 Feb 2025 01:25:59 +0800 Subject: [PATCH] Fix linting --- lightrag/kg/faiss_impl.py | 11 +++-------- lightrag/kg/json_doc_status_impl.py | 4 +++- lightrag/kg/json_kv_impl.py | 4 +++- lightrag/kg/nano_vector_db_impl.py | 5 ++--- lightrag/kg/networkx_impl.py | 7 ++++--- lightrag/kg/shared_storage.py | 13 ++++++++----- 6 files changed, 23 insertions(+), 21 deletions(-) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index a3520653..d0ef6ed0 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -17,6 +17,7 @@ if not pm.is_installed("faiss"): import faiss # type: ignore from threading import Lock as ThreadLock + @final @dataclass class FaissVectorDBStorage(BaseVectorStorage): @@ -59,7 +60,6 @@ class FaissVectorDBStorage(BaseVectorStorage): with self._storage_lock: self._load_faiss_index() - def _get_index(self): """Check if the shtorage should be reloaded""" return self._index @@ -224,10 +224,7 @@ class FaissVectorDBStorage(BaseVectorStorage): logger.debug(f"Searching relations for entity {entity_name}") relations = [] for fid, meta in self._id_to_meta.items(): - if ( - meta.get("src_id") == entity_name - or meta.get("tgt_id") == entity_name - ): + if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name: relations.append(fid) logger.debug(f"Found {len(relations)} relations for {entity_name}") @@ -265,7 +262,7 @@ class FaissVectorDBStorage(BaseVectorStorage): new_id_to_meta[new_fid] = vec_meta with self._storage_lock: - # Re-init index + # Re-init index self._index = faiss.IndexFlatIP(self._dim) if vectors_to_keep: arr = np.array(vectors_to_keep, dtype=np.float32) @@ -273,7 +270,6 @@ class FaissVectorDBStorage(BaseVectorStorage): self._id_to_meta = new_id_to_meta - def _save_faiss_index(self): """ Save the current Faiss index + metadata to disk so it can persist across runs. @@ -290,7 +286,6 @@ class FaissVectorDBStorage(BaseVectorStorage): with open(self._meta_file, "w", encoding="utf-8") as f: json.dump(serializable_dict, f) - def _load_faiss_index(self): """ Load the Faiss index + metadata from disk if it exists, diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index b71cf618..05e6da37 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -84,7 +84,9 @@ class JsonDocStatusStorage(DocStatusStorage): async def index_done_callback(self) -> None: with self._storage_lock: - data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data + data_dict = ( + dict(self._data) if hasattr(self._data, "_getvalue") else self._data + ) write_json(data_dict, self._file_name) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index c5bff177..a4ce91a5 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -36,7 +36,9 @@ class JsonKVStorage(BaseKVStorage): async def index_done_callback(self) -> None: with self._storage_lock: - data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data + data_dict = ( + dict(self._data) if hasattr(self._data, "_getvalue") else self._data + ) write_json(data_dict, self._file_name) async def get_by_id(self, id: str) -> dict[str, Any] | None: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index b8fe573d..bbf991bf 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -18,6 +18,7 @@ if not pm.is_installed("nano-vectordb"): from nano_vectordb import NanoVectorDB from threading import Lock as ThreadLock + @final @dataclass class NanoVectorDBStorage(BaseVectorStorage): @@ -148,9 +149,7 @@ class NanoVectorDBStorage(BaseVectorStorage): for dp in storage["data"] if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name ] - logger.debug( - f"Found {len(relations)} relations for entity {entity_name}" - ) + logger.debug(f"Found {len(relations)} relations for entity {entity_name}") ids_to_delete = [relation["__id__"] for relation in relations] if ids_to_delete: diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 1f14d5b0..ccf85855 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -19,6 +19,7 @@ import networkx as nx from graspologic import embed from threading import Lock as ThreadLock + @final @dataclass class NetworkXStorage(BaseGraphStorage): @@ -231,9 +232,9 @@ class NetworkXStorage(BaseGraphStorage): if len(subgraph.nodes()) > max_graph_nodes: origin_nodes = len(subgraph.nodes()) node_degrees = dict(subgraph.degree()) - top_nodes = sorted( - node_degrees.items(), key=lambda x: x[1], reverse=True - )[:max_graph_nodes] + top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[ + :max_graph_nodes + ] top_node_ids = [node[0] for node in top_nodes] # Create new subgraph with only top nodes subgraph = subgraph.subgraph(top_node_ids) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index f7c2e909..19b1b1cb 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -26,6 +26,7 @@ _global_lock: Optional[LockType] = None _shared_dicts: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized + def initialize_share_data(workers: int = 1): """ Initialize shared storage data for single or multi-process mode. @@ -66,9 +67,7 @@ def initialize_share_data(workers: int = 1): is_multiprocess = True _global_lock = _manager.Lock() _shared_dicts = _manager.dict() - _init_flags = ( - _manager.dict() - ) + _init_flags = _manager.dict() direct_log( f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})" ) @@ -95,9 +94,13 @@ def try_initialize_namespace(namespace: str) -> bool: if namespace not in _init_flags: _init_flags[namespace] = True - direct_log(f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]") + direct_log( + f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]" + ) return True - direct_log(f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]") + direct_log( + f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]" + ) return False