Revert vector and graph storage to local data (single process)

yangdx
2025-02-28 01:14:25 +08:00
parent db2a902fcb
commit 291e0c1b14
4 changed files with 287 additions and 443 deletions

View File

@@ -10,19 +10,12 @@ import pipmaster as pm
 from lightrag.utils import logger, compute_mdhash_id
 from lightrag.base import BaseVectorStorage
-from .shared_storage import (
-    get_namespace_data,
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)

 if not pm.is_installed("faiss"):
     pm.install("faiss")

 import faiss  # type: ignore
+from threading import Lock as ThreadLock


 @final
 @dataclass
@@ -51,35 +44,29 @@ class FaissVectorDBStorage(BaseVectorStorage):
         self._max_batch_size = self.global_config["embedding_batch_num"]
         # Embedding dimension (e.g. 768) must match your embedding function
         self._dim = self.embedding_func.embedding_dim
-        self._storage_lock = get_storage_lock()
-
-        # check need_init must before get_namespace_object/get_namespace_data
-        need_init = try_initialize_namespace("faiss_indices")
-        self._index = get_namespace_object("faiss_indices")
-        self._id_to_meta = get_namespace_data("faiss_meta")
-
-        if need_init:
-            if is_multiprocess:
-                # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
-                # If you have a large number of vectors, you might want IVF or other indexes.
-                # For demonstration, we use a simple IndexFlatIP.
-                self._index.value = faiss.IndexFlatIP(self._dim)
-                # Keep a local store for metadata, IDs, etc.
-                # Maps <int faiss_id> → metadata (including your original ID).
-                self._id_to_meta.update({})
-                # Attempt to load an existing index + metadata from disk
-                self._load_faiss_index()
-            else:
-                self._index = faiss.IndexFlatIP(self._dim)
-                self._id_to_meta.update({})
-                self._load_faiss_index()
+        self._storage_lock = ThreadLock()
+
+        # Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
+        # If you have a large number of vectors, you might want IVF or other indexes.
+        # For demonstration, we use a simple IndexFlatIP.
+        self._index = faiss.IndexFlatIP(self._dim)
+        # Keep a local store for metadata, IDs, etc.
+        # Maps <int faiss_id> → metadata (including your original ID).
+        self._id_to_meta = {}
+
+        # Attempt to load an existing index + metadata from disk
+        with self._storage_lock:
+            self._load_faiss_index()

     def _get_index(self):
-        """
-        Helper method to get the correct index object based on multiprocess mode.
-        Returns the actual index object that can be used for operations.
-        """
-        return self._index.value if is_multiprocess else self._index
+        """Check if the storage should be reloaded"""
+        return self._index
+
+    async def index_done_callback(self) -> None:
+        with self._storage_lock:
+            self._save_faiss_index()

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
         """
@@ -134,34 +121,33 @@ class FaissVectorDBStorage(BaseVectorStorage):
         # Normalize embeddings for cosine similarity (in-place)
         faiss.normalize_L2(embeddings)

-        with self._storage_lock:
-            # Upsert logic:
-            # 1. Identify which vectors to remove if they exist
-            # 2. Remove them
-            # 3. Add the new vectors
-            existing_ids_to_remove = []
-            for meta, emb in zip(list_data, embeddings):
-                faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
-                if faiss_internal_id is not None:
-                    existing_ids_to_remove.append(faiss_internal_id)
-
-            if existing_ids_to_remove:
-                self._remove_faiss_ids(existing_ids_to_remove)
-
-            # Step 2: Add new vectors
-            index = self._get_index()
-            start_idx = index.ntotal
-            index.add(embeddings)
-
-            # Step 3: Store metadata + vector for each new ID
-            for i, meta in enumerate(list_data):
-                fid = start_idx + i
-                # Store the raw vector so we can rebuild if something is removed
-                meta["__vector__"] = embeddings[i].tolist()
-                self._id_to_meta.update({fid: meta})
+        # Upsert logic:
+        # 1. Identify which vectors to remove if they exist
+        # 2. Remove them
+        # 3. Add the new vectors
+        existing_ids_to_remove = []
+        for meta, emb in zip(list_data, embeddings):
+            faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
+            if faiss_internal_id is not None:
+                existing_ids_to_remove.append(faiss_internal_id)
+
+        if existing_ids_to_remove:
+            self._remove_faiss_ids(existing_ids_to_remove)
+
+        # Step 2: Add new vectors
+        index = self._get_index()
+        start_idx = index.ntotal
+        index.add(embeddings)
+
+        # Step 3: Store metadata + vector for each new ID
+        for i, meta in enumerate(list_data):
+            fid = start_idx + i
+            # Store the raw vector so we can rebuild if something is removed
+            meta["__vector__"] = embeddings[i].tolist()
+            self._id_to_meta.update({fid: meta})

         logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
         return [m["__id__"] for m in list_data]

     async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
         """
@@ -177,57 +163,54 @@ class FaissVectorDBStorage(BaseVectorStorage):
         )

         # Perform the similarity search
-        with self._storage_lock:
-            distances, indices = self._get_index().search(embedding, top_k)
+        distances, indices = self._get_index().search(embedding, top_k)

         distances = distances[0]
         indices = indices[0]

         results = []
         for dist, idx in zip(distances, indices):
             if idx == -1:
                 # Faiss returns -1 if no neighbor
                 continue

             # Cosine similarity threshold
             if dist < self.cosine_better_than_threshold:
                 continue

             meta = self._id_to_meta.get(idx, {})
             results.append(
                 {
                     **meta,
                     "id": meta.get("__id__"),
                     "distance": float(dist),
                     "created_at": meta.get("__created_at__"),
                 }
             )

         return results

     @property
     def client_storage(self):
         # Return whatever structure LightRAG might need for debugging
-        with self._storage_lock:
-            return {"data": list(self._id_to_meta.values())}
+        return {"data": list(self._id_to_meta.values())}

     async def delete(self, ids: list[str]):
         """
         Delete vectors for the provided custom IDs.
         """
         logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
-        with self._storage_lock:
-            to_remove = []
-            for cid in ids:
-                fid = self._find_faiss_id_by_custom_id(cid)
-                if fid is not None:
-                    to_remove.append(fid)
-
-            if to_remove:
-                self._remove_faiss_ids(to_remove)
+        to_remove = []
+        for cid in ids:
+            fid = self._find_faiss_id_by_custom_id(cid)
+            if fid is not None:
+                to_remove.append(fid)
+
+        if to_remove:
+            self._remove_faiss_ids(to_remove)
         logger.debug(
             f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
         )

     async def delete_entity(self, entity_name: str) -> None:
         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
@@ -239,23 +222,18 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Delete relations for a given entity by scanning metadata.
         """
         logger.debug(f"Searching relations for entity {entity_name}")
-        with self._storage_lock:
-            relations = []
-            for fid, meta in self._id_to_meta.items():
-                if (
-                    meta.get("src_id") == entity_name
-                    or meta.get("tgt_id") == entity_name
-                ):
-                    relations.append(fid)
+        relations = []
+        for fid, meta in self._id_to_meta.items():
+            if (
+                meta.get("src_id") == entity_name
+                or meta.get("tgt_id") == entity_name
+            ):
+                relations.append(fid)

         logger.debug(f"Found {len(relations)} relations for {entity_name}")
         if relations:
             self._remove_faiss_ids(relations)
             logger.debug(f"Deleted {len(relations)} relations for {entity_name}")

-    async def index_done_callback(self) -> None:
-        with self._storage_lock:
-            self._save_faiss_index()
-
     # --------------------------------------------------------------------------------
     # Internal helper methods
@@ -265,11 +243,10 @@ class FaissVectorDBStorage(BaseVectorStorage):
         """
         Return the Faiss internal ID for a given custom ID, or None if not found.
         """
-        with self._storage_lock:
-            for fid, meta in self._id_to_meta.items():
-                if meta.get("__id__") == custom_id:
-                    return fid
-            return None
+        for fid, meta in self._id_to_meta.items():
+            if meta.get("__id__") == custom_id:
+                return fid
+        return None

     def _remove_faiss_ids(self, fid_list):
         """
@@ -277,48 +254,42 @@ class FaissVectorDBStorage(BaseVectorStorage):
         Because IndexFlatIP doesn't support 'removals',
         we rebuild the index excluding those vectors.
         """
+        keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
+
+        # Rebuild the index
+        vectors_to_keep = []
+        new_id_to_meta = {}
+        for new_fid, old_fid in enumerate(keep_fids):
+            vec_meta = self._id_to_meta[old_fid]
+            vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
+            new_id_to_meta[new_fid] = vec_meta
+
         with self._storage_lock:
-            keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
-
-            # Rebuild the index
-            vectors_to_keep = []
-            new_id_to_meta = {}
-            for new_fid, old_fid in enumerate(keep_fids):
-                vec_meta = self._id_to_meta[old_fid]
-                vectors_to_keep.append(vec_meta["__vector__"])  # stored as list
-                new_id_to_meta[new_fid] = vec_meta
-
             # Re-init index
-            new_index = faiss.IndexFlatIP(self._dim)
+            self._index = faiss.IndexFlatIP(self._dim)
             if vectors_to_keep:
                 arr = np.array(vectors_to_keep, dtype=np.float32)
-                new_index.add(arr)
-            if is_multiprocess:
-                self._index.value = new_index
-            else:
-                self._index = new_index
+                self._index.add(arr)

-            self._id_to_meta.update(new_id_to_meta)
+            self._id_to_meta = new_id_to_meta

     def _save_faiss_index(self):
         """
         Save the current Faiss index + metadata to disk so it can persist across runs.
         """
-        with self._storage_lock:
-            faiss.write_index(
-                self._get_index(),
-                self._faiss_index_file,
-            )
-
-            # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
-            # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
-            # We'll keep the int -> dict, but JSON requires string keys.
-            serializable_dict = {}
-            for fid, meta in self._id_to_meta.items():
-                serializable_dict[str(fid)] = meta
-
-            with open(self._meta_file, "w", encoding="utf-8") as f:
-                json.dump(serializable_dict, f)
+        faiss.write_index(self._index, self._faiss_index_file)
+
+        # Save metadata dict to JSON. Convert all keys to strings for JSON storage.
+        # _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
+        # We'll keep the int -> dict, but JSON requires string keys.
+        serializable_dict = {}
+        for fid, meta in self._id_to_meta.items():
+            serializable_dict[str(fid)] = meta
+
+        with open(self._meta_file, "w", encoding="utf-8") as f:
+            json.dump(serializable_dict, f)

     def _load_faiss_index(self):
         """
@@ -331,31 +302,22 @@ class FaissVectorDBStorage(BaseVectorStorage):
         try:
             # Load the Faiss index
-            loaded_index = faiss.read_index(self._faiss_index_file)
-            if is_multiprocess:
-                self._index.value = loaded_index
-            else:
-                self._index = loaded_index
+            self._index = faiss.read_index(self._faiss_index_file)

             # Load metadata
             with open(self._meta_file, "r", encoding="utf-8") as f:
                 stored_dict = json.load(f)

             # Convert string keys back to int
-            self._id_to_meta.update({})
+            self._id_to_meta = {}
             for fid_str, meta in stored_dict.items():
                 fid = int(fid_str)
                 self._id_to_meta[fid] = meta

             logger.info(
-                f"Faiss index loaded with {loaded_index.ntotal} vectors from {self._faiss_index_file}"
+                f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
             )
         except Exception as e:
             logger.error(f"Failed to load Faiss index or metadata: {e}")
             logger.warning("Starting with an empty Faiss index.")
-            new_index = faiss.IndexFlatIP(self._dim)
-            if is_multiprocess:
-                self._index.value = new_index
-            else:
-                self._index = new_index
-            self._id_to_meta.update({})
+            self._index = faiss.IndexFlatIP(self._dim)
+            self._id_to_meta = {}
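For orientation, the Faiss pattern this file relies on can be seen in a minimal standalone sketch (illustrative only, not the storage class itself; variable names such as id_to_meta are made up here): vectors are L2-normalized so an inner-product index scores cosine similarity, raw vectors are kept alongside metadata, and the index is rebuilt from those raw vectors when entries need to be dropped, mirroring the class above.

import faiss  # assumes faiss-cpu (or faiss-gpu) is installed
import numpy as np

dim = 4
index = faiss.IndexFlatIP(dim)   # inner-product index
id_to_meta = {}                  # faiss position -> metadata, as in the class above

# Upsert: normalize so inner product == cosine similarity
vecs = np.random.rand(3, dim).astype(np.float32)
faiss.normalize_L2(vecs)         # in-place, row-wise
start = index.ntotal
index.add(vecs)
for i in range(len(vecs)):
    id_to_meta[start + i] = {"__id__": f"doc-{i}", "__vector__": vecs[i].tolist()}

# Query: search returns (scores, positions); a position of -1 means no neighbor
query = np.random.rand(1, dim).astype(np.float32)
faiss.normalize_L2(query)
scores, positions = index.search(query, 2)

# "Delete": the class treats the flat index as append-only and rebuilds it
keep = [m for fid, m in id_to_meta.items() if fid != 0]
index = faiss.IndexFlatIP(dim)
if keep:
    index.add(np.array([m["__vector__"] for m in keep], dtype=np.float32))
id_to_meta = dict(enumerate(keep))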

View File

@@ -11,25 +11,19 @@ from lightrag.utils import (
 )
 import pipmaster as pm
 from lightrag.base import BaseVectorStorage
-from .shared_storage import (
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)

 if not pm.is_installed("nano-vectordb"):
     pm.install("nano-vectordb")

 from nano_vectordb import NanoVectorDB
+from threading import Lock as ThreadLock


 @final
 @dataclass
 class NanoVectorDBStorage(BaseVectorStorage):
     def __post_init__(self):
         # Initialize lock only for file operations
-        self._storage_lock = get_storage_lock()
+        self._storage_lock = ThreadLock()

         # Use global config value if specified, otherwise use default
         kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -45,32 +39,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
         )
         self._max_batch_size = self.global_config["embedding_batch_num"]

-        # check need_init must before get_namespace_object
-        need_init = try_initialize_namespace(self.namespace)
-        self._client = get_namespace_object(self.namespace)
-
-        if need_init:
-            if is_multiprocess:
-                self._client.value = NanoVectorDB(
-                    self.embedding_func.embedding_dim,
-                    storage_file=self._client_file_name,
-                )
-                logger.info(
-                    f"Initialized vector DB client for namespace {self.namespace}"
-                )
-            else:
-                self._client = NanoVectorDB(
-                    self.embedding_func.embedding_dim,
-                    storage_file=self._client_file_name,
-                )
-                logger.info(
-                    f"Initialized vector DB client for namespace {self.namespace}"
-                )
+        with self._storage_lock:
+            self._client = NanoVectorDB(
+                self.embedding_func.embedding_dim,
+                storage_file=self._client_file_name,
+            )

     def _get_client(self):
-        """Get the appropriate client instance based on multiprocess mode"""
-        if is_multiprocess:
-            return self._client.value
+        """Check if the storage should be reloaded"""
         return self._client

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
@@ -101,8 +77,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         if len(embeddings) == len(list_data):
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-            with self._storage_lock:
-                results = self._get_client().upsert(datas=list_data)
+            results = self._get_client().upsert(datas=list_data)
             return results
         else:
             # sometimes the embedding is not returned correctly. just log it.
@@ -115,21 +90,20 @@ class NanoVectorDBStorage(BaseVectorStorage):
         embedding = await self.embedding_func([query])
         embedding = embedding[0]

-        with self._storage_lock:
-            results = self._get_client().query(
-                query=embedding,
-                top_k=top_k,
-                better_than_threshold=self.cosine_better_than_threshold,
-            )
-            results = [
-                {
-                    **dp,
-                    "id": dp["__id__"],
-                    "distance": dp["__metrics__"],
-                    "created_at": dp.get("__created_at__"),
-                }
-                for dp in results
-            ]
+        results = self._get_client().query(
+            query=embedding,
+            top_k=top_k,
+            better_than_threshold=self.cosine_better_than_threshold,
+        )
+        results = [
+            {
+                **dp,
+                "id": dp["__id__"],
+                "distance": dp["__metrics__"],
+                "created_at": dp.get("__created_at__"),
+            }
+            for dp in results
+        ]

         return results

     @property
@@ -143,8 +117,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
             ids: List of vector IDs to be deleted
         """
         try:
-            with self._storage_lock:
-                self._get_client().delete(ids)
+            self._get_client().delete(ids)
             logger.debug(
                 f"Successfully deleted {len(ids)} vectors from {self.namespace}"
             )
@@ -158,37 +131,35 @@ class NanoVectorDBStorage(BaseVectorStorage):
                 f"Attempting to delete entity {entity_name} with ID {entity_id}"
             )

-            with self._storage_lock:
-                # Check if the entity exists
-                if self._get_client().get([entity_id]):
-                    self._get_client().delete([entity_id])
-                    logger.debug(f"Successfully deleted entity {entity_name}")
-                else:
-                    logger.debug(f"Entity {entity_name} not found in storage")
+            # Check if the entity exists
+            if self._get_client().get([entity_id]):
+                self._get_client().delete([entity_id])
+                logger.debug(f"Successfully deleted entity {entity_name}")
+            else:
+                logger.debug(f"Entity {entity_name} not found in storage")
         except Exception as e:
             logger.error(f"Error deleting entity {entity_name}: {e}")

     async def delete_entity_relation(self, entity_name: str) -> None:
         try:
-            with self._storage_lock:
-                storage = getattr(self._get_client(), "_NanoVectorDB__storage")
-                relations = [
-                    dp
-                    for dp in storage["data"]
-                    if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
-                ]
-                logger.debug(
-                    f"Found {len(relations)} relations for entity {entity_name}"
-                )
-                ids_to_delete = [relation["__id__"] for relation in relations]
-
-                if ids_to_delete:
-                    self._get_client().delete(ids_to_delete)
-                    logger.debug(
-                        f"Deleted {len(ids_to_delete)} relations for {entity_name}"
-                    )
-                else:
-                    logger.debug(f"No relations found for entity {entity_name}")
+            storage = getattr(self._get_client(), "_NanoVectorDB__storage")
+            relations = [
+                dp
+                for dp in storage["data"]
+                if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
+            ]
+            logger.debug(
+                f"Found {len(relations)} relations for entity {entity_name}"
+            )
+            ids_to_delete = [relation["__id__"] for relation in relations]
+
+            if ids_to_delete:
+                self._get_client().delete(ids_to_delete)
+                logger.debug(
+                    f"Deleted {len(ids_to_delete)} relations for {entity_name}"
+                )
+            else:
+                logger.debug(f"No relations found for entity {entity_name}")
         except Exception as e:
             logger.error(f"Error deleting relations for {entity_name}: {e}")
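The NanoVectorDB calls used above (constructor taking an embedding dimension and a storage_file, upsert(datas=...), query(query=..., top_k=..., better_than_threshold=...)) can be exercised on their own. Below is a rough sketch of the reverted single-process setup, assuming nano-vectordb behaves as those calls suggest; the file name, dimension, and threshold are placeholders.

from threading import Lock
import numpy as np
from nano_vectordb import NanoVectorDB

storage_lock = Lock()
with storage_lock:  # one file-backed client per namespace, guarded by a plain thread lock
    client = NanoVectorDB(8, storage_file="./demo_vdb.json")

records = [
    {"__id__": "doc-1", "__vector__": np.random.rand(8).astype(np.float32)},
    {"__id__": "doc-2", "__vector__": np.random.rand(8).astype(np.float32)},
]
client.upsert(datas=records)

hits = client.query(
    query=np.random.rand(8).astype(np.float32),
    top_k=2,
    better_than_threshold=0.2,  # plays the role of cosine_better_than_threshold above
)
print([(h["__id__"], h["__metrics__"]) for h in hits])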

View File

@@ -6,12 +6,6 @@ import numpy as np
 from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 from lightrag.utils import logger
 from lightrag.base import BaseGraphStorage
-from .shared_storage import (
-    get_storage_lock,
-    get_namespace_object,
-    is_multiprocess,
-    try_initialize_namespace,
-)

 import pipmaster as pm
@@ -23,7 +17,7 @@ if not pm.is_installed("graspologic"):

 import networkx as nx
 from graspologic import embed
+from threading import Lock as ThreadLock


 @final
 @dataclass
@@ -78,38 +72,23 @@ class NetworkXStorage(BaseGraphStorage):
         self._graphml_xml_file = os.path.join(
             self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
         )
-        self._storage_lock = get_storage_lock()
-
-        # check need_init must before get_namespace_object
-        need_init = try_initialize_namespace(self.namespace)
-        self._graph = get_namespace_object(self.namespace)
-
-        if need_init:
-            if is_multiprocess:
-                preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                self._graph.value = preloaded_graph or nx.Graph()
-                if preloaded_graph:
-                    logger.info(
-                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                    )
-            else:
-                preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-                self._graph = preloaded_graph or nx.Graph()
-                if preloaded_graph:
-                    logger.info(
-                        f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
-                    )
+        self._storage_lock = ThreadLock()
+
+        with self._storage_lock:
+            preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
+            if preloaded_graph is not None:
+                logger.info(
+                    f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
+                )
+            else:
                 logger.info("Created new empty graph")
+            self._graph = preloaded_graph or nx.Graph()

         self._node_embed_algorithms = {
             "node2vec": self._node2vec_embed,
         }

     def _get_graph(self):
-        """Get the appropriate graph instance based on multiprocess mode"""
-        if is_multiprocess:
-            return self._graph.value
+        """Check if the storage should be reloaded"""
         return self._graph

     async def index_done_callback(self) -> None:
@@ -117,54 +96,44 @@ class NetworkXStorage(BaseGraphStorage):
             NetworkXStorage.write_nx_graph(self._get_graph(), self._graphml_xml_file)

     async def has_node(self, node_id: str) -> bool:
-        with self._storage_lock:
-            return self._get_graph().has_node(node_id)
+        return self._get_graph().has_node(node_id)

     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        with self._storage_lock:
-            return self._get_graph().has_edge(source_node_id, target_node_id)
+        return self._get_graph().has_edge(source_node_id, target_node_id)

     async def get_node(self, node_id: str) -> dict[str, str] | None:
-        with self._storage_lock:
-            return self._get_graph().nodes.get(node_id)
+        return self._get_graph().nodes.get(node_id)

     async def node_degree(self, node_id: str) -> int:
-        with self._storage_lock:
-            return self._get_graph().degree(node_id)
+        return self._get_graph().degree(node_id)

     async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        with self._storage_lock:
-            return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)
+        return self._get_graph().degree(src_id) + self._get_graph().degree(tgt_id)

     async def get_edge(
         self, source_node_id: str, target_node_id: str
     ) -> dict[str, str] | None:
-        with self._storage_lock:
-            return self._get_graph().edges.get((source_node_id, target_node_id))
+        return self._get_graph().edges.get((source_node_id, target_node_id))

     async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
-        with self._storage_lock:
-            if self._get_graph().has_node(source_node_id):
-                return list(self._get_graph().edges(source_node_id))
-            return None
+        if self._get_graph().has_node(source_node_id):
+            return list(self._get_graph().edges(source_node_id))
+        return None

     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
-        with self._storage_lock:
-            self._get_graph().add_node(node_id, **node_data)
+        self._get_graph().add_node(node_id, **node_data)

     async def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ) -> None:
-        with self._storage_lock:
-            self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)
+        self._get_graph().add_edge(source_node_id, target_node_id, **edge_data)

     async def delete_node(self, node_id: str) -> None:
-        with self._storage_lock:
-            if self._get_graph().has_node(node_id):
-                self._get_graph().remove_node(node_id)
-                logger.debug(f"Node {node_id} deleted from the graph.")
-            else:
-                logger.warning(f"Node {node_id} not found in the graph for deletion.")
+        if self._get_graph().has_node(node_id):
+            self._get_graph().remove_node(node_id)
+            logger.debug(f"Node {node_id} deleted from the graph.")
+        else:
+            logger.warning(f"Node {node_id} not found in the graph for deletion.")

     async def embed_nodes(
         self, algorithm: str
@@ -175,13 +144,12 @@ class NetworkXStorage(BaseGraphStorage):
     # TODO: NOT USED
     async def _node2vec_embed(self):
-        with self._storage_lock:
-            graph = self._get_graph()
-            embeddings, nodes = embed.node2vec_embed(
-                graph,
-                **self.global_config["node2vec_params"],
-            )
-            nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
+        graph = self._get_graph()
+        embeddings, nodes = embed.node2vec_embed(
+            graph,
+            **self.global_config["node2vec_params"],
+        )
+        nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
         return embeddings, nodes_ids

     def remove_nodes(self, nodes: list[str]):
@@ -190,11 +158,10 @@ class NetworkXStorage(BaseGraphStorage):
         Args:
             nodes: List of node IDs to be deleted
         """
-        with self._storage_lock:
-            graph = self._get_graph()
-            for node in nodes:
-                if graph.has_node(node):
-                    graph.remove_node(node)
+        graph = self._get_graph()
+        for node in nodes:
+            if graph.has_node(node):
+                graph.remove_node(node)

     def remove_edges(self, edges: list[tuple[str, str]]):
         """Delete multiple edges
@@ -202,11 +169,10 @@ class NetworkXStorage(BaseGraphStorage):
         Args:
             edges: List of edges to be deleted, each edge is a (source, target) tuple
         """
-        with self._storage_lock:
-            graph = self._get_graph()
-            for source, target in edges:
-                if graph.has_edge(source, target):
-                    graph.remove_edge(source, target)
+        graph = self._get_graph()
+        for source, target in edges:
+            if graph.has_edge(source, target):
+                graph.remove_edge(source, target)

     async def get_all_labels(self) -> list[str]:
         """
@@ -214,10 +180,9 @@ class NetworkXStorage(BaseGraphStorage):
         Returns:
             [label1, label2, ...]  # Alphabetically sorted label list
         """
-        with self._storage_lock:
-            labels = set()
-            for node in self._get_graph().nodes():
-                labels.add(str(node))  # Add node id as a label
+        labels = set()
+        for node in self._get_graph().nodes():
+            labels.add(str(node))  # Add node id as a label

         # Return sorted list
         return sorted(list(labels))
@@ -239,88 +204,87 @@ class NetworkXStorage(BaseGraphStorage):
         seen_nodes = set()
         seen_edges = set()

-        with self._storage_lock:
-            graph = self._get_graph()
+        graph = self._get_graph()

         # Handle special case for "*" label
         if node_label == "*":
             # For "*", return the entire graph including all nodes and edges
             subgraph = (
                 graph.copy()
             )  # Create a copy to avoid modifying the original graph
         else:
             # Find nodes with matching node id (partial match)
             nodes_to_explore = []
             for n, attr in graph.nodes(data=True):
                 if node_label in str(n):  # Use partial matching
                     nodes_to_explore.append(n)

             if not nodes_to_explore:
                 logger.warning(f"No nodes found with label {node_label}")
                 return result

             # Get subgraph using ego_graph
             subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)

         # Check if number of nodes exceeds max_graph_nodes
         max_graph_nodes = 500
         if len(subgraph.nodes()) > max_graph_nodes:
             origin_nodes = len(subgraph.nodes())
             node_degrees = dict(subgraph.degree())
             top_nodes = sorted(
                 node_degrees.items(), key=lambda x: x[1], reverse=True
             )[:max_graph_nodes]
             top_node_ids = [node[0] for node in top_nodes]
             # Create new subgraph with only top nodes
             subgraph = subgraph.subgraph(top_node_ids)
             logger.info(
                 f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
             )

         # Add nodes to result
         for node in subgraph.nodes():
             if str(node) in seen_nodes:
                 continue

             node_data = dict(subgraph.nodes[node])
             # Get entity_type as labels
             labels = []
             if "entity_type" in node_data:
                 if isinstance(node_data["entity_type"], list):
                     labels.extend(node_data["entity_type"])
                 else:
                     labels.append(node_data["entity_type"])

             # Create node with properties
             node_properties = {k: v for k, v in node_data.items()}

             result.nodes.append(
                 KnowledgeGraphNode(
                     id=str(node), labels=[str(node)], properties=node_properties
                 )
             )
             seen_nodes.add(str(node))

         # Add edges to result
         for edge in subgraph.edges():
             source, target = edge
             edge_id = f"{source}-{target}"
             if edge_id in seen_edges:
                 continue

             edge_data = dict(subgraph.edges[edge])

             # Create edge with complete information
             result.edges.append(
                 KnowledgeGraphEdge(
                     id=edge_id,
                     type="DIRECTED",
                     source=str(source),
                     target=str(target),
                     properties=edge_data,
                 )
             )
             seen_edges.add(edge_id)

         logger.info(
             f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
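The subgraph logic in the last hunk boils down to an ego-graph expansion capped by node degree. A small illustrative helper (not part of NetworkXStorage; the function name and the karate-club example are made up) shows the same idea with plain networkx:

import networkx as nx

def bounded_neighborhood(graph: nx.Graph, label: str, max_depth: int = 2, max_nodes: int = 500) -> nx.Graph:
    """Expand around the first node whose id contains `label`, then keep only
    the highest-degree nodes if the neighborhood grows past max_nodes."""
    if label == "*":
        return graph.copy()
    matches = [n for n in graph.nodes() if label in str(n)]  # partial match, as in the diff
    if not matches:
        return nx.Graph()
    subgraph = nx.ego_graph(graph, matches[0], radius=max_depth)
    if subgraph.number_of_nodes() > max_nodes:
        top = sorted(dict(subgraph.degree()).items(), key=lambda kv: kv[1], reverse=True)
        subgraph = subgraph.subgraph([n for n, _ in top[:max_nodes]])
    return subgraph

g = nx.karate_club_graph()
print(bounded_neighborhood(g, "1", max_depth=1).number_of_nodes())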

View File

@@ -20,15 +20,12 @@ LockType = Union[ProcessLock, ThreadLock]
 _manager = None
 _initialized = None
 is_multiprocess = None
-_global_lock: Optional[LockType] = None

 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
-_share_objects: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
+_global_lock: Optional[LockType] = None


 def initialize_share_data(workers: int = 1):
     """
     Initialize shared storage data for single or multi-process mode.
@@ -53,7 +50,6 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess, \
         _global_lock, \
         _shared_dicts, \
-        _share_objects, \
         _init_flags, \
         _initialized
@@ -72,7 +68,6 @@ def initialize_share_data(workers: int = 1):
         _global_lock = _manager.Lock()
         # Create shared dictionaries with manager
         _shared_dicts = _manager.dict()
-        _share_objects = _manager.dict()
        _init_flags = (
             _manager.dict()
         )  # Use shared dictionary to store initialization flags
@@ -83,7 +78,6 @@ def initialize_share_data(workers: int = 1):
         is_multiprocess = False
         _global_lock = ThreadLock()
         _shared_dicts = {}
-        _share_objects = {}
         _init_flags = {}
         direct_log(f"Process {os.getpid()} Shared-Data created for Single Process")
@@ -99,11 +93,7 @@ def try_initialize_namespace(namespace: str) -> bool:
     global _init_flags, _manager

     if _init_flags is None:
-        direct_log(
-            f"Error: try to create namespace before Shared-Data is initialized, pid={os.getpid()}",
-            level="ERROR",
-        )
-        raise ValueError("Shared dictionaries not initialized")
+        raise ValueError("Try to create namespace before Shared-Data is initialized")

     if namespace not in _init_flags:
         _init_flags[namespace] = True
@@ -113,43 +103,9 @@ def try_initialize_namespace(namespace: str) -> bool:
     return False


-def _get_global_lock() -> LockType:
-    return _global_lock
-
-
 def get_storage_lock() -> LockType:
     """return storage lock for data consistency"""
-    return _get_global_lock()
+    return _global_lock


-def get_scan_lock() -> LockType:
-    """return scan_progress lock for data consistency"""
-    return get_storage_lock()
-
-
-def get_namespace_object(namespace: str) -> Any:
-    """Get an object for specific namespace"""
-
-    if _share_objects is None:
-        direct_log(
-            f"Error: try to get namespace before Shared-Data is initialized, pid={os.getpid()}",
-            level="ERROR",
-        )
-        raise ValueError("Shared dictionaries not initialized")
-
-    lock = _get_global_lock()
-    with lock:
-        if namespace not in _share_objects:
-            if namespace not in _share_objects:
-                if is_multiprocess:
-                    _share_objects[namespace] = _manager.Value("O", None)
-                else:
-                    _share_objects[namespace] = None
-                direct_log(
-                    f"Created namespace: {namespace}(type={type(_share_objects[namespace])})"
-                )
-
-    return _share_objects[namespace]
-
-
 def get_namespace_data(namespace: str) -> Dict[str, Any]:
@@ -161,7 +117,7 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
         )
         raise ValueError("Shared dictionaries not initialized")

-    lock = _get_global_lock()
+    lock = get_storage_lock()
     with lock:
         if namespace not in _shared_dicts:
             if is_multiprocess and _manager is not None:
@@ -175,11 +131,6 @@ def get_namespace_data(namespace: str) -> Dict[str, Any]:
     return _shared_dicts[namespace]


-def get_scan_progress() -> Dict[str, Any]:
-    """get storage space for document scanning progress data"""
-    return get_namespace_data("scan_progress")
-
-
 def finalize_share_data():
     """
     Release shared resources and clean up.
@@ -195,7 +146,6 @@ def finalize_share_data():
         is_multiprocess, \
         _global_lock, \
         _shared_dicts, \
-        _share_objects, \
         _init_flags, \
         _initialized
@@ -216,8 +166,6 @@ def finalize_share_data():
         # Clear shared dictionaries first
         if _shared_dicts is not None:
             _shared_dicts.clear()
-        if _share_objects is not None:
-            _share_objects.clear()
         if _init_flags is not None:
             _init_flags.clear()
@@ -234,7 +182,6 @@ def finalize_share_data():
     _initialized = None
     is_multiprocess = None
     _shared_dicts = None
-    _share_objects = None
     _init_flags = None
     _global_lock = None
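What remains of the shared-storage module after these removals is a single global lock, one dict of per-namespace dicts, and one-shot init flags. Below is a condensed, runnable sketch of that bookkeeping, written for illustration only and simplified relative to the module itself:

from multiprocessing import Manager
from threading import Lock as ThreadLock
from typing import Any, Dict, Optional

_manager = None
is_multiprocess = False
_shared_dicts: Optional[Dict[str, Any]] = None
_init_flags: Optional[Dict[str, bool]] = None
_global_lock = None

def initialize_share_data(workers: int = 1) -> None:
    global _manager, is_multiprocess, _shared_dicts, _init_flags, _global_lock
    if workers > 1:
        _manager = Manager()
        is_multiprocess = True
        _global_lock = _manager.Lock()
        _shared_dicts, _init_flags = _manager.dict(), _manager.dict()
    else:
        is_multiprocess = False
        _global_lock = ThreadLock()
        _shared_dicts, _init_flags = {}, {}

def try_initialize_namespace(namespace: str) -> bool:
    """Return True exactly once per namespace; the caller then builds its data."""
    if _init_flags is None:
        raise ValueError("Try to create namespace before Shared-Data is initialized")
    if namespace not in _init_flags:
        _init_flags[namespace] = True
        return True
    return False

def get_namespace_data(namespace: str) -> Dict[str, Any]:
    with _global_lock:
        if namespace not in _shared_dicts:
            _shared_dicts[namespace] = _manager.dict() if is_multiprocess else {}
    return _shared_dicts[namespace]

# Single-process usage:
initialize_share_data(workers=1)
if try_initialize_namespace("demo_namespace"):
    get_namespace_data("demo_namespace")["ready"] = True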