magnet.aggmodels.gnn.ReinforceLearnGNN
- class magnet.aggmodels.gnn.ReinforceLearnGNN(*args, **kwargs)
- Bases: GeometricGNN
- Reinforcement Learning GNN for graph partitioning.
- Constructor
  - __init__(*args, **kwargs) -> None
  - Initialize internal Module state, shared by both nn.Module and ScriptModule.
 - Methods
   - A2C_train(training_dataset[, batch_size, ...])
   - change_vert(graph, action): Move a vertex to another subgraph, in place.
   - compute_episode_length(graph)
   - reward_function(new_state, old_state, action)
   - update_state(graph, action)
 - A2C_train(training_dataset: MeshDataset, batch_size: int = 1, epochs: int = 1, gamma: float = 0.9, alpha: float = 0.1, optimizer=None, **kwargs)
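A2C_train is documented only by its signature. Its name and the gamma parameter suggest an advantage actor-critic (A2C) scheme with gamma as the discount factor, but that is an assumption. The plain-Python sketch below shows the per-episode quantities such a scheme computes: discounted returns, advantages, and the actor and critic losses. It does not model alpha, optimizer, or batching, and the function names are hypothetical.

```python
from typing import List, Tuple


def discounted_returns(rewards: List[float], gamma: float = 0.9) -> List[float]:
    """G_t = r_t + gamma * G_{t+1}, accumulated backwards over one episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def a2c_losses(
    log_probs: List[float], values: List[float], rewards: List[float], gamma: float = 0.9
) -> Tuple[float, float]:
    """Return (actor_loss, critic_loss) for one non-empty episode (hypothetical helper)."""
    returns = discounted_returns(rewards, gamma)
    # Advantage A_t = G_t - V(s_t): how much better the outcome was than the critic expected.
    adv = [g - v for g, v in zip(returns, values)]
    n = len(adv)
    # Policy-gradient term: raise the log-probability of actions with positive advantage.
    actor_loss = -sum(lp * a for lp, a in zip(log_probs, adv)) / n
    # Critic regression: mean squared error between predicted values and observed returns.
    critic_loss = sum(a * a for a in adv) / n
    return actor_loss, critic_loss


actor, critic = a2c_losses([-1.0, -1.0, -1.0], [0.5, 0.5, 0.5], [1.0, 0.0, 2.0])
```

In a tensor implementation, the advantage is usually treated as a constant (detached) in the actor loss, so that only the critic loss trains the value estimate.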
 - change_vert(graph: Data, action: int)
- Move a vertex to another subgraph, in place.
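The page does not say how subgraph membership is stored. Assuming, for illustration, a bisection with one 0/1 label per vertex, an in-place move reduces to flipping that vertex's label. ToyGraph is a hypothetical stand-in for the torch_geometric Data object the real method receives.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyGraph:
    # Hypothetical stand-in for torch_geometric's Data: one 0/1 subgraph label per vertex.
    labels: List[int] = field(default_factory=list)


def change_vert(graph: ToyGraph, action: int) -> None:
    """Move vertex `action` to the other side of a bisection by mutating `graph`."""
    if not 0 <= action < len(graph.labels):
        raise IndexError(f"vertex {action} out of range")
    graph.labels[action] = 1 - graph.labels[action]  # 0 -> 1 or 1 -> 0


g = ToyGraph([0, 0, 1, 1])
change_vert(g, 1)
print(g.labels)  # [0, 1, 1, 1]
```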
 - abstract compute_episode_length(graph: Data) -> int
 - abstract reward_function(new_state: Data, old_state: Data, action: int) -> Tensor
 - abstract update_state(graph: Data, action: int) -> Data
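The three abstract methods define the RL environment a subclass must supply. The sketch below mirrors them with Python's abc module and shows one way they could fit together in an episode rollout. The base class, ToyGraph, run_episode, the episode length, and the cut-reduction reward are all hypothetical choices for illustration, not MAGNET's implementation; the reward returns a float here rather than a Tensor.

```python
import copy
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple


class ToyGraph:
    """Hypothetical stand-in for torch_geometric's Data: 0/1 labels plus an edge list."""

    def __init__(self, labels: List[int], edges: List[Tuple[int, int]]):
        self.labels = labels
        self.edges = edges

    def edge_cut(self) -> int:
        # Number of edges whose endpoints sit in different subgraphs.
        return sum(1 for u, v in self.edges if self.labels[u] != self.labels[v])


class RLPartitioner(ABC):
    """Mirrors the three abstract hooks; subclasses define the RL environment."""

    @abstractmethod
    def compute_episode_length(self, graph: ToyGraph) -> int: ...

    @abstractmethod
    def reward_function(self, new_state: ToyGraph, old_state: ToyGraph, action: int) -> float: ...

    @abstractmethod
    def update_state(self, graph: ToyGraph, action: int) -> ToyGraph: ...

    def run_episode(self, graph: ToyGraph, policy: Callable[[ToyGraph], int]) -> float:
        """Roll out one episode and return the summed (undiscounted) reward."""
        state, total = graph, 0.0
        for _ in range(self.compute_episode_length(graph)):
            action = policy(state)
            new_state = self.update_state(state, action)
            total += self.reward_function(new_state, state, action)
            state = new_state
        return total


class CutReducer(RLPartitioner):
    def compute_episode_length(self, graph: ToyGraph) -> int:
        return len(graph.labels) // 2  # assumption: one move per pair of vertices

    def update_state(self, graph: ToyGraph, action: int) -> ToyGraph:
        new_state = copy.deepcopy(graph)  # keep the old state intact for the reward
        new_state.labels[action] = 1 - new_state.labels[action]
        return new_state

    def reward_function(self, new_state: ToyGraph, old_state: ToyGraph, action: int) -> float:
        return float(old_state.edge_cut() - new_state.edge_cut())  # reward = drop in cut


g = ToyGraph([0, 1, 0, 1], [(0, 1), (1, 2), (2, 3)])
moves = iter([1, 3])  # scripted policy standing in for the GNN actor
total = CutReducer().run_episode(g, lambda state: next(moves))
print(total)  # 3.0
```

Because update_state returns a copy here, reward_function can compare both states. The example episode ends with every vertex in one subgraph, which minimizes the cut trivially, so a usable reward would also need a balance term.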
 - Inherited Methods
   - __init__(*args, **kwargs): Initialize internal Module state, shared by both nn.Module and ScriptModule.
   - agglomerate(mesh[, mode, nref, mult_factor]): Agglomerate a mesh.
   - agglomerate_dataset(dataset, **kwargs): Agglomerate all meshes in a dataset.
   - bisect(mesh): Bisect the mesh once.
   - bisection_Nref(mesh, Nref[, warm_start]): Bisect the mesh recursively a set number of times.
   - bisection_mult_factor(mesh, mult_factor[, ...]): Bisect a mesh until the agglomerated elements are small enough.
   - bisection_segregated(mesh, mult_factor[, subset]): Bisect a heterogeneous mesh until the elements are small enough.
   - coarsen(mesh, subset[, mode, nref, mult_factor]): Coarsen a subregion of the mesh.
   - get_number_of_parameters(): Get the total number of parameters of the GNN.
   - get_sample(mesh[, randomRotate, selfloop, ...]): Create a graph data structure sample from a mesh.
   - load_model(model_path): Load the model from a state dictionary.
   - loss_function(output, graph): Loss function used during training.
   - multilevel_bisection(mesh[, refiner, ...])
   - normalize(x): Normalize the data before feeding it to the GNN.
   - save_model(output_path): Save the current model to a state dictionary.
   - train_GNN(training_dataset, ...[, ...]): Train the Graph Neural Network.
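The recursive bisection methods are summarized in one line each. The sketch below illustrates the two stopping rules they describe: a fixed number of rounds (as in bisection_Nref) and splitting until parts are small enough (as in bisection_mult_factor). A median split along the wider coordinate axis stands in for the GNN, and "small enough" is modeled as a maximum element count rather than via mult_factor; both are assumptions, as is every function name here.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def bisect(points: Sequence[Point], ids: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical stand-in for the GNN: split ids at the median of the wider axis."""
    xs = [points[i][0] for i in ids]
    ys = [points[i][1] for i in ids]
    axis = 0 if max(xs) - min(xs) >= max(ys) - min(ys) else 1
    order = sorted(ids, key=lambda i: points[i][axis])
    half = len(order) // 2
    return order[:half], order[half:]


def bisection_nref(points: Sequence[Point], nref: int) -> List[List[int]]:
    """Bisect every part `nref` times, giving up to 2**nref parts."""
    parts = [list(range(len(points)))]
    for _ in range(nref):
        next_parts: List[List[int]] = []
        for part in parts:
            if len(part) < 2:
                next_parts.append(part)  # a single element cannot be split further
                continue
            next_parts.extend(bisect(points, part))
        parts = next_parts
    return parts


def bisection_until(points: Sequence[Point], max_size: int) -> List[List[int]]:
    """Keep bisecting any part with more than `max_size` elements."""
    done: List[List[int]] = []
    todo = [list(range(len(points)))]
    while todo:
        part = todo.pop()
        if len(part) <= max_size or len(part) < 2:
            done.append(part)
        else:
            todo.extend(bisect(points, part))
    return done
```

For example, on a 4 by 2 grid of element centroids, bisection_nref(points, 2) yields four parts of two elements each; the fixed-round rule is predictable in part count, while the size-based rule adapts to local mesh density.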
