Fix the bug of AGE processing
This commit is contained in:
@@ -81,12 +81,12 @@ class PostgreSQLDB:
|
||||
|
||||
|
||||
async def query(
|
||||
self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False
|
||||
self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None
|
||||
) -> Union[dict, None, list[dict]]:
|
||||
async with self.pool.acquire() as connection:
|
||||
try:
|
||||
if for_age:
|
||||
await connection.execute('SET search_path = ag_catalog, "$user", public')
|
||||
await PostgreSQLDB._prerequisite(connection, graph_name)
|
||||
if params:
|
||||
rows = await connection.fetch(sql, *params.values())
|
||||
else:
|
||||
@@ -95,10 +95,7 @@ class PostgreSQLDB:
|
||||
if multirows:
|
||||
if rows:
|
||||
columns = [col for col in rows[0].keys()]
|
||||
# print("columns", columns.__class__, columns)
|
||||
# print("rows", rows)
|
||||
data = [dict(zip(columns, row)) for row in rows]
|
||||
# print("data", data)
|
||||
else:
|
||||
data = []
|
||||
else:
|
||||
@@ -114,11 +111,11 @@ class PostgreSQLDB:
|
||||
print(params)
|
||||
raise
|
||||
|
||||
async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False):
|
||||
async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
|
||||
try:
|
||||
async with self.pool.acquire() as connection:
|
||||
if for_age:
|
||||
await connection.execute('SET search_path = ag_catalog, "$user", public')
|
||||
await PostgreSQLDB._prerequisite(connection, graph_name)
|
||||
|
||||
if data is None:
|
||||
await connection.execute(sql)
|
||||
@@ -130,6 +127,14 @@ class PostgreSQLDB:
|
||||
print(data)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
|
||||
try:
|
||||
await conn.execute(f'SET search_path = ag_catalog, "$user", public')
|
||||
await conn.execute(f"""select create_graph('{graph_name}')""")
|
||||
except asyncpg.exceptions.InvalidSchemaNameError:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class PGKVStorage(BaseKVStorage):
|
||||
@@ -346,18 +351,14 @@ class PGVectorStorage(BaseVectorStorage):
|
||||
embeddings = await self.embedding_func([query])
|
||||
embedding = embeddings[0]
|
||||
embedding_string = ",".join(map(str, embedding))
|
||||
# print("Namespace", self.namespace)
|
||||
|
||||
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
||||
# print("sql is: ", sql)
|
||||
params = {
|
||||
"workspace": self.db.workspace,
|
||||
"better_than_threshold": self.cosine_better_than_threshold,
|
||||
"top_k": top_k,
|
||||
}
|
||||
# print("params is: ", params)
|
||||
results = await self.db.query(sql, params=params, multirows=True)
|
||||
print("vector search result:", results)
|
||||
return results
|
||||
|
||||
@dataclass
|
||||
@@ -487,19 +488,6 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def index_done_callback(self):
|
||||
print("KG successfully indexed.")
|
||||
|
||||
async def check_graph_exists(self):
|
||||
try:
|
||||
res = await self.db.query(f"SELECT * FROM ag_catalog.ag_graph WHERE name = '{self.graph_name}'")
|
||||
if res:
|
||||
logger.info(f"Graph {self.graph_name} exists.")
|
||||
else:
|
||||
logger.info(f"Graph {self.graph_name} does not exist. Creating...")
|
||||
await self.db.execute(f"SELECT create_graph('{self.graph_name}')", for_age=True)
|
||||
logger.info(f"Graph {self.graph_name} created.")
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to check/create graph {self.graph_name}:", e)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -572,7 +560,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
Args:
|
||||
properties (Dict[str,str]): a dictionary containing node/edge properties
|
||||
id (Union[str, None]): the id of the node or None if none exists
|
||||
_id (Union[str, None]): the id of the node or None if none exists
|
||||
|
||||
Returns:
|
||||
str: the properties dictionary as a properly formatted string
|
||||
@@ -591,7 +579,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
def _encode_graph_label(label: str) -> str:
|
||||
"""
|
||||
Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
|
||||
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
||||
|
||||
Args:
|
||||
label (str): the original label
|
||||
@@ -604,7 +592,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
@staticmethod
|
||||
def _decode_graph_label(encoded_label: str) -> str:
|
||||
"""
|
||||
Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
|
||||
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
||||
|
||||
Args:
|
||||
encoded_label (str): the encoded label
|
||||
@@ -656,8 +644,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
# pgsql template
|
||||
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
|
||||
{query}
|
||||
$$) AS ({fields});"""
|
||||
{query}
|
||||
$$) AS ({fields})"""
|
||||
|
||||
# if there are any returned fields they must be added to the pgsql query
|
||||
if "return" in query.lower():
|
||||
@@ -702,7 +690,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
projection=select_str,
|
||||
)
|
||||
|
||||
async def _query(self, query: str, readonly=True, **params: str) -> List[Dict[str, Any]]:
|
||||
async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query the graph by taking a cypher query, converting it to an
|
||||
age compatible query, executing it and converting the result
|
||||
@@ -720,9 +708,14 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
# execute the query, rolling back on an error
|
||||
try:
|
||||
if readonly:
|
||||
data = await self.db.query(wrapped_query, multirows=True, for_age=True)
|
||||
data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
|
||||
else:
|
||||
data = await self.db.execute(wrapped_query, for_age=True)
|
||||
# for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
|
||||
# It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
|
||||
if upsert_edge:
|
||||
data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
|
||||
else:
|
||||
data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
|
||||
except Exception as e:
|
||||
raise PGGraphQueryException(
|
||||
{
|
||||
@@ -743,9 +736,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def has_node(self, node_id: str) -> bool:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists
|
||||
"""
|
||||
query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"""
|
||||
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
||||
single_result = (await self._query(query, **params))[0]
|
||||
logger.debug(
|
||||
@@ -761,10 +752,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
|
||||
RETURN COUNT(r) > 0 AS edge_exists
|
||||
"""
|
||||
query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
|
||||
RETURN COUNT(r) > 0 AS edge_exists"""
|
||||
params = {
|
||||
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
|
||||
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
|
||||
@@ -780,9 +769,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
|
||||
async def get_node(self, node_id: str) -> Union[dict, None]:
|
||||
entity_name_label = node_id.strip('"')
|
||||
query = """
|
||||
MATCH (n:`{label}`) RETURN n
|
||||
"""
|
||||
query = """MATCH (n:`{label}`) RETURN n"""
|
||||
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
||||
record = await self._query(query, **params)
|
||||
if record:
|
||||
@@ -800,10 +787,7 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def node_degree(self, node_id: str) -> int:
|
||||
entity_name_label = node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (n:`{label}`)-[]->(x)
|
||||
RETURN count(x) AS total_edge_count
|
||||
"""
|
||||
query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count"""
|
||||
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
||||
record = (await self._query(query, **params))[0]
|
||||
if record:
|
||||
@@ -841,8 +825,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
Find all edges between nodes of two given labels
|
||||
|
||||
Args:
|
||||
source_node_label (str): Label of the source nodes
|
||||
target_node_label (str): Label of the target nodes
|
||||
source_node_id (str): Label of the source nodes
|
||||
target_node_id (str): Label of the target nodes
|
||||
|
||||
Returns:
|
||||
list: List of all relationships/edges found
|
||||
@@ -850,11 +834,9 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
entity_name_label_source = source_node_id.strip('"')
|
||||
entity_name_label_target = target_node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
|
||||
query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
|
||||
RETURN properties(r) as edge_properties
|
||||
LIMIT 1
|
||||
"""
|
||||
LIMIT 1"""
|
||||
params = {
|
||||
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
|
||||
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
|
||||
@@ -877,11 +859,9 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
"""
|
||||
node_label = source_node_id.strip('"')
|
||||
|
||||
query = """
|
||||
MATCH (n:`{label}`)
|
||||
query = """MATCH (n:`{label}`)
|
||||
OPTIONAL MATCH (n)-[r]-(connected)
|
||||
RETURN n, r, connected
|
||||
"""
|
||||
RETURN n, r, connected"""
|
||||
params = {"label": PGGraphStorage._encode_graph_label(node_label)}
|
||||
results = await self._query(query, **params)
|
||||
edges = []
|
||||
@@ -919,10 +899,8 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
label = node_id.strip('"')
|
||||
properties = node_data
|
||||
|
||||
query = """
|
||||
MERGE (n:`{label}`)
|
||||
SET n += {properties}
|
||||
"""
|
||||
query = """MERGE (n:`{label}`)
|
||||
SET n += {properties}"""
|
||||
params = {
|
||||
"label": PGGraphStorage._encode_graph_label(label),
|
||||
"properties": PGGraphStorage._format_properties(properties),
|
||||
@@ -957,22 +935,22 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
source_node_label = source_node_id.strip('"')
|
||||
target_node_label = target_node_id.strip('"')
|
||||
edge_properties = edge_data
|
||||
logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")
|
||||
|
||||
query = """
|
||||
MATCH (source:`{src_label}`)
|
||||
query = """MATCH (source:`{src_label}`)
|
||||
WITH source
|
||||
MATCH (target:`{tgt_label}`)
|
||||
MERGE (source)-[r:DIRECTED]->(target)
|
||||
SET r += {properties}
|
||||
RETURN r
|
||||
"""
|
||||
RETURN r"""
|
||||
params = {
|
||||
"src_label": PGGraphStorage._encode_graph_label(source_node_label),
|
||||
"tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
|
||||
"properties": PGGraphStorage._format_properties(edge_properties),
|
||||
}
|
||||
# logger.info(f"-- inserting edge after formatted: {params}")
|
||||
try:
|
||||
await self._query(query, readonly=False, **params)
|
||||
await self._query(query, readonly=False, upsert_edge=True, **params)
|
||||
logger.debug(
|
||||
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
|
||||
source_node_label,
|
||||
@@ -1127,7 +1105,7 @@ SQL_TEMPLATES = {
|
||||
updatetime = CURRENT_TIMESTAMP
|
||||
""",
|
||||
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (workspace,id) DO UPDATE
|
||||
SET entity_name=EXCLUDED.entity_name,
|
||||
content=EXCLUDED.content,
|
||||
|
Reference in New Issue
Block a user