diff --git a/examples/graph_visual_with_html.py b/examples/graph_visual_with_html.py
index c1a6a015..56ed43cc 100644
--- a/examples/graph_visual_with_html.py
+++ b/examples/graph_visual_with_html.py
@@ -1,9 +1,11 @@
-import networkx as nx
import pipmaster as pm
if not pm.is_installed("pyvis"):
pm.install("pyvis")
+if not pm.is_installed("networkx"):
+ pm.install("networkx")
+import networkx as nx
from pyvis.network import Network
import random
diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py
index c7b16a70..a6e6edfd 100644
--- a/lightrag/kg/mongo_impl.py
+++ b/lightrag/kg/mongo_impl.py
@@ -797,8 +797,8 @@ class MongoGraphStorage(BaseGraphStorage):
@final
@dataclass
class MongoVectorDBStorage(BaseVectorStorage):
- db: AsyncIOMotorDatabase = field(default=None)
- _data: AsyncIOMotorCollection = field(default=None)
+ db: AsyncIOMotorDatabase | None = field(default=None)
+ _data: AsyncIOMotorCollection | None = field(default=None)
def __post_init__(self):
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 03b1bbcb..82631cf8 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -43,10 +43,6 @@ config.read("config.ini", "utf-8")
@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
- @staticmethod
- def load_nx_graph(file_name):
- print("no preloading of graph with neo4j in production")
-
def __init__(self, namespace, global_config, embedding_func):
super().__init__(
namespace=namespace,
diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py
index ac321d24..313d9f8d 100644
--- a/lightrag/kg/networkx_impl.py
+++ b/lightrag/kg/networkx_impl.py
@@ -15,11 +15,10 @@ from lightrag.base import (
)
import pipmaster as pm
-if not pm.is_installed("graspologic"):
- pm.install("graspologic")
-
if not pm.is_installed("networkx"):
pm.install("networkx")
+if not pm.is_installed("graspologic"):
+ pm.install("graspologic")
try:
from graspologic import embed
diff --git a/lightrag/kg/oracle_impl.py b/lightrag/kg/oracle_impl.py
index 3e0c6799..0916f6b0 100644
--- a/lightrag/kg/oracle_impl.py
+++ b/lightrag/kg/oracle_impl.py
@@ -178,11 +178,11 @@ class OracleDB:
class ClientManager:
- _instances = {"db": None, "ref_count": 0}
+ _instances: dict[str, Any] = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
- def get_config():
+ def get_config() -> dict[str, Any]:
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -398,7 +398,7 @@ class OracleKVStorage(BaseKVStorage):
@final
@dataclass
class OracleVectorDBStorage(BaseVectorStorage):
- db: OracleDB = field(default=None)
+ db: OracleDB | None = field(default=None)
def __post_init__(self):
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py
index fd560668..044bf4c1 100644
--- a/lightrag/kg/postgres_impl.py
+++ b/lightrag/kg/postgres_impl.py
@@ -1,10 +1,9 @@
import asyncio
-import inspect
import json
import os
import time
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Union, final
+from typing import Any, Union, final
import numpy as np
import configparser
@@ -41,6 +40,7 @@ if not pm.is_installed("asyncpg"):
try:
import asyncpg
+ from asyncpg import Pool
except ImportError as e:
raise ImportError(
@@ -49,8 +49,7 @@ except ImportError as e:
class PostgreSQLDB:
- def __init__(self, config, **kwargs):
- self.pool = None
+ def __init__(self, config: dict[str, Any], **kwargs: Any):
self.host = config.get("host", "localhost")
self.port = config.get("port", 5432)
self.user = config.get("user", "postgres")
@@ -59,7 +58,7 @@ class PostgreSQLDB:
self.workspace = config.get("workspace", "default")
self.max = 12
self.increment = 1
- logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
+ self.pool: Pool | None = None
if self.user is None or self.password is None or self.database is None:
raise ValueError(
@@ -68,7 +67,7 @@ class PostgreSQLDB:
async def initdb(self):
try:
- self.pool = await asyncpg.create_pool(
+ self.pool = await asyncpg.create_pool( # type: ignore
user=self.user,
password=self.password,
database=self.database,
@@ -79,43 +78,51 @@ class PostgreSQLDB:
)
logger.info(
- f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+ f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
)
except Exception as e:
logger.error(
- f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+ f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}"
)
- logger.error(f"PostgreSQL database error: {e}")
raise
+ async def check_graph_requirement(self, graph_name: str):
+ async with self.pool.acquire() as connection: # type: ignore
+ try:
+ await connection.execute(
+ 'SET search_path = ag_catalog, "$user", public'
+ ) # type: ignore
+ await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
+ except (
+ asyncpg.exceptions.InvalidSchemaNameError,
+ asyncpg.exceptions.UniqueViolationError,
+ ):
+ pass
+
async def check_tables(self):
for k, v in TABLES.items():
try:
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
except Exception as e:
- logger.error(f"Failed to check table {k} in PostgreSQL database")
logger.error(f"PostgreSQL database error: {e}")
try:
+ logger.info(f"PostgreSQL, Try Creating table {k} in database")
await self.execute(v["ddl"])
- logger.info(f"Created table {k} in PostgreSQL database")
+ logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database")
except Exception as e:
- logger.error(f"Failed to create table {k} in PostgreSQL database")
- logger.error(f"PostgreSQL database error: {e}")
-
- logger.info("Finished checking all tables in PostgreSQL database")
+ logger.error(
+ f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
+ )
+ raise e
async def query(
self,
sql: str,
- params: dict = None,
+ params: dict[str, Any] | None = None,
multirows: bool = False,
- for_age: bool = False,
- graph_name: str = None,
- ) -> Union[dict, None, list[dict]]:
- async with self.pool.acquire() as connection:
+ ) -> dict[str, Any] | None | list[dict[str, Any]]:
+ async with self.pool.acquire() as connection: # type: ignore
try:
- if for_age:
- await PostgreSQLDB._prerequisite(connection, graph_name)
if params:
rows = await connection.fetch(sql, *params.values())
else:
@@ -143,20 +150,15 @@ class PostgreSQLDB:
async def execute(
self,
sql: str,
- data: Union[list, dict] = None,
- for_age: bool = False,
- graph_name: str = None,
+ data: dict[str, Any] | None = None,
upsert: bool = False,
):
try:
- async with self.pool.acquire() as connection:
- if for_age:
- await PostgreSQLDB._prerequisite(connection, graph_name)
-
+ async with self.pool.acquire() as connection: # type: ignore
if data is None:
- await connection.execute(sql)
+ await connection.execute(sql) # type: ignore
else:
- await connection.execute(sql, *data.values())
+ await connection.execute(sql, *data.values()) # type: ignore
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DuplicateTableError,
@@ -169,24 +171,13 @@ class PostgreSQLDB:
logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}")
raise
- @staticmethod
- async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
- try:
- await conn.execute('SET search_path = ag_catalog, "$user", public')
- await conn.execute(f"""select create_graph('{graph_name}')""")
- except (
- asyncpg.exceptions.InvalidSchemaNameError,
- asyncpg.exceptions.UniqueViolationError,
- ):
- pass
-
class ClientManager:
- _instances = {"db": None, "ref_count": 0}
+ _instances: dict[str, Any] = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
- def get_config():
+ def get_config() -> dict[str, Any]:
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -377,7 +368,7 @@ class PGKVStorage(BaseKVStorage):
@final
@dataclass
class PGVectorStorage(BaseVectorStorage):
- db: PostgreSQLDB = field(default=None)
+ db: PostgreSQLDB | None = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -398,10 +389,10 @@ class PGVectorStorage(BaseVectorStorage):
await ClientManager.release_client(self.db)
self.db = None
- def _upsert_chunks(self, item: dict):
+ def _upsert_chunks(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
try:
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
- data = {
+ data: dict[str, Any] = {
"workspace": self.db.workspace,
"id": item["__id__"],
"tokens": item["tokens"],
@@ -416,9 +407,9 @@ class PGVectorStorage(BaseVectorStorage):
return upsert_sql, data
- def _upsert_entities(self, item: dict):
+ def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
upsert_sql = SQL_TEMPLATES["upsert_entity"]
- data = {
+ data: dict[str, Any] = {
"workspace": self.db.workspace,
"id": item["__id__"],
"entity_name": item["entity_name"],
@@ -427,9 +418,9 @@ class PGVectorStorage(BaseVectorStorage):
}
return upsert_sql, data
- def _upsert_relationships(self, item: dict):
+ def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
- data = {
+ data: dict[str, Any] = {
"workspace": self.db.workspace,
"id": item["__id__"],
"source_id": item["src_id"],
@@ -558,16 +549,16 @@ class PGDocStatusStorage(DocStatusStorage):
)
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
+ """Get doc_chunks data by id"""
raise NotImplementedError
- async def get_status_counts(self) -> Dict[str, int]:
+ async def get_status_counts(self) -> dict[str, int]:
"""Get counts of documents in each status"""
sql = """SELECT status as "status", COUNT(1) as "count"
FROM LIGHTRAG_DOC_STATUS
where workspace=$1 GROUP BY STATUS
"""
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
- # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
counts = {}
for doc in result:
counts[doc["status"]] = doc["count"]
@@ -575,7 +566,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_docs_by_status(
self, status: DocStatus
- ) -> Dict[str, DocProcessingStatus]:
+ ) -> dict[str, DocProcessingStatus]:
"""all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.db.workspace, "status": status.value}
@@ -602,7 +593,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""Update or insert document status
Args:
- data: Dictionary of document IDs and their status data
+ data: dictionary of document IDs and their status data
"""
sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status)
values($1,$2,$3,$4,$5,$6,$7)
@@ -627,7 +618,6 @@ class PGDocStatusStorage(DocStatusStorage):
"status": v["status"],
},
)
- return data
async def drop(self) -> None:
"""Drop the storage"""
@@ -638,7 +628,7 @@ class PGDocStatusStorage(DocStatusStorage):
class PGGraphQueryException(Exception):
"""Exception for the AGE queries."""
- def __init__(self, exception: Union[str, Dict]) -> None:
+ def __init__(self, exception: Union[str, dict[str, Any]]) -> None:
if isinstance(exception, dict):
self.message = exception["message"] if "message" in exception else "unknown"
self.details = exception["details"] if "details" in exception else "unknown"
@@ -656,21 +646,19 @@ class PGGraphQueryException(Exception):
@final
@dataclass
class PGGraphStorage(BaseGraphStorage):
- db: PostgreSQLDB = field(default=None)
-
- @staticmethod
- def load_nx_graph(file_name):
- print("no preloading of graph with AGE in production")
-
def __post_init__(self):
self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag")
self._node_embed_algorithms = {
"node2vec": self._node2vec_embed,
}
+ self.db: PostgreSQLDB | None = None
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
+ # `check_graph_requirement` is required to be executed after `get_client`
+ # to ensure the graph is created before any query is executed.
+ await self.db.check_graph_requirement(self.graph_name)
async def finalize(self):
if self.db is not None:
@@ -682,7 +670,7 @@ class PGGraphStorage(BaseGraphStorage):
pass
@staticmethod
- def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
+ def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]:
"""
Convert a record returned from an age query to a dictionary
@@ -690,7 +678,7 @@ class PGGraphStorage(BaseGraphStorage):
record (): a record from an age query result
Returns:
- Dict[str, Any]: a dictionary representation of the record where
+ dict[str, Any]: a dictionary representation of the record where
the dictionary key is the field name and the value is the
value converted to a python type
"""
@@ -745,14 +733,14 @@ class PGGraphStorage(BaseGraphStorage):
@staticmethod
def _format_properties(
- properties: Dict[str, Any], _id: Union[str, None] = None
+ properties: dict[str, Any], _id: Union[str, None] = None
) -> str:
"""
Convert a dictionary of properties to a string representation that
can be used in a cypher query insert/merge statement.
Args:
- properties (Dict[str,str]): a dictionary containing node/edge properties
+ properties (dict[str,str]): a dictionary containing node/edge properties
_id (Union[str, None]): the id of the node or None if none exists
Returns:
@@ -820,8 +808,11 @@ class PGGraphStorage(BaseGraphStorage):
return field.replace("(", "_").replace(")", "")
async def _query(
- self, query: str, readonly: bool = True, upsert: bool = False
- ) -> List[Dict[str, Any]]:
+ self,
+ query: str,
+ readonly: bool = True,
+ upsert: bool = False,
+ ) -> list[dict[str, Any]]:
"""
Query the graph by taking a cypher query, converting it to an
age compatible query, executing it and converting the result
@@ -831,32 +822,24 @@ class PGGraphStorage(BaseGraphStorage):
params (dict): parameters for the query
Returns:
- List[Dict[str, Any]]: a list of dictionaries containing the result set
+ list[dict[str, Any]]: a list of dictionaries containing the result set
"""
- # convert cypher query to pgsql/age query
- wrapped_query = query
-
- # execute the query, rolling back on an error
try:
if readonly:
data = await self.db.query(
- wrapped_query,
+ query,
multirows=True,
- for_age=True,
- graph_name=self.graph_name,
)
else:
data = await self.db.execute(
- wrapped_query,
- for_age=True,
- graph_name=self.graph_name,
+ query,
upsert=upsert,
)
except Exception as e:
raise PGGraphQueryException(
{
"message": f"Error executing graph query: {query}",
- "wrapped": wrapped_query,
+ "wrapped": query,
"detail": str(e),
}
) from e
@@ -865,12 +848,12 @@ class PGGraphStorage(BaseGraphStorage):
result = []
# decode records
else:
- result = [PGGraphStorage._record_to_dict(d) for d in data]
+ result = [self._record_to_dict(d) for d in data]
return result
async def has_node(self, node_id: str) -> bool:
- entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+ entity_name_label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
@@ -878,18 +861,12 @@ class PGGraphStorage(BaseGraphStorage):
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
single_result = (await self._query(query))[0]
- logger.debug(
- "{%s}:query:{%s}:result:{%s}",
- inspect.currentframe().f_code.co_name,
- query,
- single_result["node_exists"],
- )
return single_result["node_exists"]
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
- src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
- tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
+ src_label = self._encode_graph_label(source_node_id.strip('"'))
+ tgt_label = self._encode_graph_label(target_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
@@ -901,16 +878,11 @@ class PGGraphStorage(BaseGraphStorage):
)
single_result = (await self._query(query))[0]
- logger.debug(
- "{%s}:query:{%s}:result:{%s}",
- inspect.currentframe().f_code.co_name,
- query,
- single_result["edge_exists"],
- )
+
return single_result["edge_exists"]
async def get_node(self, node_id: str) -> dict[str, str] | None:
- label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+ label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
RETURN n
@@ -919,17 +891,12 @@ class PGGraphStorage(BaseGraphStorage):
if record:
node = record[0]
node_dict = node["n"]
- logger.debug(
- "{%s}: query: {%s}, result: {%s}",
- inspect.currentframe().f_code.co_name,
- query,
- node_dict,
- )
+
return node_dict
return None
async def node_degree(self, node_id: str) -> int:
- label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+ label = self._encode_graph_label(node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})-[]->(x)
@@ -938,12 +905,7 @@ class PGGraphStorage(BaseGraphStorage):
record = (await self._query(query))[0]
if record:
edge_count = int(record["total_edge_count"])
- logger.debug(
- "{%s}:query:{%s}:result:{%s}",
- inspect.currentframe().f_code.co_name,
- query,
- edge_count,
- )
+
return edge_count
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -955,18 +917,14 @@ class PGGraphStorage(BaseGraphStorage):
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
- logger.debug(
- "{%s}:query:src_Degree+trg_degree:result:{%s}",
- inspect.currentframe().f_code.co_name,
- degrees,
- )
+
return degrees
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
- src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
- tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
+ src_label = self._encode_graph_label(source_node_id.strip('"'))
+ tgt_label = self._encode_graph_label(target_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
@@ -980,20 +938,15 @@ class PGGraphStorage(BaseGraphStorage):
record = await self._query(query)
if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"]
- logger.debug(
- "{%s}:query:{%s}:result:{%s}",
- inspect.currentframe().f_code.co_name,
- query,
- result,
- )
+
return result
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""
Retrieves all edges (relationships) for a particular node identified by its label.
- :return: List of dictionaries containing edge information
+ :return: list of dictionaries containing edge information
"""
- label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
+ label = self._encode_graph_label(source_node_id.strip('"'))
query = """SELECT * FROM cypher('%s', $$
MATCH (n:Entity {node_id: "%s"})
@@ -1024,8 +977,8 @@ class PGGraphStorage(BaseGraphStorage):
if source_label and target_label:
edges.append(
(
- PGGraphStorage._decode_graph_label(source_label),
- PGGraphStorage._decode_graph_label(target_label),
+ self._decode_graph_label(source_label),
+ self._decode_graph_label(target_label),
)
)
@@ -1037,7 +990,7 @@ class PGGraphStorage(BaseGraphStorage):
retry=retry_if_exception_type((PGGraphQueryException,)),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
- label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
+ label = self._encode_graph_label(node_id.strip('"'))
properties = node_data
query = """SELECT * FROM cypher('%s', $$
@@ -1047,18 +1000,14 @@ class PGGraphStorage(BaseGraphStorage):
$$) AS (n agtype)""" % (
self.graph_name,
label,
- PGGraphStorage._format_properties(properties),
+ self._format_properties(properties),
)
try:
await self._query(query, readonly=False, upsert=True)
- logger.debug(
- "Upserted node with label '{%s}' and properties: {%s}",
- label,
- properties,
- )
+
except Exception as e:
- logger.error("Error during upsert: {%s}", e)
+ logger.error("POSTGRES, Error during upsert: {%s}", e)
raise
@retry(
@@ -1075,10 +1024,10 @@ class PGGraphStorage(BaseGraphStorage):
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
- edge_data (dict): Dictionary of properties to set on the edge
+ edge_data (dict): dictionary of properties to set on the edge
"""
- src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
- tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
+ src_label = self._encode_graph_label(source_node_id.strip('"'))
+ tgt_label = self._encode_graph_label(target_node_id.strip('"'))
edge_properties = edge_data
query = """SELECT * FROM cypher('%s', $$
@@ -1092,17 +1041,12 @@ class PGGraphStorage(BaseGraphStorage):
self.graph_name,
src_label,
tgt_label,
- PGGraphStorage._format_properties(edge_properties),
+ self._format_properties(edge_properties),
)
- # logger.info(f"-- inserting edge after formatted: {params}")
+
try:
await self._query(query, readonly=False, upsert=True)
- logger.debug(
- "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
- src_label,
- tgt_label,
- edge_properties,
- )
+
except Exception as e:
logger.error("Error during edge upsert: {%s}", e)
raise
diff --git a/lightrag/kg/tidb_impl.py b/lightrag/kg/tidb_impl.py
index b94148d6..ed9c8d4b 100644
--- a/lightrag/kg/tidb_impl.py
+++ b/lightrag/kg/tidb_impl.py
@@ -58,7 +58,6 @@ class TiDB:
logger.error(f"Failed to check table {k} in TiDB database")
logger.error(f"TiDB database error: {e}")
try:
- # print(v["ddl"])
await self.execute(v["ddl"])
logger.info(f"Created table {k} in TiDB database")
except Exception as e:
@@ -106,11 +105,11 @@ class TiDB:
class ClientManager:
- _instances = {"db": None, "ref_count": 0}
+ _instances: dict[str, Any] = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
- def get_config():
+ def get_config() -> dict[str, Any]:
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@@ -278,7 +277,7 @@ class TiDBKVStorage(BaseKVStorage):
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
- db: TiDB = field(default=None)
+ db: TiDB | None = field(default=None)
def __post_init__(self):
self._client_file_name = os.path.join(