diff --git a/lightrag/kg/age_impl.py b/lightrag/kg/age_impl.py index cd4b4f3a..d74995a0 100644 --- a/lightrag/kg/age_impl.py +++ b/lightrag/kg/age_impl.py @@ -282,11 +282,9 @@ class AGEStorage(BaseGraphStorage): select_str = "*" - query = query.format(**params) - return template.format( graph_name=graph_name, - query=query, + query=query.format(**params), fields=fields_str, projection=select_str, ) @@ -349,16 +347,15 @@ class AGEStorage(BaseGraphStorage): async def has_node(self, node_id: str) -> bool: entity_name_label = node_id.strip('"') - query = "MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists" - single_result = ( - await self._query( - query, label=AGEStorage._encode_graph_label(entity_name_label) - ) - )[0] + query = """ + MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists + """ + params = {"label": AGEStorage._encode_graph_label(entity_name_label)} + single_result = (await self._query(query, **params))[0] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query, + query.format(**params), single_result[0], ) @@ -368,20 +365,19 @@ class AGEStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - query = ( - "MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) " - "RETURN COUNT(r) > 0 AS edge_exists" - ) - single_result = ( - await self._query( - query, - src_label=AGEStorage._encode_graph_label(entity_name_label_source), - tgt_label=AGEStorage._encode_graph_label(entity_name_label_target), - ) - )[0] + query = """ + MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) + RETURN COUNT(r) > 0 AS edge_exists + """ + params = { + "src_label": AGEStorage._encode_graph_label(entity_name_label_source), + "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), + } + single_result = (await self._query(query, **params))[0] logger.debug( - "{%s}:query:{query}:result:{%s}", + "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, + query.format(**params), single_result[0], ) return single_result["edge_exists"].lower() == "true" @@ -389,16 +385,15 @@ class AGEStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> Union[dict, None]: entity_name_label = node_id.strip('"') query = "MATCH (n:`{label}`) RETURN n" - record = await self._query( - query, label=AGEStorage._encode_graph_label(entity_name_label) - ) + params = {"label": AGEStorage._encode_graph_label(entity_name_label)} + record = await self._query(query, **params) if record: node = record[0] node_dict = node["n"] logger.debug( "{%s}: query: {%s}, result: {%s}", inspect.currentframe().f_code.co_name, - query, + query.format(**params), node_dict, ) return node_dict @@ -408,20 +403,17 @@ class AGEStorage(BaseGraphStorage): entity_name_label = node_id.strip('"') query = """ - MATCH (n:`{label}`)-[]->(x) - RETURN count(x) AS total_edge_count - """ - record = ( - await self._query( - query, label=AGEStorage._encode_graph_label(entity_name_label) - ) - )[0] + MATCH (n:`{label}`)-[]->(x) + RETURN count(x) AS total_edge_count + """ + params = {"label": AGEStorage._encode_graph_label(entity_name_label)} + record = (await self._query(query, **params))[0] if record: edge_count = int(record["total_edge_count"]) logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query, + query.format(**params), edge_count, ) return edge_count @@ -461,22 +453,21 @@ class AGEStorage(BaseGraphStorage): entity_name_label_target = target_node_id.strip('"') query = """ - MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) - RETURN properties(r) as edge_properties - LIMIT 1 - """ - - record = await self._query( - query, - src_label=AGEStorage._encode_graph_label(entity_name_label_source), - tgt_label=AGEStorage._encode_graph_label(entity_name_label_target), - ) + MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) + RETURN properties(r) as edge_properties + LIMIT 1 + """ + params = { + "src_label": AGEStorage._encode_graph_label(entity_name_label_source), + "tgt_label": AGEStorage._encode_graph_label(entity_name_label_target), + } + record = await self._query(query, **params) if record and record[0] and record[0]["edge_properties"]: result = record[0]["edge_properties"] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query, + query.format(**params), result, ) return result @@ -488,12 +479,13 @@ class AGEStorage(BaseGraphStorage): """ node_label = source_node_id.strip('"') - query = """MATCH (n:`{label}`) + query = """ + MATCH (n:`{label}`) OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected""" - results = await self._query( - query, label=AGEStorage._encode_graph_label(node_label) - ) + RETURN n, r, connected + """ + params = {"label": AGEStorage._encode_graph_label(node_label)} + results = await self._query(query, **params) edges = [] for record in results: source_node = record["n"] if record["n"] else None @@ -530,15 +522,15 @@ class AGEStorage(BaseGraphStorage): properties = node_data query = """ - MERGE (n:`{label}`) - SET n += {properties} - """ + MERGE (n:`{label}`) + SET n += {properties} + """ + params = { + "label": AGEStorage._encode_graph_label(label), + "properties": AGEStorage._format_properties(properties), + } try: - await self._query( - query, - label=AGEStorage._encode_graph_label(label), - properties=AGEStorage._format_properties(properties), - ) + await self._query(query, **params) logger.debug( "Upserted node with label '{%s}' and properties: {%s}", label, @@ -569,20 +561,20 @@ class AGEStorage(BaseGraphStorage): edge_properties = edge_data query = """ - MATCH (source:`{src_label}`) - WITH source - MATCH (target:`{tgt_label}`) - MERGE (source)-[r:DIRECTED]->(target) - SET r += {properties} - RETURN r - """ + MATCH (source:`{src_label}`) + WITH source + MATCH (target:`{tgt_label}`) + MERGE (source)-[r:DIRECTED]->(target) + SET r += {properties} + RETURN r + """ + params = { + "src_label": AGEStorage._encode_graph_label(source_node_label), + "tgt_label": AGEStorage._encode_graph_label(target_node_label), + "properties": AGEStorage._format_properties(edge_properties), + } try: - await self._query( - query, - src_label=AGEStorage._encode_graph_label(source_node_label), - tgt_label=AGEStorage._encode_graph_label(target_node_label), - properties=AGEStorage._format_properties(edge_properties), - ) + await self._query(query, **params) logger.debug( "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", source_node_label,