Fix linting

yangdx
2025-02-28 01:25:59 +08:00
parent 3dcfa561d7
commit cd7648791a
6 changed files with 23 additions and 21 deletions

View File

@@ -17,6 +17,7 @@ if not pm.is_installed("faiss"):
 import faiss  # type: ignore
 from threading import Lock as ThreadLock

+
 @final
 @dataclass
 class FaissVectorDBStorage(BaseVectorStorage):
@@ -59,7 +60,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with self._storage_lock:
             self._load_faiss_index()
-
     def _get_index(self):
         """Check if the storage should be reloaded"""
         return self._index
@@ -224,10 +224,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.debug(f"Searching relations for entity {entity_name}")
         relations = []
         for fid, meta in self._id_to_meta.items():
-            if (
-                meta.get("src_id") == entity_name
-                or meta.get("tgt_id") == entity_name
-            ):
+            if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
                 relations.append(fid)

         logger.debug(f"Found {len(relations)} relations for {entity_name}")
@@ -273,7 +270,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._id_to_meta = new_id_to_meta
-
     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
@@ -290,7 +286,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with open(self._meta_file, "w", encoding="utf-8") as f:
             json.dump(serializable_dict, f)
-
     def _load_faiss_index(self):
         """
         Load the Faiss index + metadata from disk if it exists,
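Note: _save_faiss_index / _load_faiss_index persist the index together with its id-to-metadata map. A minimal sketch of that pattern, assuming hypothetical file paths and a flat inner-product index; faiss.write_index, faiss.read_index, and faiss.IndexFlatIP are real faiss calls, the rest is illustrative:

import json
import os

import faiss  # faiss-cpu

INDEX_FILE = "faiss_index.bin"       # stand-in for self._faiss_index_file
META_FILE = "faiss_index.meta.json"  # stand-in for self._meta_file

def save_faiss_index(index, id_to_meta: dict) -> None:
    # faiss serializes the index itself; metadata goes to a JSON sidecar
    faiss.write_index(index, INDEX_FILE)
    # JSON requires string keys, so stringify the integer Faiss ids
    serializable = {str(fid): meta for fid, meta in id_to_meta.items()}
    with open(META_FILE, "w", encoding="utf-8") as f:
        json.dump(serializable, f)

def load_faiss_index(dim: int):
    # Start from an empty index when nothing has been persisted yet
    if not (os.path.exists(INDEX_FILE) and os.path.exists(META_FILE)):
        return faiss.IndexFlatIP(dim), {}
    index = faiss.read_index(INDEX_FILE)
    with open(META_FILE, "r", encoding="utf-8") as f:
        id_to_meta = {int(fid): meta for fid, meta in json.load(f).items()}
    return index, id_to_meta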

View File

@@ -84,7 +84,9 @@ class JsonDocStatusStorage(DocStatusStorage):
     async def index_done_callback(self) -> None:
         with self._storage_lock:
-            data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            data_dict = (
+                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            )
             write_json(data_dict, self._file_name)

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:

View File

@@ -36,7 +36,9 @@ class JsonKVStorage(BaseKVStorage):
     async def index_done_callback(self) -> None:
         with self._storage_lock:
-            data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            data_dict = (
+                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            )
             write_json(data_dict, self._file_name)

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
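Note: both JSON storages use the same guard before serializing. In multi-process mode self._data is a multiprocessing Manager proxy rather than a plain dict; proxy objects expose a private _getvalue() method, so the hasattr check detects them and dict(...) copies the proxy into something write_json can serialize. A self-contained sketch of the idea (the helper name is made up):

import json
from multiprocessing import Manager

def to_plain_dict(data):
    # Manager proxies (e.g. DictProxy) have _getvalue(); plain dicts do not.
    # Copying into a real dict makes the data safe for json.dump.
    return dict(data) if hasattr(data, "_getvalue") else data

if __name__ == "__main__":
    with Manager() as manager:
        shared = manager.dict({"doc-1": {"status": "done"}})
        print(json.dumps(to_plain_dict(shared)))  # {"doc-1": {"status": "done"}}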

View File

@@ -18,6 +18,7 @@ if not pm.is_installed("nano-vectordb"):
 from nano_vectordb import NanoVectorDB
 from threading import Lock as ThreadLock

+
 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
@@ -148,9 +149,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 for dp in storage["data"]
                 if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
             ]
-            logger.debug(
-                f"Found {len(relations)} relations for entity {entity_name}"
-            )
+            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")

             ids_to_delete = [relation["__id__"] for relation in relations]
             if ids_to_delete:

View File

@@ -19,6 +19,7 @@ import networkx as nx
 from graspologic import embed
 from threading import Lock as ThreadLock

+
 @final
 @dataclass
 class NetworkXStorage(BaseGraphStorage):
@@ -231,9 +232,9 @@ class NetworkXStorage(BaseGraphStorage):
         if len(subgraph.nodes()) > max_graph_nodes:
             origin_nodes = len(subgraph.nodes())
             node_degrees = dict(subgraph.degree())
-            top_nodes = sorted(
-                node_degrees.items(), key=lambda x: x[1], reverse=True
-            )[:max_graph_nodes]
+            top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
+                :max_graph_nodes
+            ]
             top_node_ids = [node[0] for node in top_nodes]
             # Create new subgraph with only top nodes
             subgraph = subgraph.subgraph(top_node_ids)
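Note: the reflow is behavior-preserving: when a subgraph exceeds max_graph_nodes, only the highest-degree nodes are kept. A standalone sketch of the same truncation on a toy graph (the helper name is made up):

import networkx as nx

def truncate_by_degree(graph: nx.Graph, max_graph_nodes: int) -> nx.Graph:
    if len(graph.nodes()) <= max_graph_nodes:
        return graph
    node_degrees = dict(graph.degree())
    # Sort nodes by degree, highest first, then keep the top max_graph_nodes
    top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
        :max_graph_nodes
    ]
    top_node_ids = [node for node, _ in top_nodes]
    return graph.subgraph(top_node_ids)

if __name__ == "__main__":
    g = nx.path_graph(10)  # interior nodes have degree 2, endpoints degree 1
    print(sorted(truncate_by_degree(g, 3).nodes()))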

View File

@@ -26,6 +26,7 @@ _global_lock: Optional[LockType] = None
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized

+
 def initialize_share_data(workers: int = 1):
     """
     Initialize shared storage data for single or multi-process mode.
@@ -66,9 +67,7 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess = True
         _global_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
-        _init_flags = (
-            _manager.dict()
-        )
+        _init_flags = _manager.dict()
         direct_log(
             f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
         )
@@ -95,9 +94,13 @@ def try_initialize_namespace(namespace: str) -> bool:
     if namespace not in _init_flags:
         _init_flags[namespace] = True
-        direct_log(f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]")
+        direct_log(
+            f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+        )
         return True
-    direct_log(f"Process {os.getpid()} storage namespace already to initialized: [{namespace}]")
+    direct_log(
+        f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
+    )
     return False
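Note: initialize_share_data and try_initialize_namespace together implement one-time, per-namespace setup that works in both single- and multi-process mode. A simplified sketch of the pattern, assuming a Manager-backed dict for the flags when workers > 1 (the lock around the flag check is added here for illustration and may differ from the actual code):

from multiprocessing import Manager
from threading import Lock as ThreadLock

_manager = None
_global_lock = None
_init_flags = None

def initialize_share_data(workers: int = 1):
    global _manager, _global_lock, _init_flags
    if workers > 1:
        # Manager-backed objects are shared across worker processes
        _manager = Manager()
        _global_lock = _manager.Lock()
        _init_flags = _manager.dict()
    else:
        # Single process: plain objects suffice
        _global_lock = ThreadLock()
        _init_flags = {}

def try_initialize_namespace(namespace: str) -> bool:
    # Returns True exactly once per namespace, so callers can guard
    # one-time initialization work behind it.
    with _global_lock:
        if namespace not in _init_flags:
            _init_flags[namespace] = True
            return True
        return False

if __name__ == "__main__":
    initialize_share_data(workers=1)
    print(try_initialize_namespace("kv_store"))  # True on first call
    print(try_initialize_namespace("kv_store"))  # False afterwards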