diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py index 7574521a..1876bb8c 100644 --- a/examples/lightrag_zhipu_postgres_demo.py +++ b/examples/lightrag_zhipu_postgres_demo.py @@ -6,7 +6,7 @@ import time from dotenv import load_dotenv from lightrag import LightRAG, QueryParam -from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage +from lightrag.kg.postgres_impl import PostgreSQLDB from lightrag.llm import ollama_embedding, zhipu_complete from lightrag.utils import EmbeddingFunc @@ -67,7 +67,6 @@ async def main(): rag.entities_vdb.db = postgres_db rag.graph_storage_cls.db = postgres_db rag.chunk_entity_relation_graph.db = postgres_db - await rag.chunk_entity_relation_graph.check_graph_exists() # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func @@ -103,21 +102,6 @@ async def main(): ) print(f"Hybrid Query Time: {time.time() - start_time} seconds") - print("**** Start Stream Query ****") - start_time = time.time() - # stream response - resp = await rag.aquery( - "What are the top themes in this story?", - param=QueryParam(mode="hybrid", stream=True), - ) - print(f"Stream Query Time: {time.time() - start_time} seconds") - print("**** Done Stream Query ****") - - if inspect.isasyncgen(resp): - asyncio.run(print_stream(resp)) - else: - print(resp) - if __name__ == "__main__": asyncio.run(main()) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 80e97b16..8d60f471 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -81,12 +81,12 @@ class PostgreSQLDB: async def query( - self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False + self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None ) -> Union[dict, None, list[dict]]: async with self.pool.acquire() as connection: try: if for_age: - await connection.execute('SET search_path = ag_catalog, "$user", public') + await PostgreSQLDB._prerequisite(connection, graph_name) if params: rows = await connection.fetch(sql, *params.values()) else: @@ -95,10 +95,7 @@ class PostgreSQLDB: if multirows: if rows: columns = [col for col in rows[0].keys()] - # print("columns", columns.__class__, columns) - # print("rows", rows) data = [dict(zip(columns, row)) for row in rows] - # print("data", data) else: data = [] else: @@ -114,11 +111,11 @@ class PostgreSQLDB: print(params) raise - async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False): + async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None): try: async with self.pool.acquire() as connection: if for_age: - await connection.execute('SET search_path = ag_catalog, "$user", public') + await PostgreSQLDB._prerequisite(connection, graph_name) if data is None: await connection.execute(sql) @@ -130,6 +127,14 @@ class PostgreSQLDB: print(data) raise + @staticmethod + async def _prerequisite(conn: asyncpg.Connection, graph_name: str): + try: + await conn.execute(f'SET search_path = ag_catalog, "$user", public') + await conn.execute(f"""select create_graph('{graph_name}')""") + except asyncpg.exceptions.InvalidSchemaNameError: + pass + @dataclass class PGKVStorage(BaseKVStorage): @@ -346,18 +351,14 @@ class PGVectorStorage(BaseVectorStorage): embeddings = await self.embedding_func([query]) embedding = embeddings[0] embedding_string = ",".join(map(str, embedding)) - # print("Namespace", self.namespace) sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) - # print("sql is: ", sql) params = { "workspace": self.db.workspace, "better_than_threshold": self.cosine_better_than_threshold, "top_k": top_k, } - # print("params is: ", params) results = await self.db.query(sql, params=params, multirows=True) - print("vector search result:", results) return results @dataclass @@ -487,19 +488,6 @@ class PGGraphStorage(BaseGraphStorage): async def index_done_callback(self): print("KG successfully indexed.") - async def check_graph_exists(self): - try: - res = await self.db.query(f"SELECT * FROM ag_catalog.ag_graph WHERE name = '{self.graph_name}'") - if res: - logger.info(f"Graph {self.graph_name} exists.") - else: - logger.info(f"Graph {self.graph_name} does not exist. Creating...") - await self.db.execute(f"SELECT create_graph('{self.graph_name}')", for_age=True) - logger.info(f"Graph {self.graph_name} created.") - except Exception as e: - logger.info(f"Failed to check/create graph {self.graph_name}:", e) - raise e - @staticmethod def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: """ @@ -572,7 +560,7 @@ class PGGraphStorage(BaseGraphStorage): Args: properties (Dict[str,str]): a dictionary containing node/edge properties - id (Union[str, None]): the id of the node or None if none exists + _id (Union[str, None]): the id of the node or None if none exists Returns: str: the properties dictionary as a properly formatted string @@ -591,7 +579,7 @@ class PGGraphStorage(BaseGraphStorage): @staticmethod def _encode_graph_label(label: str) -> str: """ - Since AGE suports only alphanumerical labels, we will encode generic label as HEX string + Since AGE supports only alphanumerical labels, we will encode generic label as HEX string Args: label (str): the original label @@ -604,7 +592,7 @@ class PGGraphStorage(BaseGraphStorage): @staticmethod def _decode_graph_label(encoded_label: str) -> str: """ - Since AGE suports only alphanumerical labels, we will encode generic label as HEX string + Since AGE supports only alphanumerical labels, we will encode generic label as HEX string Args: encoded_label (str): the encoded label @@ -656,8 +644,8 @@ class PGGraphStorage(BaseGraphStorage): # pgsql template template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$ - {query} - $$) AS ({fields});""" + {query} + $$) AS ({fields})""" # if there are any returned fields they must be added to the pgsql query if "return" in query.lower(): @@ -702,7 +690,7 @@ class PGGraphStorage(BaseGraphStorage): projection=select_str, ) - async def _query(self, query: str, readonly=True, **params: str) -> List[Dict[str, Any]]: + async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]: """ Query the graph by taking a cypher query, converting it to an age compatible query, executing it and converting the result @@ -720,9 +708,14 @@ class PGGraphStorage(BaseGraphStorage): # execute the query, rolling back on an error try: if readonly: - data = await self.db.query(wrapped_query, multirows=True, for_age=True) + data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name) else: - data = await self.db.execute(wrapped_query, for_age=True) + # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING) + # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future. + if upsert_edge: + data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name) + else: + data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name) except Exception as e: raise PGGraphQueryException( { @@ -743,9 +736,7 @@ class PGGraphStorage(BaseGraphStorage): async def has_node(self, node_id: str) -> bool: entity_name_label = node_id.strip('"') - query = """ - MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists - """ + query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists""" params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} single_result = (await self._query(query, **params))[0] logger.debug( @@ -761,10 +752,8 @@ class PGGraphStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - query = """ - MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) - RETURN COUNT(r) > 0 AS edge_exists - """ + query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) + RETURN COUNT(r) > 0 AS edge_exists""" params = { "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source), "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target), @@ -780,9 +769,7 @@ class PGGraphStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> Union[dict, None]: entity_name_label = node_id.strip('"') - query = """ - MATCH (n:`{label}`) RETURN n - """ + query = """MATCH (n:`{label}`) RETURN n""" params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} record = await self._query(query, **params) if record: @@ -800,10 +787,7 @@ class PGGraphStorage(BaseGraphStorage): async def node_degree(self, node_id: str) -> int: entity_name_label = node_id.strip('"') - query = """ - MATCH (n:`{label}`)-[]->(x) - RETURN count(x) AS total_edge_count - """ + query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count""" params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)} record = (await self._query(query, **params))[0] if record: @@ -841,8 +825,8 @@ class PGGraphStorage(BaseGraphStorage): Find all edges between nodes of two given labels Args: - source_node_label (str): Label of the source nodes - target_node_label (str): Label of the target nodes + source_node_id (str): Label of the source nodes + target_node_id (str): Label of the target nodes Returns: list: List of all relationships/edges found @@ -850,11 +834,9 @@ class PGGraphStorage(BaseGraphStorage): entity_name_label_source = source_node_id.strip('"') entity_name_label_target = target_node_id.strip('"') - query = """ - MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) + query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`) RETURN properties(r) as edge_properties - LIMIT 1 - """ + LIMIT 1""" params = { "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source), "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target), @@ -877,11 +859,9 @@ class PGGraphStorage(BaseGraphStorage): """ node_label = source_node_id.strip('"') - query = """ - MATCH (n:`{label}`) + query = """MATCH (n:`{label}`) OPTIONAL MATCH (n)-[r]-(connected) - RETURN n, r, connected - """ + RETURN n, r, connected""" params = {"label": PGGraphStorage._encode_graph_label(node_label)} results = await self._query(query, **params) edges = [] @@ -919,10 +899,8 @@ class PGGraphStorage(BaseGraphStorage): label = node_id.strip('"') properties = node_data - query = """ - MERGE (n:`{label}`) - SET n += {properties} - """ + query = """MERGE (n:`{label}`) + SET n += {properties}""" params = { "label": PGGraphStorage._encode_graph_label(label), "properties": PGGraphStorage._format_properties(properties), @@ -957,22 +935,22 @@ class PGGraphStorage(BaseGraphStorage): source_node_label = source_node_id.strip('"') target_node_label = target_node_id.strip('"') edge_properties = edge_data + logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}") - query = """ - MATCH (source:`{src_label}`) + query = """MATCH (source:`{src_label}`) WITH source MATCH (target:`{tgt_label}`) MERGE (source)-[r:DIRECTED]->(target) SET r += {properties} - RETURN r - """ + RETURN r""" params = { "src_label": PGGraphStorage._encode_graph_label(source_node_label), "tgt_label": PGGraphStorage._encode_graph_label(target_node_label), "properties": PGGraphStorage._format_properties(edge_properties), } + # logger.info(f"-- inserting edge after formatted: {params}") try: - await self._query(query, readonly=False, **params) + await self._query(query, readonly=False, upsert_edge=True, **params) logger.debug( "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", source_node_label, @@ -1127,7 +1105,7 @@ SQL_TEMPLATES = { updatetime = CURRENT_TIMESTAMP """, "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector) - VALUES ($1, $2, $3, $4, $5, $6) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT (workspace,id) DO UPDATE SET entity_name=EXCLUDED.entity_name, content=EXCLUDED.content, diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py new file mode 100644 index 00000000..66cdfcc7 --- /dev/null +++ b/lightrag/kg/postgres_impl_test.py @@ -0,0 +1,122 @@ +import asyncio +import asyncpg +import sys, os + +import psycopg +from psycopg_pool import AsyncConnectionPool +from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage + +DB="rag" +USER="rag" +PASSWORD="rag" +HOST="localhost" +PORT="15432" +os.environ["AGE_GRAPH_NAME"] = "dickens" + +if sys.platform.startswith("win"): + import asyncio.windows_events + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + +async def get_pool(): + return await asyncpg.create_pool( + f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}", + min_size=10, # 连接池初始化时默认的最小连接数, 默认为1 0 + max_size=10, # 连接池的最大连接数, 默认为 10 + max_queries=5000, # 每个链接最大查询数量, 超过了就换新的连接, 默认 5000 + # 最大不活跃时间, 默认 300.0, 超过这个时间的连接就会被关闭, 传入 0 的话则永不关闭 + max_inactive_connection_lifetime=300.0 + ) + +async def main1(): + connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}" + pool = AsyncConnectionPool(connection_string, open=False) + await pool.open() + + try: + conn = await pool.getconn(timeout=10) + async with conn.cursor() as curs: + try: + await curs.execute('SET search_path = ag_catalog, "$user", public') + await curs.execute(f"SELECT create_graph('dickens-2')") + await conn.commit() + print("create_graph success") + except ( + psycopg.errors.InvalidSchemaName, + psycopg.errors.UniqueViolation, + ): + print("create_graph already exists") + await conn.rollback() + finally: + pass + +db = PostgreSQLDB( + config={ + "host": "localhost", + "port": 15432, + "user": "rag", + "password": "rag", + "database": "rag", + } +) + +async def query_with_age(): + await db.initdb() + graph = PGGraphStorage( + namespace="chunk_entity_relation", + global_config={}, + embedding_func=None, + ) + graph.db = db + res = await graph.get_node('"CHRISTMAS-TIME"') + print("Node is: ", res) + +async def create_edge_with_age(): + await db.initdb() + graph = PGGraphStorage( + namespace="chunk_entity_relation", + global_config={}, + embedding_func=None, + ) + graph.db = db + await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"}) + await graph.upsert_node('"THE GIRLS"', {"world": "hello"}) + await graph.upsert_edge( + '"THE CRATCHITS"', + '"THE GIRLS"', + edge_data={ + "weight": 7.0, + "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.', + "keywords": '"family, collective effort"', + "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8", + }, + ) + res = await graph.get_edge('THE CRATCHITS', '"THE GIRLS"') + print("Edge is: ", res) + + +async def main(): + pool = await get_pool() + # 如果还有其它什么特殊参数,也可以直接往里面传递,因为设置了 **connect_kwargs + # 专门用来设置一些数据库独有的某些属性 + # 从池子中取出一个连接 + sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)" + # cypher = "MATCH (n:how_are_you_doing) RETURN n" + async with pool.acquire() as conn: + try: + await conn.execute("""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""") + except asyncpg.exceptions.InvalidSchemaNameError: + print("create_graph already exists") + # stmt = await conn.prepare(sql) + row = await conn.fetch(sql) + print("row is: ", row) + + # 解决办法就是起一个别名 + row = await conn.fetchrow("select '100'::int + 200 as result") + print(row) # + # 我们的连接是从池子里面取出的,上下文结束之后会自动放回到到池子里面 + + +if __name__ == '__main__': + asyncio.run(query_with_age()) + +