Merge branch 'main' into select-datastore-in-api-server
This commit is contained in:
@@ -18,7 +18,7 @@ config.read("config.ini", "utf-8")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MilvusVectorDBStorge(BaseVectorStorage):
|
||||
class MilvusVectorDBStorage(BaseVectorStorage):
|
||||
cosine_better_than_threshold: float = None
|
||||
|
||||
@staticmethod
|
||||
@@ -65,7 +65,7 @@ class MilvusVectorDBStorge(BaseVectorStorage):
|
||||
),
|
||||
)
|
||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||
MilvusVectorDBStorge.create_collection_if_not_exist(
|
||||
MilvusVectorDBStorage.create_collection_if_not_exist(
|
||||
self._client,
|
||||
self.namespace,
|
||||
dimension=self.embedding_func.embedding_dim,
|
||||
|
@@ -26,6 +26,7 @@ from tenacity import (
|
||||
|
||||
from ..utils import logger
|
||||
from ..base import BaseGraphStorage
|
||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
@@ -379,7 +380,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
|
||||
async def get_knowledge_graph(
|
||||
self, node_label: str, max_depth: int = 5
|
||||
) -> Dict[str, List[Dict]]:
|
||||
) -> KnowledgeGraph:
|
||||
"""
|
||||
Get complete connected subgraph for specified node (including the starting node itself)
|
||||
|
||||
@@ -390,32 +391,41 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
4. Add depth control
|
||||
"""
|
||||
label = node_label.strip('"')
|
||||
result = {"nodes": [], "edges": []}
|
||||
result = KnowledgeGraph()
|
||||
seen_nodes = set()
|
||||
seen_edges = set()
|
||||
|
||||
async with self._driver.session(database=self._DATABASE) as session:
|
||||
try:
|
||||
# Critical debug step: first verify if starting node exists
|
||||
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
|
||||
validate_result = await session.run(validate_query)
|
||||
if not await validate_result.single():
|
||||
logger.warning(f"Starting node {label} does not exist!")
|
||||
return result
|
||||
main_query = ""
|
||||
if label == "*":
|
||||
main_query = """
|
||||
MATCH (n)
|
||||
WITH collect(DISTINCT n) AS nodes
|
||||
MATCH ()-[r]-()
|
||||
RETURN nodes, collect(DISTINCT r) AS relationships;
|
||||
"""
|
||||
else:
|
||||
# Critical debug step: first verify if starting node exists
|
||||
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
|
||||
validate_result = await session.run(validate_query)
|
||||
if not await validate_result.single():
|
||||
logger.warning(f"Starting node {label} does not exist!")
|
||||
return result
|
||||
|
||||
# Optimized query (including direction handling and self-loops)
|
||||
main_query = f"""
|
||||
MATCH (start:`{label}`)
|
||||
WITH start
|
||||
CALL apoc.path.subgraphAll(start, {{
|
||||
relationshipFilter: '>',
|
||||
minLevel: 0,
|
||||
maxLevel: {max_depth},
|
||||
bfs: true
|
||||
}})
|
||||
YIELD nodes, relationships
|
||||
RETURN nodes, relationships
|
||||
"""
|
||||
# Optimized query (including direction handling and self-loops)
|
||||
main_query = f"""
|
||||
MATCH (start:`{label}`)
|
||||
WITH start
|
||||
CALL apoc.path.subgraphAll(start, {{
|
||||
relationshipFilter: '>',
|
||||
minLevel: 0,
|
||||
maxLevel: {max_depth},
|
||||
bfs: true
|
||||
}})
|
||||
YIELD nodes, relationships
|
||||
RETURN nodes, relationships
|
||||
"""
|
||||
result_set = await session.run(main_query)
|
||||
record = await result_set.single()
|
||||
|
||||
@@ -423,35 +433,36 @@ class Neo4JStorage(BaseGraphStorage):
|
||||
# Handle nodes (compatible with multi-label cases)
|
||||
for node in record["nodes"]:
|
||||
# Use node ID + label combination as unique identifier
|
||||
node_id = f"{node.id}_{'_'.join(node.labels)}"
|
||||
node_id = node.id
|
||||
if node_id not in seen_nodes:
|
||||
node_data = dict(node)
|
||||
node_data["labels"] = list(node.labels) # Keep all labels
|
||||
result["nodes"].append(node_data)
|
||||
result.nodes.append(
|
||||
KnowledgeGraphNode(
|
||||
id=f"{node_id}",
|
||||
labels=list(node.labels),
|
||||
properties=dict(node),
|
||||
)
|
||||
)
|
||||
seen_nodes.add(node_id)
|
||||
|
||||
# Handle relationships (including direction information)
|
||||
for rel in record["relationships"]:
|
||||
edge_id = f"{rel.id}_{rel.type}"
|
||||
edge_id = rel.id
|
||||
if edge_id not in seen_edges:
|
||||
start = rel.start_node
|
||||
end = rel.end_node
|
||||
edge_data = dict(rel)
|
||||
edge_data.update(
|
||||
{
|
||||
"source": f"{start.id}_{'_'.join(start.labels)}",
|
||||
"target": f"{end.id}_{'_'.join(end.labels)}",
|
||||
"type": rel.type,
|
||||
"direction": rel.element_id.split(
|
||||
"->" if rel.end_node == end else "<-"
|
||||
)[1],
|
||||
}
|
||||
result.edges.append(
|
||||
KnowledgeGraphEdge(
|
||||
id=f"{edge_id}",
|
||||
type=rel.type,
|
||||
source=f"{start.id}",
|
||||
target=f"{end.id}",
|
||||
properties=dict(rel),
|
||||
)
|
||||
)
|
||||
result["edges"].append(edge_data)
|
||||
seen_edges.add(edge_id)
|
||||
|
||||
logger.info(
|
||||
f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}"
|
||||
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||
)
|
||||
|
||||
except neo4jExceptions.ClientError as e:
|
||||
|
Reference in New Issue
Block a user