GremlinStorage: fixes and a patch to support other Gremlin-compatible DBs. Tested on ArcadeDB with the Gremlin plugin. The main change is using the "entity_name" vertex property instead of the label as the node_id, since different implementations have different restrictions on label names.

This commit is contained in:
Alex Potapenko
2024-12-23 16:16:17 +01:00
parent bfacfb975e
commit 848b3f6e33

View File

@@ -2,7 +2,6 @@ import asyncio
import inspect import inspect
import json import json
import os import os
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
@@ -27,9 +26,6 @@ class GremlinStorage(BaseGraphStorage):
def load_nx_graph(file_name): def load_nx_graph(file_name):
print("no preloading of graph with Gremlin in production") print("no preloading of graph with Gremlin in production")
# Will use this to make sure single quotes are properly escaped
escape_rx = re.compile(r"(^|[^\\])((\\\\)*\\)\\'")
def __init__(self, namespace, global_config, embedding_func): def __init__(self, namespace, global_config, embedding_func):
super().__init__( super().__init__(
namespace=namespace, namespace=namespace,
@@ -51,12 +47,8 @@ class GremlinStorage(BaseGraphStorage):
# All vertices will have graph={GRAPH} property, so that we can # All vertices will have graph={GRAPH} property, so that we can
# have several logical graphs for one source # have several logical graphs for one source
GRAPH = GremlinStorage.escape_rx.sub( GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
r"\1\2'",
os.environ["GREMLIN_GRAPH"].replace("'", r"\'"),
)
self.traverse_source_name = SOURCE
self.graph_name = GRAPH self.graph_name = GRAPH
self._driver = client.Client( self._driver = client.Client(
@@ -87,7 +79,7 @@ class GremlinStorage(BaseGraphStorage):
@staticmethod @staticmethod
def _to_value_map(value: Any) -> str: def _to_value_map(value: Any) -> str:
"""Dump Python dict as Gremlin valueMap""" """Dump supported Python object as Gremlin valueMap"""
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False) json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
parsed_str = json_str.replace("'", r"\'") parsed_str = json_str.replace("'", r"\'")
@@ -122,17 +114,16 @@ class GremlinStorage(BaseGraphStorage):
"""Create chained .property() commands from properties dict""" """Create chained .property() commands from properties dict"""
props = [] props = []
for k, v in properties.items(): for k, v in properties.items():
prop_name = GremlinStorage.escape_rx.sub(r"\1\2'", k.replace("'", r"\'")) prop_name = GremlinStorage._to_value_map(k)
props.append(f".property('{prop_name}', {GremlinStorage._to_value_map(v)})") props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})")
return "".join(props) return "".join(props)
@staticmethod @staticmethod
def _fix_label(label: str) -> str: def _fix_name(name: str) -> str:
"""Strip double quotes and make sure single quotes are escaped""" """Strip double quotes and format as a proper field name"""
label = label.strip('"').replace("'", r"\'") name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'"))
label = GremlinStorage.escape_rx.sub(r"\1\2'", label)
return label return name
async def _query(self, query: str) -> List[Dict[str, Any]]: async def _query(self, query: str) -> List[Dict[str, Any]]:
""" """
@@ -146,66 +137,69 @@ class GremlinStorage(BaseGraphStorage):
""" """
result = list(await asyncio.wrap_future(self._driver.submit_async(query))) result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
if result:
result = result[0]
return result return result
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = GremlinStorage._fix_label(node_id) entity_name = GremlinStorage._fix_name(node_id)
query = f""" query = f"""g
{self.traverse_source_name} .V().has('graph', {self.graph_name})
.V().has('graph', '{self.graph_name}') .has('entity_name', {entity_name})
.hasLabel('{entity_name_label}')
.limit(1) .limit(1)
.hasNext() .count()
.project('has_node')
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
""" """
result = await self._query(query) result = await self._query(query)
logger.debug( logger.debug(
"{%s}:query:{%s}:result:{%s}", "{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
query, query,
result[0][0], result[0]["has_node"],
) )
return result[0][0] return result[0]["has_node"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = GremlinStorage._fix_label(source_node_id) entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_label_target = GremlinStorage._fix_label(target_node_id) entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f""" query = f"""g
{self.traverse_source_name} .V().has('graph', {self.graph_name})
.V().has('graph', '{self.graph_name}') .has('entity_name', {entity_name_source})
.hasLabel('{entity_name_label_source}') .outE()
.bothE() .inV().has('graph', {self.graph_name})
.otherV().has('graph', '{self.graph_name}') .has('entity_name', {entity_name_target})
.hasLabel('{entity_name_label_target}')
.limit(1) .limit(1)
.hasNext() .count()
.project('has_edge')
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
""" """
result = await self._query(query) result = await self._query(query)
logger.debug( logger.debug(
"{%s}:query:{%s}:result:{%s}", "{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
query, query,
result[0][0], result[0]["has_edge"],
) )
return result[0][0] return result[0]["has_edge"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
entity_name_label = GremlinStorage._fix_label(node_id) entity_name = GremlinStorage._fix_name(node_id)
query = f""" query = f"""g
{self.traverse_source_name} .V().has('graph', {self.graph_name})
.V().has('graph', '{self.graph_name}') .has('entity_name', {entity_name})
.hasLabel('{entity_name_label}')
.limit(1) .limit(1)
.project('properties') .project('properties')
.by(elementMap()) .by(elementMap())
""" """
result = await self._query(query) result = await self._query(query)
if result: if result:
node = result[0][0] node = result[0]
node_dict = node["properties"] node_dict = node["properties"]
logger.debug( logger.debug(
"{%s}: query: {%s}, result: {%s}", "{%s}: query: {%s}, result: {%s}",
@@ -216,19 +210,18 @@ class GremlinStorage(BaseGraphStorage):
return node_dict return node_dict
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
entity_name_label = GremlinStorage._fix_label(node_id) entity_name = GremlinStorage._fix_name(node_id)
query = f""" query = f"""g
{self.traverse_source_name} .V().has('graph', {self.graph_name})
.V().has('graph', '{self.graph_name}') .has('entity_name', {entity_name})
.hasLabel('{entity_name_label}')
.outE() .outE()
.inV().has('graph', '{self.graph_name}') .inV().has('graph', {self.graph_name})
.count() .count()
.project('total_edge_count') .project('total_edge_count')
.by() .by()
""" """
result = await self._query(query) result = await self._query(query)
edge_count = result[0][0]["total_edge_count"] edge_count = result[0]["total_edge_count"]
logger.debug( logger.debug(
"{%s}:query:{%s}:result:{%s}", "{%s}:query:{%s}:result:{%s}",
@@ -259,31 +252,30 @@ class GremlinStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> Union[dict, None]: ) -> Union[dict, None]:
""" """
Find all edges between nodes of two given labels Find all edges between nodes of two given names
Args: Args:
source_node_label (str): Label of the source nodes source_node_id (str): Name of the source nodes
target_node_label (str): Label of the target nodes target_node_id (str): Name of the target nodes
Returns: Returns:
dict|None: Dict of found edge properties, or None of not found dict|None: Dict of found edge properties, or None if not found
""" """
entity_name_label_source = GremlinStorage._fix_label(source_node_id) entity_name_source = GremlinStorage._fix_name(source_node_id)
entity_name_label_target = GremlinStorage._fix_label(target_node_id) entity_name_target = GremlinStorage._fix_name(target_node_id)
query = f""" query = f"""g
{self.traverse_source_name} .V().has('graph', {self.graph_name})
.V().has('graph', '{self.graph_name}') .has('entity_name', {entity_name_source})
.hasLabel('{entity_name_label_source}')
.outE() .outE()
.inV().has('graph', '{self.graph_name}') .inV().has('graph', {self.graph_name})
.hasLabel('{entity_name_label_target}') .has('entity_name', {entity_name_target})
.limit(1) .limit(1)
.project('edge_properties') .project('edge_properties')
.by(__.bothE().elementMap()) .by(__.bothE().elementMap())
""" """
result = await self._query(query) result = await self._query(query)
if result: if result:
edge_properties = result[0][0]["edge_properties"] edge_properties = result[0]["edge_properties"]
logger.debug( logger.debug(
"{%s}:query:{%s}:result:{%s}", "{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
@@ -294,45 +286,31 @@ class GremlinStorage(BaseGraphStorage):
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]: async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
""" """
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its name.
:return: List of tuples containing edge sources and targets :return: List of tuples containing edge sources and targets
""" """
node_label = GremlinStorage._fix_label(source_node_id) node_name = GremlinStorage._fix_name(source_node_id)
query1 = f""" query = f"""g
{self.traverse_source_name} .E()
.V().has('graph', '{self.graph_name}') .filter(
.hasLabel('{node_label}') __.or(
.out().has('graph', '{self.graph_name}') __.outV().has('graph', {self.graph_name})
.project('connected_label') .has('entity_name', {node_name}),
.by(__.label()) __.inV().has('graph', {self.graph_name})
.has('entity_name', {node_name})
)
)
.project('source_name', 'target_name')
.by(__.outV().values('entity_name'))
.by(__.inV().values('entity_name'))
""" """
query2 = f""" result = await self._query(query)
{self.traverse_source_name} edges = [(res["source_name"], res["target_name"]) for res in result]
.V().has('graph', '{self.graph_name}')
.as('connected')
.out().has('graph', '{self.graph_name}')
.hasLabel('{node_label}')
.project('connected_label')
.by(__.select('connected').label())
"""
result1, result2 = await asyncio.gather(
self._query(query1), self._query(query2)
)
edges1 = (
[(node_label, res["connected_label"]) for res in result1[0]]
if result1
else []
)
edges2 = (
[(res["connected_label"], node_label) for res in result2[0]]
if result2
else []
)
return edges1 + edges2 return edges
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)), retry=retry_if_exception_type((GremlinServerError,)),
) )
@@ -341,28 +319,30 @@ class GremlinStorage(BaseGraphStorage):
Upsert a node in the Gremlin graph. Upsert a node in the Gremlin graph.
Args: Args:
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as name)
node_data: Dictionary of node properties node_data: Dictionary of node properties
""" """
label = GremlinStorage._fix_label(node_id) name = GremlinStorage._fix_name(node_id)
properties = GremlinStorage._convert_properties(node_data) properties = GremlinStorage._convert_properties(node_data)
query = f""" query = f"""g
{self.traverse_source_name} .V().has('graph', {self.graph_name})
.V().has('graph', '{self.graph_name}') .has('entity_name', {name})
.hasLabel('{label}').fold() .fold()
.coalesce( .coalesce(
unfold(), __.unfold(),
addV('{label}')) __.addV('ENTITY')
.property('graph', '{self.graph_name}') .property('graph', {self.graph_name})
.property('entity_name', {name})
)
{properties} {properties}
""" """
try: try:
await self._query(query) await self._query(query)
logger.debug( logger.debug(
"Upserted node with label '{%s}' and properties: {%s}", "Upserted node with name {%s} and properties: {%s}",
label, name,
properties, properties,
) )
except Exception as e: except Exception as e:
@@ -370,7 +350,7 @@ class GremlinStorage(BaseGraphStorage):
raise raise
@retry( @retry(
stop=stop_after_attempt(3), stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=10), wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((GremlinServerError,)), retry=retry_if_exception_type((GremlinServerError,)),
) )
@@ -378,36 +358,35 @@ class GremlinStorage(BaseGraphStorage):
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any] self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
): ):
""" """
Upsert an edge and its properties between two nodes identified by their labels. Upsert an edge and its properties between two nodes identified by their names.
Args: Args:
source_node_id (str): Label of the source node (used as identifier) source_node_id (str): Name of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier) target_node_id (str): Name of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge edge_data (dict): Dictionary of properties to set on the edge
""" """
source_node_label = GremlinStorage._fix_label(source_node_id) source_node_name = GremlinStorage._fix_name(source_node_id)
target_node_label = GremlinStorage._fix_label(target_node_id) target_node_name = GremlinStorage._fix_name(target_node_id)
edge_properties = GremlinStorage._convert_properties(edge_data) edge_properties = GremlinStorage._convert_properties(edge_data)
query = f""" query = f"""g
{self.traverse_source_name} .V().has('graph', {self.graph_name})
.V().has('graph', '{self.graph_name}') .has('entity_name', {source_node_name}).as('source')
.hasLabel('{source_node_label}').as('source') .V().has('graph', {self.graph_name})
.V().has('graph', '{self.graph_name}') .has('entity_name', {target_node_name}).as('target')
.hasLabel('{target_node_label}').as('target')
.coalesce( .coalesce(
select('source').outE('DIRECTED').where(inV().as('target')), __.select('source').outE('DIRECTED').where(__.inV().as('target')),
select('source').addE('DIRECTED').to(select('target')) __.select('source').addE('DIRECTED').to(__.select('target'))
) )
.property('graph', '{self.graph_name}') .property('graph', {self.graph_name})
{edge_properties} {edge_properties}
""" """
try: try:
await self._query(query) await self._query(query)
logger.debug( logger.debug(
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}", "Upserted edge from {%s} to {%s} with properties: {%s}",
source_node_label, source_node_name,
target_node_label, target_node_name,
edge_properties, edge_properties,
) )
except Exception as e: except Exception as e: