cleaned code
This commit is contained in:
@@ -797,8 +797,8 @@ class MongoGraphStorage(BaseGraphStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class MongoVectorDBStorage(BaseVectorStorage):
|
class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
db: AsyncIOMotorDatabase = field(default=None)
|
db: AsyncIOMotorDatabase | None = field(default=None)
|
||||||
_data: AsyncIOMotorCollection = field(default=None)
|
_data: AsyncIOMotorCollection | None = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
|
@@ -398,7 +398,7 @@ class OracleKVStorage(BaseKVStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class OracleVectorDBStorage(BaseVectorStorage):
|
class OracleVectorDBStorage(BaseVectorStorage):
|
||||||
db: OracleDB = field(default=None)
|
db: OracleDB | None = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
||||||
|
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@@ -373,7 +372,7 @@ class PGKVStorage(BaseKVStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class PGVectorStorage(BaseVectorStorage):
|
class PGVectorStorage(BaseVectorStorage):
|
||||||
db: PostgreSQLDB = field(default=None)
|
db: PostgreSQLDB | None = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
@@ -394,10 +393,10 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
await ClientManager.release_client(self.db)
|
await ClientManager.release_client(self.db)
|
||||||
self.db = None
|
self.db = None
|
||||||
|
|
||||||
def _upsert_chunks(self, item: dict):
|
def _upsert_chunks(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
|
upsert_sql = SQL_TEMPLATES["upsert_chunk"]
|
||||||
data = {
|
data: dict[str, Any] = {
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"id": item["__id__"],
|
"id": item["__id__"],
|
||||||
"tokens": item["tokens"],
|
"tokens": item["tokens"],
|
||||||
@@ -412,9 +411,9 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
|
|
||||||
return upsert_sql, data
|
return upsert_sql, data
|
||||||
|
|
||||||
def _upsert_entities(self, item: dict):
|
def _upsert_entities(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||||
upsert_sql = SQL_TEMPLATES["upsert_entity"]
|
upsert_sql = SQL_TEMPLATES["upsert_entity"]
|
||||||
data = {
|
data: dict[str, Any] = {
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"id": item["__id__"],
|
"id": item["__id__"],
|
||||||
"entity_name": item["entity_name"],
|
"entity_name": item["entity_name"],
|
||||||
@@ -423,9 +422,9 @@ class PGVectorStorage(BaseVectorStorage):
|
|||||||
}
|
}
|
||||||
return upsert_sql, data
|
return upsert_sql, data
|
||||||
|
|
||||||
def _upsert_relationships(self, item: dict):
|
def _upsert_relationships(self, item: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||||
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
|
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
|
||||||
data = {
|
data: dict[str, Any] = {
|
||||||
"workspace": self.db.workspace,
|
"workspace": self.db.workspace,
|
||||||
"id": item["__id__"],
|
"id": item["__id__"],
|
||||||
"source_id": item["src_id"],
|
"source_id": item["src_id"],
|
||||||
@@ -881,12 +880,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
||||||
|
|
||||||
single_result = (await self._query(query))[0]
|
single_result = (await self._query(query))[0]
|
||||||
logger.debug(
|
|
||||||
"{%s}:query:{%s}:result:{%s}",
|
|
||||||
inspect.currentframe().f_code.co_name,
|
|
||||||
query,
|
|
||||||
single_result["node_exists"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return single_result["node_exists"]
|
return single_result["node_exists"]
|
||||||
|
|
||||||
@@ -904,12 +897,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
single_result = (await self._query(query))[0]
|
single_result = (await self._query(query))[0]
|
||||||
logger.debug(
|
|
||||||
"{%s}:query:{%s}:result:{%s}",
|
|
||||||
inspect.currentframe().f_code.co_name,
|
|
||||||
query,
|
|
||||||
single_result["edge_exists"],
|
|
||||||
)
|
|
||||||
return single_result["edge_exists"]
|
return single_result["edge_exists"]
|
||||||
|
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
@@ -922,12 +910,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
if record:
|
if record:
|
||||||
node = record[0]
|
node = record[0]
|
||||||
node_dict = node["n"]
|
node_dict = node["n"]
|
||||||
logger.debug(
|
|
||||||
"{%s}: query: {%s}, result: {%s}",
|
|
||||||
inspect.currentframe().f_code.co_name,
|
|
||||||
query,
|
|
||||||
node_dict,
|
|
||||||
)
|
|
||||||
return node_dict
|
return node_dict
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -941,12 +924,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
record = (await self._query(query))[0]
|
record = (await self._query(query))[0]
|
||||||
if record:
|
if record:
|
||||||
edge_count = int(record["total_edge_count"])
|
edge_count = int(record["total_edge_count"])
|
||||||
logger.debug(
|
|
||||||
"{%s}:query:{%s}:result:{%s}",
|
|
||||||
inspect.currentframe().f_code.co_name,
|
|
||||||
query,
|
|
||||||
edge_count,
|
|
||||||
)
|
|
||||||
return edge_count
|
return edge_count
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
@@ -958,11 +936,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
trg_degree = 0 if trg_degree is None else trg_degree
|
trg_degree = 0 if trg_degree is None else trg_degree
|
||||||
|
|
||||||
degrees = int(src_degree) + int(trg_degree)
|
degrees = int(src_degree) + int(trg_degree)
|
||||||
logger.debug(
|
|
||||||
"{%s}:query:src_Degree+trg_degree:result:{%s}",
|
|
||||||
inspect.currentframe().f_code.co_name,
|
|
||||||
degrees,
|
|
||||||
)
|
|
||||||
return degrees
|
return degrees
|
||||||
|
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
@@ -983,12 +957,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
record = await self._query(query)
|
record = await self._query(query)
|
||||||
if record and record[0] and record[0]["edge_properties"]:
|
if record and record[0] and record[0]["edge_properties"]:
|
||||||
result = record[0]["edge_properties"]
|
result = record[0]["edge_properties"]
|
||||||
logger.debug(
|
|
||||||
"{%s}:query:{%s}:result:{%s}",
|
|
||||||
inspect.currentframe().f_code.co_name,
|
|
||||||
query,
|
|
||||||
result,
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
@@ -1055,13 +1024,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self._query(query, readonly=False, upsert=True)
|
await self._query(query, readonly=False, upsert=True)
|
||||||
logger.debug(
|
|
||||||
"Upserted node with label '{%s}' and properties: {%s}",
|
|
||||||
label,
|
|
||||||
properties,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error during upsert: {%s}", e)
|
logger.error("POSTGRES, Error during upsert: {%s}", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -1097,15 +1062,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|||||||
tgt_label,
|
tgt_label,
|
||||||
self._format_properties(edge_properties),
|
self._format_properties(edge_properties),
|
||||||
)
|
)
|
||||||
# logger.info(f"-- inserting edge after formatted: {params}")
|
|
||||||
try:
|
try:
|
||||||
await self._query(query, readonly=False, upsert=True)
|
await self._query(query, readonly=False, upsert=True)
|
||||||
logger.debug(
|
|
||||||
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
|
|
||||||
src_label,
|
|
||||||
tgt_label,
|
|
||||||
edge_properties,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error during edge upsert: {%s}", e)
|
logger.error("Error during edge upsert: {%s}", e)
|
||||||
raise
|
raise
|
||||||
|
@@ -277,7 +277,7 @@ class TiDBKVStorage(BaseKVStorage):
|
|||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class TiDBVectorDBStorage(BaseVectorStorage):
|
class TiDBVectorDBStorage(BaseVectorStorage):
|
||||||
db: TiDB = field(default=None)
|
db: TiDB | None = field(default=None)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._client_file_name = os.path.join(
|
self._client_file_name = os.path.join(
|
||||||
|
Reference in New Issue
Block a user