Remove unused node embedding functionality from graph storage

- Deleted embed_nodes() method implementations
This commit is contained in:
yangdx
2025-04-11 18:34:48 +08:00
parent c084358dc9
commit 83353ab9a6
7 changed files with 2 additions and 141 deletions

View File

@@ -6,7 +6,6 @@ import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Union, final
import numpy as np
import pipmaster as pm
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@@ -668,21 +667,6 @@ class AGEStorage(BaseGraphStorage):
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""Embed nodes using the specified algorithm
Args:
algorithm: Name of the embedding algorithm
Returns:
tuple: (embedding matrix, list of node identifiers)
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
async def get_all_labels(self) -> list[str]:
"""Get all node labels in the database

View File

@@ -6,9 +6,6 @@ import pipmaster as pm
from dataclasses import dataclass
from typing import Any, Dict, List, final
import numpy as np
from tenacity import (
retry,
retry_if_exception_type,
@@ -419,27 +416,6 @@ class GremlinStorage(BaseGraphStorage):
logger.error(f"Error during node deletion: {str(e)}")
raise
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""
Embed nodes using the specified algorithm.
Currently, only node2vec is supported but never called.
Args:
algorithm: The name of the embedding algorithm to use
Returns:
A tuple of (embeddings, node_ids)
Raises:
NotImplementedError: If the specified algorithm is not supported
ValueError: If the algorithm is not supported
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
async def get_all_labels(self) -> list[str]:
"""
Get all node entity_names in the graph

View File

@@ -663,20 +663,6 @@ class MongoGraphStorage(BaseGraphStorage):
# Remove the node doc
await self.collection.delete_one({"_id": node_id})
#
# -------------------------------------------------------------------------
# EMBEDDINGS (NOT IMPLEMENTED)
# -------------------------------------------------------------------------
#
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""
Placeholder for demonstration, raises NotImplementedError.
"""
raise NotImplementedError("Node embedding is not used in lightrag.")
#
# -------------------------------------------------------------------------
# QUERY

View File

@@ -2,8 +2,7 @@ import inspect
import os
import re
from dataclasses import dataclass
from typing import Any, final
import numpy as np
from typing import final
import configparser
@@ -1126,11 +1125,6 @@ class Neo4JStorage(BaseGraphStorage):
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
raise NotImplementedError
async def drop(self) -> dict[str, str]:
"""Drop all data from storage and clean up resources

View File

@@ -1,7 +1,6 @@
import os
from dataclasses import dataclass
from typing import Any, final
import numpy as np
from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
@@ -16,7 +15,6 @@ if not pm.is_installed("graspologic"):
pm.install("graspologic")
import networkx as nx
from graspologic import embed
from .shared_storage import (
get_storage_lock,
get_update_flag,
@@ -42,40 +40,6 @@ class NetworkXStorage(BaseGraphStorage):
)
nx.write_graphml(graph, file_name)
# TODOdeprecated, remove later
@staticmethod
def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
Ensure an undirected graph with the same relationships will always be read the same way.
"""
fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
sorted_nodes = graph.nodes(data=True)
sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
fixed_graph.add_nodes_from(sorted_nodes)
edges = list(graph.edges(data=True))
if not graph.is_directed():
def _sort_source_target(edge):
source, target, edge_data = edge
if source > target:
temp = source
source = target
target = temp
return source, target, edge_data
edges = [_sort_source_target(edge) for edge in edges]
def _get_edge_key(source: Any, target: Any) -> str:
return f"{source} -> {target}"
edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
fixed_graph.add_edges_from(edges)
return fixed_graph
def __post_init__(self):
self._graphml_xml_file = os.path.join(
self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
@@ -191,24 +155,6 @@ class NetworkXStorage(BaseGraphStorage):
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
# TODO: NOT USED
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# TODO: NOT USED
async def _node2vec_embed(self):
graph = await self._get_graph()
embeddings, nodes = embed.node2vec_embed(
graph,
**self.global_config["node2vec_params"],
)
nodes_ids = [graph.nodes[node_id]["id"] for node_id in nodes]
return embeddings, nodes_ids
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes

View File

@@ -1485,24 +1485,6 @@ class PGGraphStorage(BaseGraphStorage):
labels = [result["label"] for result in results]
return labels
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
"""
Generate node embeddings using the specified algorithm.
Args:
algorithm (str): The name of the embedding algorithm to use.
Returns:
tuple[np.ndarray[Any, Any], list[str]]: A tuple containing the embeddings and the corresponding node IDs.
"""
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Unsupported embedding algorithm: {algorithm}")
embed_func = self._node_embed_algorithms[algorithm]
return await embed_func()
async def get_knowledge_graph(
self,
node_label: str,

View File

@@ -800,13 +800,6 @@ class TiDBGraphStorage(BaseGraphStorage):
}
await self.db.execute(merge_sql, data)
async def embed_nodes(
self, algorithm: str
) -> tuple[np.ndarray[Any, Any], list[str]]:
if algorithm not in self._node_embed_algorithms:
raise ValueError(f"Node embedding algorithm {algorithm} not supported")
return await self._node_embed_algorithms[algorithm]()
# Query
async def has_node(self, node_id: str) -> bool: