magnet.aggmodels.sageheterogeneous.GNNHeterogeneous

class magnet.aggmodels.sageheterogeneous.GNNHeterogeneous(*args, **kwargs)

Bases: GeometricGNN

Abstract base class for GNNs for agglomerating heterogeneous meshes.

Constructor

__init__(*args, **kwargs) → None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Methods

get_sample(mesh[, randomRotate, selfloop, ...])

Returns a graph data structure sample for training.

loss_function(y, graph[, coeff])

Loss function used during training.

normalize(x, edge_index)

Normalize the data before feeding it to the GNN.

get_sample(mesh: Mesh, randomRotate=False, selfloop=False, device=device(type='cpu')) → Data

Returns a graph data structure sample for training.

Parameters:
  • mesh (Mesh) – Heterogeneous mesh to be sampled.

  • randomRotate (bool, optional) – If True, randomly rotate the mesh (default is False).

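  • selfloop (bool, optional) – If True, add ones on the diagonal of the adjacency matrix, i.e., add self-loops to the graph (default is False).

  • device (torch.device, optional) – Device on which the returned sample is placed (default is CPU).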

Returns:

Graph data representing the mesh.

Return type:

Data
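The effect of the randomRotate and selfloop flags can be illustrated with a library-free sketch. It assumes 2-D element coordinates stored as (x, y) tuples and an edge list in the [sources, targets] layout of a PyTorch Geometric edge_index. The function names and the plain-dict sample are hypothetical stand-ins, not MAGNET's implementation, which returns a torch_geometric Data object.

```python
import math
import random


def add_self_loops(edge_index, num_nodes):
    """Return a copy of edge_index with an (i, i) edge appended for every node.

    edge_index is a pair of equal-length lists [sources, targets], mirroring the
    2 x num_edges layout of a PyTorch Geometric edge_index tensor.
    """
    src, dst = edge_index
    loops = list(range(num_nodes))
    return [src + loops, dst + loops]


def random_rotate(coords, rng=random):
    """Rotate 2-D points about the origin by a uniformly random angle."""
    theta = rng.uniform(0.0, 2.0 * math.pi)
    c, s = math.cos(theta), math.sin(theta)
    return [(c * x - s * y, s * x + c * y) for x, y in coords]


def get_sample_sketch(coords, edge_index, randomRotate=False, selfloop=False):
    """Hypothetical stand-in for get_sample; returns a plain dict, not a Data object."""
    if randomRotate:
        coords = random_rotate(coords)
    if selfloop:
        edge_index = add_self_loops(edge_index, len(coords))
    return {"x": coords, "edge_index": edge_index}
```

Appending one (i, i) pair per node is the edge-list equivalent of adding ones on the diagonal of the adjacency matrix. A rotation leaves distances unchanged, so the geometry of the mesh is preserved while its orientation varies.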

loss_function(y: Tensor, graph: Data, coeff=1) → Tensor

Loss function used during training.

Parameters:
  • y (torch.Tensor) – Evaluation output of the Neural Network.

  • graph (Data) – Graph on which the GNN was evaluated.

Returns:

The value of the loss function.

Return type:

torch.Tensor

Notes

The loss function may be overridden in subclasses to customize the GNN. See the losses module for the actual definitions.
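Overriding the loss in a subclass can be sketched in plain Python. Both classes below are hypothetical, and the balanced-cut loss is only an illustrative choice, not MAGNET's loss (those definitions live in the losses module). The sketch assumes y holds one probability per node and that graph exposes an edge list under "edge_index". The role given to coeff here, weighting a balance term, is also an assumption.

```python
class BaseGNNSketch:
    """Hypothetical stand-in for the base class; not MAGNET's GeometricGNN."""

    def loss_function(self, y, graph, coeff=1):
        raise NotImplementedError("subclasses define the loss")


class BalancedCutGNN(BaseGNNSketch):
    """Hypothetical subclass that customizes training by overriding loss_function."""

    def loss_function(self, y, graph, coeff=1):
        # y[i]: predicted probability that node i is assigned to part 0.
        src, dst = graph["edge_index"]
        # Expected number of listed edges whose endpoints fall in different parts.
        cut = sum(y[i] * (1 - y[j]) + y[j] * (1 - y[i]) for i, j in zip(src, dst))
        # Squared deviation of the expected part-0 size from half of the nodes.
        balance = (sum(y) - len(y) / 2) ** 2
        # Assumption: coeff weights the balance term (its role in MAGNET is not documented here).
        return cut + coeff * balance
```

The training loop only calls loss_function, so swapping the subclass changes the objective without touching the rest of the model.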

normalize(x: Tensor, edge_index: Tensor) → Tensor

Normalize the data before feeding it to the GNN.

Parameters:
  • x (torch.Tensor) – The data to be normalized.

  • edge_index (torch.Tensor) – Edge index tensor of shape [2, num_edges] listing the edges of the mesh graph, i.e., the sparse (COO) form of its adjacency matrix.

Returns:

The normalized data.

Return type:

torch.Tensor

Notes

Overrides GNN._normalize so that the physical group is also handled: it is averaged across neighbouring nodes to avoid discontinuities that may hamper GNN training.
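Averaging a per-node quantity across neighbours can be sketched without any library. This assumes the physical group is stored as one number per node and that edge_index lists every undirected edge in both directions. Whether MAGNET includes the node itself in the average, and how it treats the other features, is not stated here, so the function below is a hypothetical illustration only.

```python
def average_over_neighbours(values, edge_index):
    """Replace each node's value by the mean over itself and its neighbours.

    values: one number per node (e.g. a physical-group label).
    edge_index: [sources, targets] lists; each undirected edge appears twice.
    """
    n = len(values)
    sums = list(values)  # start from the node's own value
    counts = [1] * n
    for i, j in zip(*edge_index):
        sums[j] += values[i]  # accumulate the value of source i at target j
        counts[j] += 1
    return [s / c for s, c in zip(sums, counts)]
```

For a path 0-1-2 with values [0.0, 3.0, 6.0] the result is [1.5, 3.0, 4.5]: the jump between groups is smoothed into a ramp that is easier for a GNN to learn from.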

Inherited Methods

__init__(*args, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

agglomerate(mesh[, mode, nref, mult_factor])

Agglomerate a mesh.

agglomerate_dataset(dataset, **kwargs)

Agglomerate all meshes in a dataset.

bisect(mesh)

Bisect the mesh once.

bisection_Nref(mesh, Nref[, warm_start])

Bisect the mesh recursively a set number of times.

bisection_mult_factor(mesh, mult_factor[, ...])

Bisect a mesh until the agglomerated elements are small enough.

bisection_segregated(mesh, mult_factor[, subset])

Bisect heterogeneous mesh until elements are small enough.

coarsen(mesh, subset[, mode, nref, mult_factor])

Coarsen a subregion of the mesh.

get_number_of_parameters()

Get total number of parameters of the GNN.

load_model(model_path)

Load model from state dictionary.

multilevel_bisection(mesh[, refiner, ...])

save_model(output_path)

Save current model to state dictionary.

train_GNN(training_dataset, ...[, ...])

Train the Graph Neural Network.
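bisection_Nref and bisection_mult_factor differ mainly in their stopping rule: a fixed number of recursive levels versus a size threshold. The sketch below contrasts the two rules under simplifying assumptions: a "mesh" is just a list of element ids, the hypothetical halve function stands in for the GNN-driven bisect(mesh), and element count replaces the size measure that mult_factor controls in MAGNET, which is not reproduced here.

```python
def halve(part):
    """Hypothetical stand-in for the GNN bisection: split a list in half."""
    mid = len(part) // 2
    return [part[:mid], part[mid:]]


def bisection_nref(elements, nref, bisect=halve):
    """Bisect every part recursively, nref times (at most 2**nref parts)."""
    parts = [elements]
    for _ in range(nref):
        next_parts = []
        for part in parts:
            if len(part) < 2:
                next_parts.append(part)  # a single element cannot be split
            else:
                next_parts.extend(bisect(part))
        parts = next_parts
    return parts


def bisection_until_small(elements, max_size, bisect=halve):
    """Keep bisecting any part that is still larger than max_size."""
    pending, done = [elements], []
    while pending:
        part = pending.pop()
        if len(part) <= max_size or len(part) < 2:
            done.append(part)
        else:
            pending.extend(bisect(part))
    return done
```

The fixed-depth rule gives a predictable number of parts, while the threshold rule adapts the depth locally, so larger regions are split more often.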

Note

This class inherits from torch.nn.Module. For the full list of inherited members, see the PyTorch documentation.