magnet.aggmodels.gnn.GNN
- class magnet.aggmodels.gnn.GNN(*args, **kwargs)
Bases: torch.nn.Module
Abstract base class for Graph Neural Networks (GNNs) used for mesh agglomeration.
Constructor
- __init__(*args, **kwargs) -> None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
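Because GNN is an abstract nn.Module, a usable model has to be a subclass that defines its own layers and forward pass. The sketch below shows one plausible shape for such a subclass. It does not import magnet: the class subclasses nn.Module directly as a stand-in, and the name TinyGNN, the `forward(x, edge_index)` signature, the mean-neighbour message passing and the softmax output are all illustrative assumptions, not the library's actual design.

```python
import torch
from torch import nn


# Hypothetical concrete model. In magnet it would subclass
# magnet.aggmodels.gnn.GNN, which is itself an nn.Module.
class TinyGNN(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()  # must run before any submodule is assigned
        self.lin_in = nn.Linear(in_dim, hidden_dim)
        self.lin_out = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        h = torch.relu(self.lin_in(x))
        src, dst = edge_index  # 2 x E tensor of directed edges
        # Sum each node's incoming neighbour features ...
        agg = torch.zeros_like(h).index_add_(0, dst, h[src])
        # ... and divide by the in-degree to get the neighbour mean
        deg = torch.zeros(h.size(0)).index_add_(0, dst, torch.ones(dst.numel()))
        h = h + agg / deg.clamp(min=1).unsqueeze(1)
        # Per-node assignment probabilities (assumed output format)
        return torch.softmax(self.lin_out(h), dim=1)


# A 4-node path graph 0-1-2-3, with each edge stored in both directions
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
model = TinyGNN(in_dim=3, hidden_dim=8, out_dim=2)
out = model(x, edge_index)
print(out.shape)  # torch.Size([4, 2])
```

Calling `super().__init__()` first is required by nn.Module itself: assigning a submodule before it raises an error.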
Methods
get_number_of_parameters() – Get the total number of parameters of the GNN.
get_sample(mesh) – Get graph data from a mesh for training.
load_model(model_path) – Load the model from a state dictionary.
save_model(output_path) – Save the current model to a state dictionary.
- get_number_of_parameters() -> int
Get the total number of parameters of the GNN.
- Parameters:
None
- Returns:
The number of parameters.
- Return type:
int
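For an nn.Module, a parameter count like this is conventionally computed by summing the element counts of all registered parameter tensors. The free function below is a minimal sketch of that convention; it assumes "total" means every parameter, trainable or frozen, and is not taken from magnet's source.

```python
from torch import nn


def get_number_of_parameters(model: nn.Module) -> int:
    # Sum the element counts of every registered parameter tensor,
    # trainable or frozen (assumed meaning of "total").
    return sum(p.numel() for p in model.parameters())


model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 2))
print(get_number_of_parameters(model))  # (3*4 + 4) + (4*2 + 2) = 26
```

To count only trainable weights instead, filter with `if p.requires_grad`.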
- load_model(model_path: str) -> None
Load the model from a state dictionary.
- Parameters:
model_path (str) – The path to the .pt state dictionary file.
- Return type:
None
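Loading a state dictionary in PyTorch means reading the saved tensors with `torch.load` and copying them into an already constructed model with `load_state_dict`. The sketch below shows that pattern as a free function; mapping to CPU is an assumption added so that files saved on a GPU still load, and it is not necessarily what magnet does.

```python
import os
import tempfile

import torch
from torch import nn


def load_model(model: nn.Module, model_path: str) -> None:
    # Read the saved tensors, mapping them to CPU (assumed choice)
    state_dict = torch.load(model_path, map_location="cpu")
    # Copy them into the existing parameters; keys and shapes must match
    model.load_state_dict(state_dict)


# Demo: write the weights of one model, restore them into a fresh one
src = nn.Linear(3, 2)
dst = nn.Linear(3, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gnn.pt")
    torch.save(src.state_dict(), path)
    load_model(dst, path)
print(torch.equal(src.weight, dst.weight))  # True
```

The target model must have the same architecture as the saved one, because a state dictionary stores only tensors keyed by parameter name, not the class.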
- save_model(output_path: str) -> None
Save the current model to a state dictionary.
- Parameters:
output_path (str) – The path where the state dictionary .pt file will be saved.
- Return type:
None
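The usual PyTorch way to save a state dictionary is `torch.save(model.state_dict(), path)`. The free function below is a minimal sketch of that pattern under the assumption that save_model does the same; the file name `gnn.pt` is illustrative.

```python
import os
import tempfile

import torch
from torch import nn


def save_model(model: nn.Module, output_path: str) -> None:
    # Write only the parameter and buffer tensors, keyed by name;
    # the model class itself is not stored.
    torch.save(model.state_dict(), output_path)


model = nn.Linear(3, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gnn.pt")
    save_model(model, path)
    saved = torch.load(path)
print(sorted(saved.keys()))  # ['bias', 'weight']
```

Saving the state dictionary rather than the whole pickled module keeps the file independent of the code layout, at the cost of having to rebuild the model before loading.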
Inherited Methods
__init__(*args, **kwargs) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
Note
This class inherits from torch.nn.Module. For the full list of inherited members, see the PyTorch documentation.
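As a consequence, standard nn.Module members such as `eval()`, `train()`, `parameters()` and `state_dict()` are available on any GNN subclass without being redefined. The short sketch below illustrates this with a hypothetical subclass, MyGNN, that stands in for a concrete model.

```python
from torch import nn


# Hypothetical subclass standing in for a concrete GNN model
class MyGNN(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(4, 2)


model = MyGNN()
model.eval()                     # inherited: switch to inference mode
print(model.training)            # False
print(list(model.state_dict()))  # ['lin.weight', 'lin.bias']
```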