Fix linting
@@ -17,6 +17,7 @@ if not pm.is_installed("faiss"):
 import faiss  # type: ignore
 from threading import Lock as ThreadLock

+
 @final
 @dataclass
 class FaissVectorDBStorage(BaseVectorStorage):
@@ -59,7 +60,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with self._storage_lock:
             self._load_faiss_index()

-
     def _get_index(self):
         """Check if the storage should be reloaded"""
         return self._index
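The `_storage_lock` used in this hunk pairs with the `ThreadLock` import shown in the first hunk; it serializes access to the in-memory Faiss index while it is (re)loaded. A minimal sketch of the pattern, with the initializer wiring and the stubbed loader assumed for illustration (they are not part of this diff):

    from threading import Lock as ThreadLock

    class FaissVectorDBStorage:
        def __init__(self):
            # Assumption: one lock per storage instance guards index (re)loads.
            self._storage_lock = ThreadLock()
            self._index = None

        def _load_faiss_index(self):
            # Stub for illustration; the real method restores the index from disk.
            self._index = object()

        def _get_index(self):
            """Return the index, reloading it under the lock if needed."""
            with self._storage_lock:
                if self._index is None:
                    self._load_faiss_index()
                return self._index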
@@ -224,10 +224,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         logger.debug(f"Searching relations for entity {entity_name}")
         relations = []
         for fid, meta in self._id_to_meta.items():
-            if (
-                meta.get("src_id") == entity_name
-                or meta.get("tgt_id") == entity_name
-            ):
+            if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
                 relations.append(fid)

         logger.debug(f"Found {len(relations)} relations for {entity_name}")
@@ -265,7 +262,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 new_id_to_meta[new_fid] = vec_meta

         with self._storage_lock:
             # Re-init index
             self._index = faiss.IndexFlatIP(self._dim)
             if vectors_to_keep:
                 arr = np.array(vectors_to_keep, dtype=np.float32)
@@ -273,7 +270,6 @@ class FaissVectorDBStorage(BaseVectorStorage):

             self._id_to_meta = new_id_to_meta

-
     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
@@ -290,7 +286,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with open(self._meta_file, "w", encoding="utf-8") as f:
             json.dump(serializable_dict, f)

-
     def _load_faiss_index(self):
         """
         Load the Faiss index + metadata from disk if it exists,
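For reference, the deletion path these hunks touch rebuilds the flat inner-product index from the vectors that survive. A minimal standalone sketch of that step, assuming `vectors_to_keep` holds the surviving embeddings (the helper name `rebuild_index` is illustrative, not from this commit):

    import faiss  # type: ignore
    import numpy as np

    def rebuild_index(vectors_to_keep, dim):
        """Re-create a flat inner-product index from the surviving vectors."""
        index = faiss.IndexFlatIP(dim)
        if vectors_to_keep:
            arr = np.array(vectors_to_keep, dtype=np.float32)
            index.add(arr)  # Faiss assigns new sequential ids to the re-added rows
        return index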
@@ -84,7 +84,9 @@ class JsonDocStatusStorage(DocStatusStorage):

     async def index_done_callback(self) -> None:
         with self._storage_lock:
-            data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            data_dict = (
+                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            )
             write_json(data_dict, self._file_name)

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
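The rewrapped expression is a duck-type check: dictionaries shared through a multiprocessing Manager are proxy objects that expose a private _getvalue() method, so hasattr(self._data, "_getvalue") detects a proxy and copies it into a plain dict that write_json can serialize. A small standalone illustration of the idea (not from this commit):

    import json
    from multiprocessing import Manager

    def to_plain_dict(data):
        # Manager proxies carry a private _getvalue(); plain dicts do not.
        return dict(data) if hasattr(data, "_getvalue") else data

    if __name__ == "__main__":
        shared = Manager().dict({"doc-1": {"status": "processed"}})
        # json.dumps(shared) would raise TypeError; the plain copy serializes fine.
        print(json.dumps(to_plain_dict(shared)))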
@@ -36,7 +36,9 @@ class JsonKVStorage(BaseKVStorage):

     async def index_done_callback(self) -> None:
         with self._storage_lock:
-            data_dict = dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            data_dict = (
+                dict(self._data) if hasattr(self._data, "_getvalue") else self._data
+            )
             write_json(data_dict, self._file_name)

     async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -18,6 +18,7 @@ if not pm.is_installed("nano-vectordb"):
 from nano_vectordb import NanoVectorDB
 from threading import Lock as ThreadLock

+
 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):

@@ -148,9 +149,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 for dp in storage["data"]
                 if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
             ]
-            logger.debug(
-                f"Found {len(relations)} relations for entity {entity_name}"
-            )
+            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
            ids_to_delete = [relation["__id__"] for relation in relations]

             if ids_to_delete:
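The hunk above reflows a comprehension that selects relations touching a given entity on either end. The same pattern in isolation, with illustrative sample data (not from this commit):

    def find_relation_ids(storage, entity_name):
        """Collect ids of relations that have entity_name as source or target."""
        relations = [
            dp
            for dp in storage["data"]
            if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
        ]
        return [relation["__id__"] for relation in relations]

    storage = {"data": [
        {"__id__": "r1", "src_id": "Alice", "tgt_id": "Bob"},
        {"__id__": "r2", "src_id": "Carol", "tgt_id": "Dave"},
    ]}
    assert find_relation_ids(storage, "Bob") == ["r1"]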
@@ -19,6 +19,7 @@ import networkx as nx
 from graspologic import embed
 from threading import Lock as ThreadLock

+
 @final
 @dataclass
 class NetworkXStorage(BaseGraphStorage):

@@ -231,9 +232,9 @@ class NetworkXStorage(BaseGraphStorage):
         if len(subgraph.nodes()) > max_graph_nodes:
             origin_nodes = len(subgraph.nodes())
             node_degrees = dict(subgraph.degree())
-            top_nodes = sorted(
-                node_degrees.items(), key=lambda x: x[1], reverse=True
-            )[:max_graph_nodes]
+            top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
+                :max_graph_nodes
+            ]
             top_node_ids = [node[0] for node in top_nodes]
             # Create new subgraph with only top nodes
             subgraph = subgraph.subgraph(top_node_ids)
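The reflowed line caps an oversized subgraph at the max_graph_nodes highest-degree nodes. A runnable sketch of the same truncation (the function name is illustrative):

    import networkx as nx

    def truncate_by_degree(graph, max_graph_nodes):
        """Keep only the max_graph_nodes highest-degree nodes."""
        if len(graph.nodes()) <= max_graph_nodes:
            return graph
        node_degrees = dict(graph.degree())
        top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
            :max_graph_nodes
        ]
        top_node_ids = [node[0] for node in top_nodes]
        return graph.subgraph(top_node_ids)

    g = nx.path_graph(10)  # 0-1-2-...-9; interior nodes have degree 2
    print(len(truncate_by_degree(g, 5).nodes()))  # 5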
@@ -26,6 +26,7 @@ _global_lock: Optional[LockType] = None
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized

+
 def initialize_share_data(workers: int = 1):
     """
     Initialize shared storage data for single or multi-process mode.

@@ -66,9 +67,7 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess = True
         _global_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
-        _init_flags = (
-            _manager.dict()
-        )
+        _init_flags = _manager.dict()
         direct_log(
             f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})"
         )

@@ -95,9 +94,13 @@ def try_initialize_namespace(namespace: str) -> bool:

     if namespace not in _init_flags:
         _init_flags[namespace] = True
-        direct_log(f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]")
+        direct_log(
+            f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+        )
         return True
-    direct_log(f"Process {os.getpid()} storage namespace already initialized: [{namespace}]")
+    direct_log(
+        f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
+    )
     return False
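These hunks reformat the multiprocess branch of the shared-storage bootstrap: a multiprocessing.Manager supplies a lock and dicts that all worker processes can see. A condensed sketch of the pattern (the single-process fallback and the print call are assumptions for illustration, not text from this diff):

    import os
    from multiprocessing import Manager
    from threading import Lock as ThreadLock

    _manager = None  # keep a reference so the Manager process is not GC'd
    _global_lock = None
    _shared_dicts = None
    _init_flags = None
    is_multiprocess = False

    def initialize_share_data(workers: int = 1):
        """Create shared structures; Manager-backed only when several workers need them."""
        global _manager, _global_lock, _shared_dicts, _init_flags, is_multiprocess
        if workers > 1:
            _manager = Manager()
            is_multiprocess = True
            _global_lock = _manager.Lock()
            _shared_dicts = _manager.dict()
            _init_flags = _manager.dict()
        else:
            # Assumption: plain objects suffice when everything runs in one process.
            _global_lock = ThreadLock()
            _shared_dicts = {}
            _init_flags = {}
        print(f"Process {os.getpid()} shared data ready (workers={workers})")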