# Source code for molearn.trainers.openmm_physics_trainer

import torch
from molearn.loss_functions import openmm_energy
from .trainer import Trainer


class OpenMM_Physics_Trainer(Trainer):
    '''
    OpenMM_Physics_Trainer subclasses Trainer and replaces the valid_step and
    train_step. An extra 'physics_loss' is calculated using OpenMM and the
    forces are inserted into the backwards pass.

    To use this trainer requires the additional step of calling
    :func:`prepare_physics <molearn.trainers.OpenMM_Physics_Trainer.prepare_physics>`.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8,
                        clamp=False, start_physics_at=0, **kwargs):
        r'''
        Create ``self.physics_loss`` object from
        :func:`loss_functions.openmm_energy <molearn.loss_functions.openmm_energy>`.

        Needs ``self.mol``, ``self.std``, and ``self._data.atoms`` to have been
        set with :func:`Trainer.set_data<molearn.trainer.Trainer.set_data>`.

        :param float physics_scaling_factor: scaling factor saved to ``self.psf``
            that is used in
            :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>`
            and :func:`valid_step <molearn.trainers.OpenMM_Physics_Trainer.valid_step>`.
            Defaults to 0.1
        :param float clamp_threshold: if ``clamp=True`` is passed then forces
            will be clamped between -clamp_threshold and clamp_threshold.
            Default: 1e8
        :param bool clamp: Whether to clamp the forces. Defaults to False
        :param int start_physics_at: saved as ``self.start_physics_at``; not read
            anywhere in this class. Default: 0
        :param \*\*kwargs: All additional kwargs will be passed to
            :func:`openmm_energy <molearn.loss_functions.openmm_energy>`
        '''
        self.start_physics_at = start_physics_at
        self.psf = physics_scaling_factor
        if clamp:
            clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold)
        else:
            clamp_kwargs = None
        # Check the device *type* rather than comparing with
        # torch.device('cuda'): that equality is False for indexed devices such
        # as 'cuda:0' / 'cuda:1' (and for plain strings), which silently fell
        # back to the much slower 'Reference' platform.
        use_cuda = str(self.device).startswith('cuda')
        self.physics_loss = openmm_energy(
            self.mol,
            self.std,
            clamp=clamp_kwargs,
            platform='CUDA' if use_cuda else 'Reference',
            atoms=self._data.atoms,
            **kwargs,
        )

    def common_physics_step(self, batch, latent):
        '''
        Called from both
        :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>`
        and :func:`valid_step <molearn.trainers.OpenMM_Physics_Trainer.valid_step>`.

        Takes random interpolations between the latent vectors of adjacent
        samples (pairs 0-1, 2-3, ...). These are decoded (the decoded structures
        are saved as ``self._internal['generated']`` in case they are needed
        elsewhere) and the energy terms are calculated with ``self.physics_loss``.

        :param torch.Tensor batch: tensor of shape [batch_size, 3, n_atoms].
            Gives access to the mini-batch of structures. This is used to
            determine ``batch_size`` and ``n_atoms``.
        :param torch.Tensor latent: tensor shape [batch_size, 2, 1]. The encoded
            vectors of the mini-batch (interpolation weights have shape
            [batch_size//2, 1, 1], so a rank-3 tensor is expected).
        :returns: ``{'physics_loss': energy}`` where ``energy`` is the mean
            (ignoring NaNs) of the per-structure energies, with infinite values
            and values above 1e34 capped at 1e34.
        :rtype: dict
        '''
        # One interpolation per pair of neighbouring samples; with an odd batch
        # size the final sample is left unpaired.
        # NOTE(review): a batch of one sample yields no pairs at all; the result
        # then depends on how decode/physics_loss handle empty input (likely NaN).
        n_pairs = len(batch) // 2
        alpha = torch.rand(n_pairs, 1, 1).type_as(latent)
        latent_interpolated = (1 - alpha) * latent[:-1:2] + alpha * latent[1::2]
        # Crop the decoder output to the number of atoms in the real batch.
        generated = self.autoencoder.decode(latent_interpolated)[:, :, :batch.size(2)]
        self._internal['generated'] = generated
        energy = self.physics_loss(generated)
        # Infinite energies (either sign) become 1e35 and are then capped, along
        # with any finite value above 1e34, so a single blown-up structure
        # cannot make the loss infinite.
        energy[energy.isinf()] = 1e35
        energy = torch.clamp(energy, max=1e34)
        energy = energy.nanmean()
        return {'physics_loss': energy}

    def train_step(self, batch):
        '''
        This method overrides
        :func:`Trainer.train_step <molearn.trainers.Trainer.train_step>` and
        adds an additional 'physics_loss' term.

        mse_loss and physics_loss are summed (``mse_loss + scale*physics_loss``)
        with a scaling factor ``scale = self.psf*mse_loss/|physics_loss|``.
        For a positive physics_loss the value of the final loss is therefore
        (1+self.psf)*mse_loss. However, because the scaling factor is calculated
        within a ``torch.no_grad`` context manager, no gradients flow through
        it. This is essentially the same as scaling the physics_loss with any
        arbitrary (positive) scaling factor, except that here it happens to be
        exactly proportional to the ratio of mse_loss and physics_loss in every
        step.

        Called from :func:`Trainer.train_epoch <molearn.trainers.Trainer.train_epoch>`.

        :param torch.Tensor batch: tensor shape [Batch size, 3, Number of Atoms].
            A mini-batch of normalised protein frames. To recover the original
            data multiply by ``self.std``.
        :returns: Return loss. The dictionary must contain an entry with key
            ``'loss'`` on which
            :func:`self.train_epoch <molearn.trainers.Trainer.train_epoch>`
            will call ``result['loss'].backward()`` to obtain gradients.
        :rtype: dict
        '''
        results = self.common_step(batch)
        # common_step is expected to have stored the encoded latents here.
        results.update(self.common_physics_step(batch, self._internal['encoded']))

        with torch.no_grad():
            # Use the magnitude of the physics loss so the weight is always
            # positive: with a negative energy the plain ratio became negative
            # and the gradient of the physics term pushed the energy *up*
            # (and the denominator could approach zero near -1e-5).
            scale = (self.psf * results['mse_loss']) / (results['physics_loss'].abs() + 1e-5)
        final_loss = results['mse_loss'] + scale * results['physics_loss']
        results['loss'] = final_loss
        return results

    def valid_step(self, batch):
        '''
        This method overrides
        :func:`Trainer.valid_step <molearn.trainers.Trainer.valid_step>` and
        adds an additional 'physics_loss' term.

        Differently to
        :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>`
        this method sums the logs of mse_loss and physics_loss:
        ``final_loss = torch.log(results['mse_loss']) + self.psf*torch.log(results['physics_loss'])``

        Called from super class
        :func:`Trainer.valid_epoch<molearn.trainer.Trainer.valid_epoch>` on
        every mini-batch.

        :param torch.Tensor batch: Tensor of shape [Batch size, 3, Number of Atoms].
            A mini-batch of normalised protein frames. To recover the original
            data multiply by ``self.std``.
        :returns: Return loss. The dictionary must contain an entry with key
            ``'loss'`` that will be the score via which the best checkpoint is
            determined.
        :rtype: dict
        '''
        results = self.common_step(batch)
        results.update(self.common_physics_step(batch, self._internal['encoded']))
        # NOTE(review): torch.log of a zero/negative physics_loss gives -inf/NaN,
        # which would corrupt checkpoint selection if energies go non-positive.
        final_loss = torch.log(results['mse_loss']) + self.psf * torch.log(results['physics_loss'])
        results['loss'] = final_loss
        return results
if __name__ == "__main__":
    # Import-only module: executing it directly intentionally does nothing.
    ...