import os
import glob
import shutil
import numpy as np
import time
import torch
from molearn.data import PDBData
import json
class TrainingFailure(Exception):
    '''
    Raised when training cannot sensibly continue.

    In this module it is raised by ``Trainer.run`` when the training or
    validation loss of an epoch is NaN.
    '''
class Trainer:
    '''
    Trainer class that defines a number of useful methods for training an autoencoder.

    :ivar autoencoder: any torch.nn.module network that has methods ``autoencoder.encode`` and ``autoencoder.decode`` with the weights associated with these operations accessible via ``autoencoder.encoder`` and ``autoencoder.decoder``. This can be set using set_autoencoder
    :ivar _autoencoder_kwargs: kwargs used to initialise the network. Saved in every checkpoint under the key ``'network_kwargs'``
    :ivar torch.optim.optimiser optimiser: pytorch optimiser with access to self.autoencoder.parameters()
    :ivar torch.Device device: The device used for all operations.
    :ivar int epoch: the current epoch
    :ivar float best: The best validation score corresponding to the current best checkpoint
    :ivar str best_name: the filename corresponding to self.best
    :ivar int best_epoch: the epoch at which self.best was obtained
    :ivar std: Standard deviation of the training dataset (taken from ``data.std``). Can be used to unscale structures produced by the network.
    :ivar mean: Mean of the training dataset (taken from ``data.mean``).
    :ivar mol: Biobox molecule containing a single example frame of the protein being trained on. This can be used to save examples during training. It is also used to save a temporary pdb that may be used to initialise thirdparty packages.
    :ivar torch.Dataloader train_dataloader: Training data
    :ivar torch.Dataloader valid_dataloader: Validation data
    :ivar _data: (:func:`molearn.data <molearn.data.PDBData>`) Data object given to :func:`set_data <molearn.trainers.Trainer.set_data>`
    '''

    def __init__(self, device=None, log_filename='log_file.dat'):
        '''
        :param torch.Device device: if not given it will be determined automatically based on torch.cuda.is_available()
        :param str log_filename: (default: 'log_file.dat') file used to log outputs to
        '''
        if not device:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        else:
            self.device = device
        print(f'device: {self.device}')
        self.best = None
        self.best_name = None
        self.best_epoch = None
        self.epoch = 0
        self.scheduler = None
        self.verbose = True
        # Honour the caller's choice; this used to be overwritten by a hard-coded name.
        self.log_filename = log_filename
        self.scheduler_key = None

    def get_network_summary(self):
        '''
        Count the parameters of the encoder, the decoder and the whole autoencoder.

        :returns: dictionary with keys ``{encoder,decoder,autoencoder}_{trainable,total}`` mapping to the number of parameters.
        :rtype: dict
        '''
        def get_parameters(trainable_only, model):
            # A parameter is always counted for the total; for the trainable
            # count it must additionally require gradients.
            return sum(p.numel() for p in model.parameters() if (p.requires_grad or not trainable_only))

        return dict(
            encoder_trainable=get_parameters(True, self.autoencoder.encoder),
            encoder_total=get_parameters(False, self.autoencoder.encoder),
            decoder_trainable=get_parameters(True, self.autoencoder.decoder),
            decoder_total=get_parameters(False, self.autoencoder.decoder),
            autoencoder_trainable=get_parameters(True, self.autoencoder),
            autoencoder_total=get_parameters(False, self.autoencoder))

    def set_autoencoder(self, autoencoder, **kwargs):
        r'''
        :param autoencoder: (:func:`autoencoder <molearn.models>`,) torch network class that implements ``autoencoder.encode``, and ``autoencoder.decode``. Preferably pass the class, not the instance; an already constructed instance is also accepted and is simply moved to ``self.device``.
        :param \*\*kwargs: any other kwargs given to this method will be used to initialise the network ``self.autoencoder = autoencoder(**kwargs)``. They are stored in ``self._autoencoder_kwargs`` in either case.
        '''
        if isinstance(autoencoder, type):
            self.autoencoder = autoencoder(**kwargs).to(self.device)
        else:
            self.autoencoder = autoencoder.to(self.device)
        self._autoencoder_kwargs = kwargs

    def set_dataloader(self, train_dataloader=None, valid_dataloader=None):
        '''
        :param torch.DataLoader train_dataloader: Alternatively set using ``trainer.train_dataloader = dataloader``
        :param torch.DataLoader valid_dataloader: Alternatively set using ``trainer.valid_dataloader = dataloader``
        '''
        if train_dataloader is not None:
            self.train_dataloader = train_dataloader
        if valid_dataloader is not None:
            self.valid_dataloader = valid_dataloader

    def set_data(self, data, **kwargs):
        r'''
        Sets up internal variables and gives trainer access to dataloaders.
        ``self.train_dataloader``, ``self.valid_dataloader``, ``self.std``, ``self.mean``, ``self.mol`` will all be obtained from this object.

        :param data: (:func:`PDBData <molearn.data.PDBData>`) data object to be set.
        :param \*\*kwargs: will be passed on to :func:`data.get_dataloader(**kwargs) <molearn.data.PDBData.get_dataloader>`
        :raises NotImplementedError: if ``data`` is not a ``PDBData`` instance.
        '''
        if not isinstance(data, PDBData):
            raise NotImplementedError('Have not implemented this method to use any data other than PDBData yet')
        self.set_dataloader(*data.get_dataloader(**kwargs))
        self.std = data.std
        self.mean = data.mean
        self.mol = data.mol
        self._data = data

    def prepare_optimiser(self, lr=1e-3, weight_decay=0.0001, **optimiser_kwargs):
        r'''
        The Default optimiser is ``AdamW`` and is saved in ``self.optimiser``.
        With no optional arguments this function is the same as doing:
        ``trainer.optimiser = torch.optim.AdamW(self.autoencoder.parameters(), lr=1e-3, weight_decay = 0.0001)``

        :param float lr: (default: 1e-3) optimiser learning rate.
        :param float weight_decay: (default: 0.0001) optimiser weight_decay
        :param \*\*optimiser_kwargs: other kwargs that are passed onto AdamW
        '''
        self.optimiser = torch.optim.AdamW(self.autoencoder.parameters(), lr=lr, weight_decay=weight_decay, **optimiser_kwargs)

    def log(self, log_dict, verbose=None):
        '''
        The contents of log_dict are dumped using ``json.dumps(log_dict)``, optionally printed, and always appended to ``self.log_filename``.
        This function is called from :func:`self.run <molearn.trainers.Trainer.run>`

        :param dict log_dict: dictionary to be printed and saved
        :param bool verbose: (default: None) if True or self.verbose is true the output will be printed
        '''
        dump = json.dumps(log_dict)
        if verbose or self.verbose:
            print(dump)
        with open(self.log_filename, 'a') as f:
            f.write(dump+'\n')

    def scheduler_step(self, logs):
        '''
        This function does nothing. It is called after :func:`self.valid_epoch <molearn.trainers.Trainer.valid_epoch>` in :func:`Trainer.run() <molearn.trainers.Trainer.run>` and before :func:`checkpointing <molearn.trainers.Trainer.checkpoint>`. It is designed to be overridden if you wish to use a scheduler.

        :param dict logs: Dictionary containing all logs returned from ``self.train_epoch`` and ``self.valid_epoch``.
        '''
        pass

    def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None):
        '''
        Calls the following in a loop:

        - :func:`Trainer.train_epoch <molearn.trainers.Trainer.train_epoch>`
        - :func:`Trainer.valid_epoch <molearn.trainers.Trainer.valid_epoch>`
        - :func:`Trainer.scheduler_step <molearn.trainers.Trainer.scheduler_step>`
        - :func:`Trainer.checkpoint <molearn.trainers.Trainer.checkpoint>`
        - :func:`Trainer.log <molearn.trainers.Trainer.log>`

        :param int max_epochs: (default: 100). run until ``self.epoch`` matches max_epochs
        :param str log_filename: (default: None) If log_filename already exists, all logs are appended to the existing file. Else a new log file is created.
        :param str log_folder: (default: None) If not None log_folder directory is created and the log file is saved within this folder
        :param int checkpoint_frequency: (default: 1) The frequency at which last.ckpt is saved. A checkpoint is saved every epoch if ``'valid_loss'`` is lower else when ``self.epoch`` is divisible by checkpoint_frequency.
        :param str checkpoint_folder: (default: 'checkpoint_folder') Where to save checkpoints.
        :param int allow_n_failures: (default: 10) How many attempts are made in total. If a :class:`TrainingFailure` is raised (NaN loss) training is continued from the best checkpoint. At least one attempt is always made.
        :param bool verbose: (default: None) set trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename
        :raises TrainingFailure: if the final allowed attempt also ends with a NaN loss.
        '''
        if log_filename is not None:
            self.log_filename = log_filename
        if log_folder is not None:
            os.makedirs(log_folder, exist_ok=True)
            self.log_filename = log_folder+'/'+self.log_filename
        if verbose is not None:
            self.verbose = verbose

        # A non-positive allowance would otherwise skip training silently.
        n_attempts = max(1, allow_n_failures)
        for attempt in range(n_attempts):
            try:
                for epoch in range(self.epoch, max_epochs):
                    time1 = time.time()
                    logs = self.train_epoch(epoch)
                    time2 = time.time()
                    with torch.no_grad():
                        logs.update(self.valid_epoch(epoch))
                    time3 = time.time()
                    self.scheduler_step(logs)
                    improved = self.best is None or self.best > logs['valid_loss']
                    if improved or epoch % checkpoint_frequency == 0:
                        self.checkpoint(epoch, logs, checkpoint_folder)
                    time4 = time.time()
                    logs.update(epoch=epoch,
                                train_seconds=time2-time1,
                                valid_seconds=time3-time2,
                                checkpoint_seconds=time4-time3,
                                total_seconds=time4-time1)
                    self.log(logs)
                    if np.isnan(logs['valid_loss']) or np.isnan(logs['train_loss']):
                        raise TrainingFailure('nan received, failing')
                    self.epoch += 1
            except TrainingFailure:
                if attempt == (n_attempts-1):
                    failure_message = f'Training Failure due to Nan in attempt {attempt}, end now'
                    self.log({'Failure': failure_message})
                    raise
                failure_message = f'Training Failure due to Nan in attempt {attempt}, try again from best'
                self.log({'Failure': failure_message})
                # Restores weights, optimiser state and self.epoch from the best checkpoint.
                self.load_checkpoint('best', checkpoint_folder)
            else:
                break

    def train_epoch(self, epoch):
        '''
        Train one epoch. Called once an epoch from :func:`trainer.run <molearn.trainers.Trainer.run>`
        This method performs the following functions:

        - Sets network to train mode via ``self.autoencoder.train()``
        - for each batch in self.train_dataloader implements typical pytorch training protocol:

          * zero gradients with call ``self.optimiser.zero_grad()``
          * Use training implemented in trainer.train_step ``result = self.train_step(batch)``
          * Determine gradients using keyword ``'loss'`` e.g. ``result['loss'].backward()``
          * Update network weights. ``self.optimiser.step``

        - All results are aggregated via a batch-size weighted average and returned with ``'train_'`` prepended on the dictionary key

        :param int epoch: The epoch is passed as an argument however epoch number can also be accessed from self.epoch.
        :returns: Return all results from train_step averaged. These results will be printed and/or logged in :func:`trainer.run() <molearn.trainers.Trainer.run>` via a call to :func:`self.log(results) <molearn.trainers.Trainer.log>`
        :rtype: dict
        '''
        self.autoencoder.train()
        N = 0
        results = {}
        for batch in self.train_dataloader:
            batch = batch[0].to(self.device)
            self.optimiser.zero_grad()
            train_result = self.train_step(batch)
            train_result['loss'].backward()
            self.optimiser.step()
            batch_size = len(batch)
            # Weight each batch by its size so a smaller final batch does not skew the mean.
            for key, value in train_result.items():
                results[key] = results.get(key, 0.0) + value.item()*batch_size
            N += batch_size
        return {f'train_{key}': total/N for key, total in results.items()}

    def train_step(self, batch):
        '''
        Called from :func:`Trainer.train_epoch <molearn.trainers.Trainer.train_epoch>`.

        :param torch.Tensor batch: Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of protein frames normalised. To recover original data multiply by ``self.std``.
        :returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that :func:`self.train_epoch <molearn.trainers.Trainer.train_epoch>` will call ``result['loss'].backward()`` on to obtain gradients.
        :rtype: dict
        '''
        results = self.common_step(batch)
        results['loss'] = results['mse_loss']
        return results

    def common_step(self, batch):
        '''
        Called from both train_step and valid_step.
        Calculates the mean squared error loss for self.autoencoder.
        Encoded and decoded frames are saved in self._internal under keys ``encoded`` and ``decoded`` respectively should you wish to use them elsewhere.

        :param torch.Tensor batch: Tensor of shape [Batch size, 3, Number of Atoms] A mini-batch of protein frames normalised. To recover original data multiply by ``self.std``.
        :returns: Return calculated mse_loss
        :rtype: dict
        '''
        self._internal = {}
        encoded = self.autoencoder.encode(batch)
        self._internal['encoded'] = encoded
        # The decoder output is cropped along the last axis to the input length.
        decoded = self.autoencoder.decode(encoded)[:, :, :batch.size(2)]
        self._internal['decoded'] = decoded
        return dict(mse_loss=((batch-decoded)**2).mean())

    def valid_epoch(self, epoch):
        '''
        Called once an epoch from :func:`trainer.run <molearn.trainers.Trainer.run>` within a no_grad context.
        This method performs the following functions:

        - Sets network to eval mode via ``self.autoencoder.eval()``
        - for each batch in ``self.valid_dataloader`` calls :func:`trainer.valid_step <molearn.trainers.Trainer.valid_step>` to retrieve validation loss
        - All results are aggregated via a batch-size weighted average and returned with ``'valid_'`` prepended on the dictionary key

          * The loss with key ``'loss'`` is returned as ``'valid_loss'`` this will be the loss value by which the best checkpoint is determined.

        :param int epoch: The epoch is passed as an argument however epoch number can also be accessed from self.epoch.
        :returns: Return all results from valid_step averaged. These results will be printed and/or logged in :func:`Trainer.run() <molearn.trainers.Trainer.run>` via a call to :func:`self.log(results) <molearn.trainers.Trainer.log>`
        :rtype: dict
        '''
        self.autoencoder.eval()
        N = 0
        results = {}
        for batch in self.valid_dataloader:
            batch = batch[0].to(self.device)
            valid_result = self.valid_step(batch)
            batch_size = len(batch)
            for key, value in valid_result.items():
                results[key] = results.get(key, 0.0) + value.item()*batch_size
            N += batch_size
        return {f'valid_{key}': total/N for key, total in results.items()}

    def valid_step(self, batch):
        '''
        Called from :func:`Trainer.valid_epoch<molearn.trainer.Trainer.valid_epoch>` on every mini-batch.

        :param torch.Tensor batch: Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of protein frames normalised. To recover original data multiply by ``self.std``.
        :returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that will be the score via which the best checkpoint is determined.
        :rtype: dict
        '''
        results = self.common_step(batch)
        results['loss'] = results['mse_loss']
        return results

    def learning_rate_sweep(self, max_lr=100, min_lr=1e-5, number_of_iterations=1000, checkpoint_folder='checkpoint_sweep', train_on='mse_loss', save=['loss', 'mse_loss']):
        '''
        Deprecated method.
        Performs a sweep of learning rate between ``max_lr`` and ``min_lr`` over ``number_of_iterations``.
        See `Finding Good Learning Rate and The One Cycle Policy <https://towardsdatascience.com/finding-good-learning-rate-and-the-one-cycle-policy-7159fe1db5d6>`_

        :param float max_lr: (default: 100.0) final/maximum learning rate to be used
        :param float min_lr: (default: 1e-5) Starting learning rate
        :param int number_of_iterations: (default: 1000) Number of steps to run sweep over.
        :param str checkpoint_folder: (default: 'checkpoint_sweep') currently unused; kept for backward compatibility.
        :param str train_on: (default: 'mse_loss') key returned from trainer.train_step(batch) on which to train
        :param list save: (default: ['loss', 'mse_loss']) what loss values to return. The list is only read, never modified.
        :returns: array of shape [number_of_iterations, 1 + len(save)]; the first column is the learning rate, the remaining columns are the loss values named in ``save``.
        :rtype: numpy.ndarray
        '''
        self.autoencoder.train()

        def cycle(iterable):
            # Re-iterates the dataloader itself on every pass (unlike
            # itertools.cycle, which replays a cached copy of the first pass).
            while True:
                for item in iterable:
                    yield item

        values = []
        data = cycle(self.train_dataloader)
        for i in range(number_of_iterations):
            # Geometric interpolation from min_lr towards max_lr.
            lr = min_lr*((max_lr/min_lr)**(i/number_of_iterations))
            self.update_optimiser_hyperparameters(lr=lr)
            batch = next(data)[0].to(self.device).float()
            self.optimiser.zero_grad()
            result = self.train_step(batch)
            result[train_on].backward()
            self.optimiser.step()
            values.append((lr,)+tuple((result[name].item() for name in save)))
        values = np.array(values)
        # Only report a minimum when there is at least one finite value to report.
        if values.ndim == 2 and values.shape[1] > 1 and not np.all(np.isnan(values[:, 1])):
            print('min value ', values[np.nanargmin(values[:, 1])])
        return values

    def update_optimiser_hyperparameters(self, **kwargs):
        r'''
        Update optimiser hyperparameters e.g. ``trainer.update_optimiser_hyperparameters(lr = 1e3)``

        :param \*\*kwargs: each key value pair in \*\*kwargs is inserted into every parameter group of ``self.optimiser``
        '''
        for g in self.optimiser.param_groups:
            for key, value in kwargs.items():
                g[key] = value

    def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'):
        '''
        Checkpoint the current network. The checkpoint will be saved as ``'last.ckpt'``.
        If valid_logs[loss_key] is better than self.best then this checkpoint will replace self.best: ``'last.ckpt'`` is copied to ``f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'`` and the former best (filename saved as ``self.best_name``) will be deleted.
        A NaN loss is never recorded as the best checkpoint.

        :param int epoch: current epoch, will be saved within the ckpt. Current epoch can usually be obtained with ``self.epoch``
        :param dict valid_logs: results dictionary containing loss_key.
        :param str checkpoint_folder: The folder in which to save the checkpoint.
        :param str loss_key: (default: 'valid_loss') The key with which to get loss from valid_logs.
        '''
        valid_loss = valid_logs[loss_key]
        os.makedirs(checkpoint_folder, exist_ok=True)
        last_name = f'{checkpoint_folder}/last.ckpt'
        torch.save({'epoch': epoch,
                    'model_state_dict': self.autoencoder.state_dict(),
                    'optimizer_state_dict': self.optimiser.state_dict(),
                    'loss': valid_loss,
                    'network_kwargs': self._autoencoder_kwargs,
                    'atoms': self._data.atoms,
                    'std': self.std,
                    'mean': self.mean},
                   last_name)
        # NaN compares False against everything: once stored as best it could
        # never be replaced, and recovery would reload the diverged weights.
        if np.isnan(valid_loss):
            return
        if self.best is None or self.best > valid_loss:
            filename = f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'
            shutil.copyfile(last_name, filename)
            previous = self.best_name
            if previous is not None and previous != filename and os.path.exists(previous):
                os.remove(previous)
            self.best_name = filename
            self.best_epoch = epoch
            self.best = valid_loss

    def load_checkpoint(self, checkpoint_name='best', checkpoint_folder='', load_optimiser=True):
        '''
        Load checkpoint. ``self.epoch`` is set to the checkpoint's epoch + 1.

        :param str checkpoint_name: (default: ``'best'``) if ``'best'`` and ``self.best_name`` is set, that file is loaded. Otherwise checkpoint_folder is searched for all files beginning with ``'checkpoint_'`` and loss values are extracted from the filename by assuming all characters after ``'loss'`` and before ``'.ckpt'`` are a float. The checkpoint with the lowest (non-NaN) loss is loaded. If ``'last'``, ``last.ckpt`` in checkpoint_folder is loaded. For any other value we search for a checkpoint file at ``f'{checkpoint_folder}/{checkpoint_name}'``.
        :param str checkpoint_folder: Folder within which to search for checkpoints. An empty string means the current working directory.
        :param bool load_optimiser: (default: True) Should optimiser state dictionary be loaded.
        :raises NotImplementedError: if ``self.autoencoder`` (or ``self.optimiser`` when load_optimiser is True) has not been set.
        :raises FileNotFoundError: if ``'best'`` is requested and no usable checkpoint is found in checkpoint_folder.
        '''
        # Validate before touching disk or mutating the model so a failure leaves the trainer unchanged.
        if not hasattr(self, 'autoencoder'):
            raise NotImplementedError('self.autoencoder does not exist, I have no way of knowing what network you want to load checkoint weights into yet, please set the network first')
        if load_optimiser and not hasattr(self, 'optimiser'):
            raise NotImplementedError('self.optimiser does not exist, I have no way of knowing what optimiser you previously used, please set it first.')

        if checkpoint_name == 'best':
            if self.best_name is not None:
                _name = self.best_name
            else:
                ckpts = glob.glob(os.path.join(checkpoint_folder, 'checkpoint_*'))
                # Slice between the last 'loss' and the 5-character '.ckpt' suffix.
                losses = [float(x[x.rfind('loss')+4:-5]) for x in ckpts]
                candidates = [(loss, x) for loss, x in zip(losses, ckpts) if not np.isnan(loss)]
                if not candidates:
                    raise FileNotFoundError(f'no checkpoint_* file with a valid loss found in {checkpoint_folder!r}')
                _name = min(candidates, key=lambda pair: pair[0])[1]
        elif checkpoint_name == 'last':
            _name = os.path.join(checkpoint_folder, 'last.ckpt')
        else:
            _name = os.path.join(checkpoint_folder, checkpoint_name)

        # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
        checkpoint = torch.load(_name, map_location=self.device)
        self.autoencoder.load_state_dict(checkpoint['model_state_dict'])
        if load_optimiser:
            self.optimiser.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        self.epoch = epoch+1
if __name__ == '__main__':
    # This module only defines classes; there is nothing to do when run as a script.
    pass