diff --git a/README.md b/README.md index ea8d0a97..d6d22522 100644 --- a/README.md +++ b/README.md @@ -361,6 +361,11 @@ see test_neo4j.py for a working example. For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE). * PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac. * How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py) +* Create index for AGE example: (Change below `dickens` to your graph name if necessary) + ``` + SET search_path = ag_catalog, "$user", public; + CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"')); + ``` ### Insert Custom KG diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 033d63d6..ccbff679 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -130,6 +130,7 @@ class PostgreSQLDB: data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None, + upsert: bool = False, ): try: async with self.pool.acquire() as connection: @@ -140,6 +141,11 @@ class PostgreSQLDB: await connection.execute(sql) else: await connection.execute(sql, *data.values()) + except asyncpg.exceptions.UniqueViolationError as e: + if upsert: + print("Key value duplicate, but upsert succeeded.") + else: + logger.error(f"Upsert error: {e}") except Exception as e: logger.error(f"PostgreSQL database error: {e}") print(sql) @@ -568,10 +574,10 @@ class PGGraphStorage(BaseGraphStorage): if dtype == "vertex": vertex = json.loads(v) - field = json.loads(v).get("properties") + field = vertex.get("properties") if not field: field = {} - field["label"] = PGGraphStorage._decode_graph_label(vertex["label"]) + field["label"] = PGGraphStorage._decode_graph_label(field["node_id"]) d[k] = field # convert edge from id-label->id by replacing id with node information # we only do this if the vertex was also returned in the query @@ -666,73 +672,8 @@ class PGGraphStorage(BaseGraphStorage): # otherwise return the value stripping out some common special chars return field.replace("(", "_").replace(")", "") - @staticmethod - def _wrap_query(query: str, graph_name: str, **params: str) -> str: - """ - Convert a cypher query to an Apache Age compatible - sql query by wrapping the cypher query in ag_catalog.cypher, - casting results to agtype and building a select statement - - Args: - query (str): a valid cypher query - graph_name (str): the name of the graph to query - params (dict): parameters for the query - - Returns: - str: an equivalent pgsql query - """ - - # pgsql template - template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$ - {query} - $$) AS ({fields})""" - - # if there are any returned fields they must be added to the pgsql query - if "return" in query.lower(): - # parse return statement to identify returned fields - fields = ( - query.lower() - .split("return")[-1] - .split("distinct")[-1] - .split("order by")[0] - .split("skip")[0] - .split("limit")[0] - .split(",") - ) - - # raise exception if RETURN * is found as we can't resolve the fields - if "*" in [x.strip() for x in fields]: - raise ValueError( - "AGE graph does not support 'RETURN *'" - + " statements in Cypher queries" - ) - - # get pgsql formatted field names - fields = [ - PGGraphStorage._get_col_name(field, idx) - for idx, field in enumerate(fields) - ] - - # build resulting pgsql relation - fields_str = ", ".join( - [field.split(".")[-1] + " agtype" for field in fields] - ) - - # if no return statement we still need to return a single field of type agtype - else: - fields_str = "a agtype" - - select_str = "*" - - return template.format( - graph_name=graph_name, - query=query.format(**params), - fields=fields_str, - projection=select_str, - ) - async def _query( - self, query: str, readonly=True, upsert_edge=False, **params: str + self, query: str, readonly: bool = True, upsert: bool = False ) -> List[Dict[str, Any]]: """ Query the graph by taking a cypher query, converting it to an @@ -746,7 +687,7 @@ class PGGraphStorage(BaseGraphStorage): List[Dict[str, Any]]: a list of dictionaries containing the result set """ # convert cypher query to pgsql/age query - wrapped_query = self._wrap_query(query, self.graph_name, **params) + wrapped_query = query # execute the query, rolling back on an error try: @@ -758,22 +699,16 @@ class PGGraphStorage(BaseGraphStorage): graph_name=self.graph_name, ) else: - # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING) - # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future. - if upsert_edge: - data = await self.db.execute( - f"{wrapped_query};{wrapped_query};", - for_age=True, - graph_name=self.graph_name, - ) - else: - data = await self.db.execute( - wrapped_query, for_age=True, graph_name=self.graph_name - ) + data = await self.db.execute( + wrapped_query, + for_age=True, + graph_name=self.graph_name, + upsert=upsert, + ) except Exception as e: raise PGGraphQueryException( { - "message": f"Error executing graph query: {query.format(**params)}", + "message": f"Error executing graph query: {query}", "wrapped": wrapped_query, "detail": str(e), } @@ -788,77 +723,85 @@ class PGGraphStorage(BaseGraphStorage): return result async def has_node(self, node_id: str) -> bool: - entity_name_label = node_id.strip('"') + entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"')) - query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists""" - params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} - single_result = (await self._query(query, **params))[0] + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + RETURN count(n) > 0 AS node_exists + $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) + + single_result = (await self._query(query))[0] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, single_result["node_exists"], ) return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) - query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) - RETURN COUNT(r) > 0 AS edge_exists""" - params = { - "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source), - "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target), - } - single_result = (await self._query(query, **params))[0] + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"}) + RETURN COUNT(r) > 0 AS edge_exists + $$) AS (edge_exists bool)""" % ( + self.graph_name, + src_label, + tgt_label, + ) + + single_result = (await self._query(query))[0] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, single_result["edge_exists"], ) return single_result["edge_exists"] async def get_node(self, node_id: str) -> Union[dict, None]: - entity_name_label = node_id.strip('"') - query = """MATCH (n:`{label}`) RETURN n""" - params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} - record = await self._query(query, **params) + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + RETURN n + $$) AS (n agtype)""" % (self.graph_name, label) + record = await self._query(query) if record: node = record[0] node_dict = node["n"] logger.debug( "{%s}: query: {%s}, result: {%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, node_dict, ) return node_dict return None async def node_degree(self, node_id: str) -> int: - entity_name_label = node_id.strip('"') + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) - query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count""" - params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} - record = (await self._query(query, **params))[0] + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"})-[]->(x) + RETURN count(x) AS total_edge_count + $$) AS (total_edge_count integer)""" % (self.graph_name, label) + record = (await self._query(query))[0] if record: edge_count = int(record["total_edge_count"]) logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, edge_count, ) return edge_count async def edge_degree(self, src_id: str, tgt_id: str) -> int: - entity_name_label_source = src_id.strip('"') - entity_name_label_target = tgt_id.strip('"') - src_degree = await self.node_degree(entity_name_label_source) - trg_degree = await self.node_degree(entity_name_label_target) + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) # Convert None to 0 for addition src_degree = 0 if src_degree is None else src_degree @@ -885,23 +828,25 @@ class PGGraphStorage(BaseGraphStorage): Returns: list: List of all relationships/edges found """ - entity_name_label_source = source_node_id.strip('"') - entity_name_label_target = target_node_id.strip('"') + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) - query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) - RETURN properties(r) as edge_properties - LIMIT 1""" - params = { - "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source), - "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target), - } - record = await self._query(query, **params) + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"}) + RETURN properties(r) as edge_properties + LIMIT 1 + $$) AS (edge_properties agtype)""" % ( + self.graph_name, + src_label, + tgt_label, + ) + record = await self._query(query) if record and record[0] and record[0]["edge_properties"]: result = record[0]["edge_properties"] logger.debug( "{%s}:query:{%s}:result:{%s}", inspect.currentframe().f_code.co_name, - query.format(**params), + query, result, ) return result @@ -911,24 +856,31 @@ class PGGraphStorage(BaseGraphStorage): Retrieves all edges (relationships) for a particular node identified by its label. :return: List of dictionaries containing edge information """ - node_label = source_node_id.strip('"') + label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - query = """MATCH (n:`{label}`) - OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected""" - params = {"label": PGGraphStorage._encode_graph_label(node_label)} - results = await self._query(query, **params) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:Entity {node_id: "%s"}) + OPTIONAL MATCH (n)-[r]-(connected) + RETURN n, r, connected + $$) AS (n agtype, r agtype, connected agtype)""" % ( + self.graph_name, + label, + ) + + results = await self._query(query) edges = [] for record in results: source_node = record["n"] if record["n"] else None connected_node = record["connected"] if record["connected"] else None source_label = ( - source_node["label"] if source_node and source_node["label"] else None + source_node["node_id"] + if source_node and source_node["node_id"] + else None ) target_label = ( - connected_node["label"] - if connected_node and connected_node["label"] + connected_node["node_id"] + if connected_node and connected_node["node_id"] else None ) @@ -950,17 +902,21 @@ class PGGraphStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ - label = node_id.strip('"') + label = PGGraphStorage._encode_graph_label(node_id.strip('"')) properties = node_data - query = """MERGE (n:`{label}`) - SET n += {properties}""" - params = { - "label": PGGraphStorage._encode_graph_label(label), - "properties": PGGraphStorage._format_properties(properties), - } + query = """SELECT * FROM cypher('%s', $$ + MERGE (n:Entity {node_id: "%s"}) + SET n += %s + RETURN n + $$) AS (n agtype)""" % ( + self.graph_name, + label, + PGGraphStorage._format_properties(properties), + ) + try: - await self._query(query, readonly=False, **params) + await self._query(query, readonly=False, upsert=True) logger.debug( "Upserted node with label '{%s}' and properties: {%s}", label, @@ -986,28 +942,30 @@ class PGGraphStorage(BaseGraphStorage): target_node_id (str): Label of the target node (used as identifier) edge_data (dict): Dictionary of properties to set on the edge """ - source_node_label = source_node_id.strip('"') - target_node_label = target_node_id.strip('"') + src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) edge_properties = edge_data - query = """MATCH (source:`{src_label}`) - WITH source - MATCH (target:`{tgt_label}`) - MERGE (source)-[r:DIRECTED]->(target) - SET r += {properties} - RETURN r""" - params = { - "src_label": PGGraphStorage._encode_graph_label(source_node_label), - "tgt_label": PGGraphStorage._encode_graph_label(target_node_label), - "properties": PGGraphStorage._format_properties(edge_properties), - } + query = """SELECT * FROM cypher('%s', $$ + MATCH (source:Entity {node_id: "%s"}) + WITH source + MATCH (target:Entity {node_id: "%s"}) + MERGE (source)-[r:DIRECTED]->(target) + SET r += %s + RETURN r + $$) AS (r agtype)""" % ( + self.graph_name, + src_label, + tgt_label, + PGGraphStorage._format_properties(edge_properties), + ) # logger.info(f"-- inserting edge after formatted: {params}") try: - await self._query(query, readonly=False, upsert_edge=True, **params) + await self._query(query, readonly=False, upsert=True) logger.debug( "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", - source_node_label, - target_node_label, + src_label, + tgt_label, edge_properties, ) except Exception as e: