magnet.aggmodels.refiner.DRLRefiner
- class magnet.aggmodels.refiner.DRLRefiner(*args, **kwargs)
Bases: ReinforceLearnGNN
Constructor
- __init__(*args, **kwargs) -> None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Methods
A2C_train(training_dataset[, ...])
compute_episode_length(graph)
cut(graph)
get_sample(mesh[, k, partitioner]): Create a graph data structure sample from a mesh.
k_hop_graph_cut(graph, k): Extract the k-hop subgraph around the current cut for refinement.
objective(graph, starter)
reward_function(new_state, old_state, ...[, ...]): Modified normalized cut to take into account cell volumes instead
update_state(graph, action, nnz)
volumes(graph)
- A2C_train(training_dataset: MeshDataset, time_to_sample: int = 8, epochs: int = 1, gamma: float = 0.9, alpha: float = 0.1, partitioner: AgglomerationModel | None = None, lr: float = 0.001, optimizer=None, **kwargs)
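A2C_train takes a discount factor gamma and a coefficient alpha. As a minimal sketch of the generic advantage actor-critic (A2C) quantities such a training loop computes per episode (discounted returns, advantages, actor and critic losses), the plain-Python code below may help. It is not MAGNET's implementation: the helper names are hypothetical, and treating alpha as the weight of the critic loss is an assumption (it could, for example, be an entropy coefficient instead).

```python
from typing import List, Sequence, Tuple


def discounted_returns(rewards: Sequence[float], gamma: float = 0.9,
                       bootstrap: float = 0.0) -> List[float]:
    """Return G_t = r_t + gamma * G_{t+1}, with G_T = bootstrap (e.g. V(last state))."""
    returns: List[float] = []
    g = bootstrap
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return returns


def a2c_losses(log_probs: Sequence[float], values: Sequence[float],
               rewards: Sequence[float], gamma: float = 0.9,
               alpha: float = 0.1, bootstrap: float = 0.0) -> Tuple[float, float, float]:
    """Actor loss, critic loss and their weighted sum for one episode.

    ASSUMPTION: alpha weights the critic loss; the real A2C_train may differ.
    """
    returns = discounted_returns(rewards, gamma, bootstrap)
    # Advantage A_t = G_t - V(s_t). In an autograd framework this term would be
    # detached inside the actor loss so that only the policy gets that gradient.
    advantages = [g - v for g, v in zip(returns, values)]
    n = len(rewards)
    actor_loss = -sum(lp * a for lp, a in zip(log_probs, advantages)) / n
    critic_loss = sum(a * a for a in advantages) / n
    return actor_loss, critic_loss, actor_loss + alpha * critic_loss
```

For example, `discounted_returns([1.0, 1.0, 1.0], gamma=0.5)` gives `[1.75, 1.5, 1.0]`; a critic that predicted exactly those values would produce zero advantages and therefore zero loss.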
- compute_episode_length(graph: Data) -> int
- cut(graph)
- get_sample(mesh: Mesh | Data, k: int = 3, partitioner: AgglomerationModel | None = None, **kwargs)
Create a graph data structure sample from a mesh.
This is used for both training and running the GNN.
- Parameters:
mesh (Mesh | Data) – Mesh to be sampled.
randomRotate (bool, optional) – If True, randomly rotate the mesh (default is False).
selfloop (bool, optional) – If True, add 1 on the diagonal of the adjacency matrix, i.e., self-loops on the graph (default is False).
- Returns:
Graph data representing the mesh.
- Return type:
Data
Notes
The two tensors x and edge_index are both on DEVICE (cuda, if available).
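To illustrate what the selfloop option means for the returned edge_index, here is a minimal plain-Python sketch that builds a 2 x E edge list (the COO layout PyTorch Geometric uses for edge_index) from pairs of face-adjacent cells. The function name and the adjacent_pairs input are hypothetical; the real get_sample works on Mesh objects and returns tensors on DEVICE.

```python
from typing import Iterable, List, Tuple


def build_edge_index(num_cells: int, adjacent_pairs: Iterable[Tuple[int, int]],
                     selfloop: bool = False) -> List[List[int]]:
    """Return a 2 x E edge list in COO layout.

    adjacent_pairs (hypothetical input) lists cells that share a face; each pair
    is stored in both directions so the graph is undirected. With selfloop=True,
    an (i, i) edge is added for every cell, i.e. 1s on the adjacency diagonal.
    """
    edges = set()
    for i, j in adjacent_pairs:
        if i != j:
            edges.add((i, j))
            edges.add((j, i))
    if selfloop:
        edges.update((i, i) for i in range(num_cells))
    ordered = sorted(edges)  # deterministic ordering of the edges
    return [[src for src, _ in ordered], [dst for _, dst in ordered]]
```

For three cells in a row, `build_edge_index(3, [(0, 1), (1, 2)])` returns `[[0, 1, 1, 2], [1, 0, 2, 1]]`; passing `selfloop=True` adds the edges (0, 0), (1, 1) and (2, 2).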
- k_hop_graph_cut(graph: Data, k: int)
Extract the k-hop subgraph around the current cut for refinement.
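A minimal sketch of the idea behind a k-hop subgraph around a cut, assuming a bisection stored as 0/1 labels per vertex: seed a breadth-first search with every vertex incident to a cut edge (an edge whose endpoints carry different labels) and keep every vertex within k hops. The function name and the plain edge-list input are illustrative assumptions, not MAGNET's code; on PyTorch Geometric data, `torch_geometric.utils.k_hop_subgraph` provides a comparable primitive.

```python
from collections import deque
from typing import Dict, List, Sequence, Set, Tuple


def k_hop_around_cut(edges: Sequence[Tuple[int, int]], labels: Sequence[int],
                     k: int) -> Set[int]:
    """Return the vertices at most k hops away from the current cut."""
    adj: Dict[int, List[int]] = {}
    for u, v in edges:  # edges are treated as undirected
        adj.setdefault(u, []).append(v)
        adj.setdefault(v, []).append(u)
    # Seeds: both endpoints of every edge that crosses the partition.
    seeds = {x for u, v in edges if labels[u] != labels[v] for x in (u, v)}
    dist = {s: 0 for s in seeds}
    queue = deque(seeds)
    while queue:
        u = queue.popleft()
        if dist[u] == k:
            continue  # do not expand beyond k hops
        for w in adj.get(u, []):
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return set(dist)
```

On the path 0-1-2-3-4-5 with labels `[0, 0, 0, 1, 1, 1]`, the cut is the edge (2, 3); `k=1` keeps `{1, 2, 3, 4}`, so only the vertices near the interface are exposed to the refiner.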
- objective(graph, starter)
- reward_function(new_state: Data, old_state: Data, new_starter: Data, old_starter: Data, action: int, disc_penalty=0, cerchiobottismo=0, **kwargs)
Modified normalized cut to take into account cell volumes instead
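For a bisection (A, B), the classical normalized cut is cut(A, B)/vol(A) + cut(A, B)/vol(B), where vol is usually a sum of vertex degrees. The sketch below assumes that vol is instead the sum of cell volumes, that every cut edge counts 1, and that the reward is the decrease of this objective after an action. All three are assumptions for illustration; the exact weighting, the starter arguments, and the disc_penalty and cerchiobottismo terms of reward_function are not reproduced.

```python
from typing import Sequence, Tuple


def volume_normalized_cut(edges: Sequence[Tuple[int, int]], labels: Sequence[int],
                          cell_volumes: Sequence[float]) -> float:
    """cut(A, B) / vol(A) + cut(A, B) / vol(B), with vol() = sum of cell volumes.

    ASSUMPTIONS: labels are 0/1, each undirected edge appears once, and every
    cut edge has weight 1 (no face-length weighting).
    """
    cut = sum(1 for u, v in edges if labels[u] != labels[v])
    vol_a = sum(vol for vol, lab in zip(cell_volumes, labels) if lab == 0)
    vol_b = sum(vol for vol, lab in zip(cell_volumes, labels) if lab == 1)
    if vol_a == 0 or vol_b == 0:
        raise ValueError("both parts of the bisection must be non-empty")
    return cut / vol_a + cut / vol_b


def reward(edges, old_labels, new_labels, cell_volumes) -> float:
    """Hypothetical reward: positive when the action lowers the objective."""
    return (volume_normalized_cut(edges, old_labels, cell_volumes)
            - volume_normalized_cut(edges, new_labels, cell_volumes))
```

On the path 0-1-2-3 with unit volumes, the balanced split `[0, 0, 1, 1]` scores 1.0. With volumes `[3, 1, 1, 1]`, isolating the large cell (`[0, 1, 1, 1]`) scores 2/3 and beats that same split (0.75), which shows how weighting by volume rather than cell count changes the preferred partition.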
- update_state(graph: Data, action: int, nnz)
- volumes(graph)
Inherited Methods
__init__(*args, **kwargs): Initialize internal Module state, shared by both nn.Module and ScriptModule.
agglomerate(mesh[, mode, nref, mult_factor]): Agglomerate a mesh.
agglomerate_dataset(dataset, **kwargs): Agglomerate all meshes in a dataset.
bisect(mesh): Bisect the mesh once.
bisection_Nref(mesh, Nref[, warm_start]): Bisect the mesh recursively a set number of times.
bisection_mult_factor(mesh, mult_factor[, ...]): Bisect a mesh until the agglomerated elements are small enough.
bisection_segregated(mesh, mult_factor[, subset]): Bisect a heterogeneous mesh until the elements are small enough.
change_vert(graph, action): In-place change of a vertex to the other subgraph.
coarsen(mesh, subset[, mode, nref, mult_factor]): Coarsen a subregion of the mesh.
get_number_of_parameters(): Get the total number of parameters of the GNN.
load_model(model_path): Load the model from a state dictionary.
loss_function(output, graph): Loss function used during training.
multilevel_bisection(mesh[, refiner, ...])
normalize(x): Normalize the data before feeding it to the GNN.
save_model(output_path): Save the current model to a state dictionary.
train_GNN(training_dataset, ...[, ...]): Train the Graph Neural Network.
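The inherited bisection_Nref and bisection_mult_factor methods apply a single bisection recursively, either a fixed number of times or until parts are small enough. The sketch below shows only that recursion pattern. It swaps in a median-coordinate split as a stand-in for the GNN bisection, and uses a maximum cell count as a hypothetical "small enough" criterion, since the exact role of mult_factor is not described here. All names are illustrative.

```python
from typing import Callable, List, Sequence, Tuple

Part = List[int]
Bisector = Callable[[Part], Tuple[Part, Part]]


def median_split(xs: Sequence[float]) -> Bisector:
    """Stand-in bisector: split a part at the median x-coordinate of its cells."""
    def bisect(part: Part) -> Tuple[Part, Part]:
        ordered = sorted(part, key=lambda c: xs[c])
        half = len(ordered) // 2
        return ordered[:half], ordered[half:]
    return bisect


def bisect_n_times(cells: Sequence[int], nref: int, bisect: Bisector) -> List[Part]:
    """Bisect every part nref times, giving up to 2**nref parts."""
    parts: List[Part] = [list(cells)]
    for _ in range(nref):
        next_parts: List[Part] = []
        for part in parts:
            if len(part) < 2:
                next_parts.append(part)  # a single cell cannot be split further
            else:
                next_parts.extend(bisect(part))
        parts = next_parts
    return parts


def bisect_until_small(cells: Sequence[int], max_cells: int, bisect: Bisector) -> List[Part]:
    """Keep bisecting any part with more than max_cells cells (hypothetical criterion)."""
    done: List[Part] = []
    todo: List[Part] = [list(cells)]
    while todo:
        part = todo.pop()
        if len(part) <= max_cells or len(part) < 2:
            done.append(part)
        else:
            todo.extend(bisect(part))
    return done
```

With eight cells at x = 0, 1, ..., 7, `bisect_n_times(range(8), 2, median_split(xs))` yields `[[0, 1], [2, 3], [4, 5], [6, 7]]`, and `bisect_until_small` with `max_cells=3` reaches the same four parts, though in a different order.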