magnet.aggmodels.gnn.GeometricGNN
- class magnet.aggmodels.gnn.GeometricGNN(*args, **kwargs)
Bases: GNN, AgglomerationModel
Constructor
- __init__(*args, **kwargs) → None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Methods
get_sample(mesh[, randomRotate, selfloop, ...]): Create a graph data structure sample from a mesh.
loss_function(output, graph): Loss function used during training.
normalize(x): Normalize the data before feeding it to the GNN.
train_GNN(training_dataset, ...[, ...]): Train the Graph Neural Network.
- get_sample(mesh: Mesh, randomRotate: bool = False, selfloop: bool = False, device=device(type='cpu')) → Data
Create a graph data structure sample from a mesh.
This is used for both training and running the GNN.
- Parameters:
mesh (Mesh) – Mesh to be sampled.
randomRotate (bool, optional) – If True, randomly rotate the mesh (default is False).
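selfloop (bool, optional) – If True, add 1 to the diagonal of the adjacency matrix, i.e. add a self-loop to every node of the graph (default is False).
device (torch.device, optional) – Torch device on which the sample tensors are created.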
- Returns:
Graph data representing the mesh.
- Return type:
Data
Notes
The two tensors x and edge_index are both on DEVICE (cuda, if available).
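As a rough illustration of what the selfloop and randomRotate options of get_sample do, the sketch below turns a dense adjacency matrix into an edge_index array in the 2 x E (source, target) layout that PyTorch Geometric's Data uses, optionally putting 1 on the diagonal first, and applies an optional random planar rotation to node coordinates. It is a sketch under assumptions, not magnet's implementation: it uses NumPy arrays in place of torch tensors, assumes 2D coordinates, and the helper names adjacency_to_edge_index and random_rotation_2d are hypothetical.

```python
import numpy as np


def adjacency_to_edge_index(adj, selfloop: bool = False) -> np.ndarray:
    """Convert a dense adjacency matrix into a 2 x E array of (source, target) pairs.

    Hypothetical helper: `selfloop` mimics adding 1 on the diagonal.
    """
    adj = np.array(adj, dtype=float, copy=True)  # copy so the caller's matrix is untouched
    if selfloop:
        np.fill_diagonal(adj, 1.0)  # one self-loop per node
    rows, cols = np.nonzero(adj)  # every nonzero entry is one directed edge
    return np.vstack([rows, cols])


def random_rotation_2d(coords: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Rotate 2D points about the origin by a uniformly random angle (hypothetical)."""
    theta = rng.uniform(0.0, 2.0 * np.pi)
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    return coords @ rot.T  # apply the rotation to each row (point)


# Three elements in a row: 0-1 and 1-2 are neighbours.
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]])
edge_index = adjacency_to_edge_index(adj)                       # edges (0,1), (1,0), (1,2), (2,1)
edge_index_loops = adjacency_to_edge_index(adj, selfloop=True)  # plus (0,0), (1,1), (2,2)
print(edge_index.shape, edge_index_loops.shape)  # (2, 4) (2, 7)
```

Random rotation is a common data augmentation: it changes the coordinates the GNN sees without changing the connectivity, so edge_index is unaffected.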
- loss_function(output: Tensor, graph: Data) → Tensor
Loss function used during training.
- Parameters:
output (torch.Tensor) – Evaluation output of the Neural Network.
graph (Data) – Graph on which the GNN was evaluated.
- Returns:
The value of the loss function.
- Return type:
torch.Tensor
Notes
The loss function may be overridden in subclasses to customize the GNN. See the losses module for the actual definitions.
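Since loss_function may be overridden in subclasses, here is one example of the kind of value an override could return for a bisection model. It assumes output holds, for each node, the probability of belonging to the first of two parts, and it adds the expected fraction of cut edges to a penalty on part-size imbalance. This is a generic soft-bisection loss written with NumPy for illustration; it is not one of the definitions in magnet's losses module, and the name soft_bisection_loss and the balance_weight parameter are hypothetical.

```python
import numpy as np


def soft_bisection_loss(p: np.ndarray, edge_index: np.ndarray,
                        balance_weight: float = 1.0) -> float:
    """Expected fraction of cut edges plus a balance penalty (hypothetical loss).

    p[i] is the assumed probability that node i belongs to part A.
    edge_index is a 2 x E array of directed edges (undirected edges listed twice).
    """
    src, dst = edge_index
    # Probability that the two endpoints of each edge land in different parts.
    cut_prob = p[src] * (1.0 - p[dst]) + (1.0 - p[src]) * p[dst]
    expected_cut = cut_prob.mean()
    # Relative deviation of the expected size of part A from half of the nodes.
    imbalance = (p.sum() - 0.5 * p.size) / p.size
    return float(expected_cut + balance_weight * imbalance ** 2)


edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])  # path graph 0-1-2-3
print(soft_bisection_loss(np.array([1.0, 1.0, 0.0, 0.0]), edge_index))  # one cut edge, balanced
print(soft_bisection_loss(np.array([1.0, 0.0, 1.0, 0.0]), edge_index))  # every edge cut
```

In an actual override, the same expression would be computed with torch operations on output and graph.edge_index so that it stays differentiable.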
- normalize(x: Tensor) → Tensor
Normalize the data before feeding it to the GNN.
- Parameters:
x (torch.Tensor) – The data to be normalized.
- Returns:
The normalized data.
- Return type:
torch.Tensor
Notes
Normalization consists of rotating the mesh so that its widest direction is aligned with the x axis, then rescaling the coordinates to zero mean and unit variance.
The output is returned on the same torch device as the output of get_sample, i.e. DEVICE.
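The normalization described in the Notes can be sketched as follows for 2D point coordinates (for example element centroids). Two readings are assumptions rather than documented behaviour: "widest direction" is taken to be the leading principal component of the coordinates, and each coordinate is standardized independently. The sketch uses NumPy instead of torch, and the helper names align_widest_direction and normalize_coords are hypothetical.

```python
import numpy as np


def align_widest_direction(coords: np.ndarray) -> np.ndarray:
    """Rotate centered 2D points so their leading principal axis lies along x.

    Assumption: "widest direction" = eigenvector of the covariance with largest eigenvalue.
    """
    centered = coords - coords.mean(axis=0)
    cov = np.cov(centered, rowvar=False)
    _, eigvecs = np.linalg.eigh(cov)  # eigenvalues come back in ascending order
    principal = eigvecs[:, -1]        # direction of largest variance
    angle = np.arctan2(principal[1], principal[0])
    c, s = np.cos(-angle), np.sin(-angle)
    rot = np.array([[c, -s], [s, c]])  # proper rotation by -angle (no reflection)
    return centered @ rot.T


def normalize_coords(coords: np.ndarray) -> np.ndarray:
    """Align the widest direction with x, then standardize each coordinate (assumed per axis)."""
    aligned = align_widest_direction(coords)
    return (aligned - aligned.mean(axis=0)) / aligned.std(axis=0)


rng = np.random.default_rng(0)
points = np.column_stack([rng.normal(0.0, 0.1, 200),
                          rng.normal(0.0, 5.0, 200)])  # tall, thin cloud along y
normalized = normalize_coords(points)
print(normalized.mean(axis=0), normalized.std(axis=0))  # approximately [0, 0] and [1, 1]
```

Using a rotation built from the principal angle, rather than projecting directly onto the eigenvectors, avoids accidentally mirroring the mesh when the eigenvector basis has negative determinant.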
- train_GNN(training_dataset: MeshDataset, validation_dataset: MeshDataset, epochs: int, batch_size: int, learning_rate: float = 0.0001, save_logs: bool = True) → None
Train the Graph Neural Network.
- Parameters:
training_dataset (MeshDataset) – Dataset of meshes on which to train the GNN.
validation_dataset (MeshDataset) – Validation dataset to check that no overfitting is occurring.
epochs (int) – Number of training epochs.
batch_size (int) – Size of the minibatch to be used.
learning_rate (float, optional) – Initial learning rate for the scheduler (default is 1e-4).
save_logs (bool, optional) – If True, save the training and validation loss histories, the scheduled learning rate, their plots, plus a short summary (default is True).
- Return type:
None
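The parameters of train_GNN correspond to a standard minibatch training loop: the data are reshuffled each epoch, processed in batches of batch_size, and a validation loss is computed after each epoch to detect overfitting. The sketch below shows only that structure on a toy problem, a one-parameter least-squares model trained by minibatch gradient descent, while recording the loss and learning-rate histories of the kind save_logs refers to. The toy model, the step-decay schedule, and the name train_toy are assumptions for illustration; the optimizer and scheduler actually used by train_GNN are not specified on this page.

```python
import numpy as np


def train_toy(x_train, y_train, x_val, y_val, epochs, batch_size,
              learning_rate=1e-4, decay=0.5, decay_every=10, seed=0):
    """Fit y ~ w * x by minibatch gradient descent (hypothetical toy trainer).

    Returns the fitted weight and per-epoch histories of train/validation MSE and lr.
    The step decay below stands in for an unspecified learning-rate scheduler.
    """
    rng = np.random.default_rng(seed)
    w, lr = 0.0, learning_rate
    history = {"train": [], "val": [], "lr": []}
    n = len(x_train)
    for epoch in range(epochs):
        order = rng.permutation(n)  # reshuffle the training set every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            xb, yb = x_train[idx], y_train[idx]
            grad = 2.0 * np.mean((w * xb - yb) * xb)  # derivative of the batch MSE w.r.t. w
            w -= lr * grad
        history["train"].append(float(np.mean((w * x_train - y_train) ** 2)))
        history["val"].append(float(np.mean((w * x_val - y_val) ** 2)))  # overfitting check
        history["lr"].append(lr)
        if (epoch + 1) % decay_every == 0:
            lr *= decay  # step decay of the learning rate
    return w, history


rng = np.random.default_rng(1)
x = rng.uniform(-1.0, 1.0, 120)
y = 3.0 * x + rng.normal(0.0, 0.1, 120)  # true slope 3 plus small noise
w, hist = train_toy(x[:100], y[:100], x[100:], y[100:],
                    epochs=30, batch_size=16, learning_rate=0.1)
print(f"w = {w:.3f}, final validation MSE = {hist['val'][-1]:.4f}")
```

The toy call uses a larger initial learning rate than the 1e-4 default of train_GNN because the toy model has a single parameter and few steps per epoch.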
Inherited Methods
__init__(*args, **kwargs): Initialize internal Module state, shared by both nn.Module and ScriptModule.
agglomerate(mesh[, mode, nref, mult_factor]): Agglomerate a mesh.
agglomerate_dataset(dataset, **kwargs): Agglomerate all meshes in a dataset.
bisect(mesh): Bisect the mesh once.
bisection_Nref(mesh, Nref[, warm_start]): Bisect the mesh recursively a set number of times.
bisection_mult_factor(mesh, mult_factor[, ...]): Bisect a mesh until the agglomerated elements are small enough.
bisection_segregated(mesh, mult_factor[, subset]): Bisect a heterogeneous mesh until the elements are small enough.
coarsen(mesh, subset[, mode, nref, mult_factor]): Coarsen a subregion of the mesh.
get_number_of_parameters(): Get the total number of parameters of the GNN.
load_model(model_path): Load the model from a state dictionary.
multilevel_bisection(mesh[, refiner, ...])
save_model(output_path): Save the current model to a state dictionary.
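The inherited bisection methods apply a single bisection step repeatedly. As a rough picture of bisection_Nref, the sketch below bisects a set of element centroids recursively Nref times, which yields up to 2**Nref agglomerates. A median split along the coordinate of largest extent stands in for the GNN-driven bisect; that substitute, the use of centroids instead of a Mesh, and the helper names median_bisect and bisection_nref are assumptions, not magnet's implementation.

```python
from typing import Tuple

import numpy as np


def median_bisect(points: np.ndarray, idx: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split the elements in idx into two halves along their widest coordinate.

    Stand-in for a learned bisection: a plain median (coordinate) split.
    """
    sub = points[idx]
    axis = int(np.argmax(sub.max(axis=0) - sub.min(axis=0)))  # coordinate of largest extent
    order = idx[np.argsort(sub[:, axis], kind="stable")]
    half = len(order) // 2
    return order[:half], order[half:]


def bisection_nref(points: np.ndarray, nref: int) -> np.ndarray:
    """Recursively bisect nref times; return one agglomerate label per element."""
    parts = [np.arange(len(points))]
    for _ in range(nref):
        new_parts = []
        for part in parts:
            if len(part) < 2:  # a single element cannot be split further
                new_parts.append(part)
                continue
            new_parts.extend(median_bisect(points, part))
        parts = new_parts
    labels = np.empty(len(points), dtype=int)
    for label, part in enumerate(parts):
        labels[part] = label
    return labels


rng = np.random.default_rng(0)
centroids = rng.uniform(0.0, 1.0, size=(64, 2))
labels = bisection_nref(centroids, nref=3)
print(np.bincount(labels))  # [8 8 8 8 8 8 8 8]: 2**3 agglomerates of 8 elements each
```

Replacing the fixed depth with a size test on each part (stop splitting once a part is small enough) gives the flavour of bisection_mult_factor, although the exact criterion used by magnet is not documented here.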