diff --git a/examples/lightrag_zhipu_postgres_demo.py b/examples/lightrag_zhipu_postgres_demo.py
index 1462fdd2..f28480df 100644
--- a/examples/lightrag_zhipu_postgres_demo.py
+++ b/examples/lightrag_zhipu_postgres_demo.py
@@ -53,7 +53,7 @@ async def main():
         kv_storage="PGKVStorage",
         doc_status_storage="PGDocStatusStorage",
         graph_storage="PGGraphStorage",
-        vector_storage="PGVectorStorage"
+        vector_storage="PGVectorStorage",
     )
     # Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
     rag.doc_status.db = postgres_db
@@ -77,27 +77,35 @@ async def main():
     start_time = time.time()
     # Perform naive search
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="naive"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="naive")
+        )
     )
     print(f"Naive Query Time: {time.time() - start_time} seconds")
     # Perform local search
     print("**** Start Local Query ****")
     start_time = time.time()
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="local"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="local")
+        )
     )
     print(f"Local Query Time: {time.time() - start_time} seconds")
     # Perform global search
     print("**** Start Global Query ****")
     start_time = time.time()
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="global"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="global")
+        )
     )
     print(f"Global Query Time: {time.time() - start_time}")
     # Perform hybrid search
     print("**** Start Hybrid Query ****")
     print(
-        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+        await rag.aquery(
+            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
+        )
     )
     print(f"Hybrid Query Time: {time.time() - start_time} seconds")
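The demo relies on every PG-backed storage sharing the single connection pool owned by one PostgreSQLDB instance (it assigns that instance to rag.doc_status.db, and the full script does the same for the other storages). A minimal, hedged connectivity sketch; the credentials are placeholders, and the config keys mirror PostgreSQLDB.__init__ in the next file:

import asyncio

from lightrag.kg.postgres_impl import PostgreSQLDB


async def smoke_test():
    # Placeholder credentials; the config keys mirror PostgreSQLDB.__init__ below.
    postgres_db = PostgreSQLDB(
        config={
            "host": "localhost",
            "port": 15432,
            "user": "rag",
            "password": "rag",
            "database": "rag",
            "workspace": "default",
        }
    )
    await postgres_db.initdb()  # builds the shared asyncpg pool
    print(await postgres_db.query("SELECT 1 AS ok"))  # {'ok': 1} when reachable
    # The demo then hands this one instance to each storage, e.g.:
    # rag.doc_status.db = postgres_db


asyncio.run(smoke_test())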
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index dc027113..704fa476 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -19,7 +19,11 @@ from tenacity import (
 from ..utils import logger
 from ..base import (
     BaseKVStorage,
-    BaseVectorStorage, DocStatusStorage, DocStatus, DocProcessingStatus, BaseGraphStorage,
+    BaseVectorStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+    BaseGraphStorage,
 )
 
 if sys.platform.startswith("win"):
@@ -36,14 +40,15 @@ class PostgreSQLDB:
         self.user = config.get("user", "postgres")
         self.password = config.get("password", None)
         self.database = config.get("database", "postgres")
-        self.workspace = config.get("workspace", 'default')
+        self.workspace = config.get("workspace", "default")
         self.max = 12
         self.increment = 1
         logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
 
         if self.user is None or self.password is None or self.database is None:
-            raise ValueError("Missing database user, password, or database in addon_params")
-
+            raise ValueError(
+                "Missing database user, password, or database in addon_params"
+            )
 
     async def initdb(self):
         try:
@@ -54,12 +59,16 @@ class PostgreSQLDB:
                 host=self.host,
                 port=self.port,
                 min_size=1,
-                max_size=self.max
+                max_size=self.max,
             )
 
-            logger.info(f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}")
+            logger.info(
+                f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+            )
         except Exception as e:
-            logger.error(f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}")
+            logger.error(
+                f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+            )
             logger.error(f"PostgreSQL database error: {e}")
             raise
@@ -79,9 +88,13 @@ class PostgreSQLDB:
 
         logger.info("Finished checking all tables in PostgreSQL database")
 
-
     async def query(
-        self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None
+        self,
+        sql: str,
+        params: dict = None,
+        multirows: bool = False,
+        for_age: bool = False,
+        graph_name: str = None,
     ) -> Union[dict, None, list[dict]]:
         async with self.pool.acquire() as connection:
             try:
@@ -111,7 +124,13 @@ class PostgreSQLDB:
                 print(params)
                 raise
 
-    async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
+    async def execute(
+        self,
+        sql: str,
+        data: Union[list, dict] = None,
+        for_age: bool = False,
+        graph_name: str = None,
+    ):
         try:
             async with self.pool.acquire() as connection:
                 if for_age:
@@ -130,7 +149,7 @@ class PostgreSQLDB:
     @staticmethod
     async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
         try:
-            await conn.execute(f'SET search_path = ag_catalog, "$user", public')
+            await conn.execute('SET search_path = ag_catalog, "$user", public')
             await conn.execute(f"""select create_graph('{graph_name}')""")
         except asyncpg.exceptions.InvalidSchemaNameError:
             pass
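_prerequisite runs before every AGE call made with for_age=True: it puts ag_catalog on the search_path and creates the graph, swallowing InvalidSchemaNameError when it already exists. A hedged sketch of leaning on that directly; the graph name and returned column are illustrative, and the SQL shape follows the AGE query used in postgres_impl_test.py further down:

from lightrag.kg.postgres_impl import PostgreSQLDB


async def count_nodes(db: PostgreSQLDB):
    # With for_age=True, _prerequisite() has already run SET search_path and
    # create_graph('dickens'), so cypher() and agtype resolve unqualified.
    sql = """SELECT * FROM cypher('dickens', $$
                 MATCH (n) RETURN count(n)
             $$) AS (cnt agtype)"""
    return await db.query(sql, multirows=True, for_age=True, graph_name="dickens")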
@@ -138,7 +157,7 @@ class PostgreSQLDB:
 
 @dataclass
 class PGKVStorage(BaseKVStorage):
-    db:PostgreSQLDB = None
+    db: PostgreSQLDB = None
 
     def __post_init__(self):
         self._data = {}
@@ -180,7 +199,7 @@ class PGKVStorage(BaseKVStorage):
                 dict_res[mode] = {}
             for row in array_res:
                 dict_res[row["mode"]][row["id"]] = row
-            res = [{k:v} for k,v in dict_res.items()]
+            res = [{k: v} for k, v in dict_res.items()]
         else:
             res = await self.db.query(sql, params, multirows=True)
         if res:
@@ -191,7 +210,8 @@ class PGKVStorage(BaseKVStorage):
     async def filter_keys(self, keys: List[str]) -> Set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
-            table_name=NAMESPACE_TABLE_MAP[self.namespace], ids=",".join([f"'{id}'" for id in keys])
+            table_name=NAMESPACE_TABLE_MAP[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
         )
         params = {"workspace": self.db.workspace}
         try:
@@ -207,7 +227,6 @@ class PGKVStorage(BaseKVStorage):
             print(sql)
             print(params)
 
-
     ################ INSERT METHODS ################
     async def upsert(self, data: Dict[str, dict]):
         left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -246,7 +265,7 @@ class PGKVStorage(BaseKVStorage):
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = 0.2
-    db:PostgreSQLDB = None
+    db: PostgreSQLDB = None
 
     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -282,6 +301,7 @@ class PGVectorStorage(BaseVectorStorage):
             "content_vector": json.dumps(item["__vector__"].tolist()),
         }
         return upsert_sql, data
+
     def _upsert_relationships(self, item: dict):
         upsert_sql = SQL_TEMPLATES["upsert_relationship"]
         data = {
@@ -340,8 +360,6 @@ class PGVectorStorage(BaseVectorStorage):
 
             await self.db.execute(upsert_sql, data)
 
-
-
     async def index_done_callback(self):
         logger.info("vector data had been saved into postgresql db!")
@@ -350,7 +368,7 @@ class PGVectorStorage(BaseVectorStorage):
         """Query data from the vector database"""
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
-        embedding_string = ",".join(map(str, embedding))
+        embedding_string = ",".join(map(str, embedding))
 
         sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
         params = {
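The SQL formatted here comes from SQL_TEMPLATES at the bottom of this file: pgvector's <=> operator is cosine distance, so 1 - distance is cosine similarity, thresholded by cosine_better_than_threshold. A hedged standalone version for the chunks table, assuming query() binds the dict values positionally to $1..$3 (which is how the storage classes use it); the subquery alias cand is added here for valid PostgreSQL:

from lightrag.kg.postgres_impl import PostgreSQLDB


async def top_chunks(db: PostgreSQLDB, embedding: list[float], top_k: int = 5):
    embedding_string = ",".join(map(str, embedding))
    # Same shape as the "chunks" template in the last hunk of this file.
    sql = f"""SELECT id FROM
        (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
         FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) cand
        WHERE distance>$2 ORDER BY distance DESC LIMIT $3"""
    params = {"workspace": db.workspace, "threshold": 0.2, "top_k": top_k}
    return await db.query(sql, params=params, multirows=True)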
@@ -361,10 +379,12 @@ class PGVectorStorage(BaseVectorStorage):
         results = await self.db.query(sql, params=params, multirows=True)
         return results
 
+
 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
     """PostgreSQL implementation of document status storage"""
-    db:PostgreSQLDB = None
+
+    db: PostgreSQLDB = None
 
     def __post_init__(self):
         pass
@@ -372,41 +392,47 @@ class PGDocStatusStorage(DocStatusStorage):
     async def filter_keys(self, data: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
-        result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
+        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
         if result is None:
             return set(data)
         else:
-            existed = set([element['id'] for element in result])
+            existed = set([element["id"] for element in result])
             return set(data) - existed
 
     async def get_status_counts(self) -> Dict[str, int]:
         """Get counts of documents in each status"""
-        sql = '''SELECT status as "status", COUNT(1) as "count"
+        sql = """SELECT status as "status", COUNT(1) as "count"
                  FROM LIGHTRAG_DOC_STATUS
                  where workspace=$1 GROUP BY STATUS
-                 '''
-        result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
+                 """
+        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
         counts = {}
         for doc in result:
             counts[doc["status"]] = doc["count"]
         return counts
 
-    async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> Dict[str, DocProcessingStatus]:
         """Get all documents by status"""
-        sql = 'select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1'
-        params = {'workspace': self.db.workspace, 'status': status}
+        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
+        params = {"workspace": self.db.workspace, "status": status}
         result = await self.db.query(sql, params, True)
         # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
         # Converting to be a dict
-        return {element["id"]:
-                DocProcessingStatus(content_summary=element["content_summary"],
-                                    content_length=element["content_length"],
-                                    status=element["status"],
-                                    created_at=element["created_at"],
-                                    updated_at=element["updated_at"],
-                                    chunks_count=element["chunks_count"]) for element in result}
+        return {
+            element["id"]: DocProcessingStatus(
+                content_summary=element["content_summary"],
+                content_length=element["content_length"],
+                status=element["status"],
+                created_at=element["created_at"],
+                updated_at=element["updated_at"],
+                chunks_count=element["chunks_count"],
+            )
+            for element in result
+        }
 
     async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
         """Get all failed documents"""
@@ -436,14 +462,17 @@ class PGDocStatusStorage(DocStatusStorage):
                       updated_at = CURRENT_TIMESTAMP"""
         for k, v in data.items():
             # chunks_count is optional
-            await self.db.execute(sql, {
-                "workspace": self.db.workspace,
-                "id": k,
-                "content_summary": v["content_summary"],
-                "content_length": v["content_length"],
-                "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
-                "status": v["status"],
-            })
+            await self.db.execute(
+                sql,
+                {
+                    "workspace": self.db.workspace,
+                    "id": k,
+                    "content_summary": v["content_summary"],
+                    "content_length": v["content_length"],
+                    "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
+                    "status": v["status"],
+                },
+            )
         return data
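Document status lives in a plain relational table, so indexing progress can be inspected outside LightRAG as well. A hedged sketch; the dataclass fields other than db (namespace, global_config, embedding_func) are assumed from the storage base classes, not shown in this diff:

from lightrag.kg.postgres_impl import PGDocStatusStorage, PostgreSQLDB


async def show_progress(db: PostgreSQLDB):
    statuses = PGDocStatusStorage(
        namespace="doc_status",  # assumed namespace label
        global_config={},  # assumed minimal config
        embedding_func=None,
        db=db,
    )
    print(await statuses.get_status_counts())  # e.g. {'PENDING': 1, 'PROCESSED': 4}
    failed = await statuses.get_failed_docs()  # rows with DocStatus.FAILED
    print(sorted(failed))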
@@ -467,7 +496,7 @@ class PGGraphQueryException(Exception):
 
 @dataclass
 class PGGraphStorage(BaseGraphStorage):
-    db:PostgreSQLDB = None
+    db: PostgreSQLDB = None
 
     @staticmethod
     def load_nx_graph(file_name):
@@ -484,7 +513,6 @@ class PGGraphStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }
 
-
     async def index_done_callback(self):
         print("KG successfully indexed.")
@@ -552,7 +580,7 @@ class PGGraphStorage(BaseGraphStorage):
 
     @staticmethod
     def _format_properties(
-            properties: Dict[str, Any], _id: Union[str, None] = None
+        properties: Dict[str, Any], _id: Union[str, None] = None
     ) -> str:
         """
         Convert a dictionary of properties to a string representation that
@@ -669,7 +697,8 @@ class PGGraphStorage(BaseGraphStorage):
 
         # get pgsql formatted field names
         fields = [
-            PGGraphStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
+            PGGraphStorage._get_col_name(field, idx)
+            for idx, field in enumerate(fields)
         ]
 
         # build resulting pgsql relation
@@ -690,7 +719,9 @@ class PGGraphStorage(BaseGraphStorage):
             projection=select_str,
         )
 
-    async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
+    async def _query(
+        self, query: str, readonly=True, upsert_edge=False, **params: str
+    ) -> List[Dict[str, Any]]:
         """
         Query the graph by taking a cypher query, converting it to an age
         compatible query, executing it and converting the result
@@ -708,14 +739,25 @@ class PGGraphStorage(BaseGraphStorage):
         # execute the query, rolling back on an error
        try:
             if readonly:
-                data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
+                data = await self.db.query(
+                    wrapped_query,
+                    multirows=True,
+                    for_age=True,
+                    graph_name=self.graph_name,
+                )
             else:
                 # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
                 # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
                 if upsert_edge:
-                    data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
+                    data = await self.db.execute(
+                        f"{wrapped_query};{wrapped_query};",
+                        for_age=True,
+                        graph_name=self.graph_name,
+                    )
                 else:
-                    data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
+                    data = await self.db.execute(
+                        wrapped_query, for_age=True, graph_name=self.graph_name
+                    )
         except Exception as e:
             raise PGGraphQueryException(
                 {
@@ -819,7 +861,7 @@ class PGGraphStorage(BaseGraphStorage):
         return degrees
 
     async def get_edge(
-            self, source_node_id: str, target_node_id: str
+        self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """
         Find all edges between nodes of two given labels
@@ -922,7 +964,7 @@ class PGGraphStorage(BaseGraphStorage):
         retry=retry_if_exception_type((PGGraphQueryException,)),
     )
     async def upsert_edge(
-            self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
+        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
     ):
         """
         Upsert an edge and its properties between two nodes identified by their labels.
@@ -935,7 +977,9 @@ class PGGraphStorage(BaseGraphStorage):
         source_node_label = source_node_id.strip('"')
         target_node_label = target_node_id.strip('"')
         edge_properties = edge_data
-        logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")
+        logger.info(
+            f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
+        )
 
         query = """MATCH (source:`{src_label}`)
                    WITH source
@@ -1056,7 +1100,6 @@ TABLES = {
 }
 
 
-
 SQL_TEMPLATES = {
     # SQL for KVStorage
     "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
@@ -1107,7 +1150,7 @@ SQL_TEMPLATES = {
     "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
                       VALUES ($1, $2, $3, $4, $5)
                       ON CONFLICT (workspace,id) DO UPDATE
-                      SET entity_name=EXCLUDED.entity_name,
+                      SET entity_name=EXCLUDED.entity_name,
                       content=EXCLUDED.content,
                       content_vector=EXCLUDED.content_vector,
                       updatetime=CURRENT_TIMESTAMP
@@ -1136,5 +1179,5 @@ SQL_TEMPLATES = {
         (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
          FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
         WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-    """
+    """,
 }
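One subtlety these templates depend on: the parameter dicts are evidently bound to $1..$n in insertion order (every _upsert_* helper above builds its dict in the template's column order), so key order matters. A hedged illustration against upsert_entity, with dummy values:

from lightrag.kg.postgres_impl import SQL_TEMPLATES, PostgreSQLDB


async def upsert_dummy_entity(db: PostgreSQLDB):
    # Key order must match $1..$5 of the upsert_entity template above.
    data = {
        "workspace": db.workspace,
        "id": "ent-0001",
        "entity_name": "SCROOGE",
        "content": "SCROOGE is ...",
        "content_vector": "[0.1, 0.2, 0.3]",  # pgvector accepts a string literal
    }
    await db.execute(SQL_TEMPLATES["upsert_entity"], data)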
diff --git a/lightrag/kg/postgres_impl_test.py b/lightrag/kg/postgres_impl_test.py
index d98849c1..dc046311 100644
--- a/lightrag/kg/postgres_impl_test.py
+++ b/lightrag/kg/postgres_impl_test.py
@@ -1,33 +1,39 @@
 import asyncio
 import asyncpg
-import sys, os
+import sys
+import os
 
 import psycopg
 from psycopg_pool import AsyncConnectionPool
 from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
 
-DB="rag"
-USER="rag"
-PASSWORD="rag"
-HOST="localhost"
-PORT="15432"
+DB = "rag"
+USER = "rag"
+PASSWORD = "rag"
+HOST = "localhost"
+PORT = "15432"
 os.environ["AGE_GRAPH_NAME"] = "dickens"
 
 if sys.platform.startswith("win"):
     import asyncio.windows_events
+
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
 
+
 async def get_pool():
     return await asyncpg.create_pool(
         f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
         min_size=10,
         max_size=10,
         max_queries=5000,
-        max_inactive_connection_lifetime=300.0
+        max_inactive_connection_lifetime=300.0,
     )
 
+
 async def main1():
-    connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
+    connection_string = (
+        f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
+    )
     pool = AsyncConnectionPool(connection_string, open=False)
     await pool.open()
@@ -36,18 +42,19 @@ async def main1():
         async with conn.cursor() as curs:
             try:
                 await curs.execute('SET search_path = ag_catalog, "$user", public')
-                await curs.execute(f"SELECT create_graph('dickens-2')")
+                await curs.execute("SELECT create_graph('dickens-2')")
                 await conn.commit()
                 print("create_graph success")
             except (
-                    psycopg.errors.InvalidSchemaName,
-                    psycopg.errors.UniqueViolation,
+                psycopg.errors.InvalidSchemaName,
+                psycopg.errors.UniqueViolation,
             ):
                 print("create_graph already exists")
                 await conn.rollback()
             finally:
                 pass
 
+
 db = PostgreSQLDB(
     config={
         "host": "localhost",
@@ -58,6 +65,7 @@ db = PostgreSQLDB(
     }
 )
 
+
 async def query_with_age():
     await db.initdb()
     graph = PGGraphStorage(
@@ -69,6 +77,7 @@ async def query_with_age():
     res = await graph.get_node('"CHRISTMAS-TIME"')
     print("Node is: ", res)
 
+
 async def create_edge_with_age():
     await db.initdb()
     graph = PGGraphStorage(
@@ -89,31 +98,28 @@ async def create_edge_with_age():
             "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
         },
     )
-    res = await graph.get_edge('THE CRATCHITS', '"THE GIRLS"')
+    res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
     print("Edge is: ", res)
 
 
 async def main():
     pool = await get_pool()
-    # If there are any other special parameters, they can be passed in directly as well, since **connect_kwargs is set
-    # specifically for configuring certain database-specific properties
-    # take a connection from the pool
     sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
     # cypher = "MATCH (n:how_are_you_doing) RETURN n"
     async with pool.acquire() as conn:
-            try:
-                await conn.execute("""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""")
-            except asyncpg.exceptions.InvalidSchemaNameError:
-                print("create_graph already exists")
-            # stmt = await conn.prepare(sql)
-            row = await conn.fetch(sql)
-            print("row is: ", row)
+        try:
+            await conn.execute(
+                """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
+            )
+        except asyncpg.exceptions.InvalidSchemaNameError:
+            print("create_graph already exists")
+        # stmt = await conn.prepare(sql)
+        row = await conn.fetch(sql)
+        print("row is: ", row)
 
-            # the workaround is to use an alias
-            row = await conn.fetchrow("select '100'::int + 200 as result")
-            print(row)  #
-            # the connection came from the pool; it goes back automatically when the context exits
+        row = await conn.fetchrow("select '100'::int + 200 as result")
+        print(row)  #
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     asyncio.run(query_with_age())
diff --git a/requirements.txt b/requirements.txt
index 82252628..79249e7e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,8 @@
 accelerate
 aioboto3~=13.3.0
+aiofiles~=24.1.0
 aiohttp~=3.11.11
+asyncpg~=0.30.0
 
 # database packages
 graspologic
@@ -9,14 +11,20 @@ hnswlib
 nano-vectordb
 neo4j~=5.27.0
 networkx~=3.2.1
+
+numpy~=2.2.0
 ollama~=0.4.4
 openai~=1.58.1
 oracledb
+psycopg-pool~=3.2.4
 psycopg[binary,pool]~=3.2.3
+pydantic~=2.10.4
 pymilvus
 pymongo
 pymysql
+python-dotenv~=1.0.1
 pyvis~=0.3.2
+setuptools~=70.0.0
 # lmdeploy[all]
 sqlalchemy~=2.0.36
 tenacity~=9.0.0
@@ -25,14 +33,6 @@ tenacity~=9.0.0
 # LLM packages
 tiktoken~=0.8.0
 torch~=2.5.1+cu121
+tqdm~=4.67.1
 transformers~=4.47.1
 xxhash
-
-numpy~=2.2.0
-aiofiles~=24.1.0
-pydantic~=2.10.4
-python-dotenv~=1.0.1
-psycopg-pool~=3.2.4
-tqdm~=4.67.1
-asyncpg~=0.30.0
-setuptools~=70.0.0
\ No newline at end of file