Fix the AGE processing bug

Samuel Chan
2025-01-03 21:10:06 +08:00
parent b17cb2aa95
commit f6f62c32a8
3 changed files with 167 additions and 83 deletions

View File

@@ -6,7 +6,7 @@ import time
from dotenv import load_dotenv
from lightrag import LightRAG, QueryParam
-from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
+from lightrag.kg.postgres_impl import PostgreSQLDB
from lightrag.llm import ollama_embedding, zhipu_complete
from lightrag.utils import EmbeddingFunc
@@ -67,7 +67,6 @@ async def main():
    rag.entities_vdb.db = postgres_db
    rag.graph_storage_cls.db = postgres_db
    rag.chunk_entity_relation_graph.db = postgres_db
-    await rag.chunk_entity_relation_graph.check_graph_exists()
    # add embedding_func for the graph database; it was removed in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
    rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
@@ -103,21 +102,6 @@ async def main():
    )
    print(f"Hybrid Query Time: {time.time() - start_time} seconds")

-    print("**** Start Stream Query ****")
-    start_time = time.time()
-    # stream response
-    resp = await rag.aquery(
-        "What are the top themes in this story?",
-        param=QueryParam(mode="hybrid", stream=True),
-    )
-    print(f"Stream Query Time: {time.time() - start_time} seconds")
-    print("**** Done Stream Query ****")
-    if inspect.isasyncgen(resp):
-        asyncio.run(print_stream(resp))
-    else:
-        print(resp)

if __name__ == "__main__":
    asyncio.run(main())

View File: lightrag/kg/postgres_impl.py

@@ -81,12 +81,12 @@ class PostgreSQLDB:
    async def query(
-        self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False
+        self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None
    ) -> Union[dict, None, list[dict]]:
        async with self.pool.acquire() as connection:
            try:
                if for_age:
-                    await connection.execute('SET search_path = ag_catalog, "$user", public')
+                    await PostgreSQLDB._prerequisite(connection, graph_name)
                if params:
                    rows = await connection.fetch(sql, *params.values())
                else:
@@ -95,10 +95,7 @@ class PostgreSQLDB:
                if multirows:
                    if rows:
                        columns = [col for col in rows[0].keys()]
-                        # print("columns", columns.__class__, columns)
-                        # print("rows", rows)
                        data = [dict(zip(columns, row)) for row in rows]
-                        # print("data", data)
                    else:
                        data = []
                else:
@@ -114,11 +111,11 @@ class PostgreSQLDB:
                print(params)
                raise

-    async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False):
+    async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
        try:
            async with self.pool.acquire() as connection:
                if for_age:
-                    await connection.execute('SET search_path = ag_catalog, "$user", public')
+                    await PostgreSQLDB._prerequisite(connection, graph_name)
                if data is None:
                    await connection.execute(sql)
@@ -130,6 +127,14 @@ class PostgreSQLDB:
                print(data)
                raise

+    @staticmethod
+    async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
+        try:
+            await conn.execute('SET search_path = ag_catalog, "$user", public')
+            await conn.execute(f"""select create_graph('{graph_name}')""")
+        except asyncpg.exceptions.InvalidSchemaNameError:
+            pass
@dataclass
class PGKVStorage(BaseKVStorage):
@@ -346,18 +351,14 @@ class PGVectorStorage(BaseVectorStorage):
        embeddings = await self.embedding_func([query])
        embedding = embeddings[0]
        embedding_string = ",".join(map(str, embedding))
-        # print("Namespace", self.namespace)
        sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
-        # print("sql is: ", sql)
        params = {
            "workspace": self.db.workspace,
            "better_than_threshold": self.cosine_better_than_threshold,
            "top_k": top_k,
        }
-        # print("params is: ", params)
        results = await self.db.query(sql, params=params, multirows=True)
-        print("vector search result:", results)
        return results
@dataclass
@@ -487,19 +488,6 @@ class PGGraphStorage(BaseGraphStorage):
    async def index_done_callback(self):
        print("KG successfully indexed.")

-    async def check_graph_exists(self):
-        try:
-            res = await self.db.query(f"SELECT * FROM ag_catalog.ag_graph WHERE name = '{self.graph_name}'")
-            if res:
-                logger.info(f"Graph {self.graph_name} exists.")
-            else:
-                logger.info(f"Graph {self.graph_name} does not exist. Creating...")
-                await self.db.execute(f"SELECT create_graph('{self.graph_name}')", for_age=True)
-                logger.info(f"Graph {self.graph_name} created.")
-        except Exception as e:
-            logger.info(f"Failed to check/create graph {self.graph_name}:", e)
-            raise e

    @staticmethod
    def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
        """
@@ -572,7 +560,7 @@ class PGGraphStorage(BaseGraphStorage):
        Args:
            properties (Dict[str,str]): a dictionary containing node/edge properties
-            id (Union[str, None]): the id of the node or None if none exists
+            _id (Union[str, None]): the id of the node or None if none exists

        Returns:
            str: the properties dictionary as a properly formatted string
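For orientation, the "properly formatted string" this docstring describes is a Cypher property-map literal. A rough sketch of the flat case, assuming simple scalar values (the real `_format_properties` also handles quoting, nesting, and the optional `_id`; names here are illustrative):

    # Illustrative sketch only: the flat-dict case of property formatting.
    def format_properties_sketch(properties: dict) -> str:
        parts = [f"{k}: {v!r}" for k, v in properties.items()]
        return "{" + ", ".join(parts) + "}"

    # format_properties_sketch({"weight": 7.0, "hello": "world"})
    # -> "{weight: 7.0, hello: 'world'}"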
@@ -591,7 +579,7 @@ class PGGraphStorage(BaseGraphStorage):
    @staticmethod
    def _encode_graph_label(label: str) -> str:
        """
-        Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
+        Since AGE supports only alphanumerical labels, we will encode generic label as HEX string

        Args:
            label (str): the original label
@@ -604,7 +592,7 @@ class PGGraphStorage(BaseGraphStorage):
    @staticmethod
    def _decode_graph_label(encoded_label: str) -> str:
        """
-        Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
+        Since AGE supports only alphanumerical labels, we will encode generic label as HEX string

        Args:
            encoded_label (str): the encoded label
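Both docstrings describe the same scheme: arbitrary entity names are hex-encoded so AGE only ever sees alphanumeric labels. A minimal sketch of such a round-trip, assuming hex over UTF-8 bytes with a leading letter so the label cannot start with a digit (the exact format used by `_encode_graph_label`/`_decode_graph_label` may differ):

    # Sketch of a hex label codec; the real methods' exact format may differ.
    def encode_label_sketch(label: str) -> str:
        return "x" + label.encode().hex()  # leading "x" avoids a digit-first label

    def decode_label_sketch(encoded: str) -> str:
        return bytes.fromhex(encoded[1:]).decode()

    # encode_label_sketch("THE CRATCHITS") -> "x54484520435241544348495453"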
@@ -657,7 +645,7 @@ class PGGraphStorage(BaseGraphStorage):
        # pgsql template
        template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
        {query}
-        $$) AS ({fields});"""
+        $$) AS ({fields})"""

        # if there are any returned fields they must be added to the pgsql query
        if "return" in query.lower():
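Filled in, the template yields plain SQL such as the sketch below (graph name and hex label are illustrative). The trailing semicolon is dropped here because the edge-upsert workaround further down concatenates two copies of the wrapped query with its own semicolons:

    # What the filled-in template looks like for a simple node lookup
    # (illustrative graph name and label).
    wrapped_query = """SELECT * FROM ag_catalog.cypher('dickens', $$
    MATCH (n:`x54484520435241544348495453`) RETURN n
    $$) AS (n ag_catalog.agtype)"""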
@@ -702,7 +690,7 @@ class PGGraphStorage(BaseGraphStorage):
            projection=select_str,
        )

-    async def _query(self, query: str, readonly=True, **params: str) -> List[Dict[str, Any]]:
+    async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
        """
        Query the graph by taking a cypher query, converting it to an
        age compatible query, executing it and converting the result
@@ -720,9 +708,14 @@ class PGGraphStorage(BaseGraphStorage):
        # execute the query, rolling back on an error
        try:
            if readonly:
-                data = await self.db.query(wrapped_query, multirows=True, for_age=True)
+                data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
            else:
-                data = await self.db.execute(wrapped_query, for_age=True)
+                # When upserting an edge, the SQL must run twice, otherwise the edge
+                # properties are not updated: the first pass tries to create the edge,
+                # the second pass performs the MERGE update. This is a bug in AGE as of
+                # 2025-01-03; hopefully it will be fixed upstream.
+                if upsert_edge:
+                    data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
+                else:
+                    data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
        except Exception as e:
            raise PGGraphQueryException(
                {
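Concretely, when `upsert_edge` is true the database receives the same wrapped statement twice in a single round trip; the first pass creates the edge, the second actually applies the `SET`. A sketch with illustrative labels and properties:

    # Sketch of the doubled statement sent for an edge upsert (illustrative values).
    # First execution creates the edge; the second applies SET r += {...},
    # working around the AGE bug noted above (as of 2025-01-03).
    wrapped_query = """SELECT * FROM ag_catalog.cypher('dickens', $$
    MATCH (source:`x41`) WITH source MATCH (target:`x42`)
    MERGE (source)-[r:DIRECTED]->(target)
    SET r += {weight: 7.0}
    RETURN r
    $$) AS (r ag_catalog.agtype)"""
    sql = f"{wrapped_query};{wrapped_query};"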
@@ -743,9 +736,7 @@ class PGGraphStorage(BaseGraphStorage):
    async def has_node(self, node_id: str) -> bool:
        entity_name_label = node_id.strip('"')

-        query = """
-        MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists
-        """
+        query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"""
        params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
        single_result = (await self._query(query, **params))[0]
        logger.debug(
@@ -761,10 +752,8 @@ class PGGraphStorage(BaseGraphStorage):
        entity_name_label_source = source_node_id.strip('"')
        entity_name_label_target = target_node_id.strip('"')

-        query = """
-        MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
-        RETURN COUNT(r) > 0 AS edge_exists
-        """
+        query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
+        RETURN COUNT(r) > 0 AS edge_exists"""
        params = {
            "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
            "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
@@ -780,9 +769,7 @@ class PGGraphStorage(BaseGraphStorage):
    async def get_node(self, node_id: str) -> Union[dict, None]:
        entity_name_label = node_id.strip('"')
-        query = """
-        MATCH (n:`{label}`) RETURN n
-        """
+        query = """MATCH (n:`{label}`) RETURN n"""
        params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
        record = await self._query(query, **params)
        if record:
@@ -800,10 +787,7 @@ class PGGraphStorage(BaseGraphStorage):
    async def node_degree(self, node_id: str) -> int:
        entity_name_label = node_id.strip('"')

-        query = """
-        MATCH (n:`{label}`)-[]->(x)
-        RETURN count(x) AS total_edge_count
-        """
+        query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count"""
        params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
        record = (await self._query(query, **params))[0]
        if record:
@@ -841,8 +825,8 @@ class PGGraphStorage(BaseGraphStorage):
        Find all edges between nodes of two given labels

        Args:
-            source_node_label (str): Label of the source nodes
-            target_node_label (str): Label of the target nodes
+            source_node_id (str): Label of the source nodes
+            target_node_id (str): Label of the target nodes

        Returns:
            list: List of all relationships/edges found
@@ -850,11 +834,9 @@ class PGGraphStorage(BaseGraphStorage):
        entity_name_label_source = source_node_id.strip('"')
        entity_name_label_target = target_node_id.strip('"')

-        query = """
-        MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
+        query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
        RETURN properties(r) as edge_properties
-        LIMIT 1
-        """
+        LIMIT 1"""
        params = {
            "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
            "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
@@ -877,11 +859,9 @@
        """
        node_label = source_node_id.strip('"')

-        query = """
-        MATCH (n:`{label}`)
+        query = """MATCH (n:`{label}`)
        OPTIONAL MATCH (n)-[r]-(connected)
-        RETURN n, r, connected
-        """
+        RETURN n, r, connected"""
        params = {"label": PGGraphStorage._encode_graph_label(node_label)}
        results = await self._query(query, **params)
        edges = []
@@ -919,10 +899,8 @@ class PGGraphStorage(BaseGraphStorage):
        label = node_id.strip('"')
        properties = node_data

-        query = """
-        MERGE (n:`{label}`)
-        SET n += {properties}
-        """
+        query = """MERGE (n:`{label}`)
+        SET n += {properties}"""
        params = {
            "label": PGGraphStorage._encode_graph_label(label),
            "properties": PGGraphStorage._format_properties(properties),
@@ -957,22 +935,22 @@ class PGGraphStorage(BaseGraphStorage):
        source_node_label = source_node_id.strip('"')
        target_node_label = target_node_id.strip('"')
        edge_properties = edge_data
        logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")

-        query = """
-        MATCH (source:`{src_label}`)
+        query = """MATCH (source:`{src_label}`)
        WITH source
        MATCH (target:`{tgt_label}`)
        MERGE (source)-[r:DIRECTED]->(target)
        SET r += {properties}
-        RETURN r
-        """
+        RETURN r"""
        params = {
            "src_label": PGGraphStorage._encode_graph_label(source_node_label),
            "tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
            "properties": PGGraphStorage._format_properties(edge_properties),
        }
        # logger.info(f"-- inserting edge after formatted: {params}")
        try:
-            await self._query(query, readonly=False, **params)
+            await self._query(query, readonly=False, upsert_edge=True, **params)
            logger.debug(
                "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
                source_node_label,
@@ -1127,7 +1105,7 @@ SQL_TEMPLATES = {
                      updatetime = CURRENT_TIMESTAMP
                   """,
    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
-                      VALUES ($1, $2, $3, $4, $5, $6)
+                      VALUES ($1, $2, $3, $4, $5)
                       ON CONFLICT (workspace,id) DO UPDATE
                       SET entity_name=EXCLUDED.entity_name,
                           content=EXCLUDED.content,

View File

@@ -0,0 +1,122 @@
import asyncio
import sys, os

import asyncpg
import psycopg
from psycopg_pool import AsyncConnectionPool

from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage

DB = "rag"
USER = "rag"
PASSWORD = "rag"
HOST = "localhost"
PORT = "15432"
os.environ["AGE_GRAPH_NAME"] = "dickens"

if sys.platform.startswith("win"):
    import asyncio.windows_events

    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


async def get_pool():
    return await asyncpg.create_pool(
        f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
        min_size=10,  # minimum number of connections kept in the pool, default 10
        max_size=10,  # maximum number of connections in the pool, default 10
        max_queries=5000,  # queries per connection before it is replaced, default 5000
        # maximum idle time, default 300.0 seconds; idle connections beyond this
        # are closed, and passing 0 means connections are never closed
        max_inactive_connection_lifetime=300.0,
    )


async def main1():
    connection_string = (
        f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
    )
    pool = AsyncConnectionPool(connection_string, open=False)
    await pool.open()
    try:
        conn = await pool.getconn(timeout=10)
        async with conn.cursor() as curs:
            try:
                await curs.execute('SET search_path = ag_catalog, "$user", public')
                await curs.execute("SELECT create_graph('dickens-2')")
                await conn.commit()
                print("create_graph success")
            except (
                psycopg.errors.InvalidSchemaName,
                psycopg.errors.UniqueViolation,
            ):
                print("create_graph already exists")
                await conn.rollback()
    finally:
        pass


db = PostgreSQLDB(
    config={
        "host": "localhost",
        "port": 15432,
        "user": "rag",
        "password": "rag",
        "database": "rag",
    }
)


async def query_with_age():
    await db.initdb()
    graph = PGGraphStorage(
        namespace="chunk_entity_relation",
        global_config={},
        embedding_func=None,
    )
    graph.db = db
    res = await graph.get_node('"CHRISTMAS-TIME"')
    print("Node is: ", res)


async def create_edge_with_age():
    await db.initdb()
    graph = PGGraphStorage(
        namespace="chunk_entity_relation",
        global_config={},
        embedding_func=None,
    )
    graph.db = db
    await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
    await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
    await graph.upsert_edge(
        '"THE CRATCHITS"',
        '"THE GIRLS"',
        edge_data={
            "weight": 7.0,
            "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
            "keywords": '"family, collective effort"',
            "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
        },
    )
    res = await graph.get_edge('THE CRATCHITS', '"THE GIRLS"')
    print("Edge is: ", res)


async def main():
    pool = await get_pool()
    # Any additional driver-specific parameters can be passed straight through via
    # **connect_kwargs; it exists for database-specific settings.
    sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
    # cypher = "MATCH (n:how_are_you_doing) RETURN n"
    # take a connection from the pool
    async with pool.acquire() as conn:
        try:
            await conn.execute(
                """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
            )
        except asyncpg.exceptions.InvalidSchemaNameError:
            print("create_graph already exists")
        # stmt = await conn.prepare(sql)
        row = await conn.fetch(sql)
        print("row is: ", row)

        # an unaliased expression comes back with a generated column name,
        # so the fix is to give it an alias
        row = await conn.fetchrow("select '100'::int + 200 as result")
        print(row)  # <Record result=300>
    # the connection was taken from the pool and is returned to it automatically
    # when the context manager exits


if __name__ == "__main__":
    asyncio.run(query_with_age())
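The entry point only exercises the read path; to hit the new double-MERGE upsert path, the same script can be pointed at the edge routine instead, e.g.:

    if __name__ == "__main__":
        # run the edge-upsert demo instead of the read-only node query
        asyncio.run(create_edge_with_age())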