Trainers
- class Trainer(device=None, json_log=False)
Bases: object
Trainer class that defines a number of useful methods for training an autoencoder.
- Variables:
autoencoder – any torch.nn.Module network that has the methods autoencoder.encode and autoencoder.decode, with the weights associated with these operations accessible via autoencoder.encoder and autoencoder.decoder. This can be set using set_autoencoder.
_autoencoder_kwargs – kwargs used to initialise the network. Saved in every checkpoint under the key ‘kwargs’.
optimiser (torch.optim.Optimizer) – pytorch optimiser with access to self.autoencoder.parameters()
device (torch.device) – The device used for all operations.
epoch (int) – the current epoch
progress (TrainerProgress) – Tracks best validation loss, checkpoint metadata, and convergence state.
std (float) – Standard deviation of the training dataset. Can be used to unscale structures produced by the network.
mol – Biobox molecule containing a single example frame of the protein being trained on. This can be used to save examples during training. It is also used to save a temporary pdb that may be used to initialise third-party packages.
train_dataloader (torch.utils.data.DataLoader) – Training data
valid_dataloader (torch.utils.data.DataLoader) – Validation data
_data (molearn.data) – Data object given to set_data
- Parameters:
device (torch.device) – if not given, will be determined automatically based on torch.cuda.is_available()
log_filename (str) – (default: ‘default_log_filename.json’) file used to log outputs to
json_log (bool) – True to use json.dump to save the log file
- checkpoint(epoch, valid_logs, checkpoint_folder, is_best, has_converged=False)
Checkpoint the current network. The checkpoint will be saved as 'last.ckpt'. If valid_logs[loss_key] is better than self.progress.best_loss then this checkpoint will replace self.progress.best_checkpoint: 'last.ckpt' will be renamed to f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt' and the former best (filename saved as self.progress.best_checkpoint) will be deleted.
- Parameters:
epoch (int) – current epoch, will be saved within the ckpt. The current epoch can usually be obtained with self.epoch.
valid_logs (dict) – results dictionary containing loss_key.
checkpoint_folder (str) – The folder in which to save the checkpoint.
is_best (bool) – whether the current checkpoint is the best so far. This determines whether the checkpoint is renamed to f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'.
has_converged (bool) – if True, the checkpoint will be saved as f'{checkpoint_folder}/checkpoint_converged.ckpt'. This is used to indicate that training has converged.
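The rename-and-delete bookkeeping described for checkpoint can be sketched with plain files. This is an illustration only, not molearn's implementation: promote_checkpoint is a hypothetical helper, and the placeholder bytes stand in for whatever the trainer actually writes to 'last.ckpt'.

```python
import os
import tempfile
from typing import Optional


def promote_checkpoint(checkpoint_folder: str, epoch: int, valid_loss: float,
                       best_checkpoint: Optional[str]) -> str:
    """Hypothetical helper: rename last.ckpt to a 'best' name and drop the former best."""
    last = os.path.join(checkpoint_folder, "last.ckpt")
    new_best = os.path.join(
        checkpoint_folder, f"checkpoint_epoch{epoch}_loss{valid_loss}.ckpt"
    )
    os.replace(last, new_best)  # rename last.ckpt to the new best filename
    # Delete the former best, unless it is the file we just created.
    if best_checkpoint and best_checkpoint != new_best and os.path.exists(best_checkpoint):
        os.remove(best_checkpoint)
    return new_best


with tempfile.TemporaryDirectory() as folder:
    best = None
    for epoch, loss in [(0, 0.9), (1, 0.7)]:  # each epoch improves on the last
        with open(os.path.join(folder, "last.ckpt"), "wb") as handle:
            handle.write(b"placeholder")  # stands in for the saved network state
        best = promote_checkpoint(folder, epoch, loss, best)
    remaining = sorted(os.listdir(folder))
```

After two improving epochs only the most recent best file remains in the folder, which is why load_checkpoint can later recover the loss from the filename.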
- common_step(batch)
Called from both train_step and valid_step. Calculates the mean squared error loss for self.autoencoder. Encoded and decoded frames are saved in self._internal under the keys encoded and decoded respectively, should you wish to use them elsewhere.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
The calculated mse_loss.
- Return type:
dict
- dry_run(log_filename='dry_run.dat', log_folder='tmp_folder')
Pass the training and validation set through the network once without updating model parameters. This is useful for debugging or restoring statistics for models loaded from checkpoint.
- Parameters:
log_filename (str) – the name of the log file.
log_folder (str) – the folder in which the log file will be saved.
- fit(data: PDBData, autoencoder, *, autoencoder_kwargs: Dict[str, Any] | None = None, optimiser_kwargs: Dict[str, Any] | None = None, data_kwargs: Dict[str, Any] | None = None, epochs: int | None = None, patience: int | None = None, log_filename: str = 'log.dat', log_folder: str = 'checkpoint_folder', checkpoint_folder: str = 'checkpoint_folder', verbose: bool | None = None) -> FitResult
End-to-end helper that wires data, model, optimiser, and training.
- Returns:
FitResult describing the executed training procedure.
- get_network_summary()
Returns a dictionary containing information about the size of the autoencoder.
- get_scale(ref_loss: float, tar_loss: float, scale_scale: float = 1.0)
Get a scaling factor that brings a loss to the same order of magnitude as a reference loss.
- Parameters:
ref_loss (float) – the reference loss
tar_loss (float) – the loss that should be scaled to be of the same order of magnitude as ref_loss
scale_scale (float) – (default: 1.0) additional factor to increase or decrease the scale further
- Returns:
scaling_factor (float) – the calculated scaling factor for tar_loss
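One way to bring two losses to the same order of magnitude is to compare their base-10 exponents. The sketch below assumes that approach. The formula and the name order_of_magnitude_scale are illustrative assumptions, and the actual get_scale may compute the factor differently.

```python
import math


def order_of_magnitude_scale(ref_loss: float, tar_loss: float,
                             scale_scale: float = 1.0) -> float:
    """Hypothetical helper: power of ten moving tar_loss to ref_loss's magnitude."""
    if ref_loss <= 0 or tar_loss <= 0:
        raise ValueError("both losses must be positive to take a logarithm")
    ref_exponent = math.floor(math.log10(ref_loss))
    tar_exponent = math.floor(math.log10(tar_loss))
    # scale_scale lets the caller push the result up or down afterwards.
    return scale_scale * 10.0 ** (ref_exponent - tar_exponent)


ref_loss, tar_loss = 0.02, 3500.0
factor = order_of_magnitude_scale(ref_loss, tar_loss)
scaled = factor * tar_loss  # now of order 1e-2, like ref_loss
```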
- load_checkpoint(checkpoint_name='best', checkpoint_folder='', load_optimiser=True)
Load checkpoint.
- Parameters:
checkpoint_name (str) – (default: 'best') if 'best', then checkpoint_folder is searched for all files beginning with 'checkpoint_', and loss values are extracted from each filename by assuming all characters after 'loss' and before '.ckpt' are a float. The checkpoint with the lowest loss is loaded. If checkpoint_name is not 'best', a checkpoint file is searched for at f'{checkpoint_folder}/{checkpoint_name}'.
checkpoint_folder (str) – Folder within which to search for checkpoints.
load_optimiser (bool) – (default: True) Whether the optimiser state dictionary should be loaded.
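The 'best' lookup described above amounts to parsing a float out of each filename and taking the minimum. A minimal sketch of that selection rule follows. pick_best_checkpoint and its regular expression are assumptions made for illustration, not the library's code.

```python
import re
from typing import Iterable, Optional

# Everything after the last 'loss' and before '.ckpt' is treated as the loss value.
_LOSS_PATTERN = re.compile(r"^checkpoint_.*loss(.+)\.ckpt$")


def pick_best_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Hypothetical helper: return the filename with the lowest parsed loss."""
    best_name, best_loss = None, float("inf")
    for name in filenames:
        match = _LOSS_PATTERN.match(name)
        if match is None:
            continue  # not a best-checkpoint file, e.g. 'last.ckpt'
        try:
            loss = float(match.group(1))
        except ValueError:
            continue  # text after 'loss' is not a number
        if loss < best_loss:
            best_name, best_loss = name, loss
    return best_name


files = [
    "last.ckpt",
    "checkpoint_epoch3_loss0.52.ckpt",
    "checkpoint_epoch7_loss0.31.ckpt",
    "checkpoint_converged.ckpt",
]
best_file = pick_best_checkpoint(files)
```

Because float() accepts a leading minus sign, the same rule still works when the validation loss is a sum of logarithms and therefore negative.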
- log(log_file, log_dict, verbose=None)
The contents of log_dict are dumped using json.dumps(log_dict) and printed and/or appended to self.log_filename. This function is called from self.run.
- Parameters:
log_file (str) – file to which the log_dict will be saved. If the file does not exist it will be created.
log_dict (dict) – dictionary to be printed or saved
verbose (bool) – if True or self.verbose is true the output will be printed
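Appending one json.dumps record per call produces a file that can be read back line by line. The sketch below shows that pattern with the standard library only. append_log and read_log are hypothetical stand-ins and the exact file layout written by Trainer.log is an assumption.

```python
import json
import os
import tempfile
from typing import Dict, List


def append_log(log_file: str, log_dict: Dict, verbose: bool = False) -> None:
    """Hypothetical stand-in for the logging step: one JSON record per line."""
    line = json.dumps(log_dict)
    if verbose:
        print(line)
    # Mode 'a' creates the file when it does not exist yet.
    with open(log_file, "a") as handle:
        handle.write(line + "\n")


def read_log(log_file: str) -> List[Dict]:
    """Read the records back, skipping blank lines."""
    with open(log_file) as handle:
        return [json.loads(line) for line in handle if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "log.dat")
    append_log(path, {"epoch": 0, "train_loss": 1.5, "valid_loss": 1.7})
    append_log(path, {"epoch": 1, "train_loss": 1.1, "valid_loss": 1.3})
    records = read_log(path)
```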
- prepare_logs(log_filename, log_folder)
- Parameters:
log_filename (str) – (default: ‘log.dat’) The name of the log file.
log_folder (str) – (default: ‘checkpoint_folder’) The folder in which the log file will be saved.
- Returns:
The full path to the log file.
- prepare_optimiser(lr=0.001, weight_decay=0.0001, **optimiser_kwargs)
The default optimiser is AdamW and is saved in self.optimiser. With no optional arguments this function is the same as doing: trainer.optimiser = torch.optim.AdamW(self.autoencoder.parameters(), lr=1e-3, weight_decay=0.0001)
- Parameters:
lr (float) – (default: 1e-3) optimiser learning rate.
weight_decay (float) – (default: 0.0001) optimiser weight_decay
**optimiser_kwargs – other kwargs that are passed onto AdamW
- run(epochs=100, log_filename='log.dat', log_folder='checkpoint_folder', checkpoint_folder='checkpoint_folder', verbose=None)
Calls the following in a loop:
- Parameters:
epochs (int) – (default: 100) run this many epochs.
log_filename (str) – (default: ‘log.dat’) If log_filename already exists, all logs are appended to the existing file. Otherwise a new log file is created.
log_folder (str) – (default: ‘checkpoint_folder’) If not None, the log_folder directory is created and the log file is saved within this folder.
checkpoint_frequency (int) – (default: 1) The frequency at which last.ckpt is saved. A checkpoint is saved every epoch if 'valid_loss' is lower, else when self.epoch is divisible by checkpoint_frequency.
checkpoint_folder (str) – (default: ‘checkpoint_folder’) Where to save checkpoints.
verbose (bool) – (default: None) set trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename.
- Returns:
FitResult with information about the training run.
- run_until_converge(patience=16, log_filename='log.dat', log_folder='checkpoint_folder', checkpoint_folder='checkpoint_folder', verbose=None)
Train until convergence. This method will call Trainer.run in a loop until the training has converged. The training will stop when the validation loss has not improved for patience epochs.
- Parameters:
patience (int) – (default: 16) how many epochs to wait before stopping training if the validation loss has not improved.
log_filename (str) – the name of the log file.
log_folder (str) – the folder in which the log file will be saved.
checkpoint_folder (str) – the folder in which the checkpoint will be saved.
verbose (bool) – if True, the epoch logs will be printed as well as written to log_filename.
- Returns:
FitResult with information about the training run.
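The stopping rule (halt once the validation loss has failed to improve for patience epochs) can be illustrated without any network. In this sketch, train_with_patience, run_one_epoch, and max_epochs are hypothetical names, and the toy loss curve replaces real training. It is not the body of run_until_converge.

```python
from typing import Callable, Dict, List


def train_with_patience(run_one_epoch: Callable[[int], float],
                        patience: int = 16,
                        max_epochs: int = 1000) -> Dict[str, float]:
    """Hypothetical loop: stop when valid loss has not improved for `patience` epochs."""
    best_loss = float("inf")
    best_epoch = -1
    history: List[float] = []
    for epoch in range(max_epochs):
        valid_loss = run_one_epoch(epoch)  # stands in for one train + valid epoch
        history.append(valid_loss)
        if valid_loss < best_loss:
            best_loss, best_epoch = valid_loss, epoch
        elif epoch - best_epoch >= patience:
            break  # no improvement for `patience` consecutive epochs
    return {"best_loss": best_loss, "best_epoch": best_epoch,
            "epochs_run": len(history)}


# Toy validation curve: improves for five epochs, then plateaus.
losses = [1.0, 0.8, 0.6, 0.5, 0.45] + [0.5] * 100
result = train_with_patience(lambda epoch: losses[epoch], patience=3)
```

With patience=3 the loop runs three non-improving epochs after the best one (epoch 4) and then stops, so eight epochs are executed in total.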
- set_autoencoder(autoencoder, **kwargs)
- Parameters:
autoencoder – torch network class that implements autoencoder.encode and autoencoder.decode. Please pass the class, not an instance.
**kwargs – any other kwargs given to this method will be used to initialise the network: self.autoencoder = autoencoder(**kwargs)
- set_data(data, **kwargs)
Sets up internal variables and gives trainer access to dataloaders. self.train_dataloader, self.valid_dataloader, self.std, self.mean, self.mol will all be obtained from this object.
- Parameters:
data (PDBData) – data object to be set.
**kwargs – will be passed on to data.get_dataloader(**kwargs)
- set_dataloader(train_dataloader=None, valid_dataloader=None)
- Parameters:
train_dataloader (torch.utils.data.DataLoader) – Alternatively set using trainer.train_dataloader = dataloader
valid_dataloader (torch.utils.data.DataLoader) – Alternatively set using trainer.valid_dataloader = dataloader
- train_epoch(dry_run=False)
Train one epoch. Called once an epoch from trainer.run. This method performs the following functions:
- Sets network to train mode via self.autoencoder.train()
- For each batch in self.train_dataloader, implements the typical pytorch training protocol:
  - Zero gradients with a call to self.optimiser.zero_grad()
  - Use the training implemented in trainer.train_step: result = self.train_step(batch)
  - Determine gradients using keyword 'loss', e.g. result['loss'].backward()
  - Update network parameters with self.optimiser.step()
- All results are aggregated via averaging and returned with 'train_' prepended on the dictionary key
- Parameters:
epoch (int) – The epoch is passed as an argument however epoch number can also be accessed from self.epoch.
- Returns:
Return all results from train_step averaged. These results will be printed and/or logged in trainer.run() via a call to self.log(results)
- Return type:
dict
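The aggregation step, which averages each key over the per-batch result dictionaries and prefixes the key, can be sketched on plain floats. aggregate_results is a hypothetical helper. It takes an unweighted mean over batches, which is an assumption: the trainer may weight by batch size.

```python
from collections import defaultdict
from typing import Dict, Iterable


def aggregate_results(step_results: Iterable[Dict[str, float]],
                      prefix: str) -> Dict[str, float]:
    """Hypothetical helper: average every key over all steps and prefix the key."""
    totals = defaultdict(float)
    count = 0
    for result in step_results:
        for key, value in result.items():
            totals[key] += float(value)  # float() also accepts 0-d tensors
        count += 1
    if count == 0:
        return {}
    return {f"{prefix}{key}": total / count for key, total in totals.items()}


steps = [{"loss": 1.0, "mse_loss": 0.5}, {"loss": 3.0, "mse_loss": 1.5}]
epoch_logs = aggregate_results(steps, prefix="train_")
```

Calling the same helper with prefix="valid_" yields the 'valid_loss' key that checkpoint selection relies on.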
- train_step(batch)
Called from Trainer.train_epoch.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.
- Return type:
dict
- update_optimiser_hyperparameters(**kwargs)
Update optimiser hyperparameters, e.g. trainer.update_optimiser_hyperparameters(lr=1e-3)
- Parameters:
**kwargs – each key value pair in **kwargs is inserted into self.optimiser
- valid_epoch()
Called once an epoch from trainer.run within a no_grad context. This method performs the following functions:
- Sets network to eval mode via self.autoencoder.eval()
- For each batch in self.valid_dataloader, calls trainer.valid_step to retrieve the validation loss
- All results are aggregated via averaging and returned with 'valid_' prepended on the dictionary key
The loss with key 'loss' is returned as 'valid_loss'; this will be the loss value by which the best checkpoint is determined.
- Parameters:
epoch (int) – The epoch is passed as an argument however epoch number can also be accessed from self.epoch.
- Returns:
Return all results from valid_step averaged. These results will be printed and/or logged in Trainer.run() via a call to self.log(results)
- Return type:
dict
- valid_step(batch)
Called from Trainer.valid_epoch on every mini-batch.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss' that will be the score via which the best checkpoint is determined.
- Return type:
dict
- class OpenMM_Physics_Trainer(physics_inter_weight=0, *args, **kwargs)
Bases: Trainer
OpenMM_Physics_Trainer subclasses Trainer and replaces the valid_step and train_step. An extra ‘physics_loss’ is calculated using OpenMM and the forces are inserted into the backward pass. Using this trainer requires the additional step of calling prepare_physics.
- Parameters:
device (torch.device) – if not given, will be determined automatically based on torch.cuda.is_available()
log_filename (str) – (default: ‘default_log_filename.json’) file used to log outputs to
json_log (bool) – True to use json.dump to save the log file
- common_physics_step(batch, latent)
Called from both train_step and valid_step. Takes random interpolations between the latent vectors of adjacent samples. These are decoded (decoded structures are saved as self._internal['generated'] = generated if needed elsewhere) and the energy terms are calculated with self.physics_loss.
- Parameters:
batch (torch.Tensor) – tensor of shape [batch_size, n_atoms, 3]. Gives access to the mini-batch of structures. This is used to determine n_atoms.
latent (torch.Tensor) – tensor of shape [batch_size, 2, 1]. Pass the encoded vectors of the mini-batch.
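Random interpolation between adjacent latent vectors can be sketched on plain lists. The pairing (sample i with sample i+1), the uniform mixing weight, and the name interpolate_adjacent are all assumptions made for illustration. The trainer operates on tensors and may pair or sample differently.

```python
import random
from typing import List, Sequence


def interpolate_adjacent(latent: Sequence[Sequence[float]],
                         rng: random.Random) -> List[List[float]]:
    """Hypothetical helper: one random convex mix per pair of neighbouring vectors."""
    generated = []
    for left, right in zip(latent[:-1], latent[1:]):
        alpha = rng.random()  # mixing weight drawn uniformly from [0, 1)
        generated.append(
            [(1.0 - alpha) * a + alpha * b for a, b in zip(left, right)]
        )
    return generated


latent_batch = [[0.0, 0.0], [1.0, 2.0], [3.0, 2.0]]  # three 2D latent vectors
mixed = interpolate_adjacent(latent_batch, random.Random(0))
```

Each generated point lies on the segment between two encoded samples, so decoding it probes regions of latent space that have no training target, which is where a physics-based loss is the only available signal.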
- prepare_physics(clamp_threshold=100000000.0, clamp=False, xml_file=None, soft_NB=True, **kwargs)
Create self.physics_loss object from loss_functions.openmm_energy. Needs self.mol, self.std, and self._data.atoms to have been set with Trainer.set_data.
- Parameters:
physics_scaling_factor (float) – scaling factor saved to self.psf that is used in train_step. Defaults to 0.1
clamp_threshold (float) – if clamp=True is passed then forces will be clamped between -clamp_threshold and clamp_threshold. Default: 1e8
clamp (bool) – Whether to clamp the forces. Defaults to False
**kwargs – All additional kwargs will be passed to openmm_energy
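Clamping limits each force component to the interval [-clamp_threshold, clamp_threshold], which keeps a few badly clashing atoms from dominating the gradient. A minimal sketch on a flat list of floats follows. clamp_forces is a hypothetical helper, and the trainer applies the equivalent operation to tensors.

```python
from typing import List


def clamp_forces(forces: List[float], clamp_threshold: float = 1e8) -> List[float]:
    """Hypothetical helper: limit every component to [-clamp_threshold, clamp_threshold]."""
    if clamp_threshold < 0:
        raise ValueError("clamp_threshold must be non-negative")
    return [max(-clamp_threshold, min(clamp_threshold, f)) for f in forces]


raw_forces = [-5e9, -12.0, 0.0, 3.5, 2e10]
clamped = clamp_forces(raw_forces, clamp_threshold=1e8)
```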
- train_step(batch)
This method overrides Trainer.train_step and adds an additional ‘Physics_loss’ term. Called from Trainer.train_epoch.
- Parameters:
batch (torch.Tensor) – tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.
- Return type:
dict
- valid_step(batch)
This method overrides Trainer.valid_step and adds an additional ‘Physics_loss’ term.
Differently to train_step, this method sums the logs of mse_loss and physics_loss: final_loss = torch.log(results['mse_loss'])+scale*torch.log(results['physics_loss'])
Called from super class Trainer.valid_epoch on every mini-batch.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss' that will be the score via which the best checkpoint is determined.
- Return type:
dict
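Summing logarithms makes the validation score respond to relative rather than absolute changes in each term. The sketch below evaluates the formula above with math.log on plain floats. validation_score and the numeric values are illustrative, and both losses are assumed to be strictly positive.

```python
import math


def validation_score(mse_loss: float, physics_loss: float, scale: float = 1.0) -> float:
    """log(mse_loss) + scale * log(physics_loss), as in the formula above."""
    return math.log(mse_loss) + scale * math.log(physics_loss)


scale = 0.1
high_energy = validation_score(0.02, 1500.0, scale=scale)
low_energy = validation_score(0.02, 150.0, scale=scale)
# A tenfold drop in physics_loss lowers the score by scale * log(10),
# regardless of how large physics_loss is in absolute terms.
improvement = high_energy - low_energy
```

Note that a sum of logarithms can be negative, so the best checkpoint is the one with the lowest score and not the one closest to zero.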
- class Torch_Physics_Trainer(*args, **kwargs)
Bases: Trainer
Torch_Physics_Trainer subclasses Trainer and replaces the valid_step and train_step. An extra ‘physics_loss’ (bonds, angles, and torsions) is calculated using pytorch. Using this trainer requires the additional step of calling prepare_physics.
- Parameters:
device (torch.device) – if not given, will be determined automatically based on torch.cuda.is_available()
log_filename (str) – (default: ‘default_log_filename.json’) file used to log outputs to
json_log (bool) – True to use json.dump to save the log file
- common_physics_step(batch, latent)
Called from both train_step and valid_step. Takes random interpolations between the latent vectors of adjacent samples. These are decoded (decoded structures are saved as self._internal['generated'] = generated if needed elsewhere) and the energy terms are calculated with self.physics_loss.
- Parameters:
batch (torch.Tensor) – tensor of shape [batch_size, 3, n_atoms]. Gives access to the mini-batch of structures. This is used to determine n_atoms.
latent (torch.Tensor) – tensor of shape [batch_size, 2, 1]. Pass the encoded vectors of the mini-batch.
- prepare_physics(physics_scaling_factor=0.1)
Create self.physics_loss object from loss_functions.TorchProteinEnergy. Needs self.std and self._data to have been set with Trainer.set_data.
- Parameters:
physics_scaling_factor (float) – (default: 0.1) scaling factor saved to self.psf that is used in train_step. It will control the relative importance of mse_loss and physics_loss in training.
- train_step(batch)
This method overrides Trainer.train_step and adds an additional ‘Physics_loss’ term.
Mse_loss and physics loss are summed (Mse_loss + scale*physics_loss) with a scaling factor self.psf*mse_loss/Physics_loss. Mathematically this cancels out the physics_loss and the final loss value is (1+self.psf)*mse_loss. However, because the scaling factor is calculated within a torch.no_grad context manager, no gradients flow through it. This is essentially the same as scaling the physics_loss with any arbitrary scaling factor, but in this case the factor happens to be exactly proportional to the ratio of Mse_loss and physics_loss in every step.
Called from Trainer.train_epoch.
- Parameters:
batch (torch.Tensor) – tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.
- Return type:
dict
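The difference between the loss value and its gradient can be checked on a one-parameter toy problem with hand-written derivatives. Everything in this sketch is hypothetical: the scalar functions mse and physics stand in for the real losses, and treating scale as a constant mimics what torch.no_grad achieves in the trainer.

```python
def mse(x: float) -> float:
    return (x - 1.0) ** 2


def d_mse(x: float) -> float:
    return 2.0 * (x - 1.0)


def physics(x: float) -> float:
    return (x + 2.0) ** 2 + 1.0


def d_physics(x: float) -> float:
    return 2.0 * (x + 2.0)


psf = 0.1
x = 3.0

# Computed from current values only and then held constant (the no_grad analogue).
scale = psf * mse(x) / physics(x)

loss_value = mse(x) + scale * physics(x)           # equals (1 + psf) * mse(x)
grad_constant_scale = d_mse(x) + scale * d_physics(x)
grad_if_differentiated = (1.0 + psf) * d_mse(x)    # gradient if scale were not held fixed
```

The loss value matches (1 + psf) * mse(x), but the gradient still contains a physics term, so the parameters are pulled towards lower physics_loss with a strength set by psf.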
- valid_step(batch)
This method overrides Trainer.valid_step and adds an additional ‘Physics_loss’ term.
Differently to train_step, this method sums the logs of mse_loss and physics_loss: final_loss = torch.log(results['mse_loss'])+scale*torch.log(results['physics_loss'])
Called from super class Trainer.valid_epoch on every mini-batch.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss' that will be the score via which the best checkpoint is determined.
- Return type:
dict