GremlinStorage: fixes and a patch to support other Gremlin-compatible DBs. Tested on ArcadeDB with the Gremlin plugin. The main change is using the "entity_name" vertex property instead of the label as the node_id, since different implementations have different restrictions on label names.
This commit is contained in:
@@ -2,7 +2,6 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
@@ -27,9 +26,6 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
def load_nx_graph(file_name):
|
def load_nx_graph(file_name):
|
||||||
print("no preloading of graph with Gremlin in production")
|
print("no preloading of graph with Gremlin in production")
|
||||||
|
|
||||||
# Will use this to make sure single quotes are properly escaped
|
|
||||||
escape_rx = re.compile(r"(^|[^\\])((\\\\)*\\)\\'")
|
|
||||||
|
|
||||||
def __init__(self, namespace, global_config, embedding_func):
|
def __init__(self, namespace, global_config, embedding_func):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
namespace=namespace,
|
namespace=namespace,
|
||||||
@@ -51,12 +47,8 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
# All vertices will have graph={GRAPH} property, so that we can
|
# All vertices will have graph={GRAPH} property, so that we can
|
||||||
# have several logical graphs for one source
|
# have several logical graphs for one source
|
||||||
GRAPH = GremlinStorage.escape_rx.sub(
|
GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
|
||||||
r"\1\2'",
|
|
||||||
os.environ["GREMLIN_GRAPH"].replace("'", r"\'"),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.traverse_source_name = SOURCE
|
|
||||||
self.graph_name = GRAPH
|
self.graph_name = GRAPH
|
||||||
|
|
||||||
self._driver = client.Client(
|
self._driver = client.Client(
|
||||||
@@ -87,7 +79,7 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _to_value_map(value: Any) -> str:
|
def _to_value_map(value: Any) -> str:
|
||||||
"""Dump Python dict as Gremlin valueMap"""
|
"""Dump supported Python object as Gremlin valueMap"""
|
||||||
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
|
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
|
||||||
parsed_str = json_str.replace("'", r"\'")
|
parsed_str = json_str.replace("'", r"\'")
|
||||||
|
|
||||||
@@ -122,17 +114,16 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
"""Create chained .property() commands from properties dict"""
|
"""Create chained .property() commands from properties dict"""
|
||||||
props = []
|
props = []
|
||||||
for k, v in properties.items():
|
for k, v in properties.items():
|
||||||
prop_name = GremlinStorage.escape_rx.sub(r"\1\2'", k.replace("'", r"\'"))
|
prop_name = GremlinStorage._to_value_map(k)
|
||||||
props.append(f".property('{prop_name}', {GremlinStorage._to_value_map(v)})")
|
props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})")
|
||||||
return "".join(props)
|
return "".join(props)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fix_label(label: str) -> str:
|
def _fix_name(name: str) -> str:
|
||||||
"""Strip double quotes and make sure single quotes are escaped"""
|
"""Strip double quotes and format as a proper field name"""
|
||||||
label = label.strip('"').replace("'", r"\'")
|
name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'"))
|
||||||
label = GremlinStorage.escape_rx.sub(r"\1\2'", label)
|
|
||||||
|
|
||||||
return label
|
return name
|
||||||
|
|
||||||
async def _query(self, query: str) -> List[Dict[str, Any]]:
|
async def _query(self, query: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -146,66 +137,69 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
|
result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
|
||||||
|
if result:
|
||||||
|
result = result[0]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
entity_name_label = GremlinStorage._fix_label(node_id)
|
entity_name = GremlinStorage._fix_name(node_id)
|
||||||
|
|
||||||
query = f"""
|
query = f"""g
|
||||||
{self.traverse_source_name}
|
.V().has('graph', {self.graph_name})
|
||||||
.V().has('graph', '{self.graph_name}')
|
.has('entity_name', {entity_name})
|
||||||
.hasLabel('{entity_name_label}')
|
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.hasNext()
|
.count()
|
||||||
|
.project('has_node')
|
||||||
|
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
|
||||||
"""
|
"""
|
||||||
result = await self._query(query)
|
result = await self._query(query)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"{%s}:query:{%s}:result:{%s}",
|
"{%s}:query:{%s}:result:{%s}",
|
||||||
inspect.currentframe().f_code.co_name,
|
inspect.currentframe().f_code.co_name,
|
||||||
query,
|
query,
|
||||||
result[0][0],
|
result[0]["has_node"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return result[0][0]
|
return result[0]["has_node"]
|
||||||
|
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
entity_name_label_source = GremlinStorage._fix_label(source_node_id)
|
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
||||||
entity_name_label_target = GremlinStorage._fix_label(target_node_id)
|
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
||||||
|
|
||||||
query = f"""
|
query = f"""g
|
||||||
{self.traverse_source_name}
|
.V().has('graph', {self.graph_name})
|
||||||
.V().has('graph', '{self.graph_name}')
|
.has('entity_name', {entity_name_source})
|
||||||
.hasLabel('{entity_name_label_source}')
|
.outE()
|
||||||
.bothE()
|
.inV().has('graph', {self.graph_name})
|
||||||
.otherV().has('graph', '{self.graph_name}')
|
.has('entity_name', {entity_name_target})
|
||||||
.hasLabel('{entity_name_label_target}')
|
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.hasNext()
|
.count()
|
||||||
|
.project('has_edge')
|
||||||
|
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
|
||||||
"""
|
"""
|
||||||
result = await self._query(query)
|
result = await self._query(query)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"{%s}:query:{%s}:result:{%s}",
|
"{%s}:query:{%s}:result:{%s}",
|
||||||
inspect.currentframe().f_code.co_name,
|
inspect.currentframe().f_code.co_name,
|
||||||
query,
|
query,
|
||||||
result[0][0],
|
result[0]["has_edge"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return result[0][0]
|
return result[0]["has_edge"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||||
entity_name_label = GremlinStorage._fix_label(node_id)
|
entity_name = GremlinStorage._fix_name(node_id)
|
||||||
query = f"""
|
query = f"""g
|
||||||
{self.traverse_source_name}
|
.V().has('graph', {self.graph_name})
|
||||||
.V().has('graph', '{self.graph_name}')
|
.has('entity_name', {entity_name})
|
||||||
.hasLabel('{entity_name_label}')
|
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.project('properties')
|
.project('properties')
|
||||||
.by(elementMap())
|
.by(elementMap())
|
||||||
"""
|
"""
|
||||||
result = await self._query(query)
|
result = await self._query(query)
|
||||||
if result:
|
if result:
|
||||||
node = result[0][0]
|
node = result[0]
|
||||||
node_dict = node["properties"]
|
node_dict = node["properties"]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"{%s}: query: {%s}, result: {%s}",
|
"{%s}: query: {%s}, result: {%s}",
|
||||||
@@ -216,19 +210,18 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
return node_dict
|
return node_dict
|
||||||
|
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
entity_name_label = GremlinStorage._fix_label(node_id)
|
entity_name = GremlinStorage._fix_name(node_id)
|
||||||
query = f"""
|
query = f"""g
|
||||||
{self.traverse_source_name}
|
.V().has('graph', {self.graph_name})
|
||||||
.V().has('graph', '{self.graph_name}')
|
.has('entity_name', {entity_name})
|
||||||
.hasLabel('{entity_name_label}')
|
|
||||||
.outE()
|
.outE()
|
||||||
.inV().has('graph', '{self.graph_name}')
|
.inV().has('graph', {self.graph_name})
|
||||||
.count()
|
.count()
|
||||||
.project('total_edge_count')
|
.project('total_edge_count')
|
||||||
.by()
|
.by()
|
||||||
"""
|
"""
|
||||||
result = await self._query(query)
|
result = await self._query(query)
|
||||||
edge_count = result[0][0]["total_edge_count"]
|
edge_count = result[0]["total_edge_count"]
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"{%s}:query:{%s}:result:{%s}",
|
"{%s}:query:{%s}:result:{%s}",
|
||||||
@@ -259,31 +252,30 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> Union[dict, None]:
|
) -> Union[dict, None]:
|
||||||
"""
|
"""
|
||||||
Find all edges between nodes of two given labels
|
Find all edges between nodes of two given names
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_node_label (str): Label of the source nodes
|
source_node_id (str): Name of the source nodes
|
||||||
target_node_label (str): Label of the target nodes
|
target_node_id (str): Name of the target nodes
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict|None: Dict of found edge properties, or None of not found
|
dict|None: Dict of found edge properties, or None if not found
|
||||||
"""
|
"""
|
||||||
entity_name_label_source = GremlinStorage._fix_label(source_node_id)
|
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
||||||
entity_name_label_target = GremlinStorage._fix_label(target_node_id)
|
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
||||||
query = f"""
|
query = f"""g
|
||||||
{self.traverse_source_name}
|
.V().has('graph', {self.graph_name})
|
||||||
.V().has('graph', '{self.graph_name}')
|
.has('entity_name', {entity_name_source})
|
||||||
.hasLabel('{entity_name_label_source}')
|
|
||||||
.outE()
|
.outE()
|
||||||
.inV().has('graph', '{self.graph_name}')
|
.inV().has('graph', {self.graph_name})
|
||||||
.hasLabel('{entity_name_label_target}')
|
.has('entity_name', {entity_name_target})
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.project('edge_properties')
|
.project('edge_properties')
|
||||||
.by(__.bothE().elementMap())
|
.by(__.bothE().elementMap())
|
||||||
"""
|
"""
|
||||||
result = await self._query(query)
|
result = await self._query(query)
|
||||||
if result:
|
if result:
|
||||||
edge_properties = result[0][0]["edge_properties"]
|
edge_properties = result[0]["edge_properties"]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"{%s}:query:{%s}:result:{%s}",
|
"{%s}:query:{%s}:result:{%s}",
|
||||||
inspect.currentframe().f_code.co_name,
|
inspect.currentframe().f_code.co_name,
|
||||||
@@ -294,45 +286,31 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
||||||
"""
|
"""
|
||||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
Retrieves all edges (relationships) for a particular node identified by its name.
|
||||||
:return: List of tuples containing edge sources and targets
|
:return: List of tuples containing edge sources and targets
|
||||||
"""
|
"""
|
||||||
node_label = GremlinStorage._fix_label(source_node_id)
|
node_name = GremlinStorage._fix_name(source_node_id)
|
||||||
query1 = f"""
|
query = f"""g
|
||||||
{self.traverse_source_name}
|
.E()
|
||||||
.V().has('graph', '{self.graph_name}')
|
.filter(
|
||||||
.hasLabel('{node_label}')
|
__.or(
|
||||||
.out().has('graph', '{self.graph_name}')
|
__.outV().has('graph', {self.graph_name})
|
||||||
.project('connected_label')
|
.has('entity_name', {node_name}),
|
||||||
.by(__.label())
|
__.inV().has('graph', {self.graph_name})
|
||||||
|
.has('entity_name', {node_name})
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.project('source_name', 'target_name')
|
||||||
|
.by(__.outV().values('entity_name'))
|
||||||
|
.by(__.inV().values('entity_name'))
|
||||||
"""
|
"""
|
||||||
query2 = f"""
|
result = await self._query(query)
|
||||||
{self.traverse_source_name}
|
edges = [(res["source_name"], res["target_name"]) for res in result]
|
||||||
.V().has('graph', '{self.graph_name}')
|
|
||||||
.as('connected')
|
|
||||||
.out().has('graph', '{self.graph_name}')
|
|
||||||
.hasLabel('{node_label}')
|
|
||||||
.project('connected_label')
|
|
||||||
.by(__.select('connected').label())
|
|
||||||
"""
|
|
||||||
result1, result2 = await asyncio.gather(
|
|
||||||
self._query(query1), self._query(query2)
|
|
||||||
)
|
|
||||||
edges1 = (
|
|
||||||
[(node_label, res["connected_label"]) for res in result1[0]]
|
|
||||||
if result1
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
edges2 = (
|
|
||||||
[(res["connected_label"], node_label) for res in result2[0]]
|
|
||||||
if result2
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
return edges1 + edges2
|
return edges
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(10),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type((GremlinServerError,)),
|
retry=retry_if_exception_type((GremlinServerError,)),
|
||||||
)
|
)
|
||||||
@@ -341,28 +319,30 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
Upsert a node in the Gremlin graph.
|
Upsert a node in the Gremlin graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: The unique identifier for the node (used as label)
|
node_id: The unique identifier for the node (used as name)
|
||||||
node_data: Dictionary of node properties
|
node_data: Dictionary of node properties
|
||||||
"""
|
"""
|
||||||
label = GremlinStorage._fix_label(node_id)
|
name = GremlinStorage._fix_name(node_id)
|
||||||
properties = GremlinStorage._convert_properties(node_data)
|
properties = GremlinStorage._convert_properties(node_data)
|
||||||
|
|
||||||
query = f"""
|
query = f"""g
|
||||||
{self.traverse_source_name}
|
.V().has('graph', {self.graph_name})
|
||||||
.V().has('graph', '{self.graph_name}')
|
.has('entity_name', {name})
|
||||||
.hasLabel('{label}').fold()
|
.fold()
|
||||||
.coalesce(
|
.coalesce(
|
||||||
unfold(),
|
__.unfold(),
|
||||||
addV('{label}'))
|
__.addV('ENTITY')
|
||||||
.property('graph', '{self.graph_name}')
|
.property('graph', {self.graph_name})
|
||||||
|
.property('entity_name', {name})
|
||||||
|
)
|
||||||
{properties}
|
{properties}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._query(query)
|
await self._query(query)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Upserted node with label '{%s}' and properties: {%s}",
|
"Upserted node with name {%s} and properties: {%s}",
|
||||||
label,
|
name,
|
||||||
properties,
|
properties,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -370,7 +350,7 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(10),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
retry=retry_if_exception_type((GremlinServerError,)),
|
retry=retry_if_exception_type((GremlinServerError,)),
|
||||||
)
|
)
|
||||||
@@ -378,36 +358,35 @@ class GremlinStorage(BaseGraphStorage):
|
|||||||
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Upsert an edge and its properties between two nodes identified by their labels.
|
Upsert an edge and its properties between two nodes identified by their names.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_node_id (str): Label of the source node (used as identifier)
|
source_node_id (str): Name of the source node (used as identifier)
|
||||||
target_node_id (str): Label of the target node (used as identifier)
|
target_node_id (str): Name of the target node (used as identifier)
|
||||||
edge_data (dict): Dictionary of properties to set on the edge
|
edge_data (dict): Dictionary of properties to set on the edge
|
||||||
"""
|
"""
|
||||||
source_node_label = GremlinStorage._fix_label(source_node_id)
|
source_node_name = GremlinStorage._fix_name(source_node_id)
|
||||||
target_node_label = GremlinStorage._fix_label(target_node_id)
|
target_node_name = GremlinStorage._fix_name(target_node_id)
|
||||||
edge_properties = GremlinStorage._convert_properties(edge_data)
|
edge_properties = GremlinStorage._convert_properties(edge_data)
|
||||||
|
|
||||||
query = f"""
|
query = f"""g
|
||||||
{self.traverse_source_name}
|
.V().has('graph', {self.graph_name})
|
||||||
.V().has('graph', '{self.graph_name}')
|
.has('entity_name', {source_node_name}).as('source')
|
||||||
.hasLabel('{source_node_label}').as('source')
|
.V().has('graph', {self.graph_name})
|
||||||
.V().has('graph', '{self.graph_name}')
|
.has('entity_name', {target_node_name}).as('target')
|
||||||
.hasLabel('{target_node_label}').as('target')
|
|
||||||
.coalesce(
|
.coalesce(
|
||||||
select('source').outE('DIRECTED').where(inV().as('target')),
|
__.select('source').outE('DIRECTED').where(__.inV().as('target')),
|
||||||
select('source').addE('DIRECTED').to(select('target'))
|
__.select('source').addE('DIRECTED').to(__.select('target'))
|
||||||
)
|
)
|
||||||
.property('graph', '{self.graph_name}')
|
.property('graph', {self.graph_name})
|
||||||
{edge_properties}
|
{edge_properties}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
await self._query(query)
|
await self._query(query)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
|
"Upserted edge from {%s} to {%s} with properties: {%s}",
|
||||||
source_node_label,
|
source_node_name,
|
||||||
target_node_label,
|
target_node_name,
|
||||||
edge_properties,
|
edge_properties,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
Reference in New Issue
Block a user