Fix linting
@@ -8,14 +8,19 @@ import numpy as np
 from dataclasses import dataclass
 import pipmaster as pm

-from lightrag.utils import logger,compute_mdhash_id
+from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseVectorStorage
-from .shared_storage import get_namespace_data, get_storage_lock, get_namespace_object, is_multiprocess
+from .shared_storage import (
+    get_namespace_data,
+    get_storage_lock,
+    get_namespace_object,
+    is_multiprocess,
+)

 if not pm.is_installed("faiss"):
     pm.install("faiss")

-import faiss # type: ignore
+import faiss  # type: ignore


 @final
@@ -46,10 +51,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Embedding dimension (e.g. 768) must match your embedding function
         self._dim = self.embedding_func.embedding_dim
         self._storage_lock = get_storage_lock()

-        self._index = get_namespace_object('faiss_indices')
-        self._id_to_meta = get_namespace_data('faiss_meta')
+        self._index = get_namespace_object("faiss_indices")
+        self._id_to_meta = get_namespace_data("faiss_meta")

         with self._storage_lock:
             if is_multiprocess:
                 if self._index.value is None:
@@ -68,7 +73,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 self._id_to_meta.update({})
         self._load_faiss_index()

-
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
         Insert or update vectors in the Faiss index.
@@ -168,7 +172,9 @@ class FaissVectorDBStorage(BaseVectorStorage):

         # Perform the similarity search
         with self._storage_lock:
-            distances, indices = (self._index.value if is_multiprocess else self._index).search(embedding, top_k)
+            distances, indices = (
+                self._index.value if is_multiprocess else self._index
+            ).search(embedding, top_k)

         distances = distances[0]
         indices = indices[0]
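For readers unfamiliar with the call being reformatted here: Faiss `search` takes a batch of query vectors and returns parallel arrays of distances and indices, which is why the code above unwraps `[0]` for the single query. A minimal standalone sketch (the index type, dimensions, and data are illustrative, not taken from this commit):

    import numpy as np
    import faiss

    dim = 4
    index = faiss.IndexFlatIP(dim)                       # flat inner-product index
    vectors = np.random.rand(10, dim).astype("float32")  # Faiss requires float32
    index.add(vectors)

    query = np.random.rand(1, dim).astype("float32")     # batch of one query
    distances, indices = index.search(query, 3)          # shapes: (1, 3) and (1, 3)
    distances, indices = distances[0], indices[0]        # unwrap the single query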
@@ -232,7 +238,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with self._storage_lock:
             relations = []
             for fid, meta in self._id_to_meta.items():
-                if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
+                if (
+                    meta.get("src_id") == entity_name
+                    or meta.get("tgt_id") == entity_name
+                ):
                     relations.append(fid)

         logger.debug(f"Found {len(relations)} relations for {entity_name}")
@@ -292,7 +301,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Save the current Faiss index + metadata to disk so it can persist across runs.
         """
         with self._storage_lock:
-            faiss.write_index(self._index.value if is_multiprocess else self._index, self._faiss_index_file)
+            faiss.write_index(
+                self._index.value if is_multiprocess else self._index,
+                self._faiss_index_file,
+            )

         # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
         # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
@@ -320,7 +332,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
                 self._index.value = loaded_index
             else:
                 self._index = loaded_index

         # Load metadata
         with open(self._meta_file, "r", encoding="utf-8") as f:
             stored_dict = json.load(f)
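As a standalone illustration of the save/load pair these hunks touch: Faiss has its own binary format via `write_index`/`read_index`, while the metadata dict round-trips through JSON with string keys, as the comment above notes. A minimal sketch (file names and metadata shape are illustrative):

    import json
    import numpy as np
    import faiss

    index = faiss.IndexFlatIP(4)
    index.add(np.random.rand(8, 4).astype("float32"))
    id_to_meta = {0: {"__id__": "doc-0"}, 1: {"__id__": "doc-1"}}

    # Save: the index in Faiss's binary format, the metadata as JSON.
    faiss.write_index(index, "example.index")
    with open("example.meta.json", "w", encoding="utf-8") as f:
        json.dump({str(k): v for k, v in id_to_meta.items()}, f)

    # Load: read the index back and restore integer keys on the metadata.
    loaded_index = faiss.read_index("example.index")
    with open("example.meta.json", "r", encoding="utf-8") as f:
        id_to_meta = {int(k): v for k, v in json.load(f).items()}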
@@ -26,7 +26,6 @@ class JsonKVStorage(BaseKVStorage):
         self._data: dict[str, Any] = load_json(self._file_name) or {}
         logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

-
     async def index_done_callback(self) -> None:
         # File writes must be locked to prevent corruption when multiple processes write at the same time
         with self._storage_lock:
@@ -25,7 +25,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
         # Initialize lock only for file operations
        self._storage_lock = get_storage_lock()

         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
         cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -39,22 +39,28 @@ class NanoVectorDBStorage(BaseVectorStorage):
             self.global_config["working_dir"], f"vdb_{self.namespace}.json"
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]

         self._client = get_namespace_object(self.namespace)

         with self._storage_lock:
             if is_multiprocess:
                 if self._client.value is None:
                     self._client.value = NanoVectorDB(
-                        self.embedding_func.embedding_dim, storage_file=self._client_file_name
+                        self.embedding_func.embedding_dim,
+                        storage_file=self._client_file_name,
                     )
-                    logger.info(f"Initialized vector DB client for namespace {self.namespace}")
+                    logger.info(
+                        f"Initialized vector DB client for namespace {self.namespace}"
+                    )
             else:
                 if self._client is None:
                     self._client = NanoVectorDB(
-                        self.embedding_func.embedding_dim, storage_file=self._client_file_name
+                        self.embedding_func.embedding_dim,
+                        storage_file=self._client_file_name,
                     )
-                    logger.info(f"Initialized vector DB client for namespace {self.namespace}")
+                    logger.info(
+                        f"Initialized vector DB client for namespace {self.namespace}"
+                    )

     def _get_client(self):
         """Get the appropriate client instance based on multiprocess mode"""
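The hunk ends at `_get_client`, whose body the diff does not show. Given the `is_multiprocess` pattern used throughout the commit, a plausible reconstruction (an assumption, not the committed code) is:

    def _get_client(self):
        """Get the appropriate client instance based on multiprocess mode"""
        # In multiprocess mode the namespace object is a manager proxy and the
        # real NanoVectorDB instance lives in .value; otherwise it is stored directly.
        if is_multiprocess:
            return self._client.value
        return self._client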
@@ -104,7 +110,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         # Execute embedding outside of lock to avoid long lock times
         embedding = await self.embedding_func([query])
         embedding = embedding[0]

         with self._storage_lock:
             client = self._get_client()
             results = client.query(
@@ -150,7 +156,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             logger.debug(
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )

             with self._storage_lock:
                 client = self._get_client()
                 # Check if the entity exists
@@ -172,7 +178,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 for dp in storage["data"]
                 if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
             ]
-            logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
+            logger.debug(
+                f"Found {len(relations)} relations for entity {entity_name}"
+            )
             ids_to_delete = [relation["__id__"] for relation in relations]

             if ids_to_delete:
@@ -78,29 +78,33 @@ class NetworkXStorage(BaseGraphStorage):
         with self._storage_lock:
             if is_multiprocess:
                 if self._graph.value is None:
-                    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+                    preloaded_graph = NetworkXStorage.load_nx_graph(
+                        self._graphml_xml_file
+                    )
                     self._graph.value = preloaded_graph or nx.Graph()
                     if preloaded_graph:
                         logger.info(
                             f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
                         )
                     else:
                         logger.info("Created new empty graph")
             else:
                 if self._graph is None:
-                    preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+                    preloaded_graph = NetworkXStorage.load_nx_graph(
+                        self._graphml_xml_file
+                    )
                     self._graph = preloaded_graph or nx.Graph()
                     if preloaded_graph:
                         logger.info(
                             f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
                         )
                     else:
                         logger.info("Created new empty graph")

         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }

     def _get_graph(self):
         """Get the appropriate graph instance based on multiprocess mode"""
         if is_multiprocess:
@@ -248,11 +252,13 @@ class NetworkXStorage(BaseGraphStorage):

         with self._storage_lock:
             graph = self._get_graph()

             # Handle special case for "*" label
             if node_label == "*":
                 # For "*", return the entire graph including all nodes and edges
-                subgraph = graph.copy()  # Create a copy to avoid modifying the original graph
+                subgraph = (
+                    graph.copy()
+                )  # Create a copy to avoid modifying the original graph
             else:
                 # Find nodes with matching node id (partial match)
                 nodes_to_explore = []
@@ -272,9 +278,9 @@ class NetworkXStorage(BaseGraphStorage):
             if len(subgraph.nodes()) > max_graph_nodes:
                 origin_nodes = len(subgraph.nodes())
                 node_degrees = dict(subgraph.degree())
-                top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
-                    :max_graph_nodes
-                ]
+                top_nodes = sorted(
+                    node_degrees.items(), key=lambda x: x[1], reverse=True
+                )[:max_graph_nodes]
                 top_node_ids = [node[0] for node in top_nodes]
                 # Create new subgraph with only top nodes
                 subgraph = subgraph.subgraph(top_node_ids)
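The reformatted lines implement degree-based truncation: when the matched subgraph exceeds `max_graph_nodes`, only the highest-degree nodes are kept. A minimal runnable illustration of the same idiom on a toy graph (the graph and limit are invented for the example):

    import networkx as nx

    g = nx.path_graph(10)          # toy graph: 0-1-2-...-9
    max_graph_nodes = 5
    degrees = dict(g.degree())
    top = sorted(degrees.items(), key=lambda x: x[1], reverse=True)[:max_graph_nodes]
    sub = g.subgraph([n for n, _ in top])   # view restricted to the top-degree nodes
    print(sub.number_of_nodes())            # 5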
@@ -17,106 +17,125 @@ _shared_dicts: Optional[Dict[str, Any]] = {}
 _share_objects: Optional[Dict[str, Any]] = {}
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized


 def initialize_share_data():
     """Initialize shared data; only called when running with multiple processes (workers > 1)"""
     global _manager, _shared_dicts, _share_objects, _init_flags, is_multiprocess
     is_multiprocess = True

     logger.info(f"Process {os.getpid()} initializing shared storage")

     # Initialize manager
     if _manager is None:
         _manager = Manager()
         logger.info(f"Process {os.getpid()} created manager")

     # Create shared dictionaries with manager
     _shared_dicts = _manager.dict()
     _share_objects = _manager.dict()
     _init_flags = _manager.dict()  # Use a shared dict to store initialization flags
     logger.info(f"Process {os.getpid()} created shared dictionaries")


 def try_initialize_namespace(namespace: str) -> bool:
     """
     Try to initialize the namespace. Returns True if the current process acquired the right to initialize.
     Atomic operations on the shared dict ensure that only one process initializes successfully.
     """
     global _init_flags, _manager

     if is_multiprocess:
         if _init_flags is None:
-            raise RuntimeError("Shared storage not initialized. Call initialize_share_data() first.")
+            raise RuntimeError(
+                "Shared storage not initialized. Call initialize_share_data() first."
+            )
     else:
         if _init_flags is None:
             _init_flags = {}

     logger.info(f"Process {os.getpid()} trying to initialize namespace {namespace}")

     # Protect access to the shared dict with the global lock
     with _get_global_lock():
         # Check whether it has already been initialized
         if namespace not in _init_flags:
             # Set the initialization flag
             _init_flags[namespace] = True
-            logger.info(f"Process {os.getpid()} ready to initialize namespace {namespace}")
+            logger.info(
+                f"Process {os.getpid()} ready to initialize namespace {namespace}"
+            )
             return True

-    logger.info(f"Process {os.getpid()} found namespace {namespace} already initialized")
+    logger.info(
+        f"Process {os.getpid()} found namespace {namespace} already initialized"
+    )
     return False

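A hypothetical caller, to show the once-per-namespace pattern this function enables (the namespace name and the `build_index()` setup are invented for illustration):

    # Only the process that wins the init flag builds the shared object;
    # every other process just attaches to what that process created.
    if try_initialize_namespace("faiss_indices"):
        index_holder = get_namespace_object("faiss_indices")
        index_holder.value = build_index()  # hypothetical heavy setup
    else:
        index_holder = get_namespace_object("faiss_indices")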
 def _get_global_lock() -> LockType:
     global _global_lock, is_multiprocess, _manager

     if _global_lock is None:
         if is_multiprocess:
             _global_lock = _manager.Lock()  # Use manager for lock
         else:
             _global_lock = ThreadLock()

     return _global_lock


 def get_storage_lock() -> LockType:
     """return storage lock for data consistency"""
     return _get_global_lock()


 def get_scan_lock() -> LockType:
     """return scan_progress lock for data consistency"""
     return get_storage_lock()

 def get_namespace_object(namespace: str) -> Any:
     """Get an object for a specific namespace"""
     global _share_objects, is_multiprocess, _manager

     if is_multiprocess and not _manager:
-        raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
+        raise RuntimeError(
+            "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
+        )

     if namespace not in _share_objects:
         lock = _get_global_lock()
         with lock:
             if namespace not in _share_objects:
                 if is_multiprocess:
-                    _share_objects[namespace] = _manager.Value('O', None)
+                    _share_objects[namespace] = _manager.Value("O", None)
                 else:
                     _share_objects[namespace] = None

     return _share_objects[namespace]


 # Functions that are no longer used have been removed

 def get_namespace_data(namespace: str) -> Dict[str, Any]:
     """Get the storage space for a specific storage type (namespace)"""
     global _shared_dicts, is_multiprocess, _manager

     if is_multiprocess and not _manager:
-        raise RuntimeError("Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first.")
+        raise RuntimeError(
+            "Multiprocess mode detected but shared storage not initialized. Call initialize_share_data() first."
+        )

     if namespace not in _shared_dicts:
         lock = _get_global_lock()
         with lock:
             if namespace not in _shared_dicts:
                 _shared_dicts[namespace] = {}

     return _shared_dicts[namespace]


 def get_scan_progress() -> Dict[str, Any]:
     """get storage space for document scanning progress data"""
-    return get_namespace_data('scan_progress')
+    return get_namespace_data("scan_progress")
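Both getters above use double-checked locking: a lock-free membership test first, then a re-check under the global lock before creating the entry, so the common read path stays cheap while creation remains race-free. A hypothetical end-to-end use of this API, sketched under the assumption of a multi-worker server startup (namespace names are illustrative):

    # In the server startup path (workers > 1):
    initialize_share_data()

    # In any worker afterwards:
    scan_progress = get_scan_progress()        # same dict in every process
    meta = get_namespace_data("faiss_meta")    # lazily created, created once

    with get_storage_lock():                   # guard writes with the storage lock
        meta["0"] = {"__id__": "doc-0"}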