Merge branch 'main' into standalone-logger-setup

This commit is contained in:
yangdx
2025-03-04 09:54:14 +08:00
5 changed files with 130 additions and 33 deletions

View File

@@ -5,6 +5,7 @@
# PORT=9621
# WORKERS=1
# NAMESPACE_PREFIX=lightrag # separating data from difference Lightrag instances
# MAX_GRAPH_NODES=1000 # Max nodes return from grap retrieval
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080
### Optional SSL Configuration

View File

@@ -16,12 +16,32 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
@router.get("/graph/label/list", dependencies=[Depends(optional_api_key)])
async def get_graph_labels():
"""Get all graph labels"""
"""
Get all graph labels
Returns:
List[str]: List of graph labels
"""
return await rag.get_graph_labels()
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
async def get_knowledge_graph(label: str, max_depth: int = 3):
"""Get knowledge graph for a specific label"""
"""
Retrieve a connected subgraph of nodes where the label includes the specified label.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Maximum number of nodes is limited to env MAX_GRAPH_NODES(default: 1000)
Args:
label (str): Label to get knowledge graph for
max_depth (int, optional): Maximum depth of graph. Defaults to 3.
Returns:
Dict[str, List[str]]: Knowledge graph for label
"""
return await rag.get_knowledge_graph(node_label=label, max_depth=max_depth)
return router

View File

@@ -23,7 +23,7 @@ import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import (
from neo4j import ( # type: ignore
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
@@ -34,6 +34,9 @@ from neo4j import (
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final
@dataclass
@@ -470,40 +473,61 @@ class Neo4JStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence (nodes containing the specified label string)
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Key fixes:
1. Include the starting node itself
2. Handle multi-label nodes
3. Clarify relationship directions
4. Add depth control
Args:
node_label (str): String to match in node labels (will match any node containing this string in its label)
max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
Returns:
KnowledgeGraph: Complete connected subgraph for specified node
"""
label = node_label.strip('"')
# Escape single quotes to prevent injection attacks
escaped_label = label.replace("'", "\\'")
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
async with self._driver.session(database=self._DATABASE) as session:
try:
main_query = ""
if label == "*":
main_query = """
MATCH (n)
WITH collect(DISTINCT n) AS nodes
MATCH ()-[r]-()
RETURN nodes, collect(DISTINCT r) AS relationships;
OPTIONAL MATCH (n)-[r]-()
WITH n, count(r) AS degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect(n) AS nodes
MATCH (a)-[r]->(b)
WHERE a IN nodes AND b IN nodes
RETURN nodes, collect(DISTINCT r) AS relationships
"""
result_set = await session.run(
main_query, {"max_nodes": MAX_GRAPH_NODES}
)
else:
# Critical debug step: first verify if starting node exists
validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
validate_query = f"""
MATCH (n)
WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
RETURN n LIMIT 1
"""
validate_result = await session.run(validate_query)
if not await validate_result.single():
logger.warning(f"Starting node {label} does not exist!")
logger.warning(
f"No nodes containing '{label}' in their labels found!"
)
return result
# Optimized query (including direction handling and self-loops)
# Main query uses partial matching
main_query = f"""
MATCH (start:`{label}`)
MATCH (start)
WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
WITH start
CALL apoc.path.subgraphAll(start, {{
relationshipFilter: '>',
@@ -512,9 +536,25 @@ class Neo4JStorage(BaseGraphStorage):
bfs: true
}})
YIELD nodes, relationships
RETURN nodes, relationships
WITH start, nodes, relationships
UNWIND nodes AS node
OPTIONAL MATCH (node)-[r]-()
WITH node, count(r) AS degree, start, nodes, relationships,
CASE
WHEN id(node) = id(start) THEN 2
WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
ELSE 0
END AS priority
ORDER BY priority DESC, degree DESC
LIMIT $max_nodes
WITH collect(node) AS filtered_nodes, nodes, relationships
RETURN filtered_nodes AS nodes,
[rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
"""
result_set = await session.run(main_query)
result_set = await session.run(
main_query, {"max_nodes": MAX_GRAPH_NODES}
)
record = await result_set.single()
if record:

View File

@@ -24,6 +24,8 @@ from .shared_storage import (
is_multiprocess,
)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
@final
@dataclass
@@ -233,7 +235,12 @@ class NetworkXStorage(BaseGraphStorage):
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get complete connected subgraph for specified node (including the starting node itself)
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
When reducing the number of nodes, the prioritization criteria are as follows:
1. Label matching nodes take precedence
2. Followed by nodes directly connected to the matching nodes
3. Finally, the degree of the nodes
Args:
node_label: Label of the starting node
@@ -265,22 +272,51 @@ class NetworkXStorage(BaseGraphStorage):
logger.warning(f"No nodes found with label {node_label}")
return result
# Get subgraph using ego_graph
subgraph = nx.ego_graph(graph, nodes_to_explore[0], radius=max_depth)
# Get subgraph using ego_graph from all matching nodes
combined_subgraph = nx.Graph()
for start_node in nodes_to_explore:
node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
subgraph = combined_subgraph
# Check if number of nodes exceeds max_graph_nodes
max_graph_nodes = 500
if len(subgraph.nodes()) > max_graph_nodes:
if len(subgraph.nodes()) > MAX_GRAPH_NODES:
origin_nodes = len(subgraph.nodes())
node_degrees = dict(subgraph.degree())
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
:max_graph_nodes
start_nodes = set()
direct_connected_nodes = set()
if node_label != "*" and nodes_to_explore:
start_nodes = set(nodes_to_explore)
# Get nodes directly connected to all start nodes
for start_node in start_nodes:
direct_connected_nodes.update(subgraph.neighbors(start_node))
# Remove start nodes from directly connected nodes (avoid duplicates)
direct_connected_nodes -= start_nodes
def priority_key(node_item):
node, degree = node_item
# Priority order: start(2) > directly connected(1) > other nodes(0)
if node in start_nodes:
priority = 2
elif node in direct_connected_nodes:
priority = 1
else:
priority = 0
return (priority, degree)
# Sort by priority and degree and select top MAX_GRAPH_NODES nodes
top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
:MAX_GRAPH_NODES
]
top_node_ids = [node[0] for node in top_nodes]
# Create new subgraph with only top nodes
# Create new subgraph and keep nodes only with most degree
subgraph = subgraph.subgraph(top_node_ids)
logger.info(
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
)
# Add nodes to result
@@ -320,7 +356,7 @@ class NetworkXStorage(BaseGraphStorage):
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="DIRECTED",
type="RELATED",
source=str(source),
target=str(target),
properties=edge_data,

View File

@@ -1173,7 +1173,7 @@ class LightRAG:
"""
if param.mode in ["local", "global", "hybrid"]:
response = await kg_query(
query,
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
@@ -1194,7 +1194,7 @@ class LightRAG:
)
elif param.mode == "naive":
response = await naive_query(
query,
query.strip(),
self.chunks_vdb,
self.text_chunks,
param,
@@ -1213,7 +1213,7 @@ class LightRAG:
)
elif param.mode == "mix":
response = await mix_kg_vector_query(
query,
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,