This commit is contained in:
zrguo
2025-03-17 15:59:54 +08:00
parent 2967fd2cd2
commit 3df20ae787

View File

@@ -437,7 +437,7 @@ class PGVectorStorage(BaseVectorStorage):
chunk_ids = source_id.split("<SEP>") chunk_ids = source_id.split("<SEP>")
else: else:
chunk_ids = [source_id] chunk_ids = [source_id]
data: dict[str, Any] = { data: dict[str, Any] = {
"workspace": self.db.workspace, "workspace": self.db.workspace,
"id": item["__id__"], "id": item["__id__"],
@@ -456,7 +456,7 @@ class PGVectorStorage(BaseVectorStorage):
chunk_ids = source_id.split("<SEP>") chunk_ids = source_id.split("<SEP>")
else: else:
chunk_ids = [source_id] chunk_ids = [source_id]
data: dict[str, Any] = { data: dict[str, Any] = {
"workspace": self.db.workspace, "workspace": self.db.workspace,
"id": item["__id__"], "id": item["__id__"],
@@ -962,7 +962,11 @@ class PGGraphStorage(BaseGraphStorage):
vertices.get(edge["end_id"], {}), vertices.get(edge["end_id"], {}),
) )
else: else:
d[k] = json.loads(v) if isinstance(v, str) and ("{" in v or "[" in v) else v d[k] = (
json.loads(v)
if isinstance(v, str) and ("{" in v or "[" in v)
else v
)
return d return d
@@ -1404,7 +1408,9 @@ class PGGraphStorage(BaseGraphStorage):
embed_func = self._node_embed_algorithms[algorithm] embed_func = self._node_embed_algorithms[algorithm]
return await embed_func() return await embed_func()
async def get_knowledge_graph(self, node_label: str, max_depth: int = 5) -> KnowledgeGraph: async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
""" """
Retrieve a subgraph containing the specified node and its neighbors up to the specified depth. Retrieve a subgraph containing the specified node and its neighbors up to the specified depth.
@@ -1451,7 +1457,14 @@ class PGGraphStorage(BaseGraphStorage):
edge_key = f"{src_id},{tgt_id}" edge_key = f"{src_id},{tgt_id}"
if edge_key not in unique_edge_ids: if edge_key not in unique_edge_ids:
unique_edge_ids.add(edge_key) unique_edge_ids.add(edge_key)
edges.append((edge_key, src_id, tgt_id, {"source": edge_data[0], "target": edge_data[2]})) edges.append(
(
edge_key,
src_id,
tgt_id,
{"source": edge_data[0], "target": edge_data[2]},
)
)
# Process the query results. # Process the query results.
if node_label == "*": if node_label == "*":
@@ -1476,12 +1489,19 @@ class PGGraphStorage(BaseGraphStorage):
for node_id, node_data in nodes.items() for node_id, node_data in nodes.items()
], ],
edges=[ edges=[
KnowledgeGraphEdge(id=edge_id, type="DIRECTED", source=src, target=tgt, properties=props) KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
source=src,
target=tgt,
properties=props,
)
for edge_id, src, tgt, props in edges for edge_id, src, tgt, props in edges
], ],
) )
return kg return kg
async def drop(self) -> None: async def drop(self) -> None:
"""Drop the storage""" """Drop the storage"""
drop_sql = SQL_TEMPLATES["drop_vdb_entity"] drop_sql = SQL_TEMPLATES["drop_vdb_entity"]