magnet.graph_utils.loss_heterogeneous_domains
- magnet.graph_utils.loss_heterogeneous_domains(Y: Tensor, graph: Data) → Tensor
Compute the loss function for heterogeneous domains.
- Parameters:
Y (torch.Tensor) – Output of the Neural Network evaluated on graph: tensor of shape (num_nodes, 2) whose entries are the probabilities, assigned by the GNN to each node of the graph, of belonging to each of the two sets.
graph (Data) – Graph on which the GNN was evaluated.
- Returns:
The value of the loss function.
- Return type:
torch.Tensor
Notes
This loss function is a weighted sum of two terms: the expected normalized cut, and a term that penalizes placing very different physical groups in the same set of nodes.
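The structure described in the Notes can be illustrated with a framework-free sketch. It is not the MAGNET implementation. The expected normalized cut follows the usual soft-partition formulation: for each set k, the expected weight of edges leaving k is divided by the expected volume of k. The heterogeneity penalty and the weights are hypothetical stand-ins, because the exact form used by the library is not given here. In this sketch the penalty is the expected fraction of edges whose endpoints carry different physical-group labels but land in the same set. The names `probs`, `edges`, `groups`, `w_cut` and `w_het` are assumptions made only for this example.

```python
def expected_normalized_cut(probs, edges):
    """Expected normalized cut of a soft partition.

    probs: one tuple per node; probs[i][k] is the probability that
           node i belongs to set k (each tuple sums to 1).
    edges: undirected edges (i, j), each listed once.
    """
    num_sets = len(probs[0])
    degree = [0] * len(probs)
    for i, j in edges:
        degree[i] += 1
        degree[j] += 1
    total = 0.0
    for k in range(num_sets):
        # Expected number of edges with exactly one endpoint in set k.
        cut = sum(probs[i][k] * (1 - probs[j][k]) + probs[j][k] * (1 - probs[i][k])
                  for i, j in edges)
        # Expected volume of set k: degree-weighted membership.
        vol = sum(p[k] * d for p, d in zip(probs, degree))
        if vol > 0:  # skip a set that is (expectedly) empty
            total += cut / vol
    return total


def heterogeneity_penalty(probs, edges, groups):
    """HYPOTHETICAL penalty, not the library's: expected fraction of edges
    whose endpoints have different physical-group labels yet share a set."""
    if not edges:
        return 0.0
    num_sets = len(probs[0])
    mixed = [(i, j) for i, j in edges if groups[i] != groups[j]]
    same_set = sum(probs[i][k] * probs[j][k] for i, j in mixed for k in range(num_sets))
    return same_set / len(edges)


def combined_loss(probs, edges, groups, w_cut=1.0, w_het=1.0):
    """Weighted sum of the two terms; the weights are placeholders."""
    return (w_cut * expected_normalized_cut(probs, edges)
            + w_het * heterogeneity_penalty(probs, edges, groups))


# Path graph 0-1-2-3 whose two halves belong to different physical groups.
edges = [(0, 1), (1, 2), (2, 3)]
groups = [0, 0, 1, 1]
hard = [(1.0, 0.0), (1.0, 0.0), (0.0, 1.0), (0.0, 1.0)]  # split matching the groups
soft = [(0.5, 0.5)] * 4                                  # maximally uncertain
print(combined_loss(hard, edges, groups))
print(combined_loss(soft, edges, groups))
```

With the split that follows the physical groups, only the penalty-free cut term remains (1/3 per set). The uncertain assignment pays both a larger cut term and a positive penalty, which is the trade-off the weights are meant to balance.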