Source code for molearn.data.pdb_data

import numpy as np
import torch
from copy import deepcopy
import biobox as bb


class PDBData:
    '''
    Container turning one or more multi-PDB files into a normalised
    `torch.Tensor` dataset suitable for training.
    '''

    def __init__(self, filename=None, fix_terminal=False, atoms=None):
        '''
        Create object enabling the manipulation of multi-PDB files into a dataset suitable for training.

        :param filename: None, str or list of strings. If not None, :func:`import_pdb <molearn.data.PDBData.import_pdb>` is called on each filename provided.
        :param fix_terminal: if True, calls :func:`fix_terminal <molearn.data.PDBData.fix_terminal>` after import, and before atomselect
        :param atoms: if not None, calls :func:`atomselect <molearn.data.PDBData.atomselect>`
        '''
        if isinstance(filename, str):
            self.import_pdb(filename)
        elif filename is not None:
            for _filename in filename:
                self.import_pdb(_filename)

        if fix_terminal:
            self.fix_terminal()

        if atoms is not None:
            self.atomselect(atoms=atoms)

    def import_pdb(self, filename):
        '''
        Load multiPDB file.
        This command can be called multiple times to load many datasets, if these feature the same number of atoms

        :param filename: path to multiPDB file.
        '''
        # The molecule and the filename list are created lazily on first import.
        if not hasattr(self, '_mol'):
            self._mol = bb.Molecule()
        self._mol.import_pdb(filename)

        if not hasattr(self, 'filename'):
            self.filename = []
        self.filename.append(filename)

    def fix_terminal(self):
        '''
        Rename terminal oxygen OT1 to O if terminal oxygens are named OT1 and OT2,
        otherwise no oxygen will be selected during an atomselect using
        atoms = ['CA', 'C', 'N', 'O', 'CB']. No template will be found for the
        terminal residue in openmm_loss. Alternative solution is to use
        atoms = ['CA', 'C', 'N', 'O', 'CB', 'OT1'] instead.

        Nothing is renamed unless both OT1 and OT2 are present.
        '''
        names = self._mol.data['name']
        is_ot1 = names == 'OT1'
        # A boolean mask is used (rather than positions from np.where) so the
        # label-based .loc stays correct even when the index is not 0..n-1.
        if is_ot1.any() and (names == 'OT2').any():
            self._mol.data.loc[is_ot1, 'name'] = 'O'

    def atomselect(self, atoms, ignore_atoms=None):
        '''
        From all imported PDBs, extract only atoms of interest.
        :func:`import_pdb <molearn.data.PDBData.import_pdb>` must have been called at least once, either at class instantiation or as a separate call.

        :param atoms: list of atom names, "*" (every atom name present), or "no_hydrogen".
        :param ignore_atoms: optional collection of atom names to leave out. It is applied
            when `atoms` is "*" or an explicit list; it is not consulted for "no_hydrogen".
        :raises KeyError: with "no_hydrogen", if an atom name (and the name stripped of its
            last one or two characters) is missing from the molecule's atomtype table.
        '''
        ignored = () if ignore_atoms is None else ignore_atoms

        # isinstance guards keep array-like `atoms` from being compared
        # element-wise against the keyword strings.
        if isinstance(atoms, str) and atoms == "*":
            _atoms = [a for a in self.atoms if a not in ignored]
        elif isinstance(atoms, str) and atoms == "no_hydrogen":
            atomtype = self._mol.knowledge['atomtype']
            _atoms = []
            for a in self.atoms:
                if a in atomtype:
                    element = atomtype[a]
                elif a[:-1] in atomtype:
                    element = atomtype[a[:-1]]
                    print(f'Could not find {a}. I am assuming you meant {a[:-1]} instead.')
                elif a[:-2] in atomtype:
                    element = atomtype[a[:-2]]
                    print(f'Could not find {a}. I am assuming you meant {a[:-2]} instead.')
                else:
                    # every fallback failed: let the lookup raise the KeyError
                    element = atomtype[a]
                if element != 'H':
                    _atoms.append(a)
        else:
            _atoms = [_a for _a in atoms if _a not in ignored]

        _, self._idxs = self._mol.atomselect("*", "*", _atoms, get_index=True)
        self._mol = self._mol.get_subset(self._idxs)

    def prepare_dataset(self):
        '''
        Once all datasets have been loaded, normalise data and convert into `torch.Tensor` (ready for training)

        `self.mean` and `self.std` are computed from the data unless already set.
        Calling this again once `self.dataset` is a `torch.Tensor` is a no-op.

        :raises AssertionError: if no dataset exists and no PDB has been imported.
        '''
        if not hasattr(self, 'dataset'):
            if not hasattr(self, '_mol'):
                # explicit raise (same exception type as the former assert) so
                # the check is not stripped when running with -O
                raise AssertionError('You need to call import_pdb before preparing the dataset')
            self.dataset = self._mol.coordinates.copy()
        elif isinstance(self.dataset, torch.Tensor):
            # already converted: normalising again would corrupt the data and
            # torch.from_numpy would reject a tensor
            return

        if not hasattr(self, 'std'):
            self.std = self.dataset.std()
        if not hasattr(self, 'mean'):
            self.mean = self.dataset.mean()

        self.dataset -= self.mean
        self.dataset /= self.std
        self.dataset = torch.from_numpy(self.dataset).float()
        # swap the last two axes of the coordinate array
        self.dataset = self.dataset.permute(0, 2, 1)
        print(f'Dataset.shape: {self.dataset.shape}')
        print(f'mean: {str(self.mean)}, std: {str(self.std)}')

    def get_atominfo(self):
        '''
        generate list of all atoms in dataset, where every line contains [atom name, residue name, resid]

        The result is cached in `self.atominfo` on first call and returned as is afterwards.

        :raises AssertionError: if no PDB has been imported.
        '''
        if not hasattr(self, 'atominfo'):
            if not hasattr(self, '_mol'):
                raise AssertionError('You need to call import_pdb before getting atom info')
            self.atominfo = self._mol.get_data(columns=['name', 'resname', 'resid'])
        return self.atominfo

    def frame(self):
        '''
        return `biobox.Molecule` object with loaded data

        The returned molecule holds only the first conformation and is a deep copy.
        NOTE(review): `M.data` is the same object as `self._mol.data` when the
        'index' column is written, so the stored molecule's table is updated too.
        '''
        M = bb.Molecule()
        M.coordinates = self._mol.coordinates[[0]]
        M.data = self._mol.data
        M.data['index'] = np.arange(self._mol.coordinates.shape[1])
        M.current = 0
        M.points = M.coordinates.view()[M.current]
        M.properties['center'] = M.get_center()
        return deepcopy(M)

    def get_dataloader(self, batch_size, validation_split=0.1, pin_memory=True, dataset_sample_size=-1, manual_seed=None, shuffle=True, sampler=None):
        '''
        Randomly split the dataset and wrap both parts in shuffling dataloaders.
        The dataloaders and underlying subsets are also stored on the instance
        (`train_dataset`, `valid_dataset`, `train_dataloader`, `valid_dataloader`).

        :param batch_size: batch size used by both dataloaders
        :param validation_split: ratio of data randomly assigned to the validation set
        :param pin_memory: forwarded to `torch.utils.data.DataLoader`
        :param dataset_sample_size: accepted for compatibility; not used by this implementation
        :param manual_seed: if not None, seeds both the random split and the training sampler
        :param shuffle: accepted for compatibility; not used, both loaders always shuffle
        :param sampler: accepted for compatibility; not used by this implementation
        :return: `torch.utils.data.DataLoader` for training set
        :return: `torch.utils.data.DataLoader` for validation set
        '''
        if not hasattr(self, 'dataset'):
            self.prepare_dataset()

        tud = torch.utils.data
        n_total = len(self.dataset)
        valid_size = int(n_total * validation_split)
        train_size = n_total - valid_size
        dataset = tud.TensorDataset(self.dataset.float())

        if manual_seed is not None:
            # two independent generators, both started from the same seed
            self.train_dataset, self.valid_dataset = tud.random_split(
                dataset, [train_size, valid_size],
                generator=torch.Generator().manual_seed(manual_seed))
            train_sampler = tud.RandomSampler(
                self.train_dataset,
                generator=torch.Generator().manual_seed(manual_seed))
            self.train_dataloader = tud.DataLoader(
                self.train_dataset, batch_size=batch_size,
                pin_memory=pin_memory, sampler=train_sampler)
        else:
            self.train_dataset, self.valid_dataset = tud.random_split(
                dataset, [train_size, valid_size])
            self.train_dataloader = tud.DataLoader(
                self.train_dataset, batch_size=batch_size,
                pin_memory=pin_memory, shuffle=True)

        self.valid_dataloader = tud.DataLoader(
            self.valid_dataset, batch_size=batch_size,
            pin_memory=pin_memory, shuffle=True)
        return self.train_dataloader, self.valid_dataloader

    def split(self, *args, **kwargs):
        '''
        Split :func:`PDBData <molearn.data.PDBData>` into two other :func:`PDBData <molearn.data.PDBData>` objects corresponding to train and valid sets.
        Both returned objects share this object's `_mol`, `std`, `mean` and `filename`.

        :param manual_seed: manual seed used to split dataset
        :param validation_split: ratio of data to randomly assigned as validation
        :param train_size: if not None, specify number of train structures to be returned
        :param valid_size: if not None, specify number of valid structures to be returned
        :return: :func:`PDBData <molearn.data.PDBData>` object corresponding to train set
        :return: :func:`PDBData <molearn.data.PDBData>` object corresponding to validation set
        '''
        # arguments are forwarded untouched to get_datasets
        train_dataset, valid_dataset = self.get_datasets(*args, **kwargs)
        train = PDBData()
        valid = PDBData()
        for data in (train, valid):
            for key in ('_mol', 'std', 'mean', 'filename'):
                setattr(data, key, getattr(self, key))
        train.dataset = train_dataset
        valid.dataset = valid_dataset
        return train, valid

    def get_datasets(self, validation_split=0.1, valid_size=None, train_size=None, manual_seed=None):
        '''
        Create a training and validation set from the imported data.
        The permutation used for the split is stored in `self.indices`.

        :param validation_split: ratio of data to randomly assigned as validation. Used only for sizes
            that are not given explicitly (as a fraction of the whole dataset when neither size is
            given, as a fraction of `train_size` when only `train_size` is given).
        :param valid_size: if not None, specify number of valid structures to be returned
        :param train_size: if not None, specify number of train structures to be returned
        :param manual_seed: seed to initialise the random number generator used for splitting the dataset. Useful to replicate a specific split.
        :return: two `torch.Tensor`, for training and validation structures.
        '''
        if not hasattr(self, 'dataset'):
            self.prepare_dataset()
        dataset = self.dataset.float()
        n_total = len(self.dataset)

        if train_size is None:
            if valid_size is None:
                _valid_size = int(n_total * validation_split)
            else:
                _valid_size = int(valid_size)
            # training set takes whatever the validation set leaves
            _train_size = n_total - _valid_size
        else:
            _train_size = int(train_size)
            if valid_size is None:
                # must be an integer: it is used as a slice bound below
                _valid_size = int(validation_split * _train_size)
            else:
                _valid_size = int(valid_size)

        if manual_seed is not None:
            indices = torch.randperm(n_total, generator=torch.Generator().manual_seed(manual_seed))
        else:
            indices = torch.randperm(n_total)
        self.indices = indices

        train_dataset = dataset[indices[:_train_size]]
        valid_dataset = dataset[indices[_train_size:_train_size + _valid_size]]
        return train_dataset, valid_dataset

    @property
    def atoms(self):
        '''Sorted list of the unique atom names in the loaded molecule.'''
        return list(np.unique(self._mol.data["name"].values))

    @property
    def mol(self):
        '''Shortcut for :func:`frame <molearn.data.PDBData.frame>`.'''
        return self.frame()