Merge branch 'main' into select-datastore-in-api-server

This commit is contained in:
zrguo
2025-02-13 20:03:38 +08:00
committed by GitHub
73 changed files with 6101 additions and 1322 deletions

View File

@@ -18,7 +18,7 @@ config.read("config.ini", "utf-8")
@dataclass
class MilvusVectorDBStorge(BaseVectorStorage):
class MilvusVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = None
@staticmethod
@@ -65,7 +65,7 @@ class MilvusVectorDBStorge(BaseVectorStorage):
),
)
self._max_batch_size = self.global_config["embedding_batch_num"]
MilvusVectorDBStorge.create_collection_if_not_exist(
MilvusVectorDBStorage.create_collection_if_not_exist(
self._client,
self.namespace,
dimension=self.embedding_func.embedding_dim,

View File

@@ -26,6 +26,7 @@ from tenacity import (
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
config = configparser.ConfigParser()
@@ -379,7 +380,7 @@ class Neo4JStorage(BaseGraphStorage):
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> Dict[str, List[Dict]]:
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
@@ -390,32 +391,41 @@ class Neo4JStorage(BaseGraphStorage):
4. Add depth control
"""
label = node_label.strip('"')
result = {"nodes": [], "edges": []}
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
async with self._driver.session(database=self._DATABASE) as session:
try:
# Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_result = await session.run(validate_query)
if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!")
return result
main_query = ""
if label == "*":
main_query = """
MATCH (n)
WITH collect(DISTINCT n) AS nodes
MATCH ()-[r]-()
RETURN nodes, collect(DISTINCT r) AS relationships;
"""
else:
# Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_result = await session.run(validate_query)
if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!")
return result
# Optimized query (including direction handling and self-loops)
main_query = f"""
MATCH (start:`{label}`)
WITH start
CALL apoc.path.subgraphAll(start, {{
relationshipFilter: '>',
minLevel: 0,
maxLevel: {max_depth},
bfs: true
}})
YIELD nodes, relationships
RETURN nodes, relationships
"""
# Optimized query (including direction handling and self-loops)
main_query = f"""
MATCH (start:`{label}`)
WITH start
CALL apoc.path.subgraphAll(start, {{
relationshipFilter: '>',
minLevel: 0,
maxLevel: {max_depth},
bfs: true
}})
YIELD nodes, relationships
RETURN nodes, relationships
"""
result_set = await session.run(main_query)
record = await result_set.single()
@@ -423,35 +433,36 @@ class Neo4JStorage(BaseGraphStorage):
# Handle nodes (compatible with multi-label cases)
for node in record["nodes"]:
# Use node ID + label combination as unique identifier
node_id = f"{node.id}_{'_'.join(node.labels)}"
node_id = node.id
if node_id not in seen_nodes:
node_data = dict(node)
node_data["labels"] = list(node.labels) # Keep all labels
result["nodes"].append(node_data)
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=list(node.labels),
properties=dict(node),
)
)
seen_nodes.add(node_id)
# Handle relationships (including direction information)
for rel in record["relationships"]:
edge_id = f"{rel.id}_{rel.type}"
edge_id = rel.id
if edge_id not in seen_edges:
start = rel.start_node
end = rel.end_node
edge_data = dict(rel)
edge_data.update(
{
"source": f"{start.id}_{'_'.join(start.labels)}",
"target": f"{end.id}_{'_'.join(end.labels)}",
"type": rel.type,
"direction": rel.element_id.split(
"->" if rel.end_node == end else "<-"
)[1],
}
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
result["edges"].append(edge_data)
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result['nodes'])} | Edge count: {len(result['edges'])}"
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except neo4jExceptions.ClientError as e: