diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 07b3c907..f498e248 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1064,31 +1064,11 @@ class PGGraphStorage(BaseGraphStorage): if v.startswith("[") and v.endswith("]"): if "::vertex" in v: v = v.replace("::vertex", "") - vertexes = json.loads(v) - dl = [] - for vertex in vertexes: - prop = vertex.get("properties") - if not prop: - prop = {} - prop["label"] = PGGraphStorage._decode_graph_label( - prop["node_id"] - ) - dl.append(prop) - d[k] = dl + d[k] = json.loads(v) elif "::edge" in v: v = v.replace("::edge", "") - edges = json.loads(v) - dl = [] - for edge in edges: - dl.append( - ( - vertices[edge["start_id"]], - edge["label"], - vertices[edge["end_id"]], - ) - ) - d[k] = dl + d[k] = json.loads(v) else: print("WARNING: unsupported type") continue @@ -1097,26 +1077,9 @@ class PGGraphStorage(BaseGraphStorage): dtype = v.split("::")[-1] v = v.split("::")[0] if dtype == "vertex": - vertex = json.loads(v) - field = vertex.get("properties") - if not field: - field = {} - field["label"] = PGGraphStorage._decode_graph_label( - field["node_id"] - ) - d[k] = field - # convert edge from id-label->id by replacing id with node information - # we only do this if the vertex was also returned in the query - # this is an attempt to be consistent with neo4j implementation + d[k] = json.loads(v) elif dtype == "edge": - edge = json.loads(v) - d[k] = ( - vertices.get(edge["start_id"], {}), - edge[ - "label" - ], # we don't use decode_graph_label(), since edge label is always "DIRECTED" - vertices.get(edge["end_id"], {}), - ) + d[k] = json.loads(v) else: d[k] = ( json.loads(v) @@ -1152,56 +1115,6 @@ class PGGraphStorage(BaseGraphStorage): ) return "{" + ", ".join(props) + "}" - @staticmethod - def _encode_graph_label(label: str) -> str: - """ - Since AGE supports only alphanumerical labels, we will encode generic label as HEX string - - Args: - label (str): the original label - - Returns: - str: the encoded label - """ - return "x" + label.encode().hex() - - @staticmethod - def _decode_graph_label(encoded_label: str) -> str: - """ - Since AGE supports only alphanumerical labels, we will encode generic label as HEX string - - Args: - encoded_label (str): the encoded label - - Returns: - str: the decoded label - """ - return bytes.fromhex(encoded_label.removeprefix("x")).decode() - - @staticmethod - def _get_col_name(field: str, idx: int) -> str: - """ - Convert a cypher return field to a pgsql select field - If possible keep the cypher column name, but create a generic name if necessary - - Args: - field (str): a return field from a cypher query to be formatted for pgsql - idx (int): the position of the field in the return statement - - Returns: - str: the field to be used in the pgsql select statement - """ - # remove white space - field = field.strip() - # if an alias is provided for the field, use it - if " as " in field: - return field.split(" as ")[-1].strip() - # if the return value is an unnamed primitive, give it a generic name - if field.isnumeric() or field in ("true", "false", "null"): - return f"column_{idx}" - # otherwise return the value stripping out some common special chars - return field.replace("(", "_").replace(")", "") - async def _query( self, query: str, @@ -1252,10 +1165,10 @@ class PGGraphStorage(BaseGraphStorage): return result async def has_node(self, node_id: str) -> bool: - entity_name_label = self._encode_graph_label(node_id.strip('"')) + entity_name_label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {node_id: "%s"}) + MATCH (n:base {entity_id: "%s"}) RETURN count(n) > 0 AS node_exists $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) @@ -1264,11 +1177,11 @@ class PGGraphStorage(BaseGraphStorage): return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - src_label = self._encode_graph_label(source_node_id.strip('"')) - tgt_label = self._encode_graph_label(target_node_id.strip('"')) + src_label = source_node_id.strip('"') + tgt_label = target_node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"}) + MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) RETURN COUNT(r) > 0 AS edge_exists $$) AS (edge_exists bool)""" % ( self.graph_name, @@ -1281,13 +1194,14 @@ class PGGraphStorage(BaseGraphStorage): return single_result["edge_exists"] async def get_node(self, node_id: str) -> dict[str, str] | None: - label = self._encode_graph_label(node_id.strip('"')) + label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {node_id: "%s"}) + MATCH (n:base {entity_id: "%s"}) RETURN n $$) AS (n agtype)""" % (self.graph_name, label) record = await self._query(query) if record: + print(f"Record: {record}") node = record[0] node_dict = node["n"] @@ -1295,10 +1209,10 @@ class PGGraphStorage(BaseGraphStorage): return None async def node_degree(self, node_id: str) -> int: - label = self._encode_graph_label(node_id.strip('"')) + label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {node_id: "%s"})-[]->(x) + MATCH (n:base {entity_id: "%s"})-[]->(x) RETURN count(x) AS total_edge_count $$) AS (total_edge_count integer)""" % (self.graph_name, label) record = (await self._query(query))[0] @@ -1322,11 +1236,11 @@ class PGGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - src_label = self._encode_graph_label(source_node_id.strip('"')) - tgt_label = self._encode_graph_label(target_node_id.strip('"')) + src_label = source_node_id.strip('"') + tgt_label = target_node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"}) + MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"}) RETURN properties(r) as edge_properties LIMIT 1 $$) AS (edge_properties agtype)""" % ( @@ -1336,6 +1250,7 @@ class PGGraphStorage(BaseGraphStorage): ) record = await self._query(query) if record and record[0] and record[0]["edge_properties"]: + print(f"Record: {record}") result = record[0]["edge_properties"] return result @@ -1345,10 +1260,10 @@ class PGGraphStorage(BaseGraphStorage): Retrieves all edges (relationships) for a particular node identified by its label. :return: list of dictionaries containing edge information """ - label = self._encode_graph_label(source_node_id.strip('"')) + label = source_node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {node_id: "%s"}) + MATCH (n:base {entity_id: "%s"}) OPTIONAL MATCH (n)-[]-(connected:base) RETURN n, connected $$) AS (n agtype, connected agtype)""" % ( @@ -1362,24 +1277,17 @@ class PGGraphStorage(BaseGraphStorage): source_node = record["n"] if record["n"] else None connected_node = record["connected"] if record["connected"] else None - source_label = ( - source_node["node_id"] - if source_node and source_node["node_id"] - else None - ) - target_label = ( - connected_node["node_id"] - if connected_node and connected_node["node_id"] - else None - ) + if ( + source_node + and connected_node + and "properties" in source_node + and "properties" in connected_node + ): + source_label = source_node["properties"].get("entity_id") + target_label = connected_node["properties"].get("entity_id") - if source_label and target_label: - edges.append( - ( - self._decode_graph_label(source_label), - self._decode_graph_label(target_label), - ) - ) + if source_label and target_label: + edges.append((source_label, target_label)) return edges @@ -1389,17 +1297,17 @@ class PGGraphStorage(BaseGraphStorage): retry=retry_if_exception_type((PGGraphQueryException,)), ) async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - label = self._encode_graph_label(node_id.strip('"')) - properties = node_data + label = node_id.strip('"') + properties = self._format_properties(node_data) query = """SELECT * FROM cypher('%s', $$ - MERGE (n:base {node_id: "%s"}) + MERGE (n:base {entity_id: "%s"}) SET n += %s RETURN n $$) AS (n agtype)""" % ( self.graph_name, label, - self._format_properties(properties), + properties, ) try: @@ -1425,14 +1333,14 @@ class PGGraphStorage(BaseGraphStorage): target_node_id (str): Label of the target node (used as identifier) edge_data (dict): dictionary of properties to set on the edge """ - src_label = self._encode_graph_label(source_node_id.strip('"')) - tgt_label = self._encode_graph_label(target_node_id.strip('"')) - edge_properties = edge_data + src_label = source_node_id.strip('"') + tgt_label = target_node_id.strip('"') + edge_properties = self._format_properties(edge_data) query = """SELECT * FROM cypher('%s', $$ - MATCH (source:base {node_id: "%s"}) + MATCH (source:base {entity_id: "%s"}) WITH source - MATCH (target:base {node_id: "%s"}) + MATCH (target:base {entity_id: "%s"}) MERGE (source)-[r:DIRECTED]->(target) SET r += %s RETURN r @@ -1440,7 +1348,7 @@ class PGGraphStorage(BaseGraphStorage): self.graph_name, src_label, tgt_label, - self._format_properties(edge_properties), + edge_properties, ) try: @@ -1460,7 +1368,7 @@ class PGGraphStorage(BaseGraphStorage): Args: node_id (str): The ID of the node to delete. """ - label = self._encode_graph_label(node_id.strip('"')) + label = node_id.strip('"') query = """SELECT * FROM cypher('%s', $$ MATCH (n:base {entity_id: "%s"}) @@ -1480,14 +1388,12 @@ class PGGraphStorage(BaseGraphStorage): Args: node_ids (list[str]): A list of node IDs to remove. """ - encoded_node_ids = [ - self._encode_graph_label(node_id.strip('"')) for node_id in node_ids - ] - node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids]) + node_ids = [node_id.strip('"') for node_id in node_ids] + node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids]) query = """SELECT * FROM cypher('%s', $$ MATCH (n:base) - WHERE n.nentity_id IN [%s] + WHERE n.entity_id IN [%s] DETACH DELETE n $$) AS (n agtype)""" % (self.graph_name, node_id_list) @@ -1505,11 +1411,11 @@ class PGGraphStorage(BaseGraphStorage): edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). """ for source, target in edges: - src_label = self._encode_graph_label(source.strip('"')) - tgt_label = self._encode_graph_label(target.strip('"')) + src_label = source.strip('"') + tgt_label = target.strip('"') query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"}) + MATCH (a:base {entity_id: "%s"})-[r]->(b:base {entity_id: "%s"}) DELETE r $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label) @@ -1560,95 +1466,98 @@ class PGGraphStorage(BaseGraphStorage): return await embed_func() async def get_knowledge_graph( - self, node_label: str, max_depth: int = 5 + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, ) -> KnowledgeGraph: """ - Retrieve a subgraph containing the specified node and its neighbors up to the specified depth. + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Args: - node_label (str): The label of the node to start from. If "*", the entire graph is returned. - max_depth (int): The maximum depth to traverse from the starting node. + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 Returns: - KnowledgeGraph: The retrieved subgraph. + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit """ - MAX_GRAPH_NODES = 1000 # Build the query based on whether we want the full graph or a specific subgraph. if node_label == "*": query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:base) - OPTIONAL MATCH (n)-[r]->(m:base) - RETURN n, r, m - LIMIT {MAX_GRAPH_NODES} - $$) AS (n agtype, r agtype, m agtype)""" + MATCH (n:base) + OPTIONAL MATCH (n)-[r]->(target:base) + RETURN collect(distinct n) AS n, collect(distinct r) AS r + LIMIT {MAX_GRAPH_NODES} + $$) AS (n agtype, r agtype)""" else: - encoded_label = self._encode_graph_label(node_label.strip('"')) + strip_label = node_label.strip('"') query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:base {{entity_id: "{encoded_label}"}}) - OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) - RETURN nodes(p) AS nodes, relationships(p) AS relationships - LIMIT {MAX_GRAPH_NODES} - $$) AS (nodes agtype, relationships agtype)""" + MATCH (n:base {{entity_id: "{strip_label}"}}) + OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m) + RETURN nodes(p) AS n, relationships(p) AS r + LIMIT {max_nodes} + $$) AS (n agtype, r agtype)""" results = await self._query(query) - nodes = {} - edges = [] - unique_edge_ids = set() - - def add_node(node_data: dict): - node_id = self._decode_graph_label(node_data["node_id"]) - if node_id not in nodes: - nodes[node_id] = node_data - - def add_edge(edge_data: list): - src_id = self._decode_graph_label(edge_data[0]["node_id"]) - tgt_id = self._decode_graph_label(edge_data[2]["node_id"]) - edge_key = f"{src_id},{tgt_id}" - if edge_key not in unique_edge_ids: - unique_edge_ids.add(edge_key) - edges.append( - ( - edge_key, - src_id, - tgt_id, - {"source": edge_data[0], "target": edge_data[2]}, + # Process the query results with deduplication by node and edge IDs + nodes_dict = {} + edges_dict = {} + for result in results: + # Handle single node cases + if result.get("n") and isinstance(result["n"], dict): + node_id = str(result["n"]["id"]) + if node_id not in nodes_dict: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[result["n"]["properties"]["entity_id"]], + properties=result["n"]["properties"], ) - ) + # Handle node list cases + elif result.get("n") and isinstance(result["n"], list): + for node in result["n"]: + if isinstance(node, dict) and "id" in node: + node_id = str(node["id"]) + if node_id not in nodes_dict and "properties" in node: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[node["properties"]["entity_id"]], + properties=node["properties"], + ) + + # Handle single edge cases + if result.get("r") and isinstance(result["r"], dict): + edge_id = str(result["r"]["id"]) + if edge_id not in edges_dict: + edges_dict[edge_id] = KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(result["r"]["start_id"]), + target=str(result["r"]["end_id"]), + properties=result["r"]["properties"], + ) + # Handle edge list cases + elif result.get("r") and isinstance(result["r"], list): + for edge in result["r"]: + if isinstance(edge, dict) and "id" in edge: + edge_id = str(edge["id"]) + if edge_id not in edges_dict: + edges_dict[edge_id] = KnowledgeGraphEdge( + id=edge_id, + type="DIRECTED", + source=str(edge["start_id"]), + target=str(edge["end_id"]), + properties=edge["properties"], + ) - # Process the query results. - if node_label == "*": - for result in results: - if result.get("n"): - add_node(result["n"]) - if result.get("m"): - add_node(result["m"]) - if result.get("r"): - add_edge(result["r"]) - else: - for result in results: - for node in result.get("nodes", []): - add_node(node) - for edge in result.get("relationships", []): - add_edge(edge) - - # Construct and return the KnowledgeGraph. + # Construct and return the KnowledgeGraph with deduplicated nodes and edges kg = KnowledgeGraph( - nodes=[ - KnowledgeGraphNode(id=node_id, labels=[node_id], properties=node_data) - for node_id, node_data in nodes.items() - ], - edges=[ - KnowledgeGraphEdge( - id=edge_id, - type="DIRECTED", - source=src, - target=tgt, - properties=props, - ) - for edge_id, src, tgt, props in edges - ], + nodes=list(nodes_dict.values()), + edges=list(edges_dict.values()), + is_truncated=False, ) return kg