main_merge

This commit is contained in:
Roy
2025-03-08 20:34:29 +00:00
39 changed files with 1475 additions and 244 deletions

View File

@@ -585,6 +585,41 @@ class PGVectorStorage(BaseVectorStorage):
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
async def search_by_prefix(self, prefix: str) -> list[dict[str, Any]]:
"""Search for records with IDs starting with a specific prefix.
Args:
prefix: The prefix to search for in record IDs
Returns:
List of records with matching ID prefixes
"""
table_name = namespace_to_table_name(self.namespace)
if not table_name:
logger.error(f"Unknown namespace for prefix search: {self.namespace}")
return []
search_sql = f"SELECT * FROM {table_name} WHERE workspace=$1 AND id LIKE $2"
params = {"workspace": self.db.workspace, "prefix": f"{prefix}%"}
try:
results = await self.db.query(search_sql, params, multirows=True)
logger.debug(f"Found {len(results)} records with prefix '{prefix}'")
# Format results to match the expected return format
formatted_results = []
for record in results:
formatted_record = dict(record)
# Ensure id field is available (for consistency with NanoVectorDB implementation)
if "id" not in formatted_record:
formatted_record["id"] = record["id"]
formatted_results.append(formatted_record)
return formatted_results
except Exception as e:
logger.error(f"Error during prefix search for '{prefix}': {e}")
return []
@final
@dataclass
@@ -785,42 +820,85 @@ class PGGraphStorage(BaseGraphStorage):
v = record[k]
# agtype comes back '{key: value}::type' which must be parsed
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
vertices[vertex["id"]] = vertex.get("properties")
if v.startswith("[") and v.endswith("]"):
if "::vertex" not in v:
continue
v = v.replace("::vertex", "")
vertexes = json.loads(v)
for vertex in vertexes:
vertices[vertex["id"]] = vertex.get("properties")
else:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
vertices[vertex["id"]] = vertex.get("properties")
# iterate returned fields and parse appropriately
for k in record.keys():
v = record[k]
if isinstance(v, str) and "::" in v:
dtype = v.split("::")[-1]
v = v.split("::")[0]
else:
dtype = ""
if v.startswith("[") and v.endswith("]"):
if "::vertex" in v:
v = v.replace("::vertex", "")
vertexes = json.loads(v)
dl = []
for vertex in vertexes:
prop = vertex.get("properties")
if not prop:
prop = {}
prop["label"] = PGGraphStorage._decode_graph_label(
prop["node_id"]
)
dl.append(prop)
d[k] = dl
if dtype == "vertex":
vertex = json.loads(v)
field = vertex.get("properties")
if not field:
field = {}
field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge":
edge = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge[
"label"
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}),
)
elif "::edge" in v:
v = v.replace("::edge", "")
edges = json.loads(v)
dl = []
for edge in edges:
dl.append(
(
vertices[edge["start_id"]],
edge["label"],
vertices[edge["end_id"]],
)
)
d[k] = dl
else:
print("WARNING: unsupported type")
continue
else:
dtype = v.split("::")[-1]
v = v.split("::")[0]
if dtype == "vertex":
vertex = json.loads(v)
field = vertex.get("properties")
if not field:
field = {}
field["label"] = PGGraphStorage._decode_graph_label(
field["node_id"]
)
d[k] = field
# convert edge from id-label->id by replacing id with node information
# we only do this if the vertex was also returned in the query
# this is an attempt to be consistent with neo4j implementation
elif dtype == "edge":
edge = json.loads(v)
d[k] = (
vertices.get(edge["start_id"], {}),
edge[
"label"
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
vertices.get(edge["end_id"], {}),
)
else:
d[k] = json.loads(v) if isinstance(v, str) else v
if v is None or (v.count("{") < 1 and v.count("[") < 1):
d[k] = v
else:
d[k] = json.loads(v) if isinstance(v, str) else v
return d
@@ -1294,7 +1372,7 @@ class PGGraphStorage(BaseGraphStorage):
OPTIONAL MATCH p = (n)-[*..%d]-(m)
RETURN nodes(p) AS nodes, relationships(p) AS relationships
LIMIT %d
$$) AS (nodes agtype[], relationships agtype[])""" % (
$$) AS (nodes agtype, relationships agtype)""" % (
self.graph_name,
encoded_node_label,
max_depth,
@@ -1303,17 +1381,23 @@ class PGGraphStorage(BaseGraphStorage):
results = await self._query(query)
nodes = set()
nodes = {}
edges = []
unique_edge_ids = set()
for result in results:
if node_label == "*":
if result["n"]:
node = result["n"]
nodes.add(self._decode_graph_label(node["node_id"]))
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["m"]:
node = result["m"]
nodes.add(self._decode_graph_label(node["node_id"]))
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["r"]:
edge = result["r"]
src_id = self._decode_graph_label(edge["start_id"])
@@ -1322,16 +1406,36 @@ class PGGraphStorage(BaseGraphStorage):
else:
if result["nodes"]:
for node in result["nodes"]:
nodes.add(self._decode_graph_label(node["node_id"]))
node_id = self._decode_graph_label(node["node_id"])
if node_id not in nodes:
nodes[node_id] = node
if result["relationships"]:
for edge in result["relationships"]:
src_id = self._decode_graph_label(edge["start_id"])
tgt_id = self._decode_graph_label(edge["end_id"])
edges.append((src_id, tgt_id))
for edge in result["relationships"]: # src --DIRECTED--> target
src_id = self._decode_graph_label(edge[0]["node_id"])
tgt_id = self._decode_graph_label(edge[2]["node_id"])
id = src_id + "," + tgt_id
if id in unique_edge_ids:
continue
else:
unique_edge_ids.add(id)
edges.append(
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
)
kg = KnowledgeGraph(
nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
nodes=[
KnowledgeGraphNode(
id=node_id, labels=[node_id], properties=nodes[node_id]
)
for node_id in nodes
],
edges=[
KnowledgeGraphEdge(
id=id, type="DIRECTED", source=src, target=tgt, properties=props
)
for id, src, tgt, props in edges
],
)
return kg