Revert vector and graph storage to local data (single process)
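This revert moves the FAISS, NanoVectorDB, and NetworkX storage classes away from the cross-process shared-namespace helpers in shared_storage and back to plain per-instance objects guarded by a threading.Lock. A minimal sketch of the single-process pattern the code returns to (illustrative names only, not the committed implementation):

from threading import Lock


class LocalStorage:
    """Toy single-process storage: a plain dict plus a per-process lock."""

    def __init__(self):
        self._lock = Lock()
        self._data = {}  # plain in-process dict, no multiprocessing.Manager

    def upsert(self, key, value):
        self._data[key] = value  # in-memory updates need no cross-process lock

    def flush(self, path):
        # only file persistence is serialized, mirroring index_done_callback()
        with self._lock:
            with open(path, "w", encoding="utf-8") as f:
                f.write(repr(self._data))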
@@ -10,19 +10,12 @@ import pipmaster as pm
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseVectorStorage
-from .shared_storage import (
-    get_namespace_data,
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)
 
 if not pm.is_installed("faiss"):
     pm.install("faiss")
 
 import faiss  # type: ignore
+from threading import Lock as ThreadLock
 
 
 @final
 @dataclass
@@ -51,35 +44,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         # Embedding dimension (e.g. 768) must match your embedding function
         self._dim = self.embedding_func.embedding_dim
-        self._storage_lock = get_storage_lock()
+        self._storage_lock = ThreadLock()
 
-        # check need_init must before get_namespace_object/get_namespace_data
-        need_init = try_initialize_namespace("faiss_indices")
-        self._index = get_namespace_object("faiss_indices")
-        self._id_to_meta = get_namespace_data("faiss_meta")
-
-        if need_init:
-            if is_multiprocess:
-                # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
-                # If you have a large number of vectors, you might want IVF or other indexes.
-                # For demonstration, we use a simple IndexFlatIP.
-                self._index.value = faiss.IndexFlatIP(self._dim)
-                # Keep a local store for metadata, IDs, etc.
-                # Maps <int faiss_id> → metadata (including your original ID).
-                self._id_to_meta.update({})
-                # Attempt to load an existing index + metadata from disk
-                self._load_faiss_index()
-            else:
-                self._index = faiss.IndexFlatIP(self._dim)
-                self._id_to_meta.update({})
-                self._load_faiss_index()
+        # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
+        # If you have a large number of vectors, you might want IVF or other indexes.
+        # For demonstration, we use a simple IndexFlatIP.
+        self._index = faiss.IndexFlatIP(self._dim)
+        # Keep a local store for metadata, IDs, etc.
+        # Maps <int faiss_id> → metadata (including your original ID).
+        self._id_to_meta = {}
+
+        # Attempt to load an existing index + metadata from disk
+        with self._storage_lock:
+            self._load_faiss_index()
 
     def _get_index(self):
-        """
-        Helper method to get the correct index object based on multiprocess mode.
-        Returns the actual index object that can be used for operations.
-        """
-        return self._index.value if is_multiprocess else self._index
+        """Check if the storage should be reloaded"""
+        return self._index
+
+    async def index_done_callback(self) -> None:
+        with self._storage_lock:
+            self._save_faiss_index()
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
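The comments kept in the hunk above explain the index choice: with L2-normalized vectors, inner product equals cosine similarity, so a flat IP index works as a cosine index. A small self-contained sketch of that relationship (assumes faiss and numpy are installed; sizes are arbitrary):

import numpy as np
import faiss

dim = 8
index = faiss.IndexFlatIP(dim)                  # inner-product (IP) index
vecs = np.random.rand(100, dim).astype(np.float32)
faiss.normalize_L2(vecs)                        # unit length: IP == cosine
index.add(vecs)

query = np.random.rand(1, dim).astype(np.float32)
faiss.normalize_L2(query)
scores, ids = index.search(query, 5)            # top-5 cosine scores and row ids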
@@ -134,7 +121,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Normalize embeddings for cosine similarity (in-place)
         faiss.normalize_L2(embeddings)
 
-        with self._storage_lock:
         # Upsert logic:
         # 1. Identify which vectors to remove if they exist
         # 2. Remove them
@@ -177,7 +163,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )
 
         # Perform the similarity search
-        with self._storage_lock:
         distances, indices = self._get_index().search(embedding, top_k)
 
         distances = distances[0]
@@ -208,7 +193,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
     @property
     def client_storage(self):
         # Return whatever structure LightRAG might need for debugging
-        with self._storage_lock:
         return {"data": list(self._id_to_meta.values())}
 
     async def delete(self, ids: list[str]):
@@ -216,7 +200,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Delete vectors for the provided custom IDs.
         """
         logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
-        with self._storage_lock:
         to_remove = []
         for cid in ids:
             fid = self._find_faiss_id_by_custom_id(cid)
@@ -239,7 +222,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Delete relations for a given entity by scanning metadata.
         """
         logger.debug(f"Searching relations for entity {entity_name}")
-        with self._storage_lock:
         relations = []
         for fid, meta in self._id_to_meta.items():
             if (
@@ -253,10 +235,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
             self._remove_faiss_ids(relations)
             logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
 
-    async def index_done_callback(self) -> None:
-        with self._storage_lock:
-            self._save_faiss_index()
-
     # --------------------------------------------------------------------------------
     # Internal helper methods
     # --------------------------------------------------------------------------------
@@ -265,7 +243,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         """
         Return the Faiss internal ID for a given custom ID, or None if not found.
         """
-        with self._storage_lock:
         for fid, meta in self._id_to_meta.items():
             if meta.get("__id__") == custom_id:
                 return fid
@@ -277,7 +254,6 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Because IndexFlatIP doesn't support 'removals',
         we rebuild the index excluding those vectors.
         """
-        with self._storage_lock:
         keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
 
         # Rebuild the index
@@ -288,27 +264,21 @@ class FaissVectorDBStorage(BaseVectorStorage):
             vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
             new_id_to_meta[new_fid] = vec_meta
 
-        # Re-init index
-        new_index = faiss.IndexFlatIP(self._dim)
-        if vectors_to_keep:
-            arr = np.array(vectors_to_keep, dtype=np.float32)
-            new_index.add(arr)
-        if is_multiprocess:
-            self._index.value = new_index
-        else:
-            self._index = new_index
-
-        self._id_to_meta.update(new_id_to_meta)
+        with self._storage_lock:
+            # Re-init index
+            self._index = faiss.IndexFlatIP(self._dim)
+            if vectors_to_keep:
+                arr = np.array(vectors_to_keep, dtype=np.float32)
+                self._index.add(arr)
+            self._id_to_meta = new_id_to_meta
 
     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
         """
-        with self._storage_lock:
-            faiss.write_index(
-                self._get_index(),
-                self._faiss_index_file,
-            )
+        faiss.write_index(self._index, self._faiss_index_file)
 
         # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
         # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
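As the docstring above notes, deletions are handled by rebuilding the flat index from the vectors that are kept rather than removing rows in place. A hedged sketch of that rebuild step (illustrative helper, not part of the commit):

import numpy as np
import faiss

def rebuild_index(dim, kept_vectors):
    # build a fresh IndexFlatIP containing only the vectors we keep
    index = faiss.IndexFlatIP(dim)
    if kept_vectors:
        index.add(np.array(kept_vectors, dtype=np.float32))
    return index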
@@ -320,6 +290,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
         with open(self._meta_file, "w", encoding="utf-8") as f:
             json.dump(serializable_dict, f)
 
+
     def _load_faiss_index(self):
         """
         Load the Faiss index + metadata from disk if it exists,
@@ -331,31 +302,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
 
         try:
             # Load the Faiss index
-            loaded_index = faiss.read_index(self._faiss_index_file)
-            if is_multiprocess:
-                self._index.value = loaded_index
-            else:
-                self._index = loaded_index
+            self._index = faiss.read_index(self._faiss_index_file)
 
             # Load metadata
             with open(self._meta_file, "r", encoding="utf-8") as f:
                 stored_dict = json.load(f)
 
             # Convert string keys back to int
-            self._id_to_meta.update({})
+            self._id_to_meta = {}
             for fid_str, meta in stored_dict.items():
                 fid = int(fid_str)
                 self._id_to_meta[fid] = meta
 
             logger.info(
-                f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
+                f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
             )
         except Exception as e:
             logger.error(f"Failed to load Faiss index or metadata: {e}")
             logger.warning("Starting with an empty Faiss index.")
-            new_index = faiss.IndexFlatIP(self._dim)
-            if is_multiprocess:
-                self._index.value = new_index
-            else:
-                self._index = new_index
-            self._id_to_meta.update({})
+            self._index = faiss.IndexFlatIP(self._dim)
+            self._id_to_meta = {}
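The save/load pair above persists the index with faiss.write_index / faiss.read_index and the metadata as JSON, which only allows string keys, hence the int() conversion on load. A compact sketch of the same round trip (hypothetical helper names and paths, not the committed code):

import json
import faiss

def save_index(index, meta, index_path, meta_path):
    faiss.write_index(index, index_path)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump({str(k): v for k, v in meta.items()}, f)

def load_index(index_path, meta_path):
    index = faiss.read_index(index_path)
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = {int(k): v for k, v in json.load(f).items()}
    return index, meta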
@@ -11,25 +11,19 @@ from lightrag.utils import (
 )
 import pipmaster as pm
 from lightrag.base import BaseVectorStorage
-from .shared_storage import (
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)
 
 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")
 
 from nano_vectordb import NanoVectorDB
+from threading import Lock as ThreadLock
 
 
 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
         # Initialize lock only for file operations
-        self._storage_lock = get_storage_lock()
+        self._storage_lock = ThreadLock()
 
         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -45,32 +39,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]
 
-        # check need_init must before get_namespace_object
-        need_init = try_initialize_namespace(self.namespace)
-        self._client = get_namespace_object(self.namespace)
-
-        if need_init:
-            if is_multiprocess:
-                self._client.value = NanoVectorDB(
-                    self.embedding_func.embedding_dim,
-                    storage_file=self._client_file_name,
-                )
-                logger.info(
-                    f"Initialized vector DB client for namespace {self.namespace}"
-                )
-            else:
-                self._client = NanoVectorDB(
-                    self.embedding_func.embedding_dim,
-                    storage_file=self._client_file_name,
-                )
-                logger.info(
-                    f"Initialized vector DB client for namespace {self.namespace}"
-                )
+        with self._storage_lock:
+            self._client = NanoVectorDB(
+                self.embedding_func.embedding_dim,
+                storage_file=self._client_file_name,
+            )
 
     def _get_client(self):
-        """Get the appropriate client instance based on multiprocess mode"""
-        if is_multiprocess:
-            return self._client.value
+        """Check if the storage should be reloaded"""
         return self._client
 
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
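After the revert, __post_init__ simply constructs the client in place. A hedged usage sketch of that call, with the constructor arguments visible in the hunk above (assumes nano-vectordb is installed; dimension and path are example values):

from threading import Lock as ThreadLock
from nano_vectordb import NanoVectorDB

storage_lock = ThreadLock()
with storage_lock:
    client = NanoVectorDB(
        384,                              # embedding dimension (example value)
        storage_file="./nano_demo.json",  # hypothetical path
    )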
@@ -101,7 +77,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
         if len(embeddings) == len(list_data):
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-            with self._storage_lock:
             results = self._get_client().upsert(datas=list_data)
             return results
         else:
@@ -115,7 +90,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]
 
-        with self._storage_lock:
         results = self._get_client().query(
             query=embedding,
             top_k=top_k,
@@ -143,7 +117,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
-            with self._storage_lock:
             self._get_client().delete(ids)
             logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
@@ -158,7 +131,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )
 
-            with self._storage_lock:
             # Check if the entity exists
             if self._get_client().get([entity_id]):
                 self._get_client().delete([entity_id])
@@ -170,7 +142,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
 
     async def delete_entity_relation(self, entity_name: str) -> None:
         try:
-            with self._storage_lock:
             storage = getattr(self._get_client(), "_NanoVectorDB__storage")
             relations = [
                 dp
@@ -6,12 +6,6 @@ import numpy as np
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from lightrag.base import BaseGraphStorage
-from .shared_storage import (
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)
 
 import pipmaster as pm
 
@@ -23,7 +17,7 @@ if not pm.is_installed("graspologic"):
 
 import networkx as nx
 from graspologic import embed
+from threading import Lock as ThreadLock
 
 
 @final
 @dataclass
@@ -78,38 +72,23 @@ class NetworkXStorage(BaseGraphStorage):
         self._graphml_xml_file = os.path.join(
             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
         )
-        self._storage_lock = get_storage_lock()
+        self._storage_lock = ThreadLock()
 
-        # check need_init must before get_namespace_object
-        need_init = try_initialize_namespace(self.namespace)
-        self._graph = get_namespace_object(self.namespace)
-
-        if need_init:
-            if is_multiprocess:
+        with self._storage_lock:
             preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                self._graph.value = preloaded_graph or nx.Graph()
-                if preloaded_graph:
+            if preloaded_graph is not None:
                 logger.info(
                     f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
                 )
             else:
-                preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                self._graph = preloaded_graph or nx.Graph()
-                if preloaded_graph:
-                    logger.info(
-                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                    )
-
                 logger.info("Created new empty graph")
+            self._graph = preloaded_graph or nx.Graph()
 
         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }
 
     def _get_graph(self):
-        """Get the appropriate graph instance based on multiprocess mode"""
-        if is_multiprocess:
-            return self._graph.value
+        """Check if the storage should be reloaded"""
         return self._graph
 
     async def index_done_callback(self) -> None:
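The initialization above follows a load-or-create pattern: read the GraphML file if one exists, otherwise start with an empty graph. A self-contained sketch of that pattern (illustrative path and helper name, assuming networkx is installed):

import os
import networkx as nx

def load_or_create(graphml_path: str) -> nx.Graph:
    # read an existing GraphML file if present, otherwise start empty
    if os.path.exists(graphml_path):
        return nx.read_graphml(graphml_path)
    return nx.Graph()

graph = load_or_create("./graph_demo.graphml")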
@@ -117,49 +96,39 @@ class NetworkXStorage(BaseGraphStorage):
         NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)
 
     async def has_node(self, node_id: str) -> bool:
-        with self._storage_lock:
         return self._get_graph().has_node(node_id)
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        with self._storage_lock:
         return self._get_graph().has_edge(source_node_id, target_node_id)
 
     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        with self._storage_lock:
         return self._get_graph().nodes.get(node_id)
 
     async def node_degree(self, node_id: str) -> int:
-        with self._storage_lock:
         return self._get_graph().degree(node_id)
 
     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        with self._storage_lock:
         return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
 
     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        with self._storage_lock:
         return self._get_graph().edges.get((source_node_id, target_node_id))
 
     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        with self._storage_lock:
         if self._get_graph().has_node(source_node_id):
             return list(self._get_graph().edges(source_node_id))
         return None
 
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
-        with self._storage_lock:
         self._get_graph().add_node(node_id, **node_data)
 
     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
-        with self._storage_lock:
         self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)
 
     async def delete_node(self, node_id: str) -> None:
-        with self._storage_lock:
         if self._get_graph().has_node(node_id):
             self._get_graph().remove_node(node_id)
             logger.debug(f"Node {node_id} deleted from the graph.")
@@ -175,7 +144,6 @@ class NetworkXStorage(BaseGraphStorage):
 
     # TODO: NOT USED
     async def _node2vec_embed(self):
-        with self._storage_lock:
         graph = self._get_graph()
         embeddings, nodes = embed.node2vec_embed(
             graph,
@@ -190,7 +158,6 @@ class NetworkXStorage(BaseGraphStorage):
         Args:
             nodes: List of node IDs to be deleted
         """
-        with self._storage_lock:
         graph = self._get_graph()
         for node in nodes:
             if graph.has_node(node):
@@ -202,7 +169,6 @@ class NetworkXStorage(BaseGraphStorage):
         Args:
             edges: List of edges to be deleted, each edge is a (source, target) tuple
         """
-        with self._storage_lock:
         graph = self._get_graph()
         for source, target in edges:
             if graph.has_edge(source, target):
@@ -214,7 +180,6 @@ class NetworkXStorage(BaseGraphStorage):
         Returns:
             [label1, label2, ...]  # Alphabetically sorted label list
         """
-        with self._storage_lock:
         labels = set()
         for node in self._get_graph().nodes():
             labels.add(str(node))  # Add node id as a label
@@ -239,7 +204,6 @@ class NetworkXStorage(BaseGraphStorage):
         seen_nodes = set()
         seen_edges = set()
 
-        with self._storage_lock:
         graph = self._get_graph()
 
         # Handle special case for "*" label
@@ -20,15 +20,12 @@ LockType = Union[ProcessLock, ThreadLock]
 _manager = None
 _initialized = None
 is_multiprocess = None
+_global_lock: Optional[LockType] = None
 
 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
-_share_objects: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
 
-_global_lock: Optional[LockType] = None
-
 
 def initialize_share_data(workers: int = 1):
     """
     Initialize shared storage data for single or multi-process mode.
@@ -53,7 +50,6 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess, \
         _global_lock, \
         _shared_dicts, \
-        _share_objects, \
         _init_flags, \
         _initialized
 
@@ -72,7 +68,6 @@ def initialize_share_data(workers: int = 1):
         _global_lock = _manager.Lock()
         # Create shared dictionaries with manager
         _shared_dicts = _manager.dict()
-        _share_objects = _manager.dict()
         _init_flags = (
             _manager.dict()
         )  # Use shared dictionary to store initialization flags
@@ -83,7 +78,6 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess = False
         _global_lock = ThreadLock()
         _shared_dicts = {}
-        _share_objects = {}
         _init_flags = {}
         direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
 
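After dropping _share_objects, initialize_share_data keeps only two branches: manager-backed structures for multi-process mode and plain dicts plus a thread lock for single-process mode. A runnable sketch of that split (illustrative helper, not the full function):

from multiprocessing import Manager
from threading import Lock as ThreadLock

def make_shared_state(workers: int = 1):
    # multi-process mode: manager-backed lock and dicts shared across workers
    if workers > 1:
        manager = Manager()
        return manager.Lock(), manager.dict(), manager.dict()
    # single-process mode: a plain thread lock and ordinary dicts
    return ThreadLock(), {}, {}

lock, shared_dicts, init_flags = make_shared_state(workers=1)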
@@ -99,11 +93,7 @@ def try_initialize_namespace(namespace: str) -> bool:
     global _init_flags, _manager
 
     if _init_flags is None:
-        direct_log(
-            f"Error: try to create nanmespace before Shared-Data is initialized, pid={os.getpid()}",
-            level="ERROR",
-        )
-        raise ValueError("Shared dictionaries not initialized")
+        raise ValueError("Try to create namespace before Shared-Data is initialized")
 
     if namespace not in _init_flags:
         _init_flags[namespace] = True
@@ -113,43 +103,9 @@ def try_initialize_namespace(namespace: str) -> bool:
     return False
 
 
-def _get_global_lock() -> LockType:
-    return _global_lock
-
-
 def get_storage_lock() -> LockType:
     """return storage lock for data consistency"""
-    return _get_global_lock()
-
-
-def get_scan_lock() -> LockType:
-    """return scan_progress lock for data consistency"""
-    return get_storage_lock()
-
-
-def get_namespace_object(namespace: str) -> Any:
-    """Get an object for specific namespace"""
-
-    if _share_objects is None:
-        direct_log(
-            f"Error: try to getnanmespace before Shared-Data is initialized, pid={os.getpid()}",
-            level="ERROR",
-        )
-        raise ValueError("Shared dictionaries not initialized")
-
-    lock = _get_global_lock()
-    with lock:
-        if namespace not in _share_objects:
-            if namespace not in _share_objects:
-                if is_multiprocess:
-                    _share_objects[namespace] = _manager.Value("O", None)
-                else:
-                    _share_objects[namespace] = None
-            direct_log(
-                f"Created namespace: {namespace}(type={type(_share_objects[namespace])})"
-            )
-
-    return _share_objects[namespace]
+    return _global_lock
 
 
 def get_namespace_data(namespace: str) -> Dict[str, Any]:
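With get_namespace_object and get_scan_lock removed, the remaining shared helper is get_namespace_data, which lazily creates a per-namespace dict under the storage lock. A minimal self-contained sketch of that pattern (hypothetical module globals, single-process case only):

from threading import Lock
from typing import Any, Dict

_storage_lock = Lock()
_shared_dicts: Dict[str, Dict[str, Any]] = {}

def get_namespace_data(namespace: str) -> Dict[str, Any]:
    # lazily create the per-namespace dict under the storage lock
    with _storage_lock:
        if namespace not in _shared_dicts:
            _shared_dicts[namespace] = {}
    return _shared_dicts[namespace]

docs = get_namespace_data("full_docs")  # example namespace name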
@@ -161,7 +117,7 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
         )
         raise ValueError("Shared dictionaries not initialized")
 
-    lock = _get_global_lock()
+    lock = get_storage_lock()
     with lock:
         if namespace not in _shared_dicts:
             if is_multiprocess and _manager is not None:
@@ -175,11 +131,6 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
     return _shared_dicts[namespace]
 
 
-def get_scan_progress() -> Dict[str, Any]:
-    """get storage space for document scanning progress data"""
-    return get_namespace_data("scan_progress")
-
-
 def finalize_share_data():
     """
     Release shared resources and clean up.
@@ -195,7 +146,6 @@ def finalize_share_data():
         is_multiprocess, \
         _global_lock, \
         _shared_dicts, \
-        _share_objects, \
         _init_flags, \
         _initialized
 
@@ -216,8 +166,6 @@ def finalize_share_data():
         # Clear shared dictionaries first
         if _shared_dicts is not None:
             _shared_dicts.clear()
-        if _share_objects is not None:
-            _share_objects.clear()
         if _init_flags is not None:
             _init_flags.clear()
 
@@ -234,7 +182,6 @@ def finalize_share_data():
     _initialized = None
     is_multiprocess = None
     _shared_dicts = None
-    _share_objects = None
    _init_flags = None
     _global_lock = None
 