Trainers
- class Trainer(device=None, log_filename='log_file.dat')
Bases: object
Trainer class that defines a number of useful methods for training an autoencoder.
- Variables:
autoencoder – any torch.nn.Module network that has methods autoencoder.encode and autoencoder.decode, with the weights associated with these operations accessible via autoencoder.encoder and autoencoder.decoder. This can be set using set_autoencoder.
_autoencoder_kwargs – kwargs used to initialise the network. Saved in every checkpoint under the key ‘kwargs’.
optimiser (torch.optim.Optimizer) – PyTorch optimiser with access to self.autoencoder.parameters()
device (torch.device) – The device used for all operations.
epoch (int) – the current epoch
best (float) – The best validation score corresponding to the current best checkpoint
best_name (str) – the filename corresponding to self.best
std (float) – Standard deviation of the training dataset. Can be used to unscale structures produced by the network.
mol – Biobox molecule containing a single example frame of the protein being trained on. This can be used to save examples during training. It is also used to save a temporary PDB file that may be used to initialise third-party packages.
train_dataloader (torch.utils.data.DataLoader) – Training data
valid_dataloader (torch.utils.data.DataLoader) – Validation data
_data – molearn.data data object (e.g. PDBData) given to set_data
- Parameters:
device (torch.device) – if not given, will be determined automatically based on torch.cuda.is_available()
log_filename (str) – (default: ‘log_file.dat’) file used to log outputs to
- checkpoint(epoch, valid_logs, checkpoint_folder, loss_key='valid_loss')
Checkpoint the current network. The checkpoint will be saved as 'last.ckpt'. If valid_logs[loss_key] is better than self.best, then this checkpoint will replace self.best: 'last.ckpt' will be renamed to f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt' and the former best (filename saved as self.best_name) will be deleted.
- Parameters:
epoch (int) – current epoch, will be saved within the ckpt. The current epoch can usually be obtained with self.epoch.
valid_logs (dict) – results dictionary containing loss_key.
checkpoint_folder (str) – The folder in which to save the checkpoint.
loss_key (str) – (default: ‘valid_loss’) The key with which to get loss from valid_logs.
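The keep/replace bookkeeping described for checkpoint can be sketched with the standard library alone. This is a minimal illustration, not molearn's implementation: torch.save is replaced by writing a placeholder file, lower loss is assumed to be better, and the helper checkpoint_sketch and the state dict (standing in for self.best and self.best_name) are hypothetical.

```python
import os
import tempfile


def checkpoint_sketch(state, epoch, valid_logs, checkpoint_folder, loss_key="valid_loss"):
    """Hypothetical stand-in for Trainer.checkpoint; `state` mimics self.best / self.best_name."""
    os.makedirs(checkpoint_folder, exist_ok=True)
    last = os.path.join(checkpoint_folder, "last.ckpt")
    # Placeholder for torch.save(...): the latest state always goes to last.ckpt first.
    with open(last, "w") as f:
        f.write(f"epoch={epoch}\n")

    valid_loss = valid_logs[loss_key]
    if state["best"] is None or valid_loss < state["best"]:
        new_name = os.path.join(checkpoint_folder, f"checkpoint_epoch{epoch}_loss{valid_loss}.ckpt")
        os.replace(last, new_name)  # last.ckpt becomes the new best checkpoint
        old = state["best_name"]
        if old is not None and os.path.exists(old):
            os.remove(old)  # the former best is deleted
        state["best"], state["best_name"] = valid_loss, new_name
    return state


folder = tempfile.mkdtemp()
state = {"best": None, "best_name": None}
for epoch, loss in enumerate([0.5, 0.3, 0.4]):
    checkpoint_sketch(state, epoch, {"valid_loss": loss}, folder)
print(sorted(os.listdir(folder)))  # ['checkpoint_epoch1_loss0.3.ckpt', 'last.ckpt']
```

After three epochs only the best checkpoint and the most recent 'last.ckpt' remain on disk, which keeps the checkpoint folder from growing with every improvement.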
- common_step(batch)
Called from both train_step and valid_step. Calculates the mean squared error loss for self.autoencoder. Encoded and decoded frames are saved in self._internal under the keys encoded and decoded respectively, should you wish to use them elsewhere.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return calculated mse_loss
- Return type:
dict
- get_network_summary()
Returns a dictionary containing information about the size of the autoencoder.
- learning_rate_sweep(max_lr=100, min_lr=1e-05, number_of_iterations=1000, checkpoint_folder='checkpoint_sweep', train_on='mse_loss', save=['loss', 'mse_loss'])
Deprecated method. Performs a sweep of the learning rate from min_lr to max_lr over number_of_iterations steps. See Finding Good Learning Rate and The One Cycle Policy.
- Parameters:
max_lr (float) – (default: 100.0) final/maximum learning rate to be used
min_lr (float) – (default: 1e-5) Starting learning rate
number_of_iterations (int) – (default: 1000) Number of steps to run sweep over.
train_on (str) – (default: ‘mse_loss’) key returned from trainer.train_step(batch) on which to train
save (list) – (default: [‘loss’, ‘mse_loss’]) which loss values to return.
- Returns:
array of shape [len(save), min(number_of_iterations, iterations before NaN)] containing loss values defined in save key word.
- Return type:
numpy.ndarray
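The sweep can be pictured as a learning-rate ramp plus a loss recorder that stops at the first NaN, which is why the returned array may be shorter than number_of_iterations. This is a sketch under assumptions: the geometric (log-spaced) ramp follows the usual learning-rate range test and is not confirmed for molearn, and lr_sweep_schedule, record_until_nan and the toy loss curve are hypothetical.

```python
import math


def lr_sweep_schedule(min_lr=1e-5, max_lr=100.0, number_of_iterations=1000):
    """Hypothetical helper: geometric (log-spaced) ramp from min_lr up to max_lr."""
    ratio = (max_lr / min_lr) ** (1.0 / (number_of_iterations - 1))
    return [min_lr * ratio**i for i in range(number_of_iterations)]


def record_until_nan(loss_at_lr, schedule):
    """Collect one loss per learning rate, stopping at the first NaN."""
    losses = []
    for lr in schedule:
        loss = loss_at_lr(lr)
        if math.isnan(loss):
            break
        losses.append(loss)
    return losses


schedule = lr_sweep_schedule(number_of_iterations=5)
# Toy loss curve that "diverges" (returns NaN) once the learning rate exceeds 10.
losses = record_until_nan(lambda lr: float("nan") if lr > 10 else 1.0 / (1.0 + lr), schedule)
print(len(losses))  # 4
```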
- load_checkpoint(checkpoint_name='best', checkpoint_folder='', load_optimiser=True)
Load checkpoint.
- Parameters:
checkpoint_name (str) – (default: 'best') If 'best', then checkpoint_folder is searched for all files beginning with 'checkpoint_', and loss values are extracted from each filename by assuming all characters after 'loss' and before '.ckpt' form a float. The checkpoint with the lowest loss is loaded. If checkpoint_name is not 'best', a checkpoint file is looked for at f'{checkpoint_folder}/{checkpoint_name}'.
checkpoint_folder (str) – Folder within which to search for checkpoints.
load_optimiser (bool) – (default: True) Whether the optimiser state dictionary should be loaded.
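The 'best' lookup rule above (files starting with 'checkpoint_', loss parsed from between 'loss' and '.ckpt', lowest wins) can be sketched as follows. The helper find_best_checkpoint is hypothetical, and the extra '.ckpt' suffix check is an added guard, not something the description states.

```python
import os
import tempfile


def find_best_checkpoint(checkpoint_folder):
    """Hypothetical helper mirroring the checkpoint_name='best' lookup described above."""
    best_path, best_loss = None, float("inf")
    for name in os.listdir(checkpoint_folder):
        # Only 'checkpoint_*.ckpt' files are considered (the .ckpt check is an added guard).
        if not (name.startswith("checkpoint_") and name.endswith(".ckpt")):
            continue
        # Everything after 'loss' and before '.ckpt' is parsed as a float.
        loss = float(name[name.rindex("loss") + len("loss"):-len(".ckpt")])
        if loss < best_loss:
            best_path, best_loss = os.path.join(checkpoint_folder, name), loss
    return best_path


folder = tempfile.mkdtemp()
for name in ["checkpoint_epoch3_loss0.12.ckpt", "checkpoint_epoch7_loss0.08.ckpt", "last.ckpt"]:
    open(os.path.join(folder, name), "w").close()
print(os.path.basename(find_best_checkpoint(folder)))  # checkpoint_epoch7_loss0.08.ckpt
```

Because the loss is recovered from the filename alone, renaming checkpoint files by hand changes which one is treated as best.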
- log(log_dict, verbose=None)
The contents of log_dict are dumped using json.dumps(log_dict) and printed and/or appended to self.log_filename. This function is called from self.run.
- Parameters:
log_dict (dict) – dictionary to be printed or saved
verbose (bool) – (default: None) If True, or if self.verbose is True, the output will be printed.
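A minimal sketch of the logging behaviour described above, using only json and file appends. The helper log_sketch is hypothetical, the one-JSON-object-per-line layout is an assumption, and the loss values are illustrative only.

```python
import json
import os
import tempfile


def log_sketch(log_dict, log_filename, verbose=False):
    """Hypothetical stand-in for Trainer.log: dump with json.dumps, optionally print, then append."""
    line = json.dumps(log_dict)
    if verbose:
        print(line)
    # Assumption: one JSON object per line, so the file can be read back line by line.
    with open(log_filename, "a") as f:
        f.write(line + "\n")


log_path = os.path.join(tempfile.mkdtemp(), "log_file.dat")
# Illustrative values only.
log_sketch({"epoch": 0, "train_loss": 0.52, "valid_loss": 0.61}, log_path)
log_sketch({"epoch": 1, "train_loss": 0.31, "valid_loss": 0.40}, log_path, verbose=True)

with open(log_path) as f:
    records = [json.loads(line) for line in f]
print(len(records))  # 2
```

Appending rather than overwriting is what lets a restarted run continue writing to an existing log file.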
- prepare_optimiser(lr=0.001, weight_decay=0.0001, **optimiser_kwargs)
The default optimiser is AdamW and is saved in self.optimiser. With no optional arguments, this function is the same as doing: trainer.optimiser = torch.optim.AdamW(trainer.autoencoder.parameters(), lr=1e-3, weight_decay=0.0001)
- Parameters:
lr (float) – (default: 1e-3) optimiser learning rate.
weight_decay (float) – (default: 0.0001) optimiser weight_decay
**optimiser_kwargs – other kwargs that are passed onto AdamW
- run(max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None)
Calls the following in a loop:
- Parameters:
max_epochs (int) – (default: 100) Run until self.epoch matches max_epochs.
log_filename (str) – (default: None) If log_filename already exists, all logs are appended to the existing file; otherwise a new log file is created.
log_folder (str) – (default: None) If not None, the log_folder directory is created and the log file is saved within this folder.
checkpoint_frequency (int) – (default: 1) The frequency at which last.ckpt is saved. A checkpoint is saved in any epoch where 'valid_loss' is lower than the current best; otherwise one is saved when self.epoch is divisible by checkpoint_frequency.
checkpoint_folder (str) – (default: ‘checkpoint_folder’) Where to save checkpoints.
allow_n_failures (int) – (default: 10) How many times training may be restarted on error. Each epoch is run in a try/except block; if an error is raised, training is continued from the best checkpoint.
verbose (bool) – (default: None) Sets trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename.
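Putting the pieces of this page together, one epoch of run calls train_epoch, valid_epoch, scheduler_step, then (conditionally) checkpoint, and log, all inside a try/except that falls back to the best checkpoint. The sketch below is a rough, dependency-free illustration under assumptions: ToyTrainer and run_sketch are hypothetical, the placement of log() relative to checkpoint() is assumed, and the real method also handles log files and verbosity.

```python
class ToyTrainer:
    """Hypothetical stand-in exposing the hooks that Trainer.run is documented to call."""

    def __init__(self, valid_losses):
        self.epoch = 0
        self.best = float("inf")
        self.valid_losses = valid_losses  # illustrative values, one per epoch
        self.events = []

    def train_epoch(self, epoch):
        return {"train_loss": self.valid_losses[epoch]}

    def valid_epoch(self, epoch):
        return {"valid_loss": self.valid_losses[epoch]}

    def scheduler_step(self, logs):
        pass  # does nothing unless overridden, as in Trainer

    def checkpoint(self, epoch, valid_logs, checkpoint_folder):
        self.best = min(self.best, valid_logs["valid_loss"])
        self.events.append(("checkpoint", epoch))

    def load_checkpoint(self, checkpoint_name, checkpoint_folder):
        self.events.append(("reload", checkpoint_name))

    def log(self, logs):
        self.events.append(("log", logs["epoch"]))


def run_sketch(trainer, max_epochs=100, checkpoint_frequency=1,
               checkpoint_folder="checkpoint_folder", allow_n_failures=10):
    """Rough sketch of the Trainer.run loop; the position of log() is an assumption."""
    failures = 0
    while trainer.epoch < max_epochs:
        try:
            logs = {"epoch": trainer.epoch}
            logs.update(trainer.train_epoch(trainer.epoch))
            logs.update(trainer.valid_epoch(trainer.epoch))
            trainer.scheduler_step(logs)
            # Checkpoint on improvement, otherwise every checkpoint_frequency epochs.
            if logs["valid_loss"] < trainer.best or trainer.epoch % checkpoint_frequency == 0:
                trainer.checkpoint(trainer.epoch, logs, checkpoint_folder)
            trainer.log(logs)
            trainer.epoch += 1
        except Exception:
            failures += 1
            if failures > allow_n_failures:
                raise
            # Continue training from the best checkpoint saved so far.
            trainer.load_checkpoint("best", checkpoint_folder)


trainer = ToyTrainer([0.5, 0.3, 0.4])
run_sketch(trainer, max_epochs=3, checkpoint_frequency=5)
print([e for kind, e in trainer.events if kind == "checkpoint"])  # [0, 1]
```

With checkpoint_frequency=5, epoch 2 is not checkpointed because its validation loss (0.4) does not beat the best so far (0.3) and 2 is not divisible by 5.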
- scheduler_step(logs)
This function does nothing. It is called after self.valid_epoch in Trainer.run() and before checkpointing. It is designed to be overridden if you wish to use a scheduler.
- Parameters:
logs (dict) – Dictionary containing all logs returned from self.train_epoch and self.valid_epoch.
- set_autoencoder(autoencoder, **kwargs)
- Parameters:
autoencoder – torch network class that implements autoencoder.encode and autoencoder.decode. Please pass the class, not an instance.
**kwargs – any other kwargs given to this method will be used to initialise the network:
self.autoencoder = autoencoder(**kwargs)
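Passing the class together with its kwargs lets the trainer record exactly how the network was built, which is what allows the kwargs to be stored in every checkpoint under ‘kwargs’. The sketch below illustrates the idea with hypothetical classes (TinyAutoencoder, TrainerSketch); a real checkpoint also holds weights and optimiser state, which are not modelled here.

```python
class TinyAutoencoder:
    """Hypothetical network class; only its constructor arguments matter for this sketch."""

    def __init__(self, latent_dim=2, depth=4):
        self.latent_dim = latent_dim
        self.depth = depth


class TrainerSketch:
    """Hypothetical stand-in showing why set_autoencoder takes a class plus kwargs."""

    def set_autoencoder(self, autoencoder, **kwargs):
        self._autoencoder_kwargs = kwargs         # stored so checkpoints can record them
        self.autoencoder = autoencoder(**kwargs)  # the trainer builds the instance itself


trainer = TrainerSketch()
trainer.set_autoencoder(TinyAutoencoder, latent_dim=2, depth=3)

# A real checkpoint also holds weights and optimiser state; only 'kwargs' is modelled here.
checkpoint = {"kwargs": trainer._autoencoder_kwargs}
rebuilt = TinyAutoencoder(**checkpoint["kwargs"])
print(rebuilt.latent_dim, rebuilt.depth)  # 2 3
```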
- set_data(data, **kwargs)
Sets up internal variables and gives the trainer access to dataloaders. self.train_dataloader, self.valid_dataloader, self.std, self.mean and self.mol will all be obtained from this object.
- Parameters:
data (PDBData) – data object to be set.
**kwargs – will be passed on to data.get_dataloader(**kwargs)
- set_dataloader(train_dataloader=None, valid_dataloader=None)
- Parameters:
train_dataloader (torch.utils.data.DataLoader) – Alternatively set using trainer.train_dataloader = dataloader
valid_dataloader (torch.utils.data.DataLoader) – Alternatively set using trainer.valid_dataloader = dataloader
- train_epoch(epoch)
Train one epoch. Called once an epoch from trainer.run. This method performs the following functions:
- Sets the network to train mode via self.autoencoder.train()
- For each batch in self.train_dataloader, implements the typical PyTorch training protocol:
  - Zero gradients with a call to self.optimiser.zero_grad()
  - Use the training implemented in trainer.train_step: result = self.train_step(batch)
  - Determine gradients using the key 'loss', e.g. result['loss'].backward()
  - Update the network weights with self.optimiser.step()
- All results are aggregated via averaging and returned with 'train_' prepended to each dictionary key
- Parameters:
epoch (int) – The epoch is passed as an argument; however, the epoch number can also be accessed from self.epoch.
- Returns:
Return all results from train_step, averaged. These results will be printed and/or logged in trainer.run() via a call to self.log(results).
- Return type:
dict
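The averaging-and-prefixing step that closes train_epoch (and, with 'valid_', valid_epoch) can be sketched in plain Python. This assumes a simple unweighted mean over batches; the real code works on tensors and may weight batches differently. aggregate_epoch_results and the per-batch values are hypothetical and illustrative.

```python
def aggregate_epoch_results(step_results, prefix="train_"):
    """Hypothetical helper: average per-batch result dicts and prepend `prefix` to every key."""
    totals = {}
    for result in step_results:
        for key, value in result.items():
            totals[key] = totals.get(key, 0.0) + float(value)
    n = len(step_results)
    return {prefix + key: total / n for key, total in totals.items()}


# Illustrative per-batch outputs of train_step.
batch_results = [{"loss": 0.4, "mse_loss": 0.4}, {"loss": 0.2, "mse_loss": 0.2}]
train_logs = aggregate_epoch_results(batch_results)
valid_logs = aggregate_epoch_results(batch_results, prefix="valid_")
print(sorted(train_logs))  # ['train_loss', 'train_mse_loss']
```

With prefix='valid_' the 'loss' entry becomes 'valid_loss', the key used to pick the best checkpoint.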
- train_step(batch)
Called from Trainer.train_epoch.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.
- Return type:
dict
- update_optimiser_hyperparameters(**kwargs)
Update optimiser hyperparameters, e.g. trainer.update_optimiser_hyperparameters(lr=1e-3)
- Parameters:
**kwargs – each key-value pair in **kwargs is inserted into self.optimiser
- valid_epoch(epoch)
Called once an epoch from trainer.run within a no_grad context. This method performs the following functions:
- Sets the network to eval mode via self.autoencoder.eval()
- For each batch in self.valid_dataloader, calls trainer.valid_step to retrieve the validation loss
- All results are aggregated via averaging and returned with 'valid_' prepended to each dictionary key
The loss with key 'loss' is returned as 'valid_loss'; this is the loss value by which the best checkpoint is determined.
- Parameters:
epoch (int) – The epoch is passed as an argument; however, the epoch number can also be accessed from self.epoch.
- Returns:
Return all results from valid_step, averaged. These results will be printed and/or logged in Trainer.run() via a call to self.log(results).
- Return type:
dict
- valid_step(batch)
Called from Trainer.valid_epoch on every mini-batch.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.
- Return type:
dict
- class OpenMM_Physics_Trainer(*args, **kwargs)
Bases: Trainer
OpenMM_Physics_Trainer subclasses Trainer and replaces valid_step and train_step. An extra ‘physics_loss’ is calculated using OpenMM, and the forces are inserted into the backward pass. Using this trainer requires the additional step of calling prepare_physics.
- Parameters:
device (torch.device) – if not given, will be determined automatically based on torch.cuda.is_available()
log_filename (str) – (default: ‘log_file.dat’) file used to log outputs to
- common_physics_step(batch, latent)
Called from both train_step and valid_step. Takes random interpolations between the latent vectors of adjacent samples. These are decoded (the decoded structures are saved as self._internal['generated'] = generated if needed elsewhere) and the energy terms are calculated with self.physics_loss.
- Parameters:
batch (torch.Tensor) – tensor of shape [batch_size, 3, n_atoms]. Gives access to the mini-batch of structures. This is used to determine n_atoms.
latent (torch.Tensor) – tensor of shape [batch_size, 2, 1]. The encoded vectors of the mini-batch.
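The interpolation step can be illustrated without torch by treating each 2-D latent vector as a tuple. This is a sketch under assumptions: pairing samples 0 with 1, 2 with 3, and so on is one reading of "adjacent samples", alpha is assumed uniform in [0, 1), and the decoding and energy evaluation are omitted. interpolate_pairs is hypothetical.

```python
import random


def interpolate_pairs(latents, rng=None):
    """Hypothetical sketch: one random point between samples 0 and 1, 2 and 3, and so on.

    `latents` is a list of equal-length tuples standing in for the [batch_size, 2, 1] tensor.
    """
    rng = rng or random.Random()
    points = []
    for a, b in zip(latents[:-1:2], latents[1::2]):
        alpha = rng.random()  # uniform in [0, 1)
        points.append(tuple((1 - alpha) * x + alpha * y for x, y in zip(a, b)))
    return points


latents = [(0.0, 0.0), (1.0, 2.0), (3.0, -1.0), (5.0, 1.0)]
generated_latents = interpolate_pairs(latents, random.Random(0))
print(len(generated_latents))  # 2
```

Each interpolated point lies on the segment between its two parents, so the physics loss is evaluated on structures the network has not been trained to reconstruct directly.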
- prepare_physics(physics_scaling_factor=0.1, clamp_threshold=100000000.0, clamp=False, start_physics_at=0, **kwargs)
Create the self.physics_loss object from loss_functions.openmm_energy. Needs self.mol, self.std, and self._data.atoms to have been set with Trainer.set_data.
- Parameters:
physics_scaling_factor (float) – scaling factor saved to self.psf that is used in train_step. Default: 0.1
clamp_threshold (float) – if clamp=True is passed, forces will be clamped between -clamp_threshold and clamp_threshold. Default: 1e8
clamp (bool) – Whether to clamp the forces. Default: False
start_physics_at (int) – As yet unused parameter, saved as self.start_physics_at = start_physics_at. Default: 0
**kwargs – All additional kwargs will be passed to openmm_energy
- train_step(batch)
This method overrides Trainer.train_step and adds an additional ‘physics_loss’ term. mse_loss and physics_loss are summed (mse_loss + scale*physics_loss) with the scaling factor scale = self.psf*mse_loss/physics_loss. Numerically this cancels out the physics_loss, and the final loss value is (1+self.psf)*mse_loss. However, because the scaling factor is calculated within a torch.no_grad context manager, no gradients are computed through it, so physics_loss still contributes its own gradients. This is essentially the same as scaling the physics_loss by an arbitrary constant, except that here the constant happens to be exactly proportional to the ratio of mse_loss to physics_loss at every step.
Called from Trainer.train_epoch.
- Parameters:
batch (torch.Tensor) – tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.
- Return type:
dict
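Written out with θ for the network weights, the value of the loss and the gradient passed to backward() differ, because the scaling factor s is computed under torch.no_grad and therefore behaves as a constant during differentiation:

```latex
\begin{aligned}
s &= \mathrm{psf}\,\frac{\mathcal{L}_{\mathrm{mse}}}{\mathcal{L}_{\mathrm{phys}}}
  && \text{(computed under torch.no\_grad, treated as a constant)}\\
\mathcal{L} &= \mathcal{L}_{\mathrm{mse}} + s\,\mathcal{L}_{\mathrm{phys}}
  = (1+\mathrm{psf})\,\mathcal{L}_{\mathrm{mse}}
  && \text{(value of the loss)}\\
\nabla_\theta \mathcal{L} &= \nabla_\theta \mathcal{L}_{\mathrm{mse}} + s\,\nabla_\theta \mathcal{L}_{\mathrm{phys}}
  && \text{(gradient used by backward())}
\end{aligned}
```

So the reported loss only tracks mse_loss, while the physics gradients are weighted by psf·mse_loss/physics_loss, which is recomputed every step.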
- valid_step(batch)
This method overrides Trainer.valid_step and adds an additional ‘physics_loss’ term. Unlike train_step, this method sums the logs of mse_loss and physics_loss: final_loss = torch.log(results['mse_loss']) + scale*torch.log(results['physics_loss'])
Called from the superclass method Trainer.valid_epoch on every mini-batch.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.
- Return type:
dict
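The log-space score used by valid_step can be checked with plain floats. This sketch assumes scale is self.psf (the description only calls it scale) and that both losses are positive so the logarithms exist; physics_valid_loss and the input values are hypothetical.

```python
import math


def physics_valid_loss(mse_loss, physics_loss, scale=0.1):
    """Hypothetical helper for the log-space sum used by valid_step.

    `scale` stands in for the documented `scale` (assumed here to be self.psf);
    both losses must be positive for math.log to be defined.
    """
    return math.log(mse_loss) + scale * math.log(physics_loss)


final_loss = physics_valid_loss(0.5, 200.0)
# In linear space the score is a weighted product: mse_loss * physics_loss**scale.
print(math.isclose(math.exp(final_loss), 0.5 * 200.0**0.1))  # True
```

Since the score is a weighted product, a given relative improvement in either term shifts it by the same amount regardless of how large the two losses are in absolute terms.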
- class Torch_Physics_Trainer(*args, **kwargs)
Bases: Trainer
Torch_Physics_Trainer subclasses Trainer and replaces valid_step and train_step. An extra ‘physics_loss’ (bonds, angles, and torsions) is calculated using PyTorch. Using this trainer requires the additional step of calling prepare_physics.
- Parameters:
device (torch.device) – if not given, will be determined automatically based on torch.cuda.is_available()
log_filename (str) – (default: ‘log_file.dat’) file used to log outputs to
- common_physics_step(batch, latent)
Called from both train_step and valid_step. Takes random interpolations between the latent vectors of adjacent samples. These are decoded (the decoded structures are saved as self._internal['generated'] = generated if needed elsewhere) and the energy terms are calculated with self.physics_loss.
- Parameters:
batch (torch.Tensor) – tensor of shape [batch_size, 3, n_atoms]. Gives access to the mini-batch of structures. This is used to determine n_atoms.
latent (torch.Tensor) – tensor of shape [batch_size, 2, 1]. The encoded vectors of the mini-batch.
- prepare_physics(physics_scaling_factor=0.1)
Create the self.physics_loss object from loss_functions.TorchProteinEnergy. Needs self.std and self._data to have been set with Trainer.set_data.
- Parameters:
physics_scaling_factor (float) – (default: 0.1) scaling factor saved to self.psf that is used in train_step. It controls the relative importance of mse_loss and physics_loss in training.
- train_step(batch)
This method overrides Trainer.train_step and adds an additional ‘physics_loss’ term. mse_loss and physics_loss are summed (mse_loss + scale*physics_loss) with the scaling factor scale = self.psf*mse_loss/physics_loss. Numerically this cancels out the physics_loss, and the final loss value is (1+self.psf)*mse_loss. However, because the scaling factor is calculated within a torch.no_grad context manager, no gradients are computed through it, so physics_loss still contributes its own gradients. This is essentially the same as scaling the physics_loss by an arbitrary constant, except that here the constant happens to be exactly proportional to the ratio of mse_loss to physics_loss at every step.
Called from Trainer.train_epoch.
- Parameters:
batch (torch.Tensor) – tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', on which self.train_epoch will call result['loss'].backward() to obtain gradients.
- Return type:
dict
- valid_step(batch)
This method overrides Trainer.valid_step and adds an additional ‘physics_loss’ term. Unlike train_step, this method sums the logs of mse_loss and physics_loss: final_loss = torch.log(results['mse_loss']) + scale*torch.log(results['physics_loss'])
Called from the superclass method Trainer.valid_epoch on every mini-batch.
- Parameters:
batch (torch.Tensor) – Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by self.std.
- Returns:
Return loss. The dictionary must contain an entry with key 'loss', which will be the score by which the best checkpoint is determined.
- Return type:
dict