Merge pull request #570 from ShanGor/main
Revise the AGE usage for postgres_impl
This commit is contained in:
@@ -26,7 +26,7 @@ This repository hosts the code of LightRAG. The structure of this code is based
|
||||
</div>
|
||||
|
||||
## 🎉 News
|
||||
- [x] [2025.01.06]🎯📢LightRAG now supports [PostgreSQL for Storage](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#using-postgres-for-storage).
|
||||
- [x] [2025.01.06]🎯📢You can now [use PostgreSQL for Storage](#using-postgresql-for-storage).
|
||||
- [x] [2024.12.31]🎯📢LightRAG now supports [deletion by document ID](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#delete).
|
||||
- [x] [2024.11.25]🎯📢LightRAG now supports seamless integration of [custom knowledge graphs](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#insert-custom-kg), empowering users to enhance the system with their own domain expertise.
|
||||
- [x] [2024.11.19]🎯📢A comprehensive guide to LightRAG is now available on [LearnOpenCV](https://learnopencv.com/lightrag). Many thanks to the blog author.
|
||||
@@ -361,6 +361,11 @@ see test_neo4j.py for a working example.
|
||||
For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
|
||||
* PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac.
|
||||
* How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
|
||||
* Create index for AGE example: (Change below `dickens` to your graph name if necessary)
|
||||
```
|
||||
SET search_path = ag_catalog, "$user", public;
|
||||
CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
|
||||
```
|
||||
|
||||
### Insert Custom KG
|
||||
|
||||
|
@@ -130,6 +130,7 @@ class PostgreSQLDB:
|
||||
data: Union[list, dict] = None,
|
||||
for_age: bool = False,
|
||||
graph_name: str = None,
|
||||
upsert: bool = False,
|
||||
):
|
||||
try:
|
||||
async with self.pool.acquire() as connection:
|
||||
@@ -140,6 +141,11 @@ class PostgreSQLDB:
|
||||
await connection.execute(sql)
|
||||
else:
|
||||
await connection.execute(sql, *data.values())
|
||||
except asyncpg.exceptions.UniqueViolationError as e:
|
||||
if upsert:
|
||||
print("Key value duplicate, but upsert succeeded.")
|
||||
else:
|
||||
logger.error(f"Upsert error: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL database error: {e}")
|
||||
print(sql)
|
||||
@@ -568,10 +574,10 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
if dtype == "vertex":
|
||||
vertex = json.loads(v)
|
||||
field = json.loads(v).get("properties")
|
||||
field = vertex.get("properties")
|
||||
if not field:
|
||||
field = {}
|
||||
field["label"] = PGGraphStorage._decode_graph_label(vertex["label"])
|
||||
field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
|
||||
d[k] = field
|
||||
# convert edge from id-label->id by replacing id with node information
|
||||
# we only do this if the vertex was also returned in the query
|
||||
@@ -666,73 +672,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
# otherwise return the value stripping out some common special chars
|
||||
return field.replace("(", "_").replace(")", "")
|
||||
|
||||
@staticmethod
|
||||
def _wrap_query(query: str, graph_name: str, **params: str) -> str:
|
||||
"""
|
||||
Convert a cypher query to an Apache Age compatible
|
||||
sql query by wrapping the cypher query in ag_catalog.cypher,
|
||||
casting results to agtype and building a select statement
|
||||
|
||||
Args:
|
||||
query (str): a valid cypher query
|
||||
graph_name (str): the name of the graph to query
|
||||
params (dict): parameters for the query
|
||||
|
||||
Returns:
|
||||
str: an equivalent pgsql query
|
||||
"""
|
||||
|
||||
# pgsql template
|
||||
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
|
||||
{query}
|
||||
$$) AS ({fields})"""
|
||||
|
||||
# if there are any returned fields they must be added to the pgsql query
|
||||
if "return" in query.lower():
|
||||
# parse return statement to identify returned fields
|
||||
fields = (
|
||||
query.lower()
|
||||
.split("return")[-1]
|
||||
.split("distinct")[-1]
|
||||
.split("order by")[0]
|
||||
.split("skip")[0]
|
||||
.split("limit")[0]
|
||||
.split(",")
|
||||
)
|
||||
|
||||
# raise exception if RETURN * is found as we can't resolve the fields
|
||||
if "*" in [x.strip() for x in fields]:
|
||||
raise ValueError(
|
||||
"AGE graph does not support 'RETURN *'"
|
||||
+ " statements in Cypher queries"
|
||||
)
|
||||
|
||||
# get pgsql formatted field names
|
||||
fields = [
|
||||
PGGraphStorage._get_col_name(field, idx)
|
||||
for idx, field in enumerate(fields)
|
||||
]
|
||||
|
||||
# build resulting pgsql relation
|
||||
fields_str = ", ".join(
|
||||
[field.split(".")[-1] + " agtype" for field in fields]
|
||||
)
|
||||
|
||||
# if no return statement we still need to return a single field of type agtype
|
||||
else:
|
||||
fields_str = "a agtype"
|
||||
|
||||
select_str = "*"
|
||||
|
||||
return template.format(
|
||||
graph_name=graph_name,
|
||||
query=query.format(**params),
|
||||
fields=fields_str,
|
||||
projection=select_str,
|
||||
)
|
||||
|
||||
async def _query(
|
||||
self, query: str, readonly=True, upsert_edge=False, **params: str
|
||||
self, query: str, readonly: bool = True, upsert: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query the graph by taking a cypher query, converting it to an
|
||||
@@ -746,7 +687,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
List[Dict[str, Any]]: a list of dictionaries containing the result set
|
||||
"""
|
||||
# convert cypher query to pgsql/age query
|
||||
wrapped_query = self._wrap_query(query, self.graph_name, **params)
|
||||
wrapped_query = query
|
||||
|
||||
# execute the query, rolling back on an error
|
||||
try:
|
||||
@@ -758,22 +699,16 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
graph_name=self.graph_name,
|
||||
)
|
||||
else:
|
||||
# for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
|
||||
# It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
|
||||
if upsert_edge:
|
||||
data = await self.db.execute(
|
||||
f"{wrapped_query};{wrapped_query};",
|
||||
for_age=True,
|
||||
graph_name=self.graph_name,
|
||||
)
|
||||
else:
|
||||
data = await self.db.execute(
|
||||
wrapped_query, for_age=True, graph_name=self.graph_name
|
||||
)
|
||||
data = await self.db.execute(
|
||||
wrapped_query,
|
||||
for_age=True,
|
||||
graph_name=self.graph_name,
|
||||
upsert=upsert,
|
||||
)
|
||||
except Exception as e:
|
||||
raise PGGraphQueryException(
|
||||
{
|
||||
"message": f"Error executing graph query: {query.format(**params)}",
|
||||
"message": f"Error executing graph query: {query}",
|
||||
"wrapped": wrapped_query,
|
||||
"detail": str(e),
|
||||
}
|
||||
@@ -788,77 +723,85 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
return result
|
||||
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = node_id.strip('"')
|
||||
entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
||||
|
||||
query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"""
|
||||
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
||||
single_result = (await self._query(query, **params))[0]
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
RETURN count(n) > 0 AS node_exists
|
||||
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
||||
|
||||
single_result = (await self._query(query))[0]
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
query,
|
||||
single_result["node_exists"],
|
||||
)
|
||||
|
||||
return single_result["node_exists"]
|
||||
|
||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
||||
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
||||
|
||||
query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
|
||||
RETURN COUNT(r) > 0 AS edge_exists"""
|
||||
params = {
|
||||
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
|
||||
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
|
||||
}
|
||||
single_result = (await self._query(query, **params))[0]
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
|
||||
RETURN COUNT(r) > 0 AS edge_exists
|
||||
$$) AS (edge_exists bool)""" % (
|
||||
self.graph_name,
|
||||
src_label,
|
||||
tgt_label,
|
||||
)
|
||||
|
||||
single_result = (await self._query(query))[0]
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
query,
|
||||
single_result["edge_exists"],
|
||||
)
|
||||
return single_result["edge_exists"]
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
entity_name_label = node_id.strip('"')
|
||||
query = """MATCH (n:`{label}`) RETURN n"""
|
||||
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
||||
record = await self._query(query, **params)
|
||||
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
RETURN n
|
||||
$$) AS (n agtype)""" % (self.graph_name, label)
|
||||
record = await self._query(query)
|
||||
if record:
|
||||
node = record[0]
|
||||
node_dict = node["n"]
|
||||
logger.debug(
|
||||
"{%s}: query: {%s}, result: {%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
query,
|
||||
node_dict,
|
||||
)
|
||||
return node_dict
|
||||
return None
|
||||
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
entity_name_label = node_id.strip('"')
|
||||
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
||||
|
||||
query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count"""
|
||||
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
||||
record = (await self._query(query, **params))[0]
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})-[]->(x)
|
||||
RETURN count(x) AS total_edge_count
|
||||
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
||||
record = (await self._query(query))[0]
|
||||
if record:
|
||||
edge_count = int(record["total_edge_count"])
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
query,
|
||||
edge_count,
|
||||
)
|
||||
return edge_count
|
||||
|
||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||
entity_name_label_source = src_id.strip('"')
|
||||
entity_name_label_target = tgt_id.strip('"')
|
||||
src_degree = await self.node_degree(entity_name_label_source)
|
||||
trg_degree = await self.node_degree(entity_name_label_target)
|
||||
src_degree = await self.node_degree(src_id)
|
||||
trg_degree = await self.node_degree(tgt_id)
|
||||
|
||||
# Convert None to 0 for addition
|
||||
src_degree = 0 if src_degree is None else src_degree
|
||||
@@ -885,23 +828,25 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
"""
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
||||
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
||||
|
||||
query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1"""
|
||||
params = {
|
||||
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
|
||||
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
|
||||
}
|
||||
record = await self._query(query, **params)
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1
|
||||
$$) AS (edge_properties agtype)""" % (
|
||||
self.graph_name,
|
||||
src_label,
|
||||
tgt_label,
|
||||
)
|
||||
record = await self._query(query)
|
||||
if record and record[0] and record[0]["edge_properties"]:
|
||||
result = record[0]["edge_properties"]
|
||||
logger.debug(
|
||||
"{%s}:query:{%s}:result:{%s}",
|
||||
inspect.currentframe().f_code.co_name,
|
||||
query.format(**params),
|
||||
query,
|
||||
result,
|
||||
)
|
||||
return result
|
||||
@@ -911,24 +856,31 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Retrieves all edges (relationships) for a particular node identified by its label.
|
||||
:return: List of dictionaries containing edge information
|
||||
"""
|
||||
node_label = source_node_id.strip('"')
|
||||
label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
||||
|
||||
query = """MATCH (n:`{label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected"""
|
||||
params = {"label": PGGraphStorage._encode_graph_label(node_label)}
|
||||
results = await self._query(query, **params)
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (n:Entity {node_id: "%s"})
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected
|
||||
$$) AS (n agtype, r agtype, connected agtype)""" % (
|
||||
self.graph_name,
|
||||
label,
|
||||
)
|
||||
|
||||
results = await self._query(query)
|
||||
edges = []
|
||||
for record in results:
|
||||
source_node = record["n"] if record["n"] else None
|
||||
connected_node = record["connected"] if record["connected"] else None
|
||||
|
||||
source_label = (
|
||||
source_node["label"] if source_node and source_node["label"] else None
|
||||
source_node["node_id"]
|
||||
if source_node and source_node["node_id"]
|
||||
else None
|
||||
)
|
||||
target_label = (
|
||||
connected_node["label"]
|
||||
if connected_node and connected_node["label"]
|
||||
connected_node["node_id"]
|
||||
if connected_node and connected_node["node_id"]
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -950,17 +902,21 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
node_id: The unique identifier for the node (used as label)
|
||||
node_data: Dictionary of node properties
|
||||
"""
|
||||
label = node_id.strip('"')
|
||||
label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
|
||||
properties = node_data
|
||||
|
||||
query = """MERGE (n:`{label}`)
|
||||
SET n += {properties}"""
|
||||
params = {
|
||||
"label": PGGraphStorage._encode_graph_label(label),
|
||||
"properties": PGGraphStorage._format_properties(properties),
|
||||
}
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MERGE (n:Entity {node_id: "%s"})
|
||||
SET n += %s
|
||||
RETURN n
|
||||
$$) AS (n agtype)""" % (
|
||||
self.graph_name,
|
||||
label,
|
||||
PGGraphStorage._format_properties(properties),
|
||||
)
|
||||
|
||||
try:
|
||||
await self._query(query, readonly=False, **params)
|
||||
await self._query(query, readonly=False, upsert=True)
|
||||
logger.debug(
|
||||
"Upserted node with label '{%s}' and properties: {%s}",
|
||||
label,
|
||||
@@ -986,28 +942,30 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
target_node_id (str): Label of the target node (used as identifier)
|
||||
edge_data (dict): Dictionary of properties to set on the edge
|
||||
"""
|
||||
source_node_label = source_node_id.strip('"')
|
||||
target_node_label = target_node_id.strip('"')
|
||||
src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
|
||||
tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
|
||||
edge_properties = edge_data
|
||||
|
||||
query = """MATCH (source:`{src_label}`)
|
||||
WITH source
|
||||
MATCH (target:`{tgt_label}`)
|
||||
MERGE (source)-[r:DIRECTED]->(target)
|
||||
SET r += {properties}
|
||||
RETURN r"""
|
||||
params = {
|
||||
"src_label": PGGraphStorage._encode_graph_label(source_node_label),
|
||||
"tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
|
||||
"properties": PGGraphStorage._format_properties(edge_properties),
|
||||
}
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
MATCH (source:Entity {node_id: "%s"})
|
||||
WITH source
|
||||
MATCH (target:Entity {node_id: "%s"})
|
||||
MERGE (source)-[r:DIRECTED]->(target)
|
||||
SET r += %s
|
||||
RETURN r
|
||||
$$) AS (r agtype)""" % (
|
||||
self.graph_name,
|
||||
src_label,
|
||||
tgt_label,
|
||||
PGGraphStorage._format_properties(edge_properties),
|
||||
)
|
||||
# logger.info(f"-- inserting edge after formatted: {params}")
|
||||
try:
|
||||
await self._query(query, readonly=False, upsert_edge=True, **params)
|
||||
await self._query(query, readonly=False, upsert=True)
|
||||
logger.debug(
|
||||
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
|
||||
source_node_label,
|
||||
target_node_label,
|
||||
src_label,
|
||||
tgt_label,
|
||||
edge_properties,
|
||||
)
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user