Merge pull request #570 from ShanGor/main

Revise the AGE usage for postgres_impl
This commit is contained in:
zrguo
2025-01-12 13:23:06 +08:00
committed by GitHub
2 changed files with 119 additions and 156 deletions

View File

@@ -26,7 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
</div>
## 🎉 News
- [x] [2025.01.06]🎯📢LightRAG now supports [PostgreSQL for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-postgres-for-storage).
- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage).
- [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
- [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
- [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
@@ -361,6 +361,11 @@ see test_neo4j.py for a working example.
For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
* PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac.
* How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
* Create index for AGE example: (Change below `dickens` to your graph name if necessary)
```
SET search_path = ag_catalog, "$user", public;
CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
```
### Insert Custom KG

View File

@@ -130,6 +130,7 @@ class PostgreSQLDB:
data: Union[list, dict] = None,
for_age: bool = False,
graph_name: str = None,
upsert: bool = False,
):
try:
async with self.pool.acquire() as connection:
@@ -140,6 +141,11 @@ class PostgreSQLDB:
await connection.execute(sql)
else:
await connection.execute(sql, *data.values())
except asyncpg.exceptions.UniqueViolationError as e:
if upsert:
print("Key value duplicate, but upsert succeeded.")
else:
logger.error(f"Upsert error: {e}")
except Exception as e:
logger.error(f"PostgreSQL database error: {e}")
print(sql)
@@ -568,10 +574,10 @@ class PGGraphStorage(BaseGraphStorage):
if dtype == "vertex":
vertex = json.loads(v)
field = json.loads(v).get("properties")
field = vertex.get("properties")
if not field:
field = {}
field["label"] = PGGraphStorage._decode_graph_label(vertex["label"])
field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
@@ -666,73 +672,8 @@ class PGGraphStorage(BaseGraphStorage):
# otherwise return the value stripping out some common special chars
return field.replace("(", "_").replace(")", "")
@staticmethod
def _wrap_query(query: str, graph_name: str, **params: str) -> str:
"""
Convert a cypher query to an Apache Age compatible
sql query by wrapping the cypher query in ag_catalog.cypher,
casting results to agtype and building a select statement
Args:
query (str): a valid cypher query
graph_name (str): the name of the graph to query
params (dict): parameters for the query
Returns:
str: an equivalent pgsql query
"""
# pgsql template
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
{query}
$$) AS ({fields})"""
# if there are any returned fields they must be added to the pgsql query
if "return" in query.lower():
# parse return statement to identify returned fields
fields = (
query.lower()
.split("return")[-1]
.split("distinct")[-1]
.split("order by")[0]
.split("skip")[0]
.split("limit")[0]
.split(",")
)
# raise exception if RETURN * is found as we can't resolve the fields
if "*" in [x.strip() for x in fields]:
raise ValueError(
"AGE graph does not support 'RETURN *'"
+ " statements in Cypher queries"
)
# get pgsql formatted field names
fields = [
PGGraphStorage._get_col_name(field, idx)
for idx, field in enumerate(fields)
]
# build resulting pgsql relation
fields_str = ", ".join(
[field.split(".")[-1] + " agtype" for field in fields]
)
# if no return statement we still need to return a single field of type agtype
else:
fields_str = "a agtype"
select_str = "*"
return template.format(
graph_name=graph_name,
query=query.format(**params),
fields=fields_str,
projection=select_str,
)
async def _query(
self, query: str, readonly=True, upsert_edge=False, **params: str
self, query: str, readonly: bool = True, upsert: bool = False
) -> List[Dict[str, Any]]:
"""
Query the graph by taking a cypher query, converting it to an
@@ -746,7 +687,7 @@ class PGGraphStorage(BaseGraphStorage):
List[Dict[str, Any]]: a list of dictionaries containing the result set
"""
# convert cypher query to pgsql/age query
wrapped_query = self._wrap_query(query, self.graph_name, **params)
wrapped_query = query
# execute the query, rolling back on an error
try:
@@ -758,22 +699,16 @@ class PGGraphStorage(BaseGraphStorage):
graph_name=self.graph_name,
)
else:
# for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
# It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
if upsert_edge:
data = await self.db.execute(
f"{wrapped_query};{wrapped_query};",
for_age=True,
graph_name=self.graph_name,
)
else:
data = await self.db.execute(
wrapped_query, for_age=True, graph_name=self.graph_name
)
data = await self.db.execute(
wrapped_query,
for_age=True,
graph_name=self.graph_name,
upsert=upsert,
)
except Exception as e:
raise PGGraphQueryException(
{
"message": f"Error executing graph query: {query.format(**params)}",
"message": f"Error executing graph query: {query}",
"wrapped": wrapped_query,
"detail": str(e),
}
@@ -788,77 +723,85 @@ class PGGraphStorage(BaseGraphStorage):
return result
async def has_node(self, node_id: str) -> bool:
entity_name_label = node_id.strip('"')
entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"""
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
single_result = (await self._query(query, **params))[0]
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
RETURN count(n) > 0 AS node_exists
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
single_result = (await self._query(query))[0]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
query,
single_result["node_exists"],
)
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
RETURN COUNT(r) > 0 AS edge_exists"""
params = {
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
}
single_result = (await self._query(query, **params))[0]
query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
RETURN COUNT(r) > 0 AS edge_exists
$$) AS (edge_exists bool)""" % (
self.graph_name,
src_label,
tgt_label,
)
single_result = (await self._query(query))[0]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
query,
single_result["edge_exists"],
)
return single_result["edge_exists"]
async def get_node(self, node_id: str) -> Union[dict, None]:
entity_name_label = node_id.strip('"')
query = """MATCH (n:`{label}`) RETURN n"""
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
record = await self._query(query, **params)
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
RETURN n
$$) AS (n agtype)""" % (self.graph_name, label)
record = await self._query(query)
if record:
node = record[0]
node_dict = node["n"]
logger.debug(
"{%s}: query: {%s}, result: {%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
query,
node_dict,
)
return node_dict
return None
async def node_degree(self, node_id: str) -> int:
entity_name_label = node_id.strip('"')
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count"""
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
record = (await self._query(query, **params))[0]
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})-[]->(x)
RETURN count(x) AS total_edge_count
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0]
if record:
edge_count = int(record["total_edge_count"])
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
query,
edge_count,
)
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
entity_name_label_source = src_id.strip('"')
entity_name_label_target = tgt_id.strip('"')
src_degree = await self.node_degree(entity_name_label_source)
trg_degree = await self.node_degree(entity_name_label_target)
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
@@ -885,23 +828,25 @@ class PGGraphStorage(BaseGraphStorage):
Returns:
list: List of all relationships/edges found
"""
entity_name_label_source = source_node_id.strip('"')
entity_name_label_target = target_node_id.strip('"')
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
RETURN properties(r) as edge_properties
LIMIT 1"""
params = {
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
}
record = await self._query(query, **params)
query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
RETURN properties(r) as edge_properties
LIMIT 1
$$) AS (edge_properties agtype)""" % (
self.graph_name,
src_label,
tgt_label,
)
record = await self._query(query)
if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"]
logger.debug(
"{%s}:query:{%s}:result:{%s}",
inspect.currentframe().f_code.co_name,
query.format(**params),
query,
result,
)
return result
@@ -911,24 +856,31 @@ class PGGraphStorage(BaseGraphStorage):
Retrieves all edges (relationships) for a particular node identified by its label.
:return: List of dictionaries containing edge information
"""
node_label = source_node_id.strip('"')
label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
query = """MATCH (n:`{label}`)
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected"""
params = {"label": PGGraphStorage._encode_graph_label(node_label)}
results = await self._query(query, **params)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
OPTIONAL MATCH (n)-[r]-(connected)
RETURN n, r, connected
$$) AS (n agtype, r agtype, connected agtype)""" % (
self.graph_name,
label,
)
results = await self._query(query)
edges = []
for record in results:
source_node = record["n"] if record["n"] else None
connected_node = record["connected"] if record["connected"] else None
source_label = (
source_node["label"] if source_node and source_node["label"] else None
source_node["node_id"]
if source_node and source_node["node_id"]
else None
)
target_label = (
connected_node["label"]
if connected_node and connected_node["label"]
connected_node["node_id"]
if connected_node and connected_node["node_id"]
else None
)
@@ -950,17 +902,21 @@ class PGGraphStorage(BaseGraphStorage):
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
label = node_id.strip('"')
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
properties = node_data
query = """MERGE (n:`{label}`)
SET n += {properties}"""
params = {
"label": PGGraphStorage._encode_graph_label(label),
"properties": PGGraphStorage._format_properties(properties),
}
query = """SELECT * FROM cypher('%s', $$
MERGE (n:Entity {node_id: "%s"})
SET n += %s
RETURN n
$$) AS (n agtype)""" % (
self.graph_name,
label,
PGGraphStorage._format_properties(properties),
)
try:
await self._query(query, readonly=False, **params)
await self._query(query, readonly=False, upsert=True)
logger.debug(
"Upserted node with label '{%s}' and properties: {%s}",
label,
@@ -986,28 +942,30 @@ class PGGraphStorage(BaseGraphStorage):
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
"""
source_node_label = source_node_id.strip('"')
target_node_label = target_node_id.strip('"')
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
edge_properties = edge_data
query = """MATCH (source:`{src_label}`)
WITH source
MATCH (target:`{tgt_label}`)
MERGE (source)-[r:DIRECTED]->(target)
SET r += {properties}
RETURN r"""
params = {
"src_label": PGGraphStorage._encode_graph_label(source_node_label),
"tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
"properties": PGGraphStorage._format_properties(edge_properties),
}
query = """SELECT * FROM cypher('%s', $$
MATCH (source:Entity {node_id: "%s"})
WITH source
MATCH (target:Entity {node_id: "%s"})
MERGE (source)-[r:DIRECTED]->(target)
SET r += %s
RETURN r
$$) AS (r agtype)""" % (
self.graph_name,
src_label,
tgt_label,
PGGraphStorage._format_properties(edge_properties),
)
# logger.info(f"-- inserting edge after formatted: {params}")
try:
await self._query(query, readonly=False, upsert_edge=True, **params)
await self._query(query, readonly=False, upsert=True)
logger.debug(
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
source_node_label,
target_node_label,
src_label,
tgt_label,
edge_properties,
)
except Exception as e: