magnet.aggmodels.gnn.ReinforceLearnGNN
- class magnet.aggmodels.gnn.ReinforceLearnGNN(*args, **kwargs)
Bases: GeometricGNN
Reinforcement Learning GNN for graph partitioning.
Constructor
- __init__(*args, **kwargs) -> None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Methods
A2C_train(training_dataset[, batch_size, ...])
change_vert(graph, action): In-place change of a vertex to the other subgraph.
compute_episode_length(graph)
reward_function(new_state, old_state, action)
update_state(graph, action)
- A2C_train(training_dataset: MeshDataset, batch_size: int = 1, epochs: int = 1, gamma: float = 0.9, alpha: float = 0.1, optimizer=None, **kwargs)
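A2C_train presumably trains the network with advantage actor-critic (A2C). The sketch below is a generic, framework-free illustration of the quantity every A2C update relies on, the advantage, under the assumption that `gamma` is the reward discount factor. How MAGNET uses `alpha` and `optimizer` is not documented on this page, and the helpers `discounted_returns` and `advantages` are hypothetical names, not part of the library.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float = 0.9) -> List[float]:
    """Compute G_t = r_t + gamma * G_{t+1}, iterating backwards over one episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def advantages(returns: List[float], values: List[float]) -> List[float]:
    """Advantage A_t = G_t - V(s_t): how much better the outcome was than the critic predicted."""
    if len(returns) != len(values):
        raise ValueError("returns and values must have the same length")
    return [g - v for g, v in zip(returns, values)]


# Toy three-step episode with made-up rewards and critic estimates V(s_t).
rewards = [1.0, 0.0, 2.0]
values = [1.5, 1.0, 1.0]
G = discounted_returns(rewards, gamma=0.9)  # approximately [2.62, 1.8, 2.0]
A = advantages(G, values)                   # approximately [1.12, 0.8, 1.0]
```

In a typical actor-critic step, the actor loss is the negative log-probability of each chosen action weighted by its advantage, and the critic is regressed toward `G_t`.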
- change_vert(graph: Data, action: int)
In-place change of a vertex to the other subgraph.
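Because `change_vert` is documented as an in-place change, it mutates `graph` rather than building a new one. The sketch below illustrates that contract for a bisection, assuming (hypothetically) that the partition is stored as a 0/1 label per vertex and that `action` is a vertex index. `ToyGraph` is an invented stand-in for `torch_geometric.data.Data`; MAGNET's actual storage of the partition may differ.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a torch_geometric Data object.

    labels[i] is 0 or 1: the subgraph that vertex i currently belongs to.
    """
    labels: List[int] = field(default_factory=list)


def change_vert(graph: ToyGraph, action: int) -> None:
    """Move vertex `action` to the other subgraph, mutating `graph` in place."""
    if not 0 <= action < len(graph.labels):
        raise IndexError(f"vertex {action} out of range")
    graph.labels[action] = 1 - graph.labels[action]


g = ToyGraph(labels=[0, 0, 1, 1])
change_vert(g, 1)  # g.labels is now [0, 1, 1, 1]
```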
- abstract compute_episode_length(graph: Data) -> int
- abstract reward_function(new_state: Data, old_state: Data, action: int) -> Tensor
- abstract update_state(graph: Data, action: int) -> Data
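The three abstract methods describe the environment a concrete subclass has to supply. As a hedged sketch of that contract, the code below defines a hypothetical abstract base mirroring the three hooks and a toy subclass whose reward is the reduction in edge cut. The reward choice, the one-move-per-vertex episode length, `ToyGraph`, and every class name are assumptions rather than MAGNET code, and plain floats stand in for the documented `Tensor` return of `reward_function`.

```python
import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyGraph:
    """Hypothetical stand-in for torch_geometric.data.Data."""
    edges: List[Tuple[int, int]]
    labels: List[int]  # 0/1 subgraph membership per vertex


class RLPartitionerContract(ABC):
    """Hypothetical mirror of the three abstract hooks of ReinforceLearnGNN."""

    @abstractmethod
    def compute_episode_length(self, graph: ToyGraph) -> int: ...

    @abstractmethod
    def reward_function(self, new_state: ToyGraph, old_state: ToyGraph, action: int) -> float: ...

    @abstractmethod
    def update_state(self, graph: ToyGraph, action: int) -> ToyGraph: ...


def edge_cut(graph: ToyGraph) -> int:
    """Number of edges whose endpoints lie in different subgraphs."""
    return sum(1 for u, v in graph.edges if graph.labels[u] != graph.labels[v])


class CutReductionAgent(RLPartitionerContract):
    def compute_episode_length(self, graph: ToyGraph) -> int:
        # Assumption: allow one move per vertex.
        return len(graph.labels)

    def reward_function(self, new_state: ToyGraph, old_state: ToyGraph, action: int) -> float:
        # Assumption: reward is the decrease in edge cut caused by the action.
        return float(edge_cut(old_state) - edge_cut(new_state))

    def update_state(self, graph: ToyGraph, action: int) -> ToyGraph:
        # Return a modified copy so the old state is still available for the reward.
        new_graph = copy.deepcopy(graph)
        new_graph.labels[action] = 1 - new_graph.labels[action]
        return new_graph


# Path graph 0-1-2-3 with a poor initial split {0, 2} vs {1, 3} (edge cut 3).
g0 = ToyGraph(edges=[(0, 1), (1, 2), (2, 3)], labels=[0, 1, 0, 1])
agent = CutReductionAgent()
g1 = agent.update_state(g0, 2)      # moving vertex 2 lowers the cut to 1
r = agent.reward_function(g1, g0, 2)  # 2.0
```

Returning a copy from `update_state` keeps `old_state` intact, which is what lets `reward_function` compare the two states.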
Inherited Methods
__init__(*args, **kwargs): Initialize internal Module state, shared by both nn.Module and ScriptModule.
agglomerate(mesh[, mode, nref, mult_factor]): Agglomerate a mesh.
agglomerate_dataset(dataset, **kwargs): Agglomerate all meshes in a dataset.
bisect(mesh): Bisect the mesh once.
bisection_Nref(mesh, Nref[, warm_start]): Bisect the mesh recursively a set number of times.
bisection_mult_factor(mesh, mult_factor[, ...]): Bisect a mesh until the agglomerated elements are small enough.
bisection_segregated(mesh, mult_factor[, subset]): Bisect a heterogeneous mesh until the elements are small enough.
coarsen(mesh, subset[, mode, nref, mult_factor]): Coarsen a subregion of the mesh.
get_number_of_parameters(): Get the total number of parameters of the GNN.
get_sample(mesh[, randomRotate, selfloop, ...]): Create a graph data structure sample from a mesh.
load_model(model_path): Load the model from a state dictionary.
loss_function(output, graph): Loss function used during training.
multilevel_bisection(mesh[, refiner, ...])
normalize(x): Normalize the data before feeding it to the GNN.
save_model(output_path): Save the current model to a state dictionary.
train_GNN(training_dataset, ...[, ...]): Train the Graph Neural Network.
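The inherited bisection drivers differ mainly in their stopping rule: bisection_Nref applies a fixed number of recursive bisections (so, if every agglomerate is split at each level, it yields 2**Nref parts), while bisection_mult_factor keeps bisecting until the agglomerates are small enough. The sketch below contrasts the two rules on a plain list of element ids. `bisect_recursive`, `naive_bisect`, and the `max_size` threshold are hypothetical; the real methods operate on meshes and use the GNN's own bisection, and the exact role of `mult_factor` in MAGNET is not shown here.

```python
from typing import Callable, List, Optional, Tuple

Part = List[int]  # hypothetical: an agglomerate as a list of element ids


def naive_bisect(part: Part) -> Tuple[Part, Part]:
    """Placeholder for the GNN bisection: split the element list in half."""
    mid = len(part) // 2
    return part[:mid], part[mid:]


def bisect_recursive(part: Part, nref: Optional[int] = None,
                     max_size: Optional[int] = None,
                     bisect: Callable[[Part], Tuple[Part, Part]] = naive_bisect) -> List[Part]:
    """Recursively bisect `part`.

    Stops after `nref` levels (fixed depth, in the spirit of bisection_Nref) or
    once a part has at most `max_size` elements (a size criterion, loosely
    analogous to bisection_mult_factor). Exactly one criterion must be given.
    """
    if (nref is None) == (max_size is None):
        raise ValueError("give exactly one of nref or max_size")
    if nref is not None:
        if nref == 0 or len(part) < 2:
            return [part]
        left, right = bisect(part)
        return (bisect_recursive(left, nref=nref - 1, bisect=bisect)
                + bisect_recursive(right, nref=nref - 1, bisect=bisect))
    # Size criterion; parts with fewer than 2 elements cannot be split further.
    if len(part) <= max_size or len(part) < 2:
        return [part]
    left, right = bisect(part)
    return (bisect_recursive(left, max_size=max_size, bisect=bisect)
            + bisect_recursive(right, max_size=max_size, bisect=bisect))


elements = list(range(16))
parts = bisect_recursive(elements, nref=3)       # 2**3 = 8 parts of 2 elements
parts2 = bisect_recursive(elements, max_size=5)  # 16 -> 8 -> 4: 4 parts of 4 elements
```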