Merge remote-tracking branch 'origin/main' into dev-webui
This commit is contained in:
@@ -22,6 +22,6 @@ def create_graph_routes(rag, api_key: Optional[str] = None):
|
|||||||
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
|
@router.get("/graphs", dependencies=[Depends(optional_api_key)])
|
||||||
async def get_knowledge_graph(label: str):
|
async def get_knowledge_graph(label: str):
|
||||||
"""Get knowledge graph for a specific label"""
|
"""Get knowledge graph for a specific label"""
|
||||||
return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
|
return await rag.get_knowledge_graph(node_label=label, max_depth=3)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
@@ -5,7 +5,7 @@ from typing import Any, final
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from lightrag.types import KnowledgeGraph
|
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from lightrag.utils import (
|
from lightrag.utils import (
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
@@ -169,9 +169,118 @@ class NetworkXStorage(BaseGraphStorage):
|
|||||||
self._graph.remove_edge(source, target)
|
self._graph.remove_edge(source, target)
|
||||||
|
|
||||||
async def get_all_labels(self) -> list[str]:
|
async def get_all_labels(self) -> list[str]:
|
||||||
raise NotImplementedError
|
"""
|
||||||
|
Get all node labels in the graph
|
||||||
|
Returns:
|
||||||
|
[label1, label2, ...] # Alphabetically sorted label list
|
||||||
|
"""
|
||||||
|
labels = set()
|
||||||
|
for node in self._graph.nodes():
|
||||||
|
labels.add(str(node)) # Add node id as a label
|
||||||
|
|
||||||
|
# Return sorted list
|
||||||
|
return sorted(list(labels))
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, node_label: str, max_depth: int = 5
|
self, node_label: str, max_depth: int = 5
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
raise NotImplementedError
|
"""
|
||||||
|
Get complete connected subgraph for specified node (including the starting node itself)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_label: Label of the starting node
|
||||||
|
max_depth: Maximum depth of the subgraph
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KnowledgeGraph object containing nodes and edges
|
||||||
|
"""
|
||||||
|
result = KnowledgeGraph()
|
||||||
|
seen_nodes = set()
|
||||||
|
seen_edges = set()
|
||||||
|
|
||||||
|
# Handle special case for "*" label
|
||||||
|
if node_label == "*":
|
||||||
|
# For "*", return the entire graph including all nodes and edges
|
||||||
|
subgraph = (
|
||||||
|
self._graph.copy()
|
||||||
|
) # Create a copy to avoid modifying the original graph
|
||||||
|
else:
|
||||||
|
# Find nodes with matching node id (partial match)
|
||||||
|
nodes_to_explore = []
|
||||||
|
for n, attr in self._graph.nodes(data=True):
|
||||||
|
if node_label in str(n): # Use partial matching
|
||||||
|
nodes_to_explore.append(n)
|
||||||
|
|
||||||
|
if not nodes_to_explore:
|
||||||
|
logger.warning(f"No nodes found with label {node_label}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Get subgraph using ego_graph
|
||||||
|
subgraph = nx.ego_graph(self._graph, nodes_to_explore[0], radius=max_depth)
|
||||||
|
|
||||||
|
# Check if number of nodes exceeds max_graph_nodes
|
||||||
|
max_graph_nodes = 500
|
||||||
|
if len(subgraph.nodes()) > max_graph_nodes:
|
||||||
|
origin_nodes = len(subgraph.nodes())
|
||||||
|
node_degrees = dict(subgraph.degree())
|
||||||
|
top_nodes = sorted(node_degrees.items(), key=lambda x: x[1], reverse=True)[
|
||||||
|
:max_graph_nodes
|
||||||
|
]
|
||||||
|
top_node_ids = [node[0] for node in top_nodes]
|
||||||
|
# Create new subgraph with only top nodes
|
||||||
|
subgraph = subgraph.subgraph(top_node_ids)
|
||||||
|
logger.info(
|
||||||
|
f"Reduced graph from {origin_nodes} nodes to {max_graph_nodes} nodes (depth={max_depth})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add nodes to result
|
||||||
|
for node in subgraph.nodes():
|
||||||
|
if str(node) in seen_nodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_data = dict(subgraph.nodes[node])
|
||||||
|
# Get entity_type as labels
|
||||||
|
labels = []
|
||||||
|
if "entity_type" in node_data:
|
||||||
|
if isinstance(node_data["entity_type"], list):
|
||||||
|
labels.extend(node_data["entity_type"])
|
||||||
|
else:
|
||||||
|
labels.append(node_data["entity_type"])
|
||||||
|
|
||||||
|
# Create node with properties
|
||||||
|
node_properties = {k: v for k, v in node_data.items()}
|
||||||
|
|
||||||
|
result.nodes.append(
|
||||||
|
KnowledgeGraphNode(
|
||||||
|
id=str(node), labels=[str(node)], properties=node_properties
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_nodes.add(str(node))
|
||||||
|
|
||||||
|
# Add edges to result
|
||||||
|
for edge in subgraph.edges():
|
||||||
|
source, target = edge
|
||||||
|
edge_id = f"{source}-{target}"
|
||||||
|
if edge_id in seen_edges:
|
||||||
|
continue
|
||||||
|
|
||||||
|
edge_data = dict(subgraph.edges[edge])
|
||||||
|
|
||||||
|
# Create edge with complete information
|
||||||
|
result.edges.append(
|
||||||
|
KnowledgeGraphEdge(
|
||||||
|
id=edge_id,
|
||||||
|
type="DIRECTED",
|
||||||
|
source=str(source),
|
||||||
|
target=str(target),
|
||||||
|
properties=edge_data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_edges.add(edge_id)
|
||||||
|
|
||||||
|
# logger.info(result.edges)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
@@ -466,10 +466,10 @@ class LightRAG:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
self, nodel_label: str, max_depth: int
|
self, node_label: str, max_depth: int
|
||||||
) -> KnowledgeGraph:
|
) -> KnowledgeGraph:
|
||||||
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
return await self.chunk_entity_relation_graph.get_knowledge_graph(
|
||||||
node_label=nodel_label, max_depth=max_depth
|
node_label=node_label, max_depth=max_depth
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
|
||||||
|
@@ -139,11 +139,14 @@ async def hf_model_complete(
|
|||||||
|
|
||||||
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
||||||
device = next(embed_model.parameters()).device
|
device = next(embed_model.parameters()).device
|
||||||
input_ids = tokenizer(
|
encoded_texts = tokenizer(
|
||||||
texts, return_tensors="pt", padding=True, truncation=True
|
texts, return_tensors="pt", padding=True, truncation=True
|
||||||
).input_ids.to(device)
|
).to(device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = embed_model(input_ids)
|
outputs = embed_model(
|
||||||
|
input_ids=encoded_texts["input_ids"],
|
||||||
|
attention_mask=encoded_texts["attention_mask"],
|
||||||
|
)
|
||||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||||
if embeddings.dtype == torch.bfloat16:
|
if embeddings.dtype == torch.bfloat16:
|
||||||
return embeddings.detach().to(torch.float32).cpu().numpy()
|
return embeddings.detach().to(torch.float32).cpu().numpy()
|
||||||
|
Reference in New Issue
Block a user