Added search mode and min degree filtering for NetworkX

- Implemented exact and inclusive search modes
- Added min degree filtering for nodes
- Updated API to parse label for search options
This commit is contained in:
yangdx
2025-03-04 16:08:05 +08:00
parent 735231d851
commit 002948d342
3 changed files with 70 additions and 7 deletions

View File

@@ -34,6 +34,11 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
2. Followed by nodes directly connected to the matching nodes 2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes 3. Finally, the degree of the nodes
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000) Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Control search mode by label content:
1. label-name only : exact search with the label name (as selected from the previously returned label list)
2. label-name followed by '>n' : exact search of nodes with degree more than n
3. label-name followed by '*' : inclusive search
4. label-name followed by '>n*' : inclusive search of nodes with degree more than n
Args: Args:
label (str): Label to get knowledge graph for label (str): Label to get knowledge graph for
@@ -42,6 +47,37 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
Returns: Returns:
Dict[str, List[str]]: Knowledge graph for label Dict[str, List[str]]: Knowledge graph for label
""" """
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth) # Parse label to extract search mode and min degree if specified
search_mode = "exact" # Default search mode
min_degree = 0 # Default minimum degree
original_label = label
# First check if label ends with *
if label.endswith("*"):
search_mode = "inclusive" # Always set to inclusive if ends with *
label = label[:-1].strip() # Remove trailing *
# Try to parse >n if it exists
if ">" in label:
try:
degree_pos = label.rfind(">")
degree_str = label[degree_pos + 1:].strip()
min_degree = int(degree_str) + 1
label = label[:degree_pos].strip()
except ValueError:
# If degree parsing fails, just remove * and keep the rest as label
label = original_label[:-1].strip()
# If no *, check for >n pattern
elif ">" in label:
try:
degree_pos = label.rfind(">")
degree_str = label[degree_pos + 1:].strip()
min_degree = int(degree_str) + 1
label = label[:degree_pos].strip()
except ValueError:
# If degree parsing fails, treat the whole string as label
label = original_label
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth, search_mode=search_mode, min_degree=min_degree)
return router return router

View File

@@ -232,7 +232,7 @@ class NetworkXStorage(BaseGraphStorage):
return sorted(list(labels)) return sorted(list(labels))
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5 self, node_label: str, max_depth: int = 5, search_mode: str = "exact", min_degree: int = 0
) -> KnowledgeGraph: ) -> KnowledgeGraph:
""" """
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
@@ -245,6 +245,8 @@ class NetworkXStorage(BaseGraphStorage):
Args: Args:
node_label: Label of the starting node node_label: Label of the starting node
max_depth: Maximum depth of the subgraph max_depth: Maximum depth of the subgraph
search_mode (str, optional): Search mode, either "exact" or "inclusive". Defaults to "exact".
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
Returns: Returns:
KnowledgeGraph object containing nodes and edges KnowledgeGraph object containing nodes and edges
@@ -262,11 +264,16 @@ class NetworkXStorage(BaseGraphStorage):
graph.copy() graph.copy()
) # Create a copy to avoid modifying the original graph ) # Create a copy to avoid modifying the original graph
else: else:
# Find nodes with matching node id (partial match) # Find nodes with matching node id based on search_mode
nodes_to_explore = [] nodes_to_explore = []
for n, attr in graph.nodes(data=True): for n, attr in graph.nodes(data=True):
if node_label in str(n): # Use partial matching node_str = str(n)
nodes_to_explore.append(n) if search_mode == "exact":
if node_label == node_str: # Use exact matching
nodes_to_explore.append(n)
else: # inclusive mode
if node_label in node_str: # Use partial matching
nodes_to_explore.append(n)
if not nodes_to_explore: if not nodes_to_explore:
logger.warning(f"No nodes found with label {node_label}") logger.warning(f"No nodes found with label {node_label}")
@@ -277,6 +284,12 @@ class NetworkXStorage(BaseGraphStorage):
for start_node in nodes_to_explore: for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth) node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph) combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
# Filter nodes based on min_degree
if min_degree > 0:
nodes_to_keep = [node for node, degree in combined_subgraph.degree() if degree >= min_degree]
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
subgraph = combined_subgraph subgraph = combined_subgraph
# Check if number of nodes exceeds max_graph_nodes # Check if number of nodes exceeds max_graph_nodes

View File

@@ -504,10 +504,24 @@ class LightRAG:
return text return text
async def get_knowledge_graph( async def get_knowledge_graph(
self, node_label: str, max_depth: int self, node_label: str, max_depth: int, search_mode: str = "exact", min_degree: int = 0
) -> KnowledgeGraph: ) -> KnowledgeGraph:
"""Get knowledge graph for a given label
Args:
node_label (str): Label to get knowledge graph for
max_depth (int): Maximum depth of graph
search_mode (str, optional): Search mode, either "exact" or "inclusive". Defaults to "exact".
min_degree (int, optional): Minimum degree of nodes to include. Defaults to 0.
Returns:
KnowledgeGraph: Knowledge graph containing nodes and edges
"""
return await self.chunk_entity_relation_graph.get_knowledge_graph( return await self.chunk_entity_relation_graph.get_knowledge_graph(
node_label=node_label, max_depth=max_depth node_label=node_label,
max_depth=max_depth,
search_mode=search_mode,
min_degree=min_degree
) )
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: def _get_storage_class(self, storage_name: str) -> Callable[..., Any]: