From 05147f47996b9a8be6cdb1076711d24e9554f90b Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Wed, 19 Feb 2025 13:31:30 +0100 Subject: [PATCH 1/4] improved code of postgress and execution --- lightrag/kg/neo4j_impl.py | 4 - lightrag/kg/oracle_impl.py | 4 +- lightrag/kg/postgres_impl.py | 165 ++++++++++++++++++----------------- lightrag/kg/tidb_impl.py | 5 +- 4 files changed, 88 insertions(+), 90 deletions(-) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 03b1bbcb..82631cf8 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -43,10 +43,6 @@ config.read("config.ini", "utf-8") @final @dataclass class Neo4JStorage(BaseGraphStorage): - @staticmethod - def load_nx_graph(file_name): - print("no preloading of graph with neo4j in production") - def __init__(self, namespace, global_config, embedding_func): super().__init__( namespace=namespace, diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index 3e0c6799..e6a8e8f4 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -178,11 +178,11 @@ class OracleDB: class ClientManager: - _instances = {"db": None, "ref_count": 0} + _instances: dict[str, Any] = {"db": None, "ref_count": 0} _lock = asyncio.Lock() @staticmethod - def get_config(): + def get_config() -> dict[str, Any]: config = configparser.ConfigParser() config.read("config.ini", "utf-8") diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index fd560668..e10c3832 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -4,7 +4,7 @@ import json import os import time from dataclasses import dataclass, field -from typing import Any, Dict, List, Union, final +from typing import Any, Union, final import numpy as np import configparser @@ -41,6 +41,7 @@ if not pm.is_installed("asyncpg"): try: import asyncpg + from asyncpg import Pool except ImportError as e: raise ImportError( @@ -49,8 +50,7 @@ except ImportError as e: class PostgreSQLDB: - def __init__(self, config, **kwargs): - self.pool = None + def __init__(self, config: dict[str, Any], **kwargs: Any): self.host = config.get("host", "localhost") self.port = config.get("port", 5432) self.user = config.get("user", "postgres") @@ -59,7 +59,7 @@ class PostgreSQLDB: self.workspace = config.get("workspace", "default") self.max = 12 self.increment = 1 - logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier") + self.pool: Pool | None = None if self.user is None or self.password is None or self.database is None: raise ValueError( @@ -68,7 +68,7 @@ class PostgreSQLDB: async def initdb(self): try: - self.pool = await asyncpg.create_pool( + self.pool = await asyncpg.create_pool( # type: ignore user=self.user, password=self.password, database=self.database, @@ -79,43 +79,44 @@ class PostgreSQLDB: ) logger.info( - f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}" + f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}" ) except Exception as e: logger.error( - f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}" + f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}" ) - logger.error(f"PostgreSQL database error: {e}") raise + async def check_graph_requirement(self, graph_name: str): + async with self.pool.acquire() as connection: # type: ignore + await self._prerequisite(connection, graph_name) # type: ignore + async def check_tables(self): for k, v in TABLES.items(): try: await self.query(f"SELECT 1 FROM {k} LIMIT 1") except Exception as e: - logger.error(f"Failed to check table {k} in PostgreSQL database") logger.error(f"PostgreSQL database error: {e}") try: + logger.info(f"PostgreSQL, Try Creating table {k} in database") await self.execute(v["ddl"]) - logger.info(f"Created table {k} in PostgreSQL database") + logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database") except Exception as e: - logger.error(f"Failed to create table {k} in PostgreSQL database") - logger.error(f"PostgreSQL database error: {e}") + logger.error( + f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}" + ) + raise e logger.info("Finished checking all tables in PostgreSQL database") async def query( self, sql: str, - params: dict = None, + params: dict[str, Any] | None = None, multirows: bool = False, - for_age: bool = False, - graph_name: str = None, - ) -> Union[dict, None, list[dict]]: - async with self.pool.acquire() as connection: + ) -> dict[str, Any] | None | list[dict[str, Any]]: + async with self.pool.acquire() as connection: # type: ignore try: - if for_age: - await PostgreSQLDB._prerequisite(connection, graph_name) if params: rows = await connection.fetch(sql, *params.values()) else: @@ -143,20 +144,15 @@ class PostgreSQLDB: async def execute( self, sql: str, - data: Union[list, dict] = None, - for_age: bool = False, - graph_name: str = None, + data: dict[str, Any] | None = None, upsert: bool = False, ): try: - async with self.pool.acquire() as connection: - if for_age: - await PostgreSQLDB._prerequisite(connection, graph_name) - + async with self.pool.acquire() as connection: # type: ignore if data is None: - await connection.execute(sql) + await connection.execute(sql) # type: ignore else: - await connection.execute(sql, *data.values()) + await connection.execute(sql, *data.values()) # type: ignore except ( asyncpg.exceptions.UniqueViolationError, asyncpg.exceptions.DuplicateTableError, @@ -172,8 +168,8 @@ class PostgreSQLDB: @staticmethod async def _prerequisite(conn: asyncpg.Connection, graph_name: str): try: - await conn.execute('SET search_path = ag_catalog, "$user", public') - await conn.execute(f"""select create_graph('{graph_name}')""") + await conn.execute('SET search_path = ag_catalog, "$user", public') # type: ignore + await conn.execute(f"select create_graph('{graph_name}')") # type: ignore except ( asyncpg.exceptions.InvalidSchemaNameError, asyncpg.exceptions.UniqueViolationError, @@ -182,11 +178,11 @@ class PostgreSQLDB: class ClientManager: - _instances = {"db": None, "ref_count": 0} + _instances: dict[str, Any] = {"db": None, "ref_count": 0} _lock = asyncio.Lock() @staticmethod - def get_config(): + def get_config() -> dict[str, Any]: config = configparser.ConfigParser() config.read("config.ini", "utf-8") @@ -558,16 +554,33 @@ class PGDocStatusStorage(DocStatusStorage): ) async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - raise NotImplementedError + """Get doc_chunks data by id""" + sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + array_res = await self.db.query(sql, params, multirows=True) + modes = set() + dict_res: dict[str, dict] = {} + for row in array_res: + modes.add(row["mode"]) + for mode in modes: + if mode not in dict_res: + dict_res[mode] = {} + for row in array_res: + dict_res[row["mode"]][row["id"]] = row + return [{k: v} for k, v in dict_res.items()] + else: + return await self.db.query(sql, params, multirows=True) - async def get_status_counts(self) -> Dict[str, int]: + async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" sql = """SELECT status as "status", COUNT(1) as "count" FROM LIGHTRAG_DOC_STATUS where workspace=$1 GROUP BY STATUS """ result = await self.db.query(sql, {"workspace": self.db.workspace}, True) - # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...] counts = {} for doc in result: counts[doc["status"]] = doc["count"] @@ -575,7 +588,7 @@ class PGDocStatusStorage(DocStatusStorage): async def get_docs_by_status( self, status: DocStatus - ) -> Dict[str, DocProcessingStatus]: + ) -> dict[str, DocProcessingStatus]: """all documents with a specific status""" sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" params = {"workspace": self.db.workspace, "status": status.value} @@ -602,7 +615,7 @@ class PGDocStatusStorage(DocStatusStorage): """Update or insert document status Args: - data: Dictionary of document IDs and their status data + data: dictionary of document IDs and their status data """ sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status) values($1,$2,$3,$4,$5,$6,$7) @@ -627,7 +640,6 @@ class PGDocStatusStorage(DocStatusStorage): "status": v["status"], }, ) - return data async def drop(self) -> None: """Drop the storage""" @@ -638,7 +650,7 @@ class PGDocStatusStorage(DocStatusStorage): class PGGraphQueryException(Exception): """Exception for the AGE queries.""" - def __init__(self, exception: Union[str, Dict]) -> None: + def __init__(self, exception: Union[str, dict[str, Any]]) -> None: if isinstance(exception, dict): self.message = exception["message"] if "message" in exception else "unknown" self.details = exception["details"] if "details" in exception else "unknown" @@ -656,21 +668,17 @@ class PGGraphQueryException(Exception): @final @dataclass class PGGraphStorage(BaseGraphStorage): - db: PostgreSQLDB = field(default=None) - - @staticmethod - def load_nx_graph(file_name): - print("no preloading of graph with AGE in production") - def __post_init__(self): self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") self._node_embed_algorithms = { "node2vec": self._node2vec_embed, } + self.db: PostgreSQLDB | None = None async def initialize(self): if self.db is None: self.db = await ClientManager.get_client() + await self.db.check_graph_requirement(self.graph_name) async def finalize(self): if self.db is not None: @@ -682,7 +690,7 @@ class PGGraphStorage(BaseGraphStorage): pass @staticmethod - def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]: + def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]: """ Convert a record returned from an age query to a dictionary @@ -690,7 +698,7 @@ class PGGraphStorage(BaseGraphStorage): record (): a record from an age query result Returns: - Dict[str, Any]: a dictionary representation of the record where + dict[str, Any]: a dictionary representation of the record where the dictionary key is the field name and the value is the value converted to a python type """ @@ -745,14 +753,14 @@ class PGGraphStorage(BaseGraphStorage): @staticmethod def _format_properties( - properties: Dict[str, Any], _id: Union[str, None] = None + properties: dict[str, Any], _id: Union[str, None] = None ) -> str: """ Convert a dictionary of properties to a string representation that can be used in a cypher query insert/merge statement. Args: - properties (Dict[str,str]): a dictionary containing node/edge properties + properties (dict[str,str]): a dictionary containing node/edge properties _id (Union[str, None]): the id of the node or None if none exists Returns: @@ -820,8 +828,11 @@ class PGGraphStorage(BaseGraphStorage): return field.replace("(", "_").replace(")", "") async def _query( - self, query: str, readonly: bool = True, upsert: bool = False - ) -> List[Dict[str, Any]]: + self, + query: str, + readonly: bool = True, + upsert: bool = False, + ) -> list[dict[str, Any]]: """ Query the graph by taking a cypher query, converting it to an age compatible query, executing it and converting the result @@ -831,32 +842,24 @@ class PGGraphStorage(BaseGraphStorage): params (dict): parameters for the query Returns: - List[Dict[str, Any]]: a list of dictionaries containing the result set + list[dict[str, Any]]: a list of dictionaries containing the result set """ - # convert cypher query to pgsql/age query - wrapped_query = query - - # execute the query, rolling back on an error try: if readonly: data = await self.db.query( - wrapped_query, + query, multirows=True, - for_age=True, - graph_name=self.graph_name, ) else: data = await self.db.execute( - wrapped_query, - for_age=True, - graph_name=self.graph_name, + query, upsert=upsert, ) except Exception as e: raise PGGraphQueryException( { "message": f"Error executing graph query: {query}", - "wrapped": wrapped_query, + "wrapped": query, "detail": str(e), } ) from e @@ -865,12 +868,12 @@ class PGGraphStorage(BaseGraphStorage): result = [] # decode records else: - result = [PGGraphStorage._record_to_dict(d) for d in data] + result = [self._record_to_dict(d) for d in data] return result async def has_node(self, node_id: str) -> bool: - entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + entity_name_label = self._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ MATCH (n:Entity {node_id: "%s"}) @@ -888,8 +891,8 @@ class PGGraphStorage(BaseGraphStorage): return single_result["node_exists"] async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) + src_label = self._encode_graph_label(source_node_id.strip('"')) + tgt_label = self._encode_graph_label(target_node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"}) @@ -910,7 +913,7 @@ class PGGraphStorage(BaseGraphStorage): return single_result["edge_exists"] async def get_node(self, node_id: str) -> dict[str, str] | None: - label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + label = self._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ MATCH (n:Entity {node_id: "%s"}) RETURN n @@ -929,7 +932,7 @@ class PGGraphStorage(BaseGraphStorage): return None async def node_degree(self, node_id: str) -> int: - label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + label = self._encode_graph_label(node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ MATCH (n:Entity {node_id: "%s"})-[]->(x) @@ -965,8 +968,8 @@ class PGGraphStorage(BaseGraphStorage): async def get_edge( self, source_node_id: str, target_node_id: str ) -> dict[str, str] | None: - src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) + src_label = self._encode_graph_label(source_node_id.strip('"')) + tgt_label = self._encode_graph_label(target_node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"}) @@ -991,9 +994,9 @@ class PGGraphStorage(BaseGraphStorage): async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: """ Retrieves all edges (relationships) for a particular node identified by its label. - :return: List of dictionaries containing edge information + :return: list of dictionaries containing edge information """ - label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) + label = self._encode_graph_label(source_node_id.strip('"')) query = """SELECT * FROM cypher('%s', $$ MATCH (n:Entity {node_id: "%s"}) @@ -1024,8 +1027,8 @@ class PGGraphStorage(BaseGraphStorage): if source_label and target_label: edges.append( ( - PGGraphStorage._decode_graph_label(source_label), - PGGraphStorage._decode_graph_label(target_label), + self._decode_graph_label(source_label), + self._decode_graph_label(target_label), ) ) @@ -1037,7 +1040,7 @@ class PGGraphStorage(BaseGraphStorage): retry=retry_if_exception_type((PGGraphQueryException,)), ) async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - label = PGGraphStorage._encode_graph_label(node_id.strip('"')) + label = self._encode_graph_label(node_id.strip('"')) properties = node_data query = """SELECT * FROM cypher('%s', $$ @@ -1047,7 +1050,7 @@ class PGGraphStorage(BaseGraphStorage): $$) AS (n agtype)""" % ( self.graph_name, label, - PGGraphStorage._format_properties(properties), + self._format_properties(properties), ) try: @@ -1075,10 +1078,10 @@ class PGGraphStorage(BaseGraphStorage): Args: source_node_id (str): Label of the source node (used as identifier) target_node_id (str): Label of the target node (used as identifier) - edge_data (dict): Dictionary of properties to set on the edge + edge_data (dict): dictionary of properties to set on the edge """ - src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"')) - tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"')) + src_label = self._encode_graph_label(source_node_id.strip('"')) + tgt_label = self._encode_graph_label(target_node_id.strip('"')) edge_properties = edge_data query = """SELECT * FROM cypher('%s', $$ @@ -1092,7 +1095,7 @@ class PGGraphStorage(BaseGraphStorage): self.graph_name, src_label, tgt_label, - PGGraphStorage._format_properties(edge_properties), + self._format_properties(edge_properties), ) # logger.info(f"-- inserting edge after formatted: {params}") try: diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index b94148d6..e8aaaf66 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -58,7 +58,6 @@ class TiDB: logger.error(f"Failed to check table {k} in TiDB database") logger.error(f"TiDB database error: {e}") try: - # print(v["ddl"]) await self.execute(v["ddl"]) logger.info(f"Created table {k} in TiDB database") except Exception as e: @@ -106,11 +105,11 @@ class TiDB: class ClientManager: - _instances = {"db": None, "ref_count": 0} + _instances: dict[str, Any] = {"db": None, "ref_count": 0} _lock = asyncio.Lock() @staticmethod - def get_config(): + def get_config() -> dict[str, Any]: config = configparser.ConfigParser() config.read("config.ini", "utf-8") From 4adb4418fdaa3fe25ae8ceccc1d399718a1b9e07 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Wed, 19 Feb 2025 13:42:49 +0100 Subject: [PATCH 2/4] cleaned code --- lightrag/kg/mongo_impl.py | 4 +- lightrag/kg/oracle_impl.py | 2 +- lightrag/kg/postgres_impl.py | 72 ++++++++---------------------------- lightrag/kg/tidb_impl.py | 2 +- 4 files changed, 20 insertions(+), 60 deletions(-) diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index c7b16a70..a6e6edfd 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -797,8 +797,8 @@ class MongoGraphStorage(BaseGraphStorage): @final @dataclass class MongoVectorDBStorage(BaseVectorStorage): - db: AsyncIOMotorDatabase = field(default=None) - _data: AsyncIOMotorCollection = field(default=None) + db: AsyncIOMotorDatabase | None = field(default=None) + _data: AsyncIOMotorCollection | None = field(default=None) def __post_init__(self): kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py index e6a8e8f4..0916f6b0 100644 --- a/lightrag/kg/oracle_impl.py +++ b/lightrag/kg/oracle_impl.py @@ -398,7 +398,7 @@ class OracleKVStorage(BaseKVStorage): @final @dataclass class OracleVectorDBStorage(BaseVectorStorage): - db: OracleDB = field(default=None) + db: OracleDB | None = field(default=None) def __post_init__(self): config = self.global_config.get("vector_db_storage_cls_kwargs", {}) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index e10c3832..3c88a05b 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1,5 +1,4 @@ import asyncio -import inspect import json import os import time @@ -373,7 +372,7 @@ class PGKVStorage(BaseKVStorage): @final @dataclass class PGVectorStorage(BaseVectorStorage): - db: PostgreSQLDB = field(default=None) + db: PostgreSQLDB | None = field(default=None) def __post_init__(self): self._max_batch_size = self.global_config["embedding_batch_num"] @@ -394,10 +393,10 @@ class PGVectorStorage(BaseVectorStorage): await ClientManager.release_client(self.db) self.db = None - def _upsert_chunks(self, item: dict): + def _upsert_chunks(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]: try: upsert_sql = SQL_TEMPLATES["upsert_chunk"] - data = { + data: dict[str, Any] = { "workspace": self.db.workspace, "id": item["__id__"], "tokens": item["tokens"], @@ -412,9 +411,9 @@ class PGVectorStorage(BaseVectorStorage): return upsert_sql, data - def _upsert_entities(self, item: dict): + def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]: upsert_sql = SQL_TEMPLATES["upsert_entity"] - data = { + data: dict[str, Any] = { "workspace": self.db.workspace, "id": item["__id__"], "entity_name": item["entity_name"], @@ -423,9 +422,9 @@ class PGVectorStorage(BaseVectorStorage): } return upsert_sql, data - def _upsert_relationships(self, item: dict): + def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]: upsert_sql = SQL_TEMPLATES["upsert_relationship"] - data = { + data: dict[str, Any] = { "workspace": self.db.workspace, "id": item["__id__"], "source_id": item["src_id"], @@ -881,12 +880,6 @@ class PGGraphStorage(BaseGraphStorage): $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) single_result = (await self._query(query))[0] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - single_result["node_exists"], - ) return single_result["node_exists"] @@ -904,12 +897,7 @@ class PGGraphStorage(BaseGraphStorage): ) single_result = (await self._query(query))[0] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - single_result["edge_exists"], - ) + return single_result["edge_exists"] async def get_node(self, node_id: str) -> dict[str, str] | None: @@ -922,12 +910,7 @@ class PGGraphStorage(BaseGraphStorage): if record: node = record[0] node_dict = node["n"] - logger.debug( - "{%s}: query: {%s}, result: {%s}", - inspect.currentframe().f_code.co_name, - query, - node_dict, - ) + return node_dict return None @@ -941,12 +924,7 @@ class PGGraphStorage(BaseGraphStorage): record = (await self._query(query))[0] if record: edge_count = int(record["total_edge_count"]) - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - edge_count, - ) + return edge_count async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -958,11 +936,7 @@ class PGGraphStorage(BaseGraphStorage): trg_degree = 0 if trg_degree is None else trg_degree degrees = int(src_degree) + int(trg_degree) - logger.debug( - "{%s}:query:src_Degree+trg_degree:result:{%s}", - inspect.currentframe().f_code.co_name, - degrees, - ) + return degrees async def get_edge( @@ -983,12 +957,7 @@ class PGGraphStorage(BaseGraphStorage): record = await self._query(query) if record and record[0] and record[0]["edge_properties"]: result = record[0]["edge_properties"] - logger.debug( - "{%s}:query:{%s}:result:{%s}", - inspect.currentframe().f_code.co_name, - query, - result, - ) + return result async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: @@ -1055,13 +1024,9 @@ class PGGraphStorage(BaseGraphStorage): try: await self._query(query, readonly=False, upsert=True) - logger.debug( - "Upserted node with label '{%s}' and properties: {%s}", - label, - properties, - ) + except Exception as e: - logger.error("Error during upsert: {%s}", e) + logger.error("POSTGRES, Error during upsert: {%s}", e) raise @retry( @@ -1097,15 +1062,10 @@ class PGGraphStorage(BaseGraphStorage): tgt_label, self._format_properties(edge_properties), ) - # logger.info(f"-- inserting edge after formatted: {params}") + try: await self._query(query, readonly=False, upsert=True) - logger.debug( - "Upserted edge from '{%s}' to '{%s}' with properties: {%s}", - src_label, - tgt_label, - edge_properties, - ) + except Exception as e: logger.error("Error during edge upsert: {%s}", e) raise diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py index e8aaaf66..ed9c8d4b 100644 --- a/lightrag/kg/tidb_impl.py +++ b/lightrag/kg/tidb_impl.py @@ -277,7 +277,7 @@ class TiDBKVStorage(BaseKVStorage): @final @dataclass class TiDBVectorDBStorage(BaseVectorStorage): - db: TiDB = field(default=None) + db: TiDB | None = field(default=None) def __post_init__(self): self._client_file_name = os.path.join( From 495b0ddbe0c78eda12a2190d3e7b1f1f5bdd47eb Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Wed, 19 Feb 2025 13:47:07 +0100 Subject: [PATCH 3/4] fixed networkx --- examples/graph_visual_with_html.py | 4 +++- lightrag/kg/networkx_impl.py | 5 ++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py index c1a6a015..56ed43cc 100644 --- a/examples/graph_visual_with_html.py +++ b/examples/graph_visual_with_html.py @@ -1,9 +1,11 @@ -import networkx as nx import pipmaster as pm if not pm.is_installed("pyvis"): pm.install("pyvis") +if not pm.is_installed("networkx"): + pm.install("networkx") +import networkx as nx from pyvis.network import Network import random diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index ac321d24..313d9f8d 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -15,11 +15,10 @@ from lightrag.base import ( ) import pipmaster as pm -if not pm.is_installed("graspologic"): - pm.install("graspologic") - if not pm.is_installed("networkx"): pm.install("networkx") +if not pm.is_installed("graspologic"): + pm.install("graspologic") try: from graspologic import embed From fc151e5866e3e1955b99bbcdb2f7c0afb6ecb9a3 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Wed, 19 Feb 2025 13:50:38 +0100 Subject: [PATCH 4/4] cleaned code --- lightrag/kg/postgres_impl.py | 45 +++++++++++------------------------- 1 file changed, 13 insertions(+), 32 deletions(-) diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 3c88a05b..044bf4c1 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -88,7 +88,16 @@ class PostgreSQLDB: async def check_graph_requirement(self, graph_name: str): async with self.pool.acquire() as connection: # type: ignore - await self._prerequisite(connection, graph_name) # type: ignore + try: + await connection.execute( + 'SET search_path = ag_catalog, "$user", public' + ) # type: ignore + await connection.execute(f"select create_graph('{graph_name}')") # type: ignore + except ( + asyncpg.exceptions.InvalidSchemaNameError, + asyncpg.exceptions.UniqueViolationError, + ): + pass async def check_tables(self): for k, v in TABLES.items(): @@ -106,8 +115,6 @@ class PostgreSQLDB: ) raise e - logger.info("Finished checking all tables in PostgreSQL database") - async def query( self, sql: str, @@ -164,17 +171,6 @@ class PostgreSQLDB: logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}") raise - @staticmethod - async def _prerequisite(conn: asyncpg.Connection, graph_name: str): - try: - await conn.execute('SET search_path = ag_catalog, "$user", public') # type: ignore - await conn.execute(f"select create_graph('{graph_name}')") # type: ignore - except ( - asyncpg.exceptions.InvalidSchemaNameError, - asyncpg.exceptions.UniqueViolationError, - ): - pass - class ClientManager: _instances: dict[str, Any] = {"db": None, "ref_count": 0} @@ -554,24 +550,7 @@ class PGDocStatusStorage(DocStatusStorage): async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: """Get doc_chunks data by id""" - sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - params = {"workspace": self.db.workspace} - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - array_res = await self.db.query(sql, params, multirows=True) - modes = set() - dict_res: dict[str, dict] = {} - for row in array_res: - modes.add(row["mode"]) - for mode in modes: - if mode not in dict_res: - dict_res[mode] = {} - for row in array_res: - dict_res[row["mode"]][row["id"]] = row - return [{k: v} for k, v in dict_res.items()] - else: - return await self.db.query(sql, params, multirows=True) + raise NotImplementedError async def get_status_counts(self) -> dict[str, int]: """Get counts of documents in each status""" @@ -677,6 +656,8 @@ class PGGraphStorage(BaseGraphStorage): async def initialize(self): if self.db is None: self.db = await ClientManager.get_client() + # `check_graph_requirement` is required to be executed after `get_client` + # to ensure the graph is created before any query is executed. await self.db.check_graph_requirement(self.graph_name) async def finalize(self):