magnet.aggmodels.gnn.GNN

class magnet.aggmodels.gnn.GNN(*args, **kwargs)

Bases: Module

Abstract base class for Graph Neural Networks (GNNs) used for mesh agglomeration.
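This page does not show how `GNN` enforces its abstract method. The following is a minimal sketch of the common pattern, assuming the base class combines `torch.nn.Module` with `abc.ABC`. `GNNSketch`, `ToyGNN`, and the untyped `mesh` argument are hypothetical stand-ins, not part of MAGNET's API.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class GNNSketch(nn.Module, ABC):
    """Hypothetical stand-in for an abstract GNN base class."""

    @abstractmethod
    def get_sample(self, mesh):
        """Convert a mesh into graph data; subclasses must override this."""


class ToyGNN(GNNSketch):
    """Hypothetical concrete subclass with a single linear layer."""

    def __init__(self, in_features: int, out_features: int) -> None:
        # nn.Module state must be initialized before submodules are assigned.
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def get_sample(self, mesh):
        # Placeholder: a real subclass would build node features and edges here.
        return mesh

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


try:
    GNNSketch()  # an abstract base cannot be instantiated
except TypeError as err:
    print("TypeError:", err)

model = ToyGNN(4, 2)
print(model(torch.zeros(3, 4)).shape)  # torch.Size([3, 2])
```

Under this assumption, forgetting to implement `get_sample` in a subclass fails at construction time rather than later during training.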

Constructor

__init__(*args, **kwargs) → None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Methods

get_number_of_parameters()

Get the total number of parameters of the GNN.

get_sample(mesh)

Get graph data from a mesh for training.

load_model(model_path)

Load the model from a state dictionary.

save_model(output_path)

Save the current model to a state dictionary.

get_number_of_parameters() → int

Get the total number of parameters of the GNN.

Parameters:

None

Returns:

The number of parameters.

Return type:

int
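A minimal sketch of how such a count is commonly computed in PyTorch, by summing `numel()` over `parameters()`. This page does not say whether MAGNET counts all parameters or only trainable ones; `count_parameters` and its `trainable_only` flag are hypothetical, added for illustration.

```python
from torch import nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Return the number of scalar parameters in ``model``.

    ``trainable_only`` is a hypothetical option that skips frozen tensors.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


model = nn.Sequential(nn.Linear(3, 2), nn.Linear(2, 1))
print(count_parameters(model))  # 11: (3*2 + 2) + (2*1 + 1)

model[0].weight.requires_grad_(False)  # freeze the first weight matrix
print(count_parameters(model, trainable_only=True))  # 5
```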

abstract get_sample(mesh: Mesh) → Data

Get graph data from a mesh for training. This method is abstract and must be implemented by subclasses.
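Graph construction is left to subclasses. As an illustration only, a `get_sample` implementation might treat mesh elements as nodes and element adjacencies as edges, producing node features and a `[2, E]` `edge_index` tensor (the COO layout used by PyTorch Geometric's `Data`, which the `Data` return type presumably refers to). The `ToyMesh` container, its fields, and the use of centroids as node features are hypothetical; a plain `dict` stands in for `Data` so the sketch runs without PyTorch Geometric.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyMesh:
    """Hypothetical mesh: element centroids plus pairs of adjacent elements."""

    centroids: list[list[float]]
    adjacent_pairs: list[tuple[int, int]]


def get_sample(mesh: ToyMesh) -> dict[str, torch.Tensor]:
    """Build graph tensors from a mesh: one node per element, undirected edges."""
    x = torch.tensor(mesh.centroids, dtype=torch.float32)
    pairs = torch.tensor(mesh.adjacent_pairs, dtype=torch.long).t()  # shape [2, P]
    # Store each undirected adjacency in both directions.
    edge_index = torch.cat([pairs, pairs.flip(0)], dim=1)  # shape [2, 2P]
    return {"x": x, "edge_index": edge_index}


# Three elements in a strip: element 0 touches 1, element 1 touches 2.
mesh = ToyMesh(
    centroids=[[0.3, 0.3], [0.7, 0.7], [1.3, 0.3]],
    adjacent_pairs=[(0, 1), (1, 2)],
)
sample = get_sample(mesh)
print(sample["x"].shape)               # torch.Size([3, 2])
print(sample["edge_index"].tolist())   # [[0, 1, 1, 2], [1, 2, 0, 1]]
```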

load_model(model_path: str) → None

Load the model from a state dictionary.

Parameters:

model_path (str) – The path of the .pt state dictionary file.

Return type:

None

save_model(output_path: str) → None

Save the current model to a state dictionary.

Parameters:

output_path (str) – The path where the state dictionary .pt file will be saved.

Return type:

None
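`save_model` and `load_model` both work with `.pt` state dictionaries. Below is a minimal sketch of the standard PyTorch round trip they presumably wrap: `torch.save(model.state_dict(), path)` followed by `model.load_state_dict(torch.load(path))`. The temporary path and the `nn.Linear` stand-in are illustrative. Loading requires a model with the same architecture, because a state dictionary stores only tensors keyed by parameter name, not the class definition.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
trained = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")

    # Save: write only the parameter tensors, keyed by name.
    torch.save(trained.state_dict(), path)

    # Load: build a model with the same architecture, then copy the tensors in.
    restored = nn.Linear(4, 2)
    state = torch.load(path, map_location="cpu")
    restored.load_state_dict(state)

same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(sorted(state.keys()))  # ['bias', 'weight']
print(same)  # True
```

`map_location="cpu"` lets a file saved from a GPU model load on a CPU-only machine; the tensors can be moved afterwards with `model.to(device)`.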

Inherited Methods

__init__(*args, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Note

This class inherits from torch.nn.Module. For the full list of inherited members, see the PyTorch documentation.