Added minimum degree filter for graph queries
- Introduced min_degree parameter in graph query - Updated UI to include minimum degree setting - Modified API to handle min_degree parameter - Updated graph query logic in LightRAG
This commit is contained in:
@@ -5,7 +5,6 @@ This module contains all graph-related routes for the LightRAG API.
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from ...utils import logger
|
|
||||||
from ..utils_api import get_api_key_dependency
|
from ..utils_api import get_api_key_dependency
|
||||||
|
|
||||||
router = APIRouter(tags=["graph"])
|
router = APIRouter(tags=["graph"])
|
||||||
@@ -25,7 +24,9 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|||||||
return await rag.get_graph_labels()
|
return await rag.get_graph_labels()
|
||||||
|
|
||||||
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
|
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
|
||||||
async def get_knowledge_graph(label: str, max_depth: int = 3, inclusive: bool = False, min_degree: int = 0):
|
async def get_knowledge_graph(
|
||||||
|
label: str, max_depth: int = 3, min_degree: int = 0, inclusive: bool = False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
Retrieve a connected subgraph of nodes where the label includes the specified label.
|
||||||
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
|
||||||
@@ -44,7 +45,11 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict[str, List[str]]: Knowledge graph for label
|
Dict[str, List[str]]: Knowledge graph for label
|
||||||
"""
|
"""
|
||||||
logger.info(f"Inclusive search : {inclusive}, Min degree: {min_degree}, Label: {label}")
|
return await rag.get_knowledge_graph(
|
||||||
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth, inclusive=inclusive, min_degree=min_degree)
|
node_label=label,
|
||||||
|
max_depth=max_depth,
|
||||||
|
inclusive=inclusive,
|
||||||
|
min_degree=min_degree,
|
||||||
|
)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
@@ -232,7 +232,11 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
return sorted(list(labels))
|
return sorted(list(labels))
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, node_label: str, max_depth: int = 5, search_mode: str = "exact", min_degree: int = 0
|
self,
|
||||||
|
node_label: str,
|
||||||
|
max_depth: int = 5,
|
||||||
|
min_degree: int = 0,
|
||||||
|
inclusive: bool = False,
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
"""
|
"""
|
||||||
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
|
||||||
@@ -268,7 +272,7 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
nodes_to_explore = []
|
nodes_to_explore = []
|
||||||
for n, attr in graph.nodes(data=True):
|
for n, attr in graph.nodes(data=True):
|
||||||
node_str = str(n)
|
node_str = str(n)
|
||||||
if search_mode == "exact":
|
if not inclusive:
|
||||||
if node_label == node_str: # Use exact matching
|
if node_label == node_str: # Use exact matching
|
||||||
nodes_to_explore.append(n)
|
nodes_to_explore.append(n)
|
||||||
else: # inclusive mode
|
else: # inclusive mode
|
||||||
@@ -284,12 +288,16 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
for start_node in nodes_to_explore:
|
for start_node in nodes_to_explore:
|
||||||
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
|
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
|
||||||
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
|
||||||
|
|
||||||
# Filter nodes based on min_degree
|
# Filter nodes based on min_degree
|
||||||
if min_degree > 0:
|
if min_degree > 0:
|
||||||
nodes_to_keep = [node for node, degree in combined_subgraph.degree() if degree >= min_degree]
|
nodes_to_keep = [
|
||||||
|
node
|
||||||
|
for node, degree in combined_subgraph.degree()
|
||||||
|
if degree >= min_degree
|
||||||
|
]
|
||||||
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
|
combined_subgraph = combined_subgraph.subgraph(nodes_to_keep)
|
||||||
|
|
||||||
subgraph = combined_subgraph
|
subgraph = combined_subgraph
|
||||||
|
|
||||||
# Check if number of nodes exceeds max_graph_nodes
|
# Check if number of nodes exceeds max_graph_nodes
|
||||||
|
@@ -504,7 +504,11 @@ class LightRAG:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, node_label: str, max_depth: int, inclusive: bool = False, min_degree: int = 0
|
self,
|
||||||
|
node_label: str,
|
||||||
|
max_depth: int,
|
||||||
|
min_degree: int = 0,
|
||||||
|
inclusive: bool = False,
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
"""Get knowledge graph for a given label
|
"""Get knowledge graph for a given label
|
||||||
|
|
||||||
@@ -520,6 +524,8 @@ class LightRAG:
|
|||||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
max_depth=max_depth,
|
max_depth=max_depth,
|
||||||
|
min_degree=min_degree,
|
||||||
|
inclusive=inclusive,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
||||||
|
@@ -162,11 +162,11 @@ axiosInstance.interceptors.response.use(
|
|||||||
|
|
||||||
// API methods
|
// API methods
|
||||||
export const queryGraphs = async (
|
export const queryGraphs = async (
|
||||||
label: string,
|
label: string,
|
||||||
maxDepth: number,
|
maxDepth: number,
|
||||||
inclusive: boolean = false
|
minDegree: number
|
||||||
): Promise<LightragGraphType> => {
|
): Promise<LightragGraphType> => {
|
||||||
const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&inclusive=${inclusive}`)
|
const response = await axiosInstance.get(`/graphs?label=${encodeURIComponent(label)}&max_depth=${maxDepth}&min_degree=${minDegree}`)
|
||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -90,9 +90,12 @@ const LabeledNumberInput = ({
|
|||||||
{label}
|
{label}
|
||||||
</label>
|
</label>
|
||||||
<Input
|
<Input
|
||||||
value={currentValue || ''}
|
type="number"
|
||||||
|
value={currentValue === null ? '' : currentValue}
|
||||||
onChange={onValueChange}
|
onChange={onValueChange}
|
||||||
className="h-6 w-full min-w-0"
|
className="h-6 w-full min-w-0 pr-1"
|
||||||
|
min={min}
|
||||||
|
max={max}
|
||||||
onBlur={onBlur}
|
onBlur={onBlur}
|
||||||
onKeyDown={(e) => {
|
onKeyDown={(e) => {
|
||||||
if (e.key === 'Enter') {
|
if (e.key === 'Enter') {
|
||||||
@@ -119,6 +122,7 @@ export default function Settings() {
|
|||||||
const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
|
const enableHideUnselectedEdges = useSettingsStore.use.enableHideUnselectedEdges()
|
||||||
const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
|
const showEdgeLabel = useSettingsStore.use.showEdgeLabel()
|
||||||
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
|
const graphQueryMaxDepth = useSettingsStore.use.graphQueryMaxDepth()
|
||||||
|
const graphMinDegree = useSettingsStore.use.graphMinDegree()
|
||||||
const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
|
const graphLayoutMaxIterations = useSettingsStore.use.graphLayoutMaxIterations()
|
||||||
|
|
||||||
const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
|
const enableHealthCheck = useSettingsStore.use.enableHealthCheck()
|
||||||
@@ -177,6 +181,11 @@ export default function Settings() {
|
|||||||
useSettingsStore.setState({ graphQueryMaxDepth: depth })
|
useSettingsStore.setState({ graphQueryMaxDepth: depth })
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
|
const setGraphMinDegree = useCallback((degree: number) => {
|
||||||
|
if (degree < 0) return
|
||||||
|
useSettingsStore.setState({ graphMinDegree: degree })
|
||||||
|
}, [])
|
||||||
|
|
||||||
const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
|
const setGraphLayoutMaxIterations = useCallback((iterations: number) => {
|
||||||
if (iterations < 1) return
|
if (iterations < 1) return
|
||||||
useSettingsStore.setState({ graphLayoutMaxIterations: iterations })
|
useSettingsStore.setState({ graphLayoutMaxIterations: iterations })
|
||||||
@@ -266,6 +275,12 @@ export default function Settings() {
|
|||||||
value={graphQueryMaxDepth}
|
value={graphQueryMaxDepth}
|
||||||
onEditFinished={setGraphQueryMaxDepth}
|
onEditFinished={setGraphQueryMaxDepth}
|
||||||
/>
|
/>
|
||||||
|
<LabeledNumberInput
|
||||||
|
label="Minimum Degree"
|
||||||
|
min={0}
|
||||||
|
value={graphMinDegree}
|
||||||
|
onEditFinished={setGraphMinDegree}
|
||||||
|
/>
|
||||||
<LabeledNumberInput
|
<LabeledNumberInput
|
||||||
label="Max Layout Iterations"
|
label="Max Layout Iterations"
|
||||||
min={1}
|
min={1}
|
||||||
|
@@ -7,7 +7,7 @@ const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<'input'>>(
|
|||||||
<input
|
<input
|
||||||
type={type}
|
type={type}
|
||||||
className={cn(
|
className={cn(
|
||||||
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm',
|
'border-input file:text-foreground placeholder:text-muted-foreground focus-visible:ring-ring flex h-9 rounded-md border bg-transparent px-3 py-1 text-base shadow-sm transition-colors file:border-0 file:bg-transparent file:text-sm file:font-medium focus-visible:ring-1 focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm [&::-webkit-inner-spin-button]:opacity-100 [&::-webkit-outer-spin-button]:opacity-100',
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
ref={ref}
|
ref={ref}
|
||||||
|
@@ -50,11 +50,11 @@ export type NodeType = {
|
|||||||
}
|
}
|
||||||
export type EdgeType = { label: string }
|
export type EdgeType = { label: string }
|
||||||
|
|
||||||
const fetchGraph = async (label: string, maxDepth: number) => {
|
const fetchGraph = async (label: string, maxDepth: number, minDegree: number) => {
|
||||||
let rawData: any = null
|
let rawData: any = null
|
||||||
|
|
||||||
try {
|
try {
|
||||||
rawData = await queryGraphs(label, maxDepth)
|
rawData = await queryGraphs(label, maxDepth, minDegree)
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!')
|
useBackendState.getState().setErrorMessage(errorMessage(e), 'Query Graphs Error!')
|
||||||
return null
|
return null
|
||||||
@@ -161,13 +161,14 @@ const createSigmaGraph = (rawGraph: RawGraph | null) => {
|
|||||||
return graph
|
return graph
|
||||||
}
|
}
|
||||||
|
|
||||||
const lastQueryLabel = { label: '', maxQueryDepth: 0 }
|
const lastQueryLabel = { label: '', maxQueryDepth: 0, minDegree: 0 }
|
||||||
|
|
||||||
const useLightrangeGraph = () => {
|
const useLightrangeGraph = () => {
|
||||||
const queryLabel = useSettingsStore.use.queryLabel()
|
const queryLabel = useSettingsStore.use.queryLabel()
|
||||||
const rawGraph = useGraphStore.use.rawGraph()
|
const rawGraph = useGraphStore.use.rawGraph()
|
||||||
const sigmaGraph = useGraphStore.use.sigmaGraph()
|
const sigmaGraph = useGraphStore.use.sigmaGraph()
|
||||||
const maxQueryDepth = useSettingsStore.use.graphQueryMaxDepth()
|
const maxQueryDepth = useSettingsStore.use.graphQueryMaxDepth()
|
||||||
|
const minDegree = useSettingsStore.use.graphMinDegree()
|
||||||
|
|
||||||
const getNode = useCallback(
|
const getNode = useCallback(
|
||||||
(nodeId: string) => {
|
(nodeId: string) => {
|
||||||
@@ -185,13 +186,16 @@ const useLightrangeGraph = () => {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (queryLabel) {
|
if (queryLabel) {
|
||||||
if (lastQueryLabel.label !== queryLabel || lastQueryLabel.maxQueryDepth !== maxQueryDepth) {
|
if (lastQueryLabel.label !== queryLabel ||
|
||||||
|
lastQueryLabel.maxQueryDepth !== maxQueryDepth ||
|
||||||
|
lastQueryLabel.minDegree !== minDegree) {
|
||||||
lastQueryLabel.label = queryLabel
|
lastQueryLabel.label = queryLabel
|
||||||
lastQueryLabel.maxQueryDepth = maxQueryDepth
|
lastQueryLabel.maxQueryDepth = maxQueryDepth
|
||||||
|
lastQueryLabel.minDegree = minDegree
|
||||||
|
|
||||||
const state = useGraphStore.getState()
|
const state = useGraphStore.getState()
|
||||||
state.reset()
|
state.reset()
|
||||||
fetchGraph(queryLabel, maxQueryDepth).then((data) => {
|
fetchGraph(queryLabel, maxQueryDepth, minDegree).then((data) => {
|
||||||
// console.debug('Query label: ' + queryLabel)
|
// console.debug('Query label: ' + queryLabel)
|
||||||
state.setSigmaGraph(createSigmaGraph(data))
|
state.setSigmaGraph(createSigmaGraph(data))
|
||||||
data?.buildDynamicMap()
|
data?.buildDynamicMap()
|
||||||
@@ -203,7 +207,7 @@ const useLightrangeGraph = () => {
|
|||||||
state.reset()
|
state.reset()
|
||||||
state.setSigmaGraph(new DirectedGraph())
|
state.setSigmaGraph(new DirectedGraph())
|
||||||
}
|
}
|
||||||
}, [queryLabel, maxQueryDepth])
|
}, [queryLabel, maxQueryDepth, minDegree])
|
||||||
|
|
||||||
const lightrageGraph = useCallback(() => {
|
const lightrageGraph = useCallback(() => {
|
||||||
if (sigmaGraph) {
|
if (sigmaGraph) {
|
||||||
|
@@ -22,6 +22,9 @@ interface SettingsState {
|
|||||||
graphQueryMaxDepth: number
|
graphQueryMaxDepth: number
|
||||||
setGraphQueryMaxDepth: (depth: number) => void
|
setGraphQueryMaxDepth: (depth: number) => void
|
||||||
|
|
||||||
|
graphMinDegree: number
|
||||||
|
setGraphMinDegree: (degree: number) => void
|
||||||
|
|
||||||
graphLayoutMaxIterations: number
|
graphLayoutMaxIterations: number
|
||||||
setGraphLayoutMaxIterations: (iterations: number) => void
|
setGraphLayoutMaxIterations: (iterations: number) => void
|
||||||
|
|
||||||
@@ -66,6 +69,7 @@ const useSettingsStoreBase = create<SettingsState>()(
|
|||||||
enableEdgeEvents: false,
|
enableEdgeEvents: false,
|
||||||
|
|
||||||
graphQueryMaxDepth: 3,
|
graphQueryMaxDepth: 3,
|
||||||
|
graphMinDegree: 0,
|
||||||
graphLayoutMaxIterations: 10,
|
graphLayoutMaxIterations: 10,
|
||||||
|
|
||||||
queryLabel: defaultQueryLabel,
|
queryLabel: defaultQueryLabel,
|
||||||
@@ -107,6 +111,8 @@ const useSettingsStoreBase = create<SettingsState>()(
|
|||||||
|
|
||||||
setGraphQueryMaxDepth: (depth: number) => set({ graphQueryMaxDepth: depth }),
|
setGraphQueryMaxDepth: (depth: number) => set({ graphQueryMaxDepth: depth }),
|
||||||
|
|
||||||
|
setGraphMinDegree: (degree: number) => set({ graphMinDegree: degree }),
|
||||||
|
|
||||||
setEnableHealthCheck: (enable: boolean) => set({ enableHealthCheck: enable }),
|
setEnableHealthCheck: (enable: boolean) => set({ enableHealthCheck: enable }),
|
||||||
|
|
||||||
setApiKey: (apiKey: string | null) => set({ apiKey }),
|
setApiKey: (apiKey: string | null) => set({ apiKey }),
|
||||||
|
Reference in New Issue
Block a user