Fix linting
This commit is contained in:
@@ -363,7 +363,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""Get nodes as a batch using UNWIND
|
||||
|
||||
|
||||
Default implementation fetches nodes one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
@@ -377,7 +377,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""Node degrees as a batch using UNWIND
|
||||
|
||||
|
||||
Default implementation fetches node degrees one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
@@ -388,9 +388,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
result[node_id] = degree
|
||||
return result
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
async def edge_degrees_batch(
|
||||
self, edge_pairs: list[tuple[str, str]]
|
||||
) -> dict[tuple[str, str], int]:
|
||||
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
|
||||
|
||||
|
||||
Default implementation calculates edge degrees one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
@@ -401,9 +403,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
result[(src_id, tgt_id)] = degree
|
||||
return result
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
async def get_edges_batch(
|
||||
self, pairs: list[dict[str, str]]
|
||||
) -> dict[tuple[str, str], dict]:
|
||||
"""Get edges as a batch using UNWIND
|
||||
|
||||
|
||||
Default implementation fetches edges one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
@@ -417,9 +421,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
||||
result[(src_id, tgt_id)] = edge
|
||||
return result
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
async def get_nodes_edges_batch(
|
||||
self, node_ids: list[str]
|
||||
) -> dict[str, list[tuple[str, str]]]:
|
||||
"""Get nodes edges as a batch using UNWIND
|
||||
|
||||
|
||||
Default implementation fetches node edges one by one.
|
||||
Override this method for better performance in storage backends
|
||||
that support batch operations.
|
||||
|
@@ -311,10 +311,10 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""
|
||||
Retrieve multiple nodes in one query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node entity IDs to fetch.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its node data (or None if not found).
|
||||
"""
|
||||
@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
node_dict = dict(node)
|
||||
# Remove the 'base' label if present in a 'labels' property
|
||||
if "labels" in node_dict:
|
||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
||||
node_dict["labels"] = [
|
||||
label for label in node_dict["labels"] if label != "base"
|
||||
]
|
||||
nodes[entity_id] = node_dict
|
||||
await result.consume() # Make sure to consume the result fully
|
||||
return nodes
|
||||
@@ -385,12 +387,12 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node labels (entity_id values) to look up.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
If a node is not found, its degree will be set to 0.
|
||||
"""
|
||||
async with self._driver.session(
|
||||
@@ -407,13 +409,13 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
entity_id = record["entity_id"]
|
||||
degrees[entity_id] = record["degree"]
|
||||
await result.consume() # Ensure result is fully consumed
|
||||
|
||||
|
||||
# For any node_id that did not return a record, set degree to 0.
|
||||
for nid in node_ids:
|
||||
if nid not in degrees:
|
||||
logger.warning(f"No node found with label '{nid}'")
|
||||
degrees[nid] = 0
|
||||
|
||||
|
||||
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
|
||||
return degrees
|
||||
|
||||
@@ -436,25 +438,27 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
degrees = int(src_degree) + int(trg_degree)
|
||||
return degrees
|
||||
|
||||
async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
|
||||
async def edge_degrees_batch(
|
||||
self, edge_pairs: list[tuple[str, str]]
|
||||
) -> dict[tuple[str, str], int]:
|
||||
"""
|
||||
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
||||
in batch using the already implemented node_degrees_batch.
|
||||
|
||||
|
||||
Args:
|
||||
edge_pairs: List of (src, tgt) tuples.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
|
||||
"""
|
||||
# Collect unique node IDs from all edge pairs.
|
||||
unique_node_ids = {src for src, _ in edge_pairs}
|
||||
unique_node_ids.update({tgt for _, tgt in edge_pairs})
|
||||
|
||||
|
||||
# Get degrees for all nodes in one go.
|
||||
degrees = await self.node_degrees_batch(list(unique_node_ids))
|
||||
|
||||
|
||||
# Sum up degrees for each edge pair.
|
||||
edge_degrees = {}
|
||||
for src, tgt in edge_pairs:
|
||||
@@ -547,13 +551,15 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
async def get_edges_batch(
|
||||
self, pairs: list[dict[str, str]]
|
||||
) -> dict[tuple[str, str], dict]:
|
||||
"""
|
||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||
|
||||
|
||||
Args:
|
||||
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping (src, tgt) tuples to their edge properties.
|
||||
"""
|
||||
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
if edges and len(edges) > 0:
|
||||
edge_props = edges[0] # choose the first if multiple exist
|
||||
# Ensure required keys exist with defaults
|
||||
for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
|
||||
for key, default in {
|
||||
"weight": 0.0,
|
||||
"source_id": None,
|
||||
"description": None,
|
||||
"keywords": None,
|
||||
}.items():
|
||||
if key not in edge_props:
|
||||
edge_props[key] = default
|
||||
edges_dict[(src, tgt)] = edge_props
|
||||
else:
|
||||
# No edge found – set default edge properties
|
||||
edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
|
||||
edges_dict[(src, tgt)] = {
|
||||
"weight": 0.0,
|
||||
"source_id": None,
|
||||
"description": None,
|
||||
"keywords": None,
|
||||
}
|
||||
await result.consume()
|
||||
return edges_dict
|
||||
|
||||
@@ -644,17 +660,21 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
async def get_nodes_edges_batch(
|
||||
self, node_ids: list[str]
|
||||
) -> dict[str, list[tuple[str, str]]]:
|
||||
"""
|
||||
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
||||
"""
|
||||
async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
|
||||
async with self._driver.session(
|
||||
database=self._DATABASE, default_access_mode="READ"
|
||||
) as session:
|
||||
query = """
|
||||
UNWIND $node_ids AS id
|
||||
MATCH (n:base {entity_id: id})
|
||||
|
@@ -1461,30 +1461,29 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||
"""
|
||||
Retrieve multiple nodes in one query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node entity IDs to fetch.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its node data (or None if not found).
|
||||
"""
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
||||
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
RETURN node_id, n
|
||||
$$) AS (node_id text, n agtype)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids
|
||||
)
|
||||
|
||||
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
|
||||
# Build result dictionary
|
||||
nodes_dict = {}
|
||||
for result in results:
|
||||
@@ -1492,28 +1491,32 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
node_dict = result["n"]["properties"]
|
||||
# Remove the 'base' label if present in a 'labels' property
|
||||
if "labels" in node_dict:
|
||||
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
||||
node_dict["labels"] = [
|
||||
label for label in node_dict["labels"] if label != "base"
|
||||
]
|
||||
nodes_dict[result["node_id"]] = node_dict
|
||||
|
||||
|
||||
return nodes_dict
|
||||
|
||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node labels (entity_id values) to look up.
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
A dictionary mapping each node_id to its degree (number of relationships).
|
||||
If a node is not found, its degree will be set to 0.
|
||||
"""
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
||||
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
@@ -1521,112 +1524,122 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
RETURN node_id, count(r) AS degree
|
||||
$$) AS (node_id text, degree bigint)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids
|
||||
formatted_ids,
|
||||
)
|
||||
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
|
||||
# Build result dictionary
|
||||
degrees_dict = {}
|
||||
for result in results:
|
||||
if result["node_id"] is not None:
|
||||
degrees_dict[result["node_id"]] = int(result["degree"])
|
||||
|
||||
|
||||
# Ensure all requested node_ids are in the result dictionary
|
||||
for node_id in node_ids:
|
||||
if node_id not in degrees_dict:
|
||||
degrees_dict[node_id] = 0
|
||||
|
||||
|
||||
return degrees_dict
|
||||
|
||||
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
||||
|
||||
async def edge_degrees_batch(
|
||||
self, edges: list[tuple[str, str]]
|
||||
) -> dict[tuple[str, str], int]:
|
||||
"""
|
||||
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
||||
in batch using the already implemented node_degrees_batch.
|
||||
|
||||
|
||||
Args:
|
||||
edges: List of (source_node_id, target_node_id) tuples
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary mapping edge tuples to their combined degrees
|
||||
"""
|
||||
if not edges:
|
||||
return {}
|
||||
|
||||
|
||||
# Use node_degrees_batch to get all node degrees efficiently
|
||||
all_nodes = set()
|
||||
for src, tgt in edges:
|
||||
all_nodes.add(src)
|
||||
all_nodes.add(tgt)
|
||||
|
||||
|
||||
node_degrees = await self.node_degrees_batch(list(all_nodes))
|
||||
|
||||
|
||||
# Calculate edge degrees
|
||||
edge_degrees_dict = {}
|
||||
for src, tgt in edges:
|
||||
src_degree = node_degrees.get(src, 0)
|
||||
tgt_degree = node_degrees.get(tgt, 0)
|
||||
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
|
||||
|
||||
|
||||
return edge_degrees_dict
|
||||
|
||||
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
||||
|
||||
async def get_edges_batch(
|
||||
self, pairs: list[dict[str, str]]
|
||||
) -> dict[tuple[str, str], dict]:
|
||||
"""
|
||||
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
||||
|
||||
|
||||
Args:
|
||||
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary mapping (src, tgt) tuples to their edge properties.
|
||||
"""
|
||||
if not pairs:
|
||||
return {}
|
||||
|
||||
|
||||
# 从字典列表中提取源节点和目标节点ID
|
||||
src_nodes = []
|
||||
tgt_nodes = []
|
||||
for pair in pairs:
|
||||
src_nodes.append(pair["src"].replace('"', ''))
|
||||
tgt_nodes.append(pair["tgt"].replace('"', ''))
|
||||
|
||||
src_nodes.append(pair["src"].replace('"', ""))
|
||||
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
||||
|
||||
# 构建查询,使用数组索引来匹配源节点和目标节点
|
||||
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
||||
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
||||
|
||||
|
||||
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
||||
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
||||
UNWIND range(0, size(sources)-1) AS i
|
||||
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
|
||||
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
||||
$$) AS (source text, target text, edge_properties agtype)"""
|
||||
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
|
||||
# 构建结果字典
|
||||
edges_dict = {}
|
||||
for result in results:
|
||||
if result["source"] and result["target"] and result["edge_properties"]:
|
||||
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
|
||||
|
||||
edges_dict[(result["source"], result["target"])] = result[
|
||||
"edge_properties"
|
||||
]
|
||||
|
||||
return edges_dict
|
||||
|
||||
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
||||
|
||||
async def get_nodes_edges_batch(
|
||||
self, node_ids: list[str]
|
||||
) -> dict[str, list[tuple[str, str]]]:
|
||||
"""
|
||||
Get all edges for multiple nodes in a single batch operation.
|
||||
|
||||
|
||||
Args:
|
||||
node_ids: List of node IDs to get edges for
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary mapping node IDs to lists of (source, target) edge tuples
|
||||
"""
|
||||
if not node_ids:
|
||||
return {}
|
||||
|
||||
|
||||
# Format node IDs for the query
|
||||
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
||||
|
||||
formatted_ids = ", ".join(
|
||||
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
||||
)
|
||||
|
||||
query = """SELECT * FROM cypher('%s', $$
|
||||
UNWIND [%s] AS node_id
|
||||
MATCH (n:base {entity_id: node_id})
|
||||
@@ -1634,11 +1647,11 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
RETURN node_id, connected.entity_id AS connected_id
|
||||
$$) AS (node_id text, connected_id text)""" % (
|
||||
self.graph_name,
|
||||
formatted_ids
|
||||
formatted_ids,
|
||||
)
|
||||
|
||||
|
||||
results = await self._query(query)
|
||||
|
||||
|
||||
# Build result dictionary
|
||||
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
||||
for result in results:
|
||||
@@ -1646,9 +1659,9 @@ class PGGraphStorage(BaseGraphStorage):
|
||||
nodes_edges_dict[result["node_id"]].append(
|
||||
(result["node_id"], result["connected_id"])
|
||||
)
|
||||
|
||||
|
||||
return nodes_edges_dict
|
||||
|
||||
|
||||
async def get_all_labels(self) -> list[str]:
|
||||
"""
|
||||
Get all labels (node IDs) in the graph.
|
||||
|
@@ -1323,14 +1323,14 @@ async def _get_node_data(
|
||||
|
||||
if not len(results):
|
||||
return "", "", ""
|
||||
|
||||
|
||||
# Extract all entity IDs from your results list
|
||||
node_ids = [r["entity_name"] for r in results]
|
||||
|
||||
# Call the batch node retrieval and degree functions concurrently.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids)
|
||||
knowledge_graph_inst.get_nodes_batch(node_ids),
|
||||
knowledge_graph_inst.node_degrees_batch(node_ids),
|
||||
)
|
||||
|
||||
# Now, if you need the node data and degree in order:
|
||||
@@ -1459,7 +1459,7 @@ async def _find_most_related_text_unit_from_entities(
|
||||
for dp in node_datas
|
||||
if dp["source_id"] is not None
|
||||
]
|
||||
|
||||
|
||||
node_names = [dp["entity_name"] for dp in node_datas]
|
||||
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
|
||||
# Build the edges list in the same order as node_datas.
|
||||
@@ -1472,10 +1472,14 @@ async def _find_most_related_text_unit_from_entities(
|
||||
all_one_hop_nodes.update([e[1] for e in this_edges])
|
||||
|
||||
all_one_hop_nodes = list(all_one_hop_nodes)
|
||||
|
||||
|
||||
# Batch retrieve one-hop node data using get_nodes_batch
|
||||
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
|
||||
all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
|
||||
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
|
||||
all_one_hop_nodes
|
||||
)
|
||||
all_one_hop_nodes_data = [
|
||||
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
|
||||
]
|
||||
|
||||
# Add null check for node data
|
||||
all_one_hop_text_units_lookup = {
|
||||
@@ -1571,13 +1575,13 @@ async def _find_most_related_edges_from_entities(
|
||||
edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
|
||||
# For edge degrees, use tuples.
|
||||
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
|
||||
|
||||
|
||||
# Call the batched functions concurrently.
|
||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
|
||||
)
|
||||
|
||||
|
||||
# Reconstruct edge_datas list in the same order as the deduplicated results.
|
||||
all_edges_data = []
|
||||
for pair in all_edges:
|
||||
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
|
||||
}
|
||||
all_edges_data.append(combined)
|
||||
|
||||
|
||||
all_edges_data = sorted(
|
||||
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
|
||||
# Call the batched functions concurrently.
|
||||
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
||||
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
|
||||
)
|
||||
|
||||
# Reconstruct edge_datas list in the same order as results.
|
||||
@@ -1652,7 +1655,7 @@ async def _get_edge_data(
|
||||
**edge_props,
|
||||
}
|
||||
edge_datas.append(combined)
|
||||
|
||||
|
||||
edge_datas = sorted(
|
||||
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
||||
)
|
||||
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
|
||||
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
||||
nodes_dict, degrees_dict = await asyncio.gather(
|
||||
knowledge_graph_inst.get_nodes_batch(entity_names),
|
||||
knowledge_graph_inst.node_degrees_batch(entity_names)
|
||||
knowledge_graph_inst.node_degrees_batch(entity_names),
|
||||
)
|
||||
|
||||
# Rebuild the list in the same order as entity_names
|
||||
|
Reference in New Issue
Block a user