Revised the postgres implementation, to use attributes(node_id) rather than nodes to identify an entity. Which significantly reduced the table counts.

This commit is contained in:
Samuel Chan
2025-01-11 09:30:19 +08:00
parent 196350b75b
commit d03d6f5fc5
2 changed files with 118 additions and 155 deletions

View File

@@ -361,6 +361,11 @@ see test_neo4j.py for a working example.
For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE). For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
* PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac. * PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac.
* How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py) * How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
* Create index for AGE example: (Change below `dickens` to your graph name if necessary)
```
SET search_path = ag_catalog, "$user", public;
CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
```
### Insert Custom KG ### Insert Custom KG

View File

@@ -130,6 +130,7 @@ class PostgreSQLDB:
data: Union[list, dict] = None, data: Union[list, dict] = None,
for_age: bool = False, for_age: bool = False,
graph_name: str = None, graph_name: str = None,
upsert: bool = False,
): ):
try: try:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
@@ -140,6 +141,11 @@ class PostgreSQLDB:
await connection.execute(sql) await connection.execute(sql)
else: else:
await connection.execute(sql, *data.values()) await connection.execute(sql, *data.values())
except asyncpg.exceptions.UniqueViolationError as e:
if upsert:
print("Key value duplicate, but upsert succeeded.")
else:
logger.error(f"Upsert error: {e}")
except Exception as e: except Exception as e:
logger.error(f"PostgreSQL database error: {e}") logger.error(f"PostgreSQL database error: {e}")
print(sql) print(sql)
@@ -568,10 +574,10 @@ class PGGraphStorage(BaseGraphStorage):
if dtype == "vertex": if dtype == "vertex":
vertex = json.loads(v) vertex = json.loads(v)
field = json.loads(v).get("properties") field = vertex.get("properties")
if not field: if not field:
field = {} field = {}
field["label"] = PGGraphStorage._decode_graph_label(vertex["label"]) field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
d[k] = field d[k] = field
# convert edge from id-label->id by replacing id with node information # convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query # we only do this if the vertex was also returned in the query
@@ -666,73 +672,8 @@ class PGGraphStorage(BaseGraphStorage):
# otherwise return the value stripping out some common special chars # otherwise return the value stripping out some common special chars
return field.replace("(", "_").replace(")", "") return field.replace("(", "_").replace(")", "")
@staticmethod
def _wrap_query(query: str, graph_name: str, **params: str) -> str:
"""
Convert a cypher query to an Apache Age compatible
sql query by wrapping the cypher query in ag_catalog.cypher,
casting results to agtype and building a select statement
Args:
query (str): a valid cypher query
graph_name (str): the name of the graph to query
params (dict): parameters for the query
Returns:
str: an equivalent pgsql query
"""
# pgsql template
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
{query}
$$) AS ({fields})"""
# if there are any returned fields they must be added to the pgsql query
if "return" in query.lower():
# parse return statement to identify returned fields
fields = (
query.lower()
.split("return")[-1]
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)
# raise exception if RETURN * is found as we can't resolve the fields
if "*" in [x.strip() for x in fields]:
raise ValueError(
"AGE graph does not support 'RETURN *'"
+ " statements in Cypher queries"
)
# get pgsql formatted field names
fields = [
PGGraphStorage._get_col_name(field, idx)
for idx, field in enumerate(fields)
]
# build resulting pgsql relation
fields_str = ", ".join(
[field.split(".")[-1] + " agtype" for field in fields]
)
# if no return statement we still need to return a single field of type agtype
else:
fields_str = "a agtype"
select_str = "*"
return template.format(
graph_name=graph_name,
query=query.format(**params),
fields=fields_str,
projection=select_str,
)
async def _query( async def _query(
self, query: str, readonly=True, upsert_edge=False, **params: str self, query: str, readonly: bool = True, upsert: bool = False
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Query the graph by taking a cypher query, converting it to an Query the graph by taking a cypher query, converting it to an
@@ -746,7 +687,7 @@ class PGGraphStorage(BaseGraphStorage):
List[Dict[str, Any]]: a list of dictionaries containing the result set List[Dict[str, Any]]: a list of dictionaries containing the result set
""" """
# convert cypher query to pgsql/age query # convert cypher query to pgsql/age query
wrapped_query = self._wrap_query(query, self.graph_name, **params) wrapped_query = query
# execute the query, rolling back on an error # execute the query, rolling back on an error
try: try:
@@ -758,22 +699,16 @@ class PGGraphStorage(BaseGraphStorage):
graph_name=self.graph_name, graph_name=self.graph_name,
) )
else: else:
# for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING) data = await self.db.execute(
# It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future. wrapped_query,
if upsert_edge: for_age=True,
data = await self.db.execute( graph_name=self.graph_name,
f"{wrapped_query};{wrapped_query};", upsert=upsert,
for_age=True, )
graph_name=self.graph_name,
)
else:
data = await self.db.execute(
wrapped_query, for_age=True, graph_name=self.graph_name
)
except Exception as e: except Exception as e:
raise PGGraphQueryException( raise PGGraphQueryException(
{ {
"message": f"Error executing graph query: {query.format(**params)}", "message": f"Error executing graph query: {query}",
"wrapped": wrapped_query, "wrapped": wrapped_query,
"detail": str(e), "detail": str(e),
} }
@@ -788,77 +723,85 @@ class PGGraphStorage(BaseGraphStorage):
return result return result
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"') entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists""" query = """SELECT * FROM cypher('%s', $$
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} MATCH (n:Entity {node_id: "%s"})
single_result = (await self._query(query, **params))[0] RETURN count(n) > 0 AS node_exists
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
single_result = (await self._query(query))[0]
logger.debug( logger.debug(
"{%s}:query:{%s}:result:{%s}", "{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
query.format(**params), query,
single_result["node_exists"], single_result["node_exists"],
) )
return single_result["node_exists"] return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"') src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
entity_name_label_target = target_node_id.strip('"') tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) query = """SELECT * FROM cypher('%s', $$
RETURN COUNT(r) > 0 AS edge_exists""" MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
params = { RETURN COUNT(r) > 0 AS edge_exists
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source), $$) AS (edge_exists bool)""" % (
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target), self.graph_name,
} src_label,
single_result = (await self._query(query, **params))[0] tgt_label,
)
single_result = (await self._query(query))[0]
logger.debug( logger.debug(
"{%s}:query:{%s}:result:{%s}", "{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
query.format(**params), query,
single_result["edge_exists"], single_result["edge_exists"],
) )
return single_result["edge_exists"] return single_result["edge_exists"]
async def get_node(self, node_id: str) -> Union[dict, None]: async def get_node(self, node_id: str) -> Union[dict, None]:
entity_name_label = node_id.strip('"') label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """MATCH (n:`{label}`) RETURN n""" query = """SELECT * FROM cypher('%s', $$
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} MATCH (n:Entity {node_id: "%s"})
record = await self._query(query, **params) RETURN n
$$) AS (n agtype)""" % (self.graph_name, label)
record = await self._query(query)
if record: if record:
node = record[0] node = record[0]
node_dict = node["n"] node_dict = node["n"]
logger.debug( logger.debug(
"{%s}: query: {%s}, result: {%s}", "{%s}: query: {%s}, result: {%s}",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
query.format(**params), query,
node_dict, node_dict,
) )
return node_dict return node_dict
return None return None
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"') label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count""" query = """SELECT * FROM cypher('%s', $$
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} MATCH (n:Entity {node_id: "%s"})-[]->(x)
record = (await self._query(query, **params))[0] RETURN count(x) AS total_edge_count
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0]
if record: if record:
edge_count = int(record["total_edge_count"]) edge_count = int(record["total_edge_count"])
logger.debug( logger.debug(
"{%s}:query:{%s}:result:{%s}", "{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
query.format(**params), query,
edge_count, edge_count,
) )
return edge_count return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('"') src_degree = await self.node_degree(src_id)
entity_name_label_target = tgt_id.strip('"') trg_degree = await self.node_degree(tgt_id)
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
# Convert None to 0 for addition # Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree src_degree = 0 if src_degree is None else src_degree
@@ -885,23 +828,25 @@ class PGGraphStorage(BaseGraphStorage):
Returns: Returns:
list: List of all relationships/edges found list: List of all relationships/edges found
""" """
entity_name_label_source = source_node_id.strip('"') src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
entity_name_label_target = target_node_id.strip('"') tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) query = """SELECT * FROM cypher('%s', $$
RETURN properties(r) as edge_properties MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
LIMIT 1""" RETURN properties(r) as edge_properties
params = { LIMIT 1
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source), $$) AS (edge_properties agtype)""" % (
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target), self.graph_name,
} src_label,
record = await self._query(query, **params) tgt_label,
)
record = await self._query(query)
if record and record[0] and record[0]["edge_properties"]: if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"] result = record[0]["edge_properties"]
logger.debug( logger.debug(
"{%s}:query:{%s}:result:{%s}", "{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name, inspect.currentframe().f_code.co_name,
query.format(**params), query,
result, result,
) )
return result return result
@@ -911,24 +856,31 @@ class PGGraphStorage(BaseGraphStorage):
Retrieves all edges (relationships) for a particular node identified by its label. Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information :return: List of dictionaries containing edge information
""" """
node_label = source_node_id.strip('"') label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
query = """MATCH (n:`{label}`) query = """SELECT * FROM cypher('%s', $$
OPTIONAL MATCH (n)-[r]-(connected) MATCH (n:Entity {node_id: "%s"})
RETURN n, r, connected""" OPTIONAL MATCH (n)-[r]-(connected)
params = {"label": PGGraphStorage._encode_graph_label(node_label)} RETURN n, r, connected
results = await self._query(query, **params) $$) AS (n agtype, r agtype, connected agtype)""" % (
self.graph_name,
label,
)
results = await self._query(query)
edges = [] edges = []
for record in results: for record in results:
source_node = record["n"] if record["n"] else None source_node = record["n"] if record["n"] else None
connected_node = record["connected"] if record["connected"] else None connected_node = record["connected"] if record["connected"] else None
source_label = ( source_label = (
source_node["label"] if source_node and source_node["label"] else None source_node["node_id"]
if source_node and source_node["node_id"]
else None
) )
target_label = ( target_label = (
connected_node["label"] connected_node["node_id"]
if connected_node and connected_node["label"] if connected_node and connected_node["node_id"]
else None else None
) )
@@ -950,17 +902,21 @@ class PGGraphStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label) node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties node_data: Dictionary of node properties
""" """
label = node_id.strip('"') label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
properties = node_data properties = node_data
query = """MERGE (n:`{label}`) query = """SELECT * FROM cypher('%s', $$
SET n += {properties}""" MERGE (n:Entity {node_id: "%s"})
params = { SET n += %s
"label": PGGraphStorage._encode_graph_label(label), RETURN n
"properties": PGGraphStorage._format_properties(properties), $$) AS (n agtype)""" % (
} self.graph_name,
label,
PGGraphStorage._format_properties(properties),
)
try: try:
await self._query(query, readonly=False, **params) await self._query(query, readonly=False, upsert=True)
logger.debug( logger.debug(
"Upserted node with label '{%s}' and properties: {%s}", "Upserted node with label '{%s}' and properties: {%s}",
label, label,
@@ -986,28 +942,30 @@ class PGGraphStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier) target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge edge_data (dict): Dictionary of properties to set on the edge
""" """
source_node_label = source_node_id.strip('"') src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
target_node_label = target_node_id.strip('"') tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
edge_properties = edge_data edge_properties = edge_data
query = """MATCH (source:`{src_label}`) query = """SELECT * FROM cypher('%s', $$
WITH source MATCH (source:Entity {node_id: "%s"})
MATCH (target:`{tgt_label}`) WITH source
MERGE (source)-[r:DIRECTED]->(target) MATCH (target:Entity {node_id: "%s"})
SET r += {properties} MERGE (source)-[r:DIRECTED]->(target)
RETURN r""" SET r += %s
params = { RETURN r
"src_label": PGGraphStorage._encode_graph_label(source_node_label), $$) AS (r agtype)""" % (
"tgt_label": PGGraphStorage._encode_graph_label(target_node_label), self.graph_name,
"properties": PGGraphStorage._format_properties(edge_properties), src_label,
} tgt_label,
PGGraphStorage._format_properties(edge_properties),
)
# logger.info(f"-- inserting edge after formatted: {params}") # logger.info(f"-- inserting edge after formatted: {params}")
try: try:
await self._query(query, readonly=False, upsert_edge=True, **params) await self._query(query, readonly=False, upsert=True)
logger.debug( logger.debug(
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}", "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
source_node_label, src_label,
target_node_label, tgt_label,
edge_properties, edge_properties,
) )
except Exception as e: except Exception as e: