Trainers

class Trainer(device=None, log_filename='log_file.dat')

Bases: object

Trainer class that defines a number of useful methods for training an autoencoder.

Variables:
  • autoencoder – any torch.nn.Module network that has the methods autoencoder.encode and autoencoder.decode, with the weights associated with these operations accessible via autoencoder.encoder and autoencoder.decoder. This can be set using set_autoencoder.

  • _autoencoder_kwargs – kwargs used to initialise the network. Saved in every checkpoint under the key ‘kwargs’.

  • optimiser (torch.optim.Optimizer) – PyTorch optimiser with access to self.autoencoder.parameters().

  • device (torch.device) – The device used for all operations.

  • epoch (int) – The current epoch.

  • best (float) – The best validation score, corresponding to the current best checkpoint.

  • best_name (str) – The filename corresponding to self.best.

  • std (float) – Standard deviation of the training dataset. Can be used to unscale structures produced by the network.

  • mol – Biobox molecule containing a single example frame of the protein being trained on. This can be used to save examples during training. It is also used to save a temporary PDB file that may be used to initialise third-party packages.

  • train_dataloader (torch.utils.data.DataLoader) – Training data.

  • valid_dataloader (torch.utils.data.DataLoader) – Validation data.

  • _data – The molearn.data data object (e.g. PDBData) given to set_data.

Parameters:
  • device (torch.device) – If not given, it will be determined automatically based on torch.cuda.is_available().

  • log_filename (str) – (default: ‘log_file.dat’) File used to log outputs to.
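The automatic device selection described for device can be sketched as follows. resolve_device is a hypothetical helper name used only for illustration; it is not part of molearn.

```python
import torch


def resolve_device(device=None):
    """Hypothetical helper: use CUDA when available, otherwise fall back to the CPU."""
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device)


print(resolve_device())  # cuda or cpu, depending on the machine
```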

checkpoint(epoch, valid_logs, checkpoint_folder, loss_key='valid_loss')

Checkpoint the current network. The checkpoint is saved as 'last.ckpt'. If valid_logs[loss_key] is better (lower) than self.best, this checkpoint replaces the current best: 'last.ckpt' is renamed to f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt' and the former best checkpoint (whose filename is stored in self.best_name) is deleted.

Parameters:
  • epoch (int) – The current epoch, which will be saved within the checkpoint. The current epoch can usually be obtained with self.epoch.

  • valid_logs (dict) – results dictionary containing loss_key.

  • checkpoint_folder (str) – The folder in which to save the checkpoint.

  • loss_key (str) – (default: ‘valid_loss’) The key with which to get loss from valid_logs.
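A minimal sketch of the renaming rule described for checkpoint, assuming last.ckpt has already been written to checkpoint_folder and that a lower loss is better. promote_checkpoint is a hypothetical helper for illustration, not molearn's implementation.

```python
import os


def promote_checkpoint(checkpoint_folder, epoch, valid_loss, best=None, best_name=None):
    """Hypothetical helper mirroring the best-checkpoint rule of checkpoint().

    Returns the (possibly updated) best loss and best filename.
    """
    if best is not None and valid_loss >= best:
        return best, best_name  # not an improvement: last.ckpt is left as is
    last = os.path.join(checkpoint_folder, 'last.ckpt')
    new_name = f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'
    os.rename(last, new_name)  # last.ckpt becomes the new best checkpoint
    if best_name is not None and os.path.exists(best_name):
        os.remove(best_name)  # delete the former best checkpoint
    return valid_loss, new_name
```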

common_step(batch)

Called from both train_step and valid_step. Calculates the mean squared error loss for self.autoencoder. The encoded and decoded frames are saved in self._internal under the keys 'encoded' and 'decoded' respectively, should you wish to use them elsewhere.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

A dictionary containing the calculated mean squared error loss under the key 'mse_loss'.

Return type:

dict
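A sketch of the loss calculation described for common_step. common_step_sketch is a hypothetical stand-alone function; a plain dict stands in for self._internal, and any object with encode and decode methods can play the autoencoder.

```python
import torch
import torch.nn.functional as F


def common_step_sketch(autoencoder, batch, internal):
    """Hypothetical stand-alone version of the MSE step.

    batch: normalised frames of shape [batch_size, 3, n_atoms].
    internal: dict standing in for self._internal.
    """
    encoded = autoencoder.encode(batch)
    decoded = autoencoder.decode(encoded)
    internal['encoded'] = encoded  # kept for reuse elsewhere
    internal['decoded'] = decoded
    return dict(mse_loss=F.mse_loss(decoded, batch))
```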

get_network_summary()

Returns a dictionary containing information about the size of the autoencoder.

learning_rate_sweep(max_lr=100, min_lr=1e-05, number_of_iterations=1000, checkpoint_folder='checkpoint_sweep', train_on='mse_loss', save=['loss', 'mse_loss'])

Deprecated method. Performs a sweep of the learning rate from min_lr to max_lr over number_of_iterations steps. See ‘Finding Good Learning Rate and The One Cycle Policy’.

Parameters:
  • max_lr (float) – (default: 100.0) final/maximum learning rate to be used

  • min_lr (float) – (default: 1e-5) Starting learning rate

  • number_of_iterations (int) – (default: 1000) Number of steps to run sweep over.

  • train_on (str) – (default: ‘mse_loss’) key returned from trainer.train_step(batch) on which to train

  • save (list) – (default: [‘loss’, ‘mse_loss’]) Which loss values to return.

Returns:

Array of shape [len(save), min(number_of_iterations, iterations before NaN)] containing the loss values named in save.

Return type:

numpy.ndarray
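The text does not state how the learning rates are spaced, so this sketch assumes a geometric (log-uniform) sweep, as is usual for learning-rate finders. step_fn is a hypothetical callback that sets the learning rate, runs one training step and returns a dict of floats; lr_sweep_sketch is likewise hypothetical.

```python
import math

import numpy as np


def lr_sweep_sketch(step_fn, max_lr=100.0, min_lr=1e-5, number_of_iterations=1000,
                    save=('loss', 'mse_loss')):
    """Geometric learning-rate sweep that stops at the first NaN (assumed behaviour)."""
    history = []
    for lr in np.geomspace(min_lr, max_lr, number_of_iterations):
        result = step_fn(float(lr))  # hypothetical: set lr, train one step, return floats
        values = [float(result[key]) for key in save]
        if any(math.isnan(v) for v in values):
            break  # training diverged; stop recording
        history.append(values)
    # Shape [len(save), iterations before NaN], even if no step succeeded
    return np.array(history, dtype=float).reshape(-1, len(save)).T
```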

load_checkpoint(checkpoint_name='best', checkpoint_folder='', load_optimiser=True)

Load checkpoint.

Parameters:
  • checkpoint_name (str) – (default: 'best') If 'best', checkpoint_folder is searched for all files beginning with 'checkpoint_', and loss values are extracted from the filenames by assuming all characters after 'loss' and before '.ckpt' form a float. The checkpoint with the lowest loss is loaded. If checkpoint_name is not 'best', the checkpoint file f'{checkpoint_folder}/{checkpoint_name}' is loaded.

  • checkpoint_folder (str) – Folder within which to search for checkpoints.

  • load_optimiser (bool) – (default: True) Whether the optimiser state dictionary should be loaded.
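The filename parsing used when checkpoint_name='best' can be sketched as below. find_best_checkpoint is a hypothetical helper that only selects a path; it does not load anything.

```python
import glob
import os


def find_best_checkpoint(checkpoint_folder):
    """Return the 'checkpoint_*' file whose name encodes the lowest loss.

    The loss is read from the characters between 'loss' and '.ckpt'.
    """
    def loss_of(path):
        name = os.path.basename(path)
        return float(name[name.rindex('loss') + len('loss'):name.rindex('.ckpt')])

    candidates = glob.glob(os.path.join(checkpoint_folder, 'checkpoint_*'))
    return min(candidates, key=loss_of)
```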

log(log_dict, verbose=None)

The contents of log_dict are dumped using json.dumps(log_dict) and printed and/or appended to self.log_filename. This function is called from self.run.

Parameters:
  • log_dict (dict) – dictionary to be printed or saved

  • verbose (bool) – (default: None) If True, or if self.verbose is True, the output will be printed.
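A sketch of the logging behaviour, assuming one JSON object is appended per line. log_sketch and its default_verbose argument (standing in for self.verbose) are hypothetical.

```python
import json


def log_sketch(log_dict, log_filename, verbose=None, default_verbose=False):
    """Dump log_dict as JSON, print it if requested, and append it to log_filename."""
    dump = json.dumps(log_dict)
    if verbose or default_verbose:
        print(dump)
    with open(log_filename, 'a') as f:
        f.write(dump + '\n')  # one JSON object per line (assumed format)
    return dump
```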

prepare_optimiser(lr=0.001, weight_decay=0.0001, **optimiser_kwargs)

The default optimiser is AdamW and is saved in self.optimiser. With no optional arguments, this function is the same as doing: trainer.optimiser = torch.optim.AdamW(trainer.autoencoder.parameters(), lr=1e-3, weight_decay=0.0001)

Parameters:
  • lr (float) – (default: 1e-3) optimiser learning rate.

  • weight_decay (float) – (default: 0.0001) optimiser weight_decay

  • **optimiser_kwargs – Other kwargs that are passed on to AdamW.
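The default set-up written out directly with torch.optim.AdamW. A small torch.nn.Linear stands in for trainer.autoencoder, and betas is only an example of an extra keyword argument being forwarded through **optimiser_kwargs.

```python
import torch

model = torch.nn.Linear(6, 2)  # stand-in for trainer.autoencoder

# Equivalent of trainer.prepare_optimiser() with its defaults
optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# trainer.prepare_optimiser(lr=5e-4, betas=(0.9, 0.99)) would forward betas to AdamW:
optimiser_custom = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4,
                                     betas=(0.9, 0.99))
```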

run(max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None)

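Calls the following in a loop:

  • self.train_epoch

  • self.valid_epoch

  • self.scheduler_step

  • self.checkpoint

The results of each epoch are recorded via self.log.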

Parameters:
  • max_epochs (int) – (default: 100) Run until self.epoch reaches max_epochs.

  • log_filename (str) – (default: None) If log_filename already exists, all logs are appended to the existing file; otherwise, a new log file is created.

  • log_folder (str) – (default: None) If not None, the log_folder directory is created and the log file is saved within this folder.

  • checkpoint_frequency (int) – (default: 1) The frequency at which last.ckpt is saved. A checkpoint is saved in every epoch in which 'valid_loss' is lower than the current best; otherwise, one is saved only when self.epoch is divisible by checkpoint_frequency.

  • checkpoint_folder (str) – (default: ‘checkpoint_folder’) Where to save checkpoints.

  • allow_n_failures (int) – (default: 10) How many times training may be restarted after an error. Each epoch is run in a try/except block; if an error is raised, training continues from the best checkpoint.

  • verbose (bool) – (default: None) Sets trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename.
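An outline of the training loop, written as a hypothetical stand-alone function run_sketch. The order of calls and the restart handling are assumptions drawn from the descriptions of scheduler_step, checkpoint_frequency and allow_n_failures; this is not molearn's code.

```python
import torch


def run_sketch(trainer, max_epochs=100, checkpoint_frequency=1,
               checkpoint_folder='checkpoint_folder', allow_n_failures=10):
    """trainer: any object with epoch, best, train_epoch, valid_epoch,
    scheduler_step, checkpoint, load_checkpoint and log."""
    failures = 0
    while trainer.epoch < max_epochs:
        epoch = trainer.epoch
        try:
            logs = trainer.train_epoch(epoch)
            with torch.no_grad():
                logs.update(trainer.valid_epoch(epoch))
            trainer.scheduler_step(logs)
            improved = trainer.best is None or logs['valid_loss'] < trainer.best
            if improved or epoch % checkpoint_frequency == 0:
                trainer.checkpoint(epoch, logs, checkpoint_folder)
            trainer.log(logs)
            trainer.epoch += 1
        except Exception:
            failures += 1
            if failures > allow_n_failures:
                raise  # give up after too many failures
            # Resume from the best checkpoint saved so far
            trainer.load_checkpoint('best', checkpoint_folder)
```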

scheduler_step(logs)

This function does nothing. It is called after self.valid_epoch in Trainer.run() and before checkpointing. It is designed to be overridden if you wish to use a scheduler.

Parameters:

logs (dict) – Dictionary containing all logs returned from self.train_epoch and self.valid_epoch.
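One way to use this hook is to drive torch.optim.lr_scheduler.ReduceLROnPlateau from 'valid_loss'. PlateauSchedulerMixin and prepare_scheduler are hypothetical names; the mixin would be combined with Trainer (e.g. class MyTrainer(PlateauSchedulerMixin, Trainer)) and assumes self.optimiser already exists.

```python
import torch


class PlateauSchedulerMixin:
    """Hypothetical mixin that overrides scheduler_step; requires self.optimiser."""

    def prepare_scheduler(self, factor=0.5, patience=5):
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimiser, mode='min', factor=factor, patience=patience)

    def scheduler_step(self, logs):
        # Called after valid_epoch, so logs already holds 'valid_loss'
        self.scheduler.step(logs['valid_loss'])
```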

set_autoencoder(autoencoder, **kwargs)
Parameters:
  • autoencoder – Torch network class that implements autoencoder.encode and autoencoder.decode. Please pass the class, not an instance.

  • **kwargs – Any other kwargs given to this method will be used to initialise the network: self.autoencoder = autoencoder(**kwargs).
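A toy class satisfying the interface described above: encode and decode methods, with weights under encoder and decoder. TinyAutoencoder and its arguments are hypothetical; the [batch_size, 2, 1] latent shape follows the shape quoted for common_physics_step.

```python
import torch
from torch import nn


class TinyAutoencoder(nn.Module):
    """Hypothetical toy network with the encode/decode interface."""

    def __init__(self, n_atoms=10, latent_dim=2):
        super().__init__()
        self.n_atoms = n_atoms
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * n_atoms, latent_dim))
        self.decoder = nn.Linear(latent_dim, 3 * n_atoms)

    def encode(self, x):  # x: [batch_size, 3, n_atoms]
        return self.encoder(x).unsqueeze(-1)  # [batch_size, latent_dim, 1]

    def decode(self, z):  # z: [batch_size, latent_dim, 1]
        return self.decoder(z.squeeze(-1)).view(-1, 3, self.n_atoms)


# What set_autoencoder(TinyAutoencoder, n_atoms=10, latent_dim=2) is described as doing:
autoencoder = TinyAutoencoder(n_atoms=10, latent_dim=2)
```

Passing the class rather than an instance lets the trainer store the kwargs and rebuild the same network when a checkpoint is loaded.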

set_data(data, **kwargs)

Sets up internal variables and gives trainer access to dataloaders. self.train_dataloader, self.valid_dataloader, self.std, self.mean, self.mol will all be obtained from this object.

Parameters:
  • data (PDBData) – Data object to be set.

  • **kwargs – Will be passed on to data.get_dataloader(**kwargs).

set_dataloader(train_dataloader=None, valid_dataloader=None)
Parameters:
  • train_dataloader (torch.utils.data.DataLoader) – Alternatively, set using trainer.train_dataloader = dataloader.

  • valid_dataloader (torch.utils.data.DataLoader) – Alternatively, set using trainer.valid_dataloader = dataloader.

train_epoch(epoch)

Train one epoch. Called once per epoch from trainer.run. This method performs the following functions:

  • Sets the network to train mode via self.autoencoder.train()

  • For each batch in self.train_dataloader, implements the typical PyTorch training protocol:

    • Zero the gradients with a call to self.optimiser.zero_grad()

    • Use the training step implemented in trainer.train_step: result = self.train_step(batch)

    • Compute gradients from the entry with key 'loss', i.e. result['loss'].backward()

    • Update the network weights with self.optimiser.step()

  • All results are aggregated via averaging and returned with 'train_' prepended to each dictionary key

Parameters:

epoch (int) – The epoch is passed as an argument; however, the epoch number can also be accessed from self.epoch.

Returns:

All results from train_step, averaged. These results will be printed and/or logged in trainer.run() via a call to self.log(results).

Return type:

dict
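The protocol listed for train_epoch, sketched as a hypothetical stand-alone function. It assumes every value returned by train_step is a scalar tensor and averages per batch.

```python
import torch


def train_epoch_sketch(autoencoder, optimiser, dataloader, train_step):
    """Hypothetical stand-alone version of the training protocol."""
    autoencoder.train()
    totals, n_batches = {}, 0
    for batch in dataloader:
        optimiser.zero_grad()
        result = train_step(batch)
        result['loss'].backward()
        optimiser.step()
        for key, value in result.items():
            totals[key] = totals.get(key, 0.0) + value.item()
        n_batches += 1
    # Average over batches and prefix every key with 'train_'
    return {f'train_{key}': total / n_batches for key, total in totals.items()}
```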

train_step(batch)

Called from Trainer.train_epoch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

A dictionary of losses. It must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.

Return type:

dict

update_optimiser_hyperparameters(**kwargs)

Update optimiser hyperparameters, e.g. trainer.update_optimiser_hyperparameters(lr=1e-3)

Parameters:

**kwargs – Each key-value pair in **kwargs is set as a hyperparameter of self.optimiser.
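The text does not say exactly how the pairs reach the optimiser. This sketch assumes they are written into every entry of optimiser.param_groups, the standard way to change hyperparameters of an existing torch.optim optimiser; the function name is hypothetical.

```python
import torch


def update_optimiser_hyperparameters_sketch(optimiser, **kwargs):
    """Assumed behaviour: write each key/value into every parameter group."""
    for group in optimiser.param_groups:
        for key, value in kwargs.items():
            group[key] = value


model = torch.nn.Linear(3, 1)
optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3)
update_optimiser_hyperparameters_sketch(optimiser, lr=1e-4, weight_decay=0.0)
```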

valid_epoch(epoch)

Called once per epoch from trainer.run within a no_grad context. This method performs the following functions:

  • Sets the network to eval mode via self.autoencoder.eval()

  • For each batch in self.valid_dataloader, calls trainer.valid_step to retrieve the validation loss

  • All results are aggregated via averaging and returned with 'valid_' prepended to each dictionary key

  • The loss with key 'loss' is returned as 'valid_loss'; this will be the loss value by which the best checkpoint is determined.

Parameters:

epoch (int) – The epoch is passed as an argument; however, the epoch number can also be accessed from self.epoch.

Returns:

All results from valid_step, averaged. These results will be printed and/or logged in Trainer.run() via a call to self.log(results).

Return type:

dict
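The validation protocol sketched as a hypothetical stand-alone function. The torch.no_grad decorator reproduces the no_grad context the text mentions, and every value from valid_step is assumed to be a scalar tensor.

```python
import torch


@torch.no_grad()
def valid_epoch_sketch(autoencoder, dataloader, valid_step):
    """Hypothetical stand-alone version of the validation protocol."""
    autoencoder.eval()
    totals, n_batches = {}, 0
    for batch in dataloader:
        result = valid_step(batch)
        for key, value in result.items():
            totals[key] = totals.get(key, 0.0) + value.item()
        n_batches += 1
    # 'loss' becomes 'valid_loss', the score used to pick the best checkpoint
    return {f'valid_{key}': total / n_batches for key, total in totals.items()}
```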

valid_step(batch)

Called from Trainer.valid_epoch on every mini-batch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

A dictionary of losses. It must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.

Return type:

dict

class OpenMM_Physics_Trainer(*args, **kwargs)

Bases: Trainer

OpenMM_Physics_Trainer subclasses Trainer and replaces valid_step and train_step. An extra ‘physics_loss’ is calculated using OpenMM, and the forces are inserted into the backward pass. Using this trainer requires the additional step of calling prepare_physics.

Parameters:
  • device (torch.device) – If not given, it will be determined automatically based on torch.cuda.is_available().

  • log_filename (str) – (default: ‘log_file.dat’) File used to log outputs to.

common_physics_step(batch, latent)

Called from both train_step and valid_step. Takes random interpolations between the latent vectors of adjacent samples. These are decoded (the decoded structures are saved as self._internal['generated'] in case they are needed elsewhere) and the energy terms are calculated with self.physics_loss.

Parameters:
  • batch (torch.Tensor) – Tensor of shape [batch_size, 3, n_atoms]. Gives access to the mini-batch of structures; this is used to determine n_atoms.

  • latent (torch.Tensor) – Tensor of shape [batch_size, 2, 1]. The encoded vectors of the mini-batch.
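Which samples count as "adjacent", and how the interpolation weights are drawn, is not specified. This sketch assumes non-overlapping pairs (0, 1), (2, 3), ... with one uniform random weight per pair; interpolate_adjacent_sketch is a hypothetical name.

```python
import torch


def interpolate_adjacent_sketch(latent):
    """latent: [batch_size, latent_dim, 1]; returns [batch_size // 2, latent_dim, 1]."""
    n_pairs = latent.size(0) // 2
    alpha = torch.rand(n_pairs, 1, 1, dtype=latent.dtype, device=latent.device)
    first = latent[0:2 * n_pairs:2]   # samples 0, 2, 4, ...
    second = latent[1:2 * n_pairs:2]  # samples 1, 3, 5, ...
    # Each output lies on the straight line between its two neighbours
    return (1 - alpha) * first + alpha * second
```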

prepare_physics(physics_scaling_factor=0.1, clamp_threshold=100000000.0, clamp=False, start_physics_at=0, **kwargs)

Create the self.physics_loss object from loss_functions.openmm_energy. Needs self.mol, self.std, and self._data.atoms to have been set with Trainer.set_data.

Parameters:
  • physics_scaling_factor (float) – Scaling factor saved to self.psf that is used in train_step. Defaults to 0.1.

  • clamp_threshold (float) – If clamp=True is passed, forces will be clamped between -clamp_threshold and clamp_threshold. Default: 1e8

  • clamp (bool) – Whether to clamp the forces. Defaults to False.

  • start_physics_at (int) – Currently unused parameter, saved as self.start_physics_at = start_physics_at. Default: 0

  • **kwargs – All additional kwargs will be passed to openmm_energy.
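The effect of clamp and clamp_threshold on a force tensor, as a minimal sketch with torch.clamp. clamp_forces_sketch is hypothetical; where exactly molearn applies the clamp is not shown here.

```python
import torch


def clamp_forces_sketch(forces, clamp_threshold=1e8, clamp=False):
    """Limit each force component to [-clamp_threshold, clamp_threshold] when clamp=True."""
    if not clamp:
        return forces
    return torch.clamp(forces, min=-clamp_threshold, max=clamp_threshold)
```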

train_step(batch)

This method overrides Trainer.train_step and adds an additional ‘physics_loss’ term.

mse_loss and physics_loss are summed as mse_loss + scale*physics_loss, with the scaling factor scale = self.psf*mse_loss/physics_loss. Numerically this cancels out the physics_loss, so the value of the final loss is (1+self.psf)*mse_loss. However, because the scaling factor is calculated within a torch.no_grad context manager, no gradients are computed through it: it acts as a constant, so the gradients of physics_loss still contribute to training. This is essentially the same as scaling the physics_loss by an arbitrary constant factor, except that here the factor is recomputed at every step to be exactly proportional to the ratio of mse_loss to physics_loss.
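The scaling trick can be checked on two scalar tensors. combined_train_loss is a hypothetical stand-alone sketch of the formula above, with psf standing in for self.psf.

```python
import torch


def combined_train_loss(mse_loss, physics_loss, psf=0.1):
    """Sketch of mse_loss + scale * physics_loss with scale = psf * mse_loss / physics_loss."""
    with torch.no_grad():
        # Computed without gradient tracking, so scale behaves as a constant
        scale = psf * mse_loss / physics_loss
    return mse_loss + scale * physics_loss


mse = torch.tensor(2.0, requires_grad=True)
physics = torch.tensor(50.0, requires_grad=True)
loss = combined_train_loss(mse, physics)
loss.backward()
print(loss.item())   # approximately 2.2, i.e. (1 + psf) * mse_loss
print(physics.grad)  # approximately 0.004 (= scale): physics_loss still gets a gradient
```

The value of the loss no longer depends on physics_loss, but its gradient does, which is why the physics term still shapes training.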

Called from Trainer.train_epoch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

A dictionary of losses. It must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.

Return type:

dict

valid_step(batch)

This method overrides Trainer.valid_step and adds an additional ‘physics_loss’ term.

Unlike train_step, this method sums the logarithms of mse_loss and physics_loss: final_loss = torch.log(results['mse_loss']) + scale*torch.log(results['physics_loss'])
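A sketch of the validation score. scale is not defined in the text; this assumes it is the physics scaling factor and exposes it as a plain argument. Both losses must be positive for the logarithms to be finite. combined_valid_loss is hypothetical.

```python
import torch


def combined_valid_loss(mse_loss, physics_loss, scale=0.1):
    """Sketch of log(mse_loss) + scale * log(physics_loss)."""
    return torch.log(mse_loss) + scale * torch.log(physics_loss)


score = combined_valid_loss(torch.tensor(1.0), torch.tensor(100.0))
```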

Called from the superclass method Trainer.valid_epoch on every mini-batch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

A dictionary of losses. It must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.

Return type:

dict

class Torch_Physics_Trainer(*args, **kwargs)

Bases: Trainer

Torch_Physics_Trainer subclasses Trainer and replaces valid_step and train_step. An extra ‘physics_loss’ (bonds, angles, and torsions) is calculated using PyTorch. Using this trainer requires the additional step of calling prepare_physics.

Parameters:
  • device (torch.device) – If not given, it will be determined automatically based on torch.cuda.is_available().

  • log_filename (str) – (default: ‘log_file.dat’) File used to log outputs to.
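To show how a bonded energy term can be computed differentiably in PyTorch on [batch_size, 3, n_atoms] coordinates, here is a generic harmonic bond term. It is illustrative only: the functional form, units, parameters and atom indexing used by loss_functions.TorchProteinEnergy are not given here, and harmonic_bond_energy is hypothetical.

```python
import torch


def harmonic_bond_energy(coords, bonds, k, r0):
    """Generic harmonic bond term (illustrative, not TorchProteinEnergy).

    coords: [batch_size, 3, n_atoms]; bonds: [n_bonds, 2] atom indices;
    k, r0: [n_bonds] force constants and equilibrium lengths.
    Returns one energy per structure, shape [batch_size].
    """
    a = coords[:, :, bonds[:, 0]]          # [batch_size, 3, n_bonds]
    b = coords[:, :, bonds[:, 1]]
    r = torch.linalg.norm(a - b, dim=1)    # bond lengths, [batch_size, n_bonds]
    return (k * (r - r0) ** 2).sum(dim=1)  # differentiable w.r.t. coords


coords = torch.tensor([[[0.0, 2.0], [0.0, 0.0], [0.0, 0.0]]], requires_grad=True)
energy = harmonic_bond_energy(coords, torch.tensor([[0, 1]]), torch.tensor([1.0]), torch.tensor([1.0]))
energy.sum().backward()
```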

common_physics_step(batch, latent)

Called from both train_step and valid_step. Takes random interpolations between the latent vectors of adjacent samples. These are decoded (the decoded structures are saved as self._internal['generated'] in case they are needed elsewhere) and the energy terms are calculated with self.physics_loss.

Parameters:
  • batch (torch.Tensor) – Tensor of shape [batch_size, 3, n_atoms]. Gives access to the mini-batch of structures; this is used to determine n_atoms.

  • latent (torch.Tensor) – Tensor of shape [batch_size, 2, 1]. The encoded vectors of the mini-batch.

prepare_physics(physics_scaling_factor=0.1)

Create the self.physics_loss object from loss_functions.TorchProteinEnergy. Needs self.std and self._data to have been set with Trainer.set_data.

Parameters:

physics_scaling_factor (float) – (default: 0.1) Scaling factor saved to self.psf that is used in train_step. It controls the relative importance of mse_loss and physics_loss in training.

train_step(batch)

This method overrides Trainer.train_step and adds an additional ‘physics_loss’ term.

mse_loss and physics_loss are summed as mse_loss + scale*physics_loss, with the scaling factor scale = self.psf*mse_loss/physics_loss. Numerically this cancels out the physics_loss, so the value of the final loss is (1+self.psf)*mse_loss. However, because the scaling factor is calculated within a torch.no_grad context manager, no gradients are computed through it: it acts as a constant, so the gradients of physics_loss still contribute to training. This is essentially the same as scaling the physics_loss by an arbitrary constant factor, except that here the factor is recomputed at every step to be exactly proportional to the ratio of mse_loss to physics_loss.

Called from Trainer.train_epoch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

A dictionary of losses. It must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.

Return type:

dict

valid_step(batch)

This method overrides Trainer.valid_step and adds an additional ‘physics_loss’ term.

Unlike train_step, this method sums the logarithms of mse_loss and physics_loss: final_loss = torch.log(results['mse_loss']) + scale*torch.log(results['physics_loss'])

Called from the superclass method Trainer.valid_epoch on every mini-batch.

Parameters:

batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.

Returns:

A dictionary of losses. It must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.

Return type:

dict