Fix the lint issue
@@ -19,7 +19,11 @@ from tenacity import (
 from ..utils import logger
 from ..base import (
     BaseKVStorage,
-    BaseVectorStorage, DocStatusStorage, DocStatus, DocProcessingStatus, BaseGraphStorage,
+    BaseVectorStorage,
+    DocStatusStorage,
+    DocStatus,
+    DocProcessingStatus,
+    BaseGraphStorage,
 )

 if sys.platform.startswith("win"):
@@ -36,14 +40,15 @@ class PostgreSQLDB:
         self.user = config.get("user", "postgres")
         self.password = config.get("password", None)
         self.database = config.get("database", "postgres")
-        self.workspace = config.get("workspace", 'default')
+        self.workspace = config.get("workspace", "default")
         self.max = 12
         self.increment = 1
         logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")

         if self.user is None or self.password is None or self.database is None:
-            raise ValueError("Missing database user, password, or database in addon_params")
+            raise ValueError(
+                "Missing database user, password, or database in addon_params"
+            )

     async def initdb(self):
         try:
@@ -54,12 +59,16 @@ class PostgreSQLDB:
                 host=self.host,
                 port=self.port,
                 min_size=1,
-                max_size=self.max
+                max_size=self.max,
             )

-            logger.info(f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}")
+            logger.info(
+                f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+            )
         except Exception as e:
-            logger.error(f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}")
+            logger.error(
+                f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
+            )
             logger.error(f"PostgreSQL database error: {e}")
             raise
@@ -79,9 +88,13 @@ class PostgreSQLDB:

         logger.info("Finished checking all tables in PostgreSQL database")

     async def query(
-        self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None
+        self,
+        sql: str,
+        params: dict = None,
+        multirows: bool = False,
+        for_age: bool = False,
+        graph_name: str = None,
     ) -> Union[dict, None, list[dict]]:
         async with self.pool.acquire() as connection:
             try:
@@ -111,7 +124,13 @@ class PostgreSQLDB:
             print(params)
             raise

-    async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
+    async def execute(
+        self,
+        sql: str,
+        data: Union[list, dict] = None,
+        for_age: bool = False,
+        graph_name: str = None,
+    ):
         try:
             async with self.pool.acquire() as connection:
                 if for_age:
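
Note: with the reformatted signatures above, `query` and `execute` keep their call-compatible shape. A minimal usage sketch — assuming `PostgreSQLDB.__init__` takes the config dict read in `__init__` above, and with all connection values below as placeholders, not defaults from this commit:

    import asyncio

    async def main():
        db = PostgreSQLDB({
            "user": "postgres",
            "password": "secret",      # placeholder
            "database": "postgres",
            "workspace": "default",
        })
        await db.initdb()
        one = await db.query("SELECT 1 AS ok")                   # single dict
        many = await db.query("SELECT 1 AS ok", multirows=True)  # list of dicts
        await db.execute("CREATE TABLE IF NOT EXISTS demo(id int)")

    asyncio.run(main())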
@@ -130,7 +149,7 @@ class PostgreSQLDB:
     @staticmethod
     async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
         try:
-            await conn.execute(f'SET search_path = ag_catalog, "$user", public')
+            await conn.execute('SET search_path = ag_catalog, "$user", public')
             await conn.execute(f"""select create_graph('{graph_name}')""")
         except asyncpg.exceptions.InvalidSchemaNameError:
             pass
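
Note: the change above is the f-string-without-placeholders lint (F541 in pyflakes/Ruff terms). `"$user"` is PostgreSQL search_path syntax, not a Python replacement field, so the `f` prefix did nothing:

    # No {placeholders}, so the f-string is byte-for-byte the plain string:
    f'SET search_path = ag_catalog, "$user", public' == 'SET search_path = ag_catalog, "$user", public'  # True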
@@ -138,7 +157,7 @@ class PostgreSQLDB:

 @dataclass
 class PGKVStorage(BaseKVStorage):
-    db:PostgreSQLDB = None
+    db: PostgreSQLDB = None

     def __post_init__(self):
         self._data = {}
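
Note: the `db:PostgreSQLDB` → `db: PostgreSQLDB` edits throughout this commit are whitespace-only (missing space after the annotation colon). As an aside, not part of the commit: since the default is None, a stricter annotation would be the following, assuming `Optional` is imported from `typing`:

    db: Optional[PostgreSQLDB] = None  # suggested typing, not what the commit does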
@@ -180,7 +199,7 @@ class PGKVStorage(BaseKVStorage):
                 dict_res[mode] = {}
             for row in array_res:
                 dict_res[row["mode"]][row["id"]] = row
-            res = [{k:v} for k,v in dict_res.items()]
+            res = [{k: v} for k, v in dict_res.items()]
         else:
             res = await self.db.query(sql, params, multirows=True)
         if res:
@@ -191,7 +210,8 @@ class PGKVStorage(BaseKVStorage):
     async def filter_keys(self, keys: List[str]) -> Set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
-            table_name=NAMESPACE_TABLE_MAP[self.namespace], ids=",".join([f"'{id}'" for id in keys])
+            table_name=NAMESPACE_TABLE_MAP[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
         )
         params = {"workspace": self.db.workspace}
         try:
@@ -207,7 +227,6 @@ class PGKVStorage(BaseKVStorage):
             print(sql)
             print(params)

-
     ################ INSERT METHODS ################
     async def upsert(self, data: Dict[str, dict]):
         left_data = {k: v for k, v in data.items() if k not in self._data}
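
Note: the bare `print(sql)` / `print(params)` in this error path predate the commit; since the module already has a logger, routing them through it would be more consistent — a suggestion only:

    logger.error(f"Query failed. SQL: {sql}, params: {params}")  # suggested, not in the commit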
@@ -246,7 +265,7 @@ class PGKVStorage(BaseKVStorage):
 @dataclass
 class PGVectorStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = 0.2
-    db:PostgreSQLDB = None
+    db: PostgreSQLDB = None

     def __post_init__(self):
         self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -282,6 +301,7 @@ class PGVectorStorage(BaseVectorStorage):
             "content_vector": json.dumps(item["__vector__"].tolist()),
         }
         return upsert_sql, data
+
     def _upsert_relationships(self, item: dict):
         upsert_sql = SQL_TEMPLATES["upsert_relationship"]
         data = {
@@ -340,8 +360,6 @@ class PGVectorStorage(BaseVectorStorage):

         await self.db.execute(upsert_sql, data)

-
-
     async def index_done_callback(self):
         logger.info("vector data had been saved into postgresql db!")

@@ -350,7 +368,7 @@ class PGVectorStorage(BaseVectorStorage):
         """Query data from the vector database"""
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
-        embedding_string = ",".join(map(str, embedding))
+        embedding_string = ",".join(map(str, embedding))

         sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
         params = {
@@ -361,10 +379,12 @@ class PGVectorStorage(BaseVectorStorage):
         results = await self.db.query(sql, params=params, multirows=True)
         return results


 @dataclass
 class PGDocStatusStorage(DocStatusStorage):
     """PostgreSQL implementation of document status storage"""
-    db:PostgreSQLDB = None
+
+    db: PostgreSQLDB = None

     def __post_init__(self):
         pass
@@ -372,41 +392,47 @@ class PGDocStatusStorage(DocStatusStorage):
     async def filter_keys(self, data: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
-        result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
+        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
         if result is None:
             return set(data)
         else:
-            existed = set([element['id'] for element in result])
+            existed = set([element["id"] for element in result])
             return set(data) - existed

     async def get_status_counts(self) -> Dict[str, int]:
         """Get counts of documents in each status"""
-        sql = '''SELECT status as "status", COUNT(1) as "count"
+        sql = """SELECT status as "status", COUNT(1) as "count"
                  FROM LIGHTRAG_DOC_STATUS
                  where workspace=$1 GROUP BY STATUS
-                 '''
-        result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
+                 """
+        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
         counts = {}
         for doc in result:
             counts[doc["status"]] = doc["count"]
         return counts

-    async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
+    async def get_docs_by_status(
+        self, status: DocStatus
+    ) -> Dict[str, DocProcessingStatus]:
         """Get all documents by status"""
-        sql = 'select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1'
-        params = {'workspace': self.db.workspace, 'status': status}
+        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1"
+        params = {"workspace": self.db.workspace, "status": status}
         result = await self.db.query(sql, params, True)
         # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
         # Converting to be a dict
-        return {element["id"]:
-                DocProcessingStatus(content_summary=element["content_summary"],
-                                    content_length=element["content_length"],
-                                    status=element["status"],
-                                    created_at=element["created_at"],
-                                    updated_at=element["updated_at"],
-                                    chunks_count=element["chunks_count"]) for element in result}
+        return {
+            element["id"]: DocProcessingStatus(
+                content_summary=element["content_summary"],
+                content_length=element["content_length"],
+                status=element["status"],
+                created_at=element["created_at"],
+                updated_at=element["updated_at"],
+                chunks_count=element["chunks_count"],
+            )
+            for element in result
+        }

     async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
         """Get all failed documents"""
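
Note: `get_docs_by_status` above (only re-wrapped by this commit) binds two parameters but references `$1` twice; with asyncpg-style numbered placeholders the status condition presumably wants `$2`. A corrected sketch, not part of this commit:

    sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
    params = {"workspace": self.db.workspace, "status": status}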
@@ -436,14 +462,17 @@ class PGDocStatusStorage(DocStatusStorage):
                 updated_at = CURRENT_TIMESTAMP"""
         for k, v in data.items():
             # chunks_count is optional
-            await self.db.execute(sql, {
-                "workspace": self.db.workspace,
-                "id": k,
-                "content_summary": v["content_summary"],
-                "content_length": v["content_length"],
-                "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
-                "status": v["status"],
-            })
+            await self.db.execute(
+                sql,
+                {
+                    "workspace": self.db.workspace,
+                    "id": k,
+                    "content_summary": v["content_summary"],
+                    "content_length": v["content_length"],
+                    "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
+                    "status": v["status"],
+                },
+            )
         return data

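
Note: in the re-wrapped call above, `v["chunks_count"] if "chunks_count" in v else -1` is equivalent to the more idiomatic dict lookup below — an observation only, since the commit changes layout, not behavior:

    "chunks_count": v.get("chunks_count", -1),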
@@ -467,7 +496,7 @@ class PGGraphQueryException(Exception):

 @dataclass
 class PGGraphStorage(BaseGraphStorage):
-    db:PostgreSQLDB = None
+    db: PostgreSQLDB = None

     @staticmethod
     def load_nx_graph(file_name):
@@ -484,7 +513,6 @@ class PGGraphStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }

-
     async def index_done_callback(self):
         print("KG successfully indexed.")

@@ -552,7 +580,7 @@ class PGGraphStorage(BaseGraphStorage):

     @staticmethod
     def _format_properties(
-        properties: Dict[str, Any], _id: Union[str, None] = None
+        properties: Dict[str, Any], _id: Union[str, None] = None
     ) -> str:
         """
         Convert a dictionary of properties to a string representation that
@@ -669,7 +697,8 @@ class PGGraphStorage(BaseGraphStorage):

         # get pgsql formatted field names
         fields = [
-            PGGraphStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
+            PGGraphStorage._get_col_name(field, idx)
+            for idx, field in enumerate(fields)
         ]

         # build resulting pgsql relation
@@ -690,7 +719,9 @@ class PGGraphStorage(BaseGraphStorage):
             projection=select_str,
         )

-    async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
+    async def _query(
+        self, query: str, readonly=True, upsert_edge=False, **params: str
+    ) -> List[Dict[str, Any]]:
         """
         Query the graph by taking a cypher query, converting it to an
         age compatible query, executing it and converting the result
@@ -708,14 +739,25 @@ class PGGraphStorage(BaseGraphStorage):
         # execute the query, rolling back on an error
         try:
             if readonly:
-                data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
+                data = await self.db.query(
+                    wrapped_query,
+                    multirows=True,
+                    for_age=True,
+                    graph_name=self.graph_name,
+                )
             else:
                 # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
                 # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
                 if upsert_edge:
-                    data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
+                    data = await self.db.execute(
+                        f"{wrapped_query};{wrapped_query};",
+                        for_age=True,
+                        graph_name=self.graph_name,
+                    )
                 else:
-                    data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
+                    data = await self.db.execute(
+                        wrapped_query, for_age=True, graph_name=self.graph_name
+                    )
         except Exception as e:
             raise PGGraphQueryException(
                 {
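
Note: `wrapped_query` here is the Cypher statement embedded in AGE's `cypher()` SQL function (built earlier in `_query`, outside this hunk). It has roughly this shape — a sketch following Apache AGE's documented calling convention, with the variable names and the column list as assumptions:

    wrapped_query = (
        f"SELECT * FROM cypher('{self.graph_name}', $$ {cypher_query} $$) AS (result agtype)"
    )

Duplicating it as `f"{wrapped_query};{wrapped_query};"` runs the statement twice in one round trip, so the second pass takes the MERGE/update path that, per the comment in the hunk, AGE skips on first execution.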
@@ -819,7 +861,7 @@ class PGGraphStorage(BaseGraphStorage):
         return degrees

     async def get_edge(
-        self, source_node_id: str, target_node_id: str
+        self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """
         Find all edges between nodes of two given labels
@@ -922,7 +964,7 @@ class PGGraphStorage(BaseGraphStorage):
         retry=retry_if_exception_type((PGGraphQueryException,)),
     )
     async def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
+        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
     ):
         """
         Upsert an edge and its properties between two nodes identified by their labels.
@@ -935,7 +977,9 @@ class PGGraphStorage(BaseGraphStorage):
         source_node_label = source_node_id.strip('"')
         target_node_label = target_node_id.strip('"')
         edge_properties = edge_data
-        logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")
+        logger.info(
+            f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
+        )

         query = """MATCH (source:`{src_label}`)
                    WITH source
@@ -1056,7 +1100,6 @@ TABLES = {
 }


-
 SQL_TEMPLATES = {
     # SQL for KVStorage
     "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
@@ -1107,7 +1150,7 @@ SQL_TEMPLATES = {
     "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
                       VALUES ($1, $2, $3, $4, $5)
                       ON CONFLICT (workspace,id) DO UPDATE
-                      SET entity_name=EXCLUDED.entity_name,
+                      SET entity_name=EXCLUDED.entity_name,
                       content=EXCLUDED.content,
                       content_vector=EXCLUDED.content_vector,
                       updatetime=CURRENT_TIMESTAMP
@@ -1136,5 +1179,5 @@ SQL_TEMPLATES = {
     (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
         FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
     WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-    """
+    """,
 }
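
Note on the last template: `<=>` is pgvector's cosine-distance operator, so `1 - (content_vector <=> ...)` is cosine similarity; `$2` is the similarity cutoff (cf. `cosine_better_than_threshold: float = 0.2` earlier in the diff) and `$3` the row limit. Per the `PGVectorStorage.query` flow shown above, the template is driven like this:

    embedding_string = ",".join(map(str, embedding))  # e.g. "0.12,0.034,..."
    sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
    results = await self.db.query(sql, params=params, multirows=True)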