magnet.aggmodels.gnn.GeometricGNN

class magnet.aggmodels.gnn.GeometricGNN(*args, **kwargs)

Bases: GNN, AgglomerationModel

Constructor

__init__(*args, **kwargs) → None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Methods

get_sample(mesh[, randomRotate, selfloop, ...])

Create a graph data structure sample from a mesh.

loss_function(output, graph)

Loss function used during training.

normalize(x)

Normalize the data before feeding it to the GNN.

train_GNN(training_dataset, ...[, ...])

Train the Graph Neural Network.

get_sample(mesh: Mesh, randomRotate: bool = False, selfloop: bool = False, device=device(type='cpu')) → Data

Create a graph data structure sample from a mesh.

This is used for both training and running the GNN.

Parameters:
  • mesh (Mesh) – Mesh to be sampled.

  • randomRotate (bool, optional) – If True, randomly rotate the mesh (default is False).

  • selfloop (bool, optional) – If True, add ones on the diagonal of the adjacency matrix, i.e., add self-loops to the graph (default is False).

Returns:

Graph data representing the mesh.

Return type:

Data

Notes

The two tensors x and edge_index are both on DEVICE (cuda, if available).
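
Example

A minimal usage sketch, not taken from the MAGNET sources: the Mesh import path, the mesh file name, and the default construction of GeometricGNN are assumptions made for illustration.

import torch
from magnet.aggmodels.gnn import GeometricGNN
from magnet.mesh import Mesh  # assumed import path for the Mesh class

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = GeometricGNN()        # assumed default construction
mesh = Mesh("my_mesh.vtu")    # hypothetical mesh file

# Build a torch_geometric Data sample; x and edge_index are placed on `device`.
graph = model.get_sample(mesh, randomRotate=True, selfloop=True, device=device)
print(graph.x.shape, graph.edge_index.shape)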

loss_function(output: Tensor, graph: Data) → Tensor

Loss function used during training.

Parameters:
  • output (torch.Tensor) – Evaluation output of the Neural Network.

  • graph (Data) – Graph on which the GNN was evaluated.

Returns:

The value of the loss function.

Return type:

torch.Tensor

Notes

The loss function may be overridden in subclasses to customize the GNN. See the losses module for the actual definitions.
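
Example

A hedged sketch of such an override. The edge-cut objective below is purely illustrative and is not one of the losses defined in the losses module; the shape of output (soft cluster assignments per node) is likewise an assumption.

import torch
from torch_geometric.data import Data

from magnet.aggmodels.gnn import GeometricGNN

class MyGeometricGNN(GeometricGNN):
    def loss_function(self, output: torch.Tensor, graph: Data) -> torch.Tensor:
        # Soft cluster assignments per node, shape (num_nodes, num_clusters).
        probs = torch.softmax(output, dim=-1)
        src, dst = graph.edge_index
        # Probability that the endpoints of each edge land in different clusters.
        edge_cut = 1.0 - (probs[src] * probs[dst]).sum(dim=-1)
        return edge_cut.mean()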

normalize(x: Tensor) → Tensor

Normalize the data before feeding it to the GNN.

Parameters:

x (torch.Tensor) – The data to be normalized.

Returns:

The normalized data.

Return type:

torch.Tensor

Notes

Normalization consists of aligning the widest direction of the mesh with the x-axis by rotating it, and of rescaling the coordinates to have zero mean and unit variance.

The output is returned on the same torch device as the output of get_sample, i.e. DEVICE.
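
Example

A minimal sketch of this normalization, reconstructed from the Notes above rather than from the MAGNET sources: an SVD-based rotation aligns the widest direction of the point cloud with the x axis, then the coordinates are standardized. Per-coordinate (rather than global) rescaling is an assumption.

import torch

def normalize_sketch(x: torch.Tensor) -> torch.Tensor:
    # x: (num_nodes, dim) node coordinates.
    centered = x - x.mean(dim=0, keepdim=True)  # zero mean
    # The right-singular vectors give the principal directions; rotating by
    # them maps the widest direction of the mesh onto the x axis.
    _, _, vh = torch.linalg.svd(centered, full_matrices=False)
    rotated = centered @ vh.T
    # Rescale each coordinate to unit variance.
    return rotated / rotated.std(dim=0, keepdim=True)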

train_GNN(training_dataset: MeshDataset, validation_dataset: MeshDataset, epochs: int, batch_size: int, learning_rate: float = 0.0001, save_logs: bool = True) → None

Train the Graph Neural Network.

Parameters:
  • training_dataset (MeshDataset) – Dataset of meshes on which to train the GNN.

  • validation_dataset (MeshDataset) – Validation dataset to check that no overfitting is occurring.

  • epochs (int) – Number of training epochs.

  • batch_size (int) – Size of the minibatch to be used.

  • learning_rate (float, optional) – Initial learning rate for the scheduler (default is 1e-4).

  • save_logs (bool, optional) – If True, save the training and validation loss histories, the scheduled learning rate, the corresponding plots, and a short summary (default is True).

Return type:

None
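
Example

A hedged end-to-end sketch of a training run. Only the train_GNN signature above comes from this page; the MeshDataset import path and constructor arguments are assumptions, and the epoch and batch-size values are illustrative.

from magnet.aggmodels.gnn import GeometricGNN
from magnet.dataset import MeshDataset  # assumed import path

model = GeometricGNN()                 # assumed default construction
train_set = MeshDataset("data/train")  # hypothetical dataset locations
val_set = MeshDataset("data/val")

model.train_GNN(
    training_dataset=train_set,
    validation_dataset=val_set,
    epochs=200,
    batch_size=8,
    learning_rate=1e-4,  # initial learning rate handed to the scheduler
    save_logs=True,      # persist loss histories, learning rates, and plots
)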

Inherited Methods

__init__(*args, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

agglomerate(mesh[, mode, nref, mult_factor])

Agglomerate a mesh.

agglomerate_dataset(dataset, **kwargs)

Agglomerate all meshes in a dataset.

bisect(mesh)

Bisect the mesh once.

bisection_Nref(mesh, Nref[, warm_start])

Bisect the mesh recursively a set number of times.

bisection_mult_factor(mesh, mult_factor[, ...])

Bisect a mesh until the agglomerated elements are small enough.

bisection_segregated(mesh, mult_factor[, subset])

Bisect heterogeneous mesh until elements are small enough.

coarsen(mesh, subset[, mode, nref, mult_factor])

Coarsen a subregion of the mesh.

get_number_of_parameters()

Get total number of parameters of the GNN.

load_model(model_path)

Load model from state dictionary.

multilevel_bisection(mesh[, refiner, ...])

Perform a multilevel bisection of the mesh.

save_model(output_path)

Save current model to state dictionary.
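
Example

An illustrative workflow built from the inherited methods listed above. The keyword names follow the signatures in the table, while the import path, file names, and the mult_factor value are assumptions.

from magnet.aggmodels.gnn import GeometricGNN
from magnet.mesh import Mesh  # assumed import path

model = GeometricGNN()
model.load_model("geometric_gnn.pt")  # hypothetical checkpoint path
mesh = Mesh("fine_mesh.vtu")          # hypothetical mesh file

# Agglomerate until elements are small enough relative to the original
# mesh size; the value of mult_factor is illustrative only.
coarse = model.agglomerate(mesh, mult_factor=0.3)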