Trainers

class Trainer(device=None, json_log=False)

Bases: object

Trainer class that defines a number of useful methods for training an autoencoder.

Variables:
  • autoencoder – any torch.nn.Module network that has methods autoencoder.encode and autoencoder.decode, with the weights associated with these operations accessible via autoencoder.encoder and autoencoder.decoder. This can be set using set_autoencoder.

  • _autoencoder_kwargs – kwargs used to initialise the network. Saved in every checkpoint under the key ‘kwargs’

  • optimiser (torch.optim.Optimizer) – PyTorch optimiser with access to self.autoencoder.parameters()

  • device (torch.device) – The device used for all operations.

  • epoch (int) – the current epoch

  • progress (TrainerProgress) – Tracks best validation loss, checkpoint metadata, and convergence state.

  • std (float) – Standard deviation of the training dataset. Can be used to unscale structures produced by the network.

  • mol (biobox.Molecule) – biobox molecule containing a single example frame of the protein being trained on. This can be used to save examples during training. It is also used to save a temporary PDB file that may be used to initialise third-party packages.

  • train_dataloader (torch.utils.data.DataLoader) – Training data

  • valid_dataloader (torch.utils.data.DataLoader) – Validation data

  • _data (PDBData) – the molearn.data object given to set_data

Parameters:
  • device (torch.device) – if not given, it will be determined automatically based on torch.cuda.is_available()

  • log_filename (str) – (default: ‘default_log_filename.json’) file used to log outputs to

  • json_log (bool) – True to use json.dump to save the log file

checkpoint(epoch, valid_logs, checkpoint_folder, is_best, has_converged=False)

Checkpoint the current network. The checkpoint is saved as 'last.ckpt'. If valid_logs[loss_key] is better than self.progress.best_loss, this checkpoint replaces self.progress.best_checkpoint: 'last.ckpt' is renamed to f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt' and the former best checkpoint (whose filename is stored in self.progress.best_checkpoint) is deleted.

Parameters:
  • epoch (int) – current epoch, which will be saved within the checkpoint. The current epoch can usually be obtained with self.epoch.

  • valid_logs (dict) – results dictionary containing loss_key.

  • checkpoint_folder (str) – The folder in which to save the checkpoint.

  • is_best (bool) – if the current checkpoint is the best so far, this will be used to determine if the checkpoint should be renamed to f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'.

  • has_converged (bool) – if True, the checkpoint will be saved as f'{checkpoint_folder}/checkpoint_converged.ckpt'. This is used to indicate that training has converged.
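The file bookkeeping described above can be sketched with the standard library alone. This is a minimal illustration, not molearn's implementation: rotate_checkpoint and its payload argument are hypothetical, raw bytes stand in for whatever torch.save would write, and taking the converged copy from the freshly written last.ckpt is an assumption.

```python
import os
import shutil


def rotate_checkpoint(checkpoint_folder, epoch, valid_loss, is_best,
                      best_checkpoint=None, has_converged=False, payload=b""):
    """Hypothetical sketch of the last.ckpt / best-checkpoint bookkeeping.

    Returns the path of the best checkpoint after this call.
    """
    os.makedirs(checkpoint_folder, exist_ok=True)
    last = os.path.join(checkpoint_folder, "last.ckpt")
    # Always (over)write last.ckpt; a real trainer would call torch.save here.
    with open(last, "wb") as fh:
        fh.write(payload)
    if has_converged:
        # Assumed: the converged checkpoint is a copy of the current state.
        shutil.copyfile(last, os.path.join(checkpoint_folder, "checkpoint_converged.ckpt"))
    if not is_best:
        return best_checkpoint
    # Promote last.ckpt to a loss-tagged best checkpoint ...
    new_best = f"{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt"
    os.replace(last, new_best)
    # ... and delete the previous best, unless it is the file just written.
    if best_checkpoint and best_checkpoint != new_best and os.path.exists(best_checkpoint):
        os.remove(best_checkpoint)
    return new_best
```

os.replace overwrites an existing target in one call, so a stale file with the same name never blocks the promotion.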

common_step(batch)

Called from both train_step and valid_step. Calculates the mean squared error loss for self.autoencoder. Encoded and decoded frames are saved in self._internal under the keys 'encoded' and 'decoded' respectively, should you wish to use them elsewhere.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

Dictionary containing the calculated 'mse_loss'.

Return type:

dict

dry_run(log_filename='dry_run.dat', log_folder='tmp_folder')

Pass the training and validation set through the network once without updating model parameters. This is useful for debugging or restoring statistics for models loaded from checkpoint.

Parameters:
  • log_filename (str) – the name of the log file.

  • log_folder (str) – the folder in which the log file will be saved.

fit(data: PDBData, autoencoder, *, autoencoder_kwargs: Dict[str, Any] | None = None, optimiser_kwargs: Dict[str, Any] | None = None, data_kwargs: Dict[str, Any] | None = None, epochs: int | None = None, patience: int | None = None, log_filename: str = 'log.dat', log_folder: str = 'checkpoint_folder', checkpoint_folder: str = 'checkpoint_folder', verbose: bool | None = None) -> FitResult

End-to-end helper that wires data, model, optimiser, and training.

Returns:

FitResult describing the executed training procedure.

get_network_summary()

Returns a dictionary containing information about the size of the autoencoder.

get_scale(ref_loss: float, tar_loss: float, scale_scale: float = 1.0)

Get a scaling factor that brings a loss to the same order of magnitude as ref_loss.

Parameters:
  • ref_loss (float) – the reference loss

  • tar_loss (float) – the loss that should be scaled to the same order of magnitude as ref_loss

  • scale_scale (float) – factor to further increase or decrease the scale

Returns:

the calculated scaling factor for tar_loss

Return type:

float
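The entry above does not state the formula. A minimal sketch, assuming the factor is simply scale_scale times the ratio of the two losses (the actual molearn implementation may differ):

```python
def get_scale(ref_loss: float, tar_loss: float, scale_scale: float = 1.0) -> float:
    """Hypothetical sketch: factor that brings tar_loss to the magnitude of ref_loss."""
    if tar_loss == 0:
        raise ValueError("tar_loss must be non-zero")
    # Multiplying tar_loss by this factor gives scale_scale * ref_loss.
    return scale_scale * ref_loss / tar_loss


mse_loss, physics_loss = 0.02, 350.0
scaled_physics = get_scale(mse_loss, physics_loss) * physics_loss  # approximately 0.02
```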

load_checkpoint(checkpoint_name='best', checkpoint_folder='', load_optimiser=True)

Load a checkpoint.

Parameters:
  • checkpoint_name (str) – (default: 'best') if 'best', checkpoint_folder is searched for all files beginning with 'checkpoint_', and a loss value is extracted from each filename by assuming that all characters after 'loss' and before '.ckpt' form a float. The checkpoint with the lowest loss is loaded. If checkpoint_name is not 'best', the checkpoint file at f'{checkpoint_folder}/{checkpoint_name}' is loaded.

  • checkpoint_folder (str) – Folder within which to search for checkpoints.

  • load_optimiser (bool) – (default: True) Whether to load the optimiser state dictionary.
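The 'best' lookup can be sketched as follows. find_best_checkpoint is a hypothetical helper, and skipping files whose names carry no parsable loss (such as checkpoint_converged.ckpt) is an assumption about how such files are treated.

```python
import glob
import os


def find_best_checkpoint(checkpoint_folder: str) -> str:
    """Return the 'checkpoint_*' file whose name encodes the lowest loss.

    Names are assumed to look like 'checkpoint_epoch12_loss0.0345.ckpt':
    everything after 'loss' and before '.ckpt' is parsed as a float.
    """
    best_path, best_loss = None, float("inf")
    for path in glob.glob(os.path.join(checkpoint_folder, "checkpoint_*")):
        name = os.path.basename(path)
        if "loss" not in name or not name.endswith(".ckpt"):
            continue  # e.g. checkpoint_converged.ckpt carries no loss
        loss_str = name[name.rindex("loss") + len("loss"):-len(".ckpt")]
        try:
            loss = float(loss_str)
        except ValueError:
            continue  # not a loss-tagged filename
        if loss < best_loss:
            best_path, best_loss = path, loss
    if best_path is None:
        raise FileNotFoundError(f"no loss-tagged checkpoint in {checkpoint_folder!r}")
    return best_path
```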

log(log_file, log_dict, verbose=None)

The contents of log_dict are dumped using json.dumps(log_dict) and printed and/or appended to log_file. This function is called from self.run.

Parameters:
  • log_file (str) – file to which the log_dict will be saved. If the file does not exist it will be created.

  • log_dict (dict) – dictionary to be printed or saved

  • verbose (bool) – if True, or if self.verbose is True, the output will be printed
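A minimal sketch of this behaviour, assuming one JSON object per line (the exact separator and print format molearn uses are not specified above):

```python
import json


def log(log_file: str, log_dict: dict, verbose: bool = False) -> str:
    """Sketch: serialise log_dict as one JSON line, optionally print it, append it."""
    line = json.dumps(log_dict)
    if verbose:
        print(line)
    # Mode 'a' creates the file if it does not exist and appends otherwise.
    with open(log_file, "a") as fh:
        fh.write(line + "\n")
    return line
```

Writing one object per line keeps the file readable with a simple line-by-line json.loads.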

prepare_logs(log_filename, log_folder)
Parameters:
  • log_filename (str) – (default: ‘log.dat’) The name of the log file.

  • log_folder (str) – (default: ‘checkpoint_folder’) The folder in which the log file will be saved.

Returns:

The full path to the log file.
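Judging from its parameters and return value, prepare_logs can be sketched as below. This is an assumption rather than the library's code; in particular, returning the bare filename when no folder is given is hypothetical.

```python
import os
from typing import Optional


def prepare_logs(log_filename: str = "log.dat",
                 log_folder: Optional[str] = "checkpoint_folder") -> str:
    """Hypothetical sketch: create log_folder if needed and return the log file path."""
    if log_folder:
        os.makedirs(log_folder, exist_ok=True)
        return os.path.join(log_folder, log_filename)
    # No folder given: the log file lives in the current working directory.
    return log_filename
```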

prepare_optimiser(lr=0.001, weight_decay=0.0001, **optimiser_kwargs)

The default optimiser is AdamW and is saved in self.optimiser. With no optional arguments, this function is equivalent to: trainer.optimiser = torch.optim.AdamW(trainer.autoencoder.parameters(), lr=1e-3, weight_decay=0.0001)

Parameters:
  • lr (float) – (default: 1e-3) optimiser learning rate.

  • weight_decay (float) – (default: 0.0001) optimiser weight_decay

  • **optimiser_kwargs – other kwargs that are passed on to AdamW

run(epochs=100, log_filename='log.dat', log_folder='checkpoint_folder', checkpoint_folder='checkpoint_folder', verbose=None)

Calls the following in a loop: train_epoch, valid_epoch, log, and checkpoint.

Parameters:
  • epochs (int) – (default: 100) run this many epochs.

  • log_filename (str) – (default: 'log.dat') If log_filename already exists, all logs are appended to the existing file. Otherwise a new log file is created.

  • log_folder (str) – (default: 'checkpoint_folder') If not None, the log_folder directory is created and the log file is saved within this folder.

  • checkpoint_frequency (int) – (default: 1) The frequency at which last.ckpt is saved. A checkpoint is saved every epoch if 'valid_loss' is lower; otherwise only when self.epoch is divisible by checkpoint_frequency.

  • checkpoint_folder (str) – (default: ‘checkpoint_folder’) Where to save checkpoints.

  • verbose (bool) – (default: None) sets trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename.

Returns:

FitResult with information about the training run.

run_until_converge(patience=16, log_filename='log.dat', log_folder='checkpoint_folder', checkpoint_folder='checkpoint_folder', verbose=None)

Train until convergence. This method calls Trainer.run in a loop until training has converged. Training stops when the validation loss has not improved for patience epochs.

Parameters:
  • patience (int) – (default: 16) how many epochs to wait before stopping training if the validation loss has not improved.

  • log_filename (str) – the name of the log file.

  • log_folder (str) – the folder in which the log file will be saved.

  • checkpoint_folder (str) – the folder in which the checkpoint will be saved.

  • verbose (bool) – if True, the epoch logs will be printed as well as written to log_filename.

Returns:

FitResult with information about the training run.
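The patience criterion can be illustrated on a precomputed sequence of validation losses. epochs_until_converged is a hypothetical helper, and treating "not improved" as "not strictly lower than the best loss so far" is an assumption.

```python
from typing import Iterable, Tuple


def epochs_until_converged(valid_losses: Iterable[float], patience: int = 16) -> Tuple[int, float]:
    """Sketch of patience-based stopping.

    Returns (number of epochs run, best validation loss). Training stops once
    `patience` consecutive epochs pass without beating the best loss.
    """
    best = float("inf")
    epochs_without_improvement = 0
    epoch = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break
    return epoch, best
```

For example, with patience=3 the losses 1.0, 0.8, 0.9, 0.85, 0.95 stop training after epoch 5 with a best loss of 0.8.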

set_autoencoder(autoencoder, **kwargs)
Parameters:
  • autoencoder – torch network class that implements autoencoder.encode and autoencoder.decode. Pass the class, not an instance.

  • **kwargs – any other kwargs given to this method will be used to initialise the network self.autoencoder = autoencoder(**kwargs)
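Because set_autoencoder expects the class rather than an instance, the trainer performs the construction itself. The sketch below mirrors the documented self.autoencoder = autoencoder(**kwargs) with stand-in classes; MiniTrainer and ToyAutoencoder are hypothetical, and storing kwargs in _autoencoder_kwargs is assumed from the Variables list above.

```python
class ToyAutoencoder:
    """Stand-in network exposing encode/decode (hypothetical)."""

    def __init__(self, latent_dim=2):
        self.latent_dim = latent_dim

    def encode(self, x):
        return x[: self.latent_dim]

    def decode(self, z):
        return z


class MiniTrainer:
    """Hypothetical stand-in reproducing only the set_autoencoder pattern."""

    def set_autoencoder(self, autoencoder, **kwargs):
        # Keep the constructor kwargs so they can be written to checkpoints.
        self._autoencoder_kwargs = kwargs
        # `autoencoder` must be a class: it is called here to build the network.
        self.autoencoder = autoencoder(**kwargs)


trainer = MiniTrainer()
trainer.set_autoencoder(ToyAutoencoder, latent_dim=3)
```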

set_data(data, **kwargs)

Sets up internal variables and gives the trainer access to dataloaders. self.train_dataloader, self.valid_dataloader, self.std, self.mean, and self.mol will all be obtained from this object.

Parameters:
  • data (PDBData) – data object to be set.

  • **kwargs – will be passed on to data.get_dataloader(**kwargs)

set_dataloader(train_dataloader=None, valid_dataloader=None)
Parameters:
  • train_dataloader (torch.DataLoader) – Alternatively set using trainer.train_dataloader = dataloader

  • valid_dataloader (torch.DataLoader) – Alternatively set using trainer.valid_dataloader = dataloader

train_epoch(dry_run=False)

Train one epoch. Called once an epoch from trainer.run. This method performs the following functions:

  • Sets the network to train mode via self.autoencoder.train()

  • For each batch in self.train_dataloader, implements the typical PyTorch training protocol:

      • zero gradients with self.optimiser.zero_grad()

      • run the training step implemented in trainer.train_step: result = self.train_step(batch)

      • determine gradients using the key 'loss', e.g. result['loss'].backward()

      • update the network parameters with self.optimiser.step()

  • All results are aggregated via averaging and returned with 'train_' prepended to each dictionary key

Parameters:

epoch (int) – The epoch is passed as an argument however epoch number can also be accessed from self.epoch.

Returns:

All results from train_step, averaged over batches. These results will be printed and/or logged in trainer.run() via a call to self.log(results).

Return type:

dict
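The aggregation step (averaging per-batch results and prefixing keys) can be sketched in plain Python. average_results is a hypothetical helper, and an unweighted mean over batches is assumed.

```python
from collections import defaultdict
from typing import Dict, Iterable


def average_results(step_results: Iterable[Dict[str, float]], prefix: str) -> Dict[str, float]:
    """Average per-batch result dicts and prepend `prefix` to every key."""
    totals: Dict[str, float] = defaultdict(float)
    n_batches = 0
    for result in step_results:
        n_batches += 1
        for key, value in result.items():
            totals[key] += float(value)
    # An empty input yields an empty dict, so there is no division by zero.
    return {f"{prefix}{key}": total / n_batches for key, total in totals.items()}
```

With prefix 'valid_', the same helper turns 'loss' into 'valid_loss', the key valid_epoch uses to select the best checkpoint.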

train_step(batch)

Called from Trainer.train_epoch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.

Return type:

dict

update_optimiser_hyperparameters(**kwargs)

Update optimiser hyperparameters, e.g. trainer.update_optimiser_hyperparameters(lr=1e-3)

Parameters:

**kwargs – each key-value pair in **kwargs is inserted into every parameter group of self.optimiser (self.optimiser.param_groups)
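In PyTorch, hyperparameters such as lr live in optimiser.param_groups, a list of dictionaries. The sketch below assumes the method writes each pair into every group; update_hyperparameters is a hypothetical name, and plain dicts stand in for the groups of a real optimiser so the example runs without torch.

```python
from typing import Any, Dict, List


def update_hyperparameters(param_groups: List[Dict[str, Any]], **kwargs: Any) -> None:
    """Write every key/value pair in kwargs into each optimiser parameter group."""
    for group in param_groups:
        for key, value in kwargs.items():
            group[key] = value


# Plain dicts standing in for torch.optim.Optimizer.param_groups.
groups = [{"lr": 1e-3, "weight_decay": 1e-4}, {"lr": 1e-3, "weight_decay": 0.0}]
update_hyperparameters(groups, lr=1e-4)
```

Keys not mentioned in kwargs, such as weight_decay here, are left unchanged.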

valid_epoch()

Called once an epoch from trainer.run within a no_grad context. This method performs the following functions:

  • Sets the network to eval mode via self.autoencoder.eval()

  • For each batch in self.valid_dataloader, calls trainer.valid_step to retrieve the validation loss

  • All results are aggregated via averaging and returned with 'valid_' prepended to each dictionary key

  • The loss with key 'loss' is returned as 'valid_loss'; this is the loss value by which the best checkpoint is determined.

Parameters:

epoch (int) – The epoch is passed as an argument however epoch number can also be accessed from self.epoch.

Returns:

All results from valid_step, averaged over batches. These results will be printed and/or logged in Trainer.run() via a call to self.log(results).

Return type:

dict

valid_step(batch)

Called from Trainer.valid_epoch on every mini-batch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

Return loss. The dictionary must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.

Return type:

dict

class OpenMM_Physics_Trainer(physics_inter_weight=0, *args, **kwargs)

Bases: Trainer

OpenMM_Physics_Trainer subclasses Trainer and replaces valid_step and train_step. An extra ‘physics_loss’ is calculated using OpenMM, and the forces are inserted into the backward pass. Using this trainer requires the additional step of calling prepare_physics.

Parameters:
  • device (torch.device) – if not given, it will be determined automatically based on torch.cuda.is_available()

  • log_filename (str) – (default: ‘default_log_filename.json’) file used to log outputs to

  • json_log (bool) – True to use json.dump to save the log file

common_physics_step(batch, latent)

Called from both train_step and valid_step. Takes random interpolations between the latent vectors of adjacent samples. These are decoded (the decoded structures are saved as self._internal['generated'] if needed elsewhere) and the energy terms are calculated with self.physics_loss.

Parameters:
  • batch (torch.Tensor) – tensor of shape [batch_size, n_atoms, 3]. Gives access to the mini-batch of structures; this is used to determine n_atoms.

  • latent (torch.Tensor) – tensor of shape [batch_size, 2, 1]. The encoded vectors of the mini-batch.
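The interpolation step can be sketched with plain lists. Pairing samples as (0, 1), (2, 3), ... and drawing one uniform weight per pair are assumptions about what "adjacent samples" means here; interpolate_adjacent is a hypothetical name.

```python
import random
from typing import List, Sequence


def interpolate_adjacent(latent: Sequence[Sequence[float]], rng: random.Random) -> List[List[float]]:
    """Random linear interpolation between latent vectors of adjacent sample pairs.

    Pairs are (0, 1), (2, 3), ...; a trailing unpaired sample is ignored.
    """
    out = []
    for a, b in zip(latent[0::2], latent[1::2]):
        alpha = rng.random()  # one uniform weight in [0, 1) per pair
        out.append([(1 - alpha) * x + alpha * y for x, y in zip(a, b)])
    return out
```

Each interpolated vector lies on the segment between its two endpoints; decoding such points probes regions of latent space between training examples.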

prepare_physics(clamp_threshold=100000000.0, clamp=False, xml_file=None, soft_NB=True, **kwargs)

Create the self.physics_loss object from loss_functions.openmm_energy. Requires self.mol, self.std, and self._data.atoms to have been set with Trainer.set_data.

Parameters:
  • physics_scaling_factor (float) – scaling factor saved to self.psf that is used in train_step. Defaults to 0.1

  • clamp_threshold (float) – if clamp=True is passed, forces will be clamped between -clamp_threshold and clamp_threshold. Default: 1e8

  • clamp (bool) – Whether to clamp the forces. Defaults to False

  • **kwargs – All additional kwargs will be passed to openmm_energy
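When clamp=True, every force component is restricted to [-clamp_threshold, clamp_threshold]. A list-based sketch follows; clamp_forces is hypothetical, and on tensors torch.clamp(forces, min=-t, max=t) performs the same element-wise operation.

```python
from typing import List


def clamp_forces(forces: List[float], clamp_threshold: float = 1e8) -> List[float]:
    """Clamp each force component to the interval [-clamp_threshold, clamp_threshold]."""
    return [max(-clamp_threshold, min(clamp_threshold, f)) for f in forces]
```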

train_step(batch)

This method overrides Trainer.train_step and adds a ‘physics_loss’ term. Called from Trainer.train_epoch.

Parameters:

batch (torch.Tensor) – tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.

Return type:

dict

valid_step(batch)

This method overrides Trainer.valid_step and adds a ‘physics_loss’ term.

Unlike train_step, this method sums the logs of mse_loss and physics_loss: final_loss = torch.log(results['mse_loss'])+scale*torch.log(results['physics_loss'])

Called from the superclass method Trainer.valid_epoch on every mini-batch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

Return loss. The dictionary must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.

Return type:

dict
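The validation score can be computed as below. log_sum_loss is a hypothetical helper taking plain floats; how scale is chosen inside valid_step is not specified above, so it is left as a parameter.

```python
import math


def log_sum_loss(mse_loss: float, physics_loss: float, scale: float) -> float:
    """final_loss = log(mse_loss) + scale * log(physics_loss); both losses must be > 0."""
    return math.log(mse_loss) + scale * math.log(physics_loss)
```

Since log(c*x) = log(c) + log(x), multiplying either loss by a constant only shifts the score by a constant, so checkpoint rankings depend on relative changes in each term rather than on which term is numerically larger.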

class Torch_Physics_Trainer(*args, **kwargs)

Bases: Trainer

Torch_Physics_Trainer subclasses Trainer and replaces valid_step and train_step. An extra ‘physics_loss’ (bonds, angles, and torsions) is calculated using PyTorch. Using this trainer requires the additional step of calling prepare_physics.

Parameters:
  • device (torch.device) – if not given, it will be determined automatically based on torch.cuda.is_available()

  • log_filename (str) – (default: ‘default_log_filename.json’) file used to log outputs to

  • json_log (bool) – True to use json.dump to save the log file

common_physics_step(batch, latent)

Called from both train_step and valid_step. Takes random interpolations between the latent vectors of adjacent samples. These are decoded (the decoded structures are saved as self._internal['generated'] if needed elsewhere) and the energy terms are calculated with self.physics_loss.

Parameters:
  • batch (torch.Tensor) – tensor of shape [batch_size, 3, n_atoms]. Gives access to the mini-batch of structures; this is used to determine n_atoms.

  • latent (torch.Tensor) – tensor of shape [batch_size, 2, 1]. The encoded vectors of the mini-batch.

prepare_physics(physics_scaling_factor=0.1)

Create the self.physics_loss object from loss_functions.TorchProteinEnergy. Requires self.std and self._data to have been set with Trainer.set_data.

Parameters:
  • physics_scaling_factor (float) – (default: 0.1) scaling factor saved to self.psf that is used in train_step. It controls the relative importance of mse_loss and physics_loss in training.

train_step(batch)

This method overrides Trainer.train_step and adds a ‘physics_loss’ term.

mse_loss and physics_loss are summed as mse_loss + scale*physics_loss, with the scaling factor scale = self.psf*mse_loss/physics_loss. Numerically this cancels out physics_loss, so the value of the final loss is (1+self.psf)*mse_loss. However, because the scaling factor is calculated within a torch.no_grad context manager, no gradients are computed through it, and the gradient of physics_loss still contributes to training. This is essentially the same as scaling physics_loss by an arbitrary scaling factor, except that here the factor happens to be exactly proportional to the ratio of mse_loss to physics_loss at every step.
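Writing m for mse_loss, p for physics_loss, ψ for self.psf, and sg(·) for a value computed under torch.no_grad (a stop-gradient), the argument above reads:

```latex
\begin{aligned}
s &= \operatorname{sg}\!\left(\frac{\psi\, m}{p}\right), \qquad L = m + s\,p,\\
L &= m + \frac{\psi\, m}{p}\, p = (1+\psi)\, m \quad \text{(value)},\\
\nabla_\theta L &= \nabla_\theta m + \frac{\psi\, m}{p}\, \nabla_\theta p \quad \text{(gradient; } s \text{ treated as a constant)}.
\end{aligned}
```

The value of L no longer depends on p, but its gradient does, which is why the physics term still shapes training.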

Called from Trainer.train_epoch.

Parameters:

batch (torch.Tensor) – tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.

Return type:

dict

valid_step(batch)

This method overrides Trainer.valid_step and adds a ‘physics_loss’ term.

Unlike train_step, this method sums the logs of mse_loss and physics_loss: final_loss = torch.log(results['mse_loss'])+scale*torch.log(results['physics_loss'])

Called from the superclass method Trainer.valid_epoch on every mini-batch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

Return loss. The dictionary must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.

Return type:

dict