magnet.aggmodels.refiner.DRLRefiner

class magnet.aggmodels.refiner.DRLRefiner(*args, **kwargs)

Bases: ReinforceLearnGNN

Constructor

__init__(*args, **kwargs) -> None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Methods

A2C_train(training_dataset[, ...])

compute_episode_length(graph)

cut(graph)

get_sample(mesh[, k, partitioner])

Create a graph data structure sample from a mesh.

k_hop_graph_cut(graph, k)

Extract k-hop subgraph around the current cut for refinement.

objective(graph, starter)

reward_function(new_state, old_state, ...[, ...])

Modified normalized cut to take into account cell volumes instead

update_state(graph, action, nnz)

volumes(graph)

A2C_train(training_dataset: MeshDataset, time_to_sample: int = 8, epochs: int = 1, gamma: float = 0.9, alpha: float = 0.1, partitioner: AgglomerationModel | None = None, lr: float = 0.001, optimizer=None, **kwargs)
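
In standard advantage actor-critic (A2C) training, a parameter named gamma is usually the discount factor applied to future rewards; that this holds for A2C_train is an assumption, since the page does not say so. The sketch below shows how discounted returns and advantages are commonly computed for one episode. It is a generic illustration, not the training loop used by DRLRefiner, and the names discounted_returns, advantages, rewards, and values are hypothetical.

```python
def discounted_returns(rewards, gamma=0.9):
    """Return G_t = r_t + gamma * G_{t+1} for every step of one episode."""
    returns = []
    running = 0.0
    # Walk the episode backwards so each return reuses the following one.
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def advantages(returns, values):
    """Advantage estimate: observed return minus the critic's value estimate."""
    return [g - v for g, v in zip(returns, values)]


# Hypothetical three-step episode with a critic that always predicts 1.0.
episode_returns = discounted_returns([1.0, 1.0, 1.0], gamma=0.5)
episode_advantages = advantages(episode_returns, [1.0, 1.0, 1.0])
```

In a typical A2C setup the actor is updated with the advantages and the critic is regressed toward the returns; how DRLRefiner combines them is not documented here.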
compute_episode_length(graph: Data) -> int
cut(graph)
get_sample(mesh: Mesh | Data, k: int = 3, partitioner: AgglomerationModel | None = None, **kwargs)

Create a graph data structure sample from a mesh.

This is used for both training and running the GNN.

Parameters:
  • mesh (Mesh) – Mesh to be sampled.

  • randomRotate (bool, optional) – If True, randomly rotate the mesh (default is False).

  • selfloop (bool, optional) – If True, add 1 on the diagonal of the adjacency matrix, i.e., self-loops on the graph (default is False).

Returns:

Graph data representing the mesh.

Return type:

Data

Notes

The two tensors x and edge_index are both on DEVICE (cuda, if available).
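
The selfloop option described above adds 1 on the diagonal of the adjacency matrix. The following is a minimal pure-Python sketch of that operation on a dense matrix stored as nested lists. The helper with_self_loops is hypothetical and is not part of magnet; the library itself works on tensors.

```python
def with_self_loops(adjacency):
    """Return a copy of a square adjacency matrix with 1 added on the diagonal."""
    result = [row[:] for row in adjacency]  # copy so the input stays untouched
    for i in range(len(result)):
        result[i][i] += 1
    return result


# Path graph on three cells: 0 - 1 - 2
adjacency = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
looped = with_self_loops(adjacency)
```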

k_hop_graph_cut(graph: Data, k: int)

Extract k-hop subgraph around the current cut for refinement.
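
To illustrate the idea of a k-hop neighborhood around a cut, the sketch below collects every node within k hops of an edge that crosses the current bisection. It is written in plain Python under stated assumptions: the graph is an undirected edge list, the partition is a dict mapping each node to 0 or 1, and the endpoints of cut edges count as hop 0. The function k_hop_around_cut is hypothetical and does not reproduce the library's implementation.

```python
def k_hop_around_cut(edges, part, k):
    """Return the set of nodes within k hops of any edge crossing the cut.

    edges: list of undirected (u, v) pairs
    part:  dict mapping node -> 0 or 1 (the current bisection)
    k:     number of breadth-first expansion steps
    """
    # Build an adjacency list for the undirected graph.
    neighbors = {}
    for u, v in edges:
        neighbors.setdefault(u, set()).add(v)
        neighbors.setdefault(v, set()).add(u)

    # Seed with the endpoints of every cut edge (hop 0).
    frontier = {n for u, v in edges if part[u] != part[v] for n in (u, v)}
    selected = set(frontier)

    # Expand breadth-first, one hop per iteration.
    for _ in range(k):
        frontier = {m for n in frontier for m in neighbors[n]} - selected
        selected |= frontier
    return selected


# Path graph 0-1-2-3-4-5, split between nodes 2 and 3.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
part = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
near_cut = k_hop_around_cut(edges, part, k=1)
```

Restricting refinement to such a neighborhood keeps the number of candidate nodes small, since only cells close to the interface can change the cut when moved.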

objective(graph, starter)
reward_function(new_state: Data, old_state: Data, new_starter: Data, old_starter: Data, action: int, disc_penalty=0, cerchiobottismo=0, **kwargs)

Modified normalized cut to take into account cell volumes instead
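
For orientation, the classical normalized cut of a bisection (A, B) is cut(A, B) / vol(A) + cut(A, B) / vol(B). A volume-weighted variant could measure vol by summing cell volumes on each side. The sketch below implements that reading in plain Python. The exact formula used by reward_function is not given on this page, so both the formula and the helper volume_normalized_cut are assumptions for illustration only.

```python
def volume_normalized_cut(edges, part, volume):
    """Normalized cut of a bisection, using cell volumes as the normalizer.

    edges:  list of undirected (u, v) pairs
    part:   dict mapping node -> 0 or 1
    volume: dict mapping node -> cell volume (assumed positive)
    """
    # Number of edges whose endpoints lie on different sides.
    cut = sum(1 for u, v in edges if part[u] != part[v])

    # Total cell volume on each side of the bisection.
    side_volume = [0.0, 0.0]
    for node, side in part.items():
        side_volume[side] += volume[node]

    if side_volume[0] == 0.0 or side_volume[1] == 0.0:
        raise ValueError("each side of the bisection needs positive volume")
    return cut / side_volume[0] + cut / side_volume[1]


# Path graph 0-1-2-3 with larger cells on the right-hand side.
edges = [(0, 1), (1, 2), (2, 3)]
part = {0: 0, 1: 0, 2: 1, 3: 1}
volume = {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0}
score = volume_normalized_cut(edges, part, volume)
```

A reward for a refinement step could then be derived from how this score changes between old_state and new_state, although the page does not specify how the two are combined.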

update_state(graph: Data, action: int, nnz)
volumes(graph)

Inherited Methods

__init__(*args, **kwargs)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

agglomerate(mesh[, mode, nref, mult_factor])

Agglomerate a mesh.

agglomerate_dataset(dataset, **kwargs)

Agglomerate all meshes in a dataset.

bisect(mesh)

Bisect the mesh once.

bisection_Nref(mesh, Nref[, warm_start])

Bisect the mesh recursively a set number of times.

bisection_mult_factor(mesh, mult_factor[, ...])

Bisect a mesh until the agglomerated elements are small enough.

bisection_segregated(mesh, mult_factor[, subset])

Bisect heterogeneous mesh until elements are small enough.

change_vert(graph, action)

In place change of vertex to other subgraph.

coarsen(mesh, subset[, mode, nref, mult_factor])

Coarsen a subregion of the mesh.

get_number_of_parameters()

Get total number of parameters of the GNN.

load_model(model_path)

Load model from state dictionary.

loss_function(output, graph)

Loss function used during training.

multilevel_bisection(mesh[, refiner, ...])

normalize(x)

Normalize the data before feeding it to the GNN.

save_model(output_path)

Save current model to state dictionary.

train_GNN(training_dataset, ...[, ...])

Train the Graph Neural Network.