Source code for molearn.analysis.GUI

# Copyright (c) 2021 Venkata K. Ramaswamy, Samuel C. Musson, Chris G. Willcocks, Matteo T. Degiacomi
#
# Molearn is free software ;
# you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation ;
# either version 2 of the License, or (at your option) any later version.
# molearn is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY ;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along with molearn ;
# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
#
# Author: Matteo Degiacomi

import time
import pickle
from IPython import display

import numpy as np
import MDAnalysis as mda

import warnings
warnings.filterwarnings("ignore")

from ipywidgets import Layout
from ipywidgets import widgets
from ipywidgets import interact
from tkinter import Tk, filedialog
import plotly.graph_objects as go
import nglview as nv

from .analyser import MolearnAnalysis
from .path import oversample, get_path_aggregate
from ..utils import as_numpy


class MolearnGUI:
    '''
    This class produces an interactive visualisation for data stored in a
    :func:`MolearnAnalysis <molearn.analysis.MolearnAnalysis>` object,
    viewable within a Jupyter notebook.

    The scene is made of three panels: a column of control widgets, a plotly
    figure of the latent space (surface, dataset scatter, path trail) and an
    nglview 3D view of the structures generated along the trail.
    '''

    def __init__(self, MA=None):
        '''
        :param MA: Either :func:`MolearnAnalysis <molearn.analysis.MolearnAnalysis>` instance,
                   or None (default). If None an empty GUI will be produced.
        :raises Exception: if MA is neither None nor a MolearnAnalysis instance.
        '''
        if MA is not None and not isinstance(MA, MolearnAnalysis):
            raise Exception(f'Expecting an MolearnAnalysis instance, {type(MA)} found')

        self.MA = MA
        self.waypoints = []  # collection of all saved waypoints
        self.samples = []  # collection of all calculated sampling points
        self.run()

    def _ask_filename(self, save, extension, label):
        '''
        open a native file dialog on top of all windows and return the chosen path

        :param save: if True show a "save as" dialog, otherwise an "open" dialog
        :param extension: default file extension, also used as file type filter
        :param label: human readable description of the file type
        :return: selected file name; an empty value if the dialog was cancelled
        '''
        root = Tk()
        try:
            root.withdraw()  # Hide the main window.
            root.call('wm', 'attributes', '.', '-topmost', True)  # Raise the root to the top of all windows.
            chooser = filedialog.asksaveasfilename if save else filedialog.askopenfilename
            return chooser(defaultextension=extension, filetypes=[(label, extension)])
        finally:
            # without this every button click leaves a hidden Tk root alive
            root.destroy()

    def update_trails(self):
        '''
        update latent space representation with interpolation points
        '''
        try:
            # the sampling value is forwarded as typed: get_samples converts it
            # only when it is needed (i.e. not for A* paths)
            crd = self.get_samples(self.mybox.value, self.samplebox.value, self.drop_path.value)
            self.samples = crd.copy()
        except Exception:
            # no valid trail: nothing can be saved, as in interact_3D
            self.button_pdb.disabled = True
            return

        # update latent space plot (fall back on raw waypoints if no samples)
        if len(self.samples) == 0:
            if len(self.waypoints) > 0:
                self.latent.data[2].x = self.waypoints[:, 0]
                self.latent.data[2].y = self.waypoints[:, 1]
        else:
            self.latent.data[2].x = self.samples[:, 0]
            self.latent.data[2].y = self.samples[:, 1]

        self.latent.update()

    def on_click(self, trace, points, selector):
        '''
        add the clicked point of the latent space plot to the list of waypoints

        :param trace: plotly trace that received the click (unused)
        :param points: plotly selection, its first point becomes a new waypoint
        :param selector: plotly selector (unused)
        '''
        with self.output:
            if len(points.xs) == 0:
                return

            # add new waypoint to list
            pt = np.array([[points.xs[0], points.ys[0]]])
            if len(self.waypoints) == 0:
                self.waypoints = pt
            else:
                self.waypoints = np.concatenate((self.waypoints, pt))

            # update textbox (triggering update of 3D representation)
            try:
                pt = self.waypoints.flatten().round(decimals=4).astype(str)
                self.mybox.value = " ".join(pt)
            except Exception:
                return

            self.update_trails()

    def get_samples(self, mybox, samplebox, path):
        '''
        provide a trail of point between list of waypoints, either connected
        on a straight line or via a shortest path calculated with the A* algorithm

        :param mybox: whitespace separated string of numbers, read as x y pairs
        :param samplebox: number of sampling points (anything int() accepts),
                          only used when the path is not "A*"
        :param path: "A*" for path finding, anything else for straight lines
        :return: array of coordinates, as returned by the path helpers
        :raises Exception: if the waypoints cannot be parsed or the trail
                           cannot be calculated
        '''
        try:
            crd = np.array(mybox.split()).astype(float)
            crd = crd.reshape((len(crd) // 2, 2))
        except Exception as e:
            raise Exception("Cannot define sampling points") from e

        if path == "A*":
            # connect points via A*
            try:
                landscape = self.latent.data[0].z
                crd = get_path_aggregate(crd, landscape.T, self.MA.xvals, self.MA.yvals)
            except Exception as e:
                raise Exception(f"Cannot define sampling points: path finding failed. {e}") from e
        else:
            # connect points via straight line
            try:
                crd = oversample(crd, pts=int(samplebox))
            except Exception as e:
                raise Exception(f"Cannot define sampling points: oversample failed. {e}") from e

        return crd

    def interact_3D(self, mybox, samplebox, path):
        '''
        generate and display proteins according to latent space trail

        :param mybox: content of the coordinates text box
        :param samplebox: content of the sampling text box
        :param path: selected path finding method
        '''
        try:
            crd = self.get_samples(mybox, samplebox, path)
            self.samples = crd.copy()
            crd = crd.reshape((1, len(crd), 2))
        except Exception:
            self.button_pdb.disabled = True
            return

        if crd.shape[1] == 0:
            self.button_pdb.disabled = True
            return

        # generate structures along path
        t = time.time()
        gen = self.MA.generate(crd)
        print(f'{crd.shape[1]} struct. in {time.time()-t:.4f} sec.')

        # display generated structures
        self.mymol.load_new(gen)
        view = nv.show_mdanalysis(self.mymol)
        view.add_representation("spacefill")
        display.display(view)

        self.button_pdb.disabled = False

    def drop_background_event(self, change):
        '''
        control colouring style of latent space surface

        :param change: widget change notification, change.new is the surface name
        '''
        if "custom" in change.new:
            mykey = change.new.split(":")[1]
        else:
            mykey = change.new

        try:
            data = self.MA.surfaces[mykey]
        except Exception as e:
            print(f"{e}")
            return

        lo = np.min(data)
        hi = np.max(data)
        slider = self.range_slider  # second child of the widgets column

        slider.readout_format = '.1f' if np.abs(hi - lo) < 100 else 'd'
        self.latent.data[0].z = data

        # step below necessary to avoid situations whereby temporarily min>max
        try:
            self.latent.data[0].zmin = lo
            self.latent.data[0].zmax = hi
            slider.min = lo
            slider.max = hi
        except Exception:
            self.latent.data[0].zmax = hi
            self.latent.data[0].zmin = lo
            slider.max = hi
            slider.min = lo

        slider.value = (lo, hi)
        self.update_trails()

    def drop_dataset_event(self, change):
        '''
        control which dataset is displayed

        :param change: widget change notification, change.new is the dataset
                       name or "none"
        '''
        with self.output:
            if change.new == "none":
                self.latent.data[1].x = []
                self.latent.data[1].y = []
                return

            try:
                data = as_numpy(self.MA.get_encoded(change.new))
            except Exception as e:
                print(f"{e}")
                return

            with self.latent.batch_update():
                self.latent.data[1].x = data[:, 0]
                self.latent.data[1].y = data[:, 1]
                self.latent.data[1].name = change.new
                self.latent.data[1].visible = True

    def drop_path_event(self, change):
        '''
        control way paths are looked for

        :param change: widget change notification, change.new is the method name
        '''
        # the sampling box is ignored by A*, so it is greyed out in that case
        self.samplebox.disabled = change.new == "A*"
        self.update_trails()

    def range_slider_event(self, change):
        '''
        update surface colouring upon manipulation of range slider

        :param change: widget change notification, change.new is (min, max)
        '''
        self.latent.data[0].zmin = change.new[0]
        self.latent.data[0].zmax = change.new[1]
        self.latent.update()

    def trail_update_event(self, change):
        '''
        update trails (waypoints and way they are connected)

        :param change: widget change notification (unused, widgets are read directly)
        '''
        try:
            crd = np.array(self.mybox.value.split()).astype(float)
            crd = crd.reshape((len(crd) // 2, 2))
        except Exception:
            # unparsable coordinates: nothing can be saved, as in interact_3D
            self.button_pdb.disabled = True
            return

        self.waypoints = crd.copy()
        self.update_trails()

    def button_pdb_event(self, check):
        '''
        save PDB file corresponding to the interpolation shown in the 3D view

        :param check: button that was clicked (unused)
        '''
        fname = self._ask_filename(save=True, extension="pdb", label="PDB file")
        if not fname:  # dialog cancelled ("" or empty tuple depending on platform)
            return

        crd = self.get_samples(self.mybox.value, self.samplebox.value, self.drop_path.value)
        self.samples = crd.copy()
        crd = crd.reshape((1, len(crd), 2))

        if crd.shape[1] == 0:
            return

        gen = self.MA.generate(crd)
        self.mymol.load_new(gen)
        protein = self.mymol.select_atoms("all")

        with mda.Writer(fname, protein.n_atoms) as W:
            # stepping through the trajectory moves the universe to each frame
            for _ in self.mymol.trajectory:
                W.write(protein)

    def button_save_state_event(self, check):
        '''
        save class state (analysis object and waypoints) in a pickle file

        :param check: button that was clicked (unused)
        '''
        with self.output:
            fname = self._ask_filename(save=True, extension="p", label="pickle file")
            if not fname:
                return

            # same layout as the one unpacked by button_load_state_event
            with open(fname, "wb") as f:
                pickle.dump((self.MA, self.waypoints), f)

    def button_load_state_event(self, check):
        '''
        load class state from a pickle file, and rebuild the GUI

        :param check: button that was clicked (unused)
        :raises Exception: if the file cannot be read or the GUI cannot be rebuilt
        '''
        with self.output:
            fname = self._ask_filename(save=False, extension="p", label="pickle file")
            if not fname:
                return

            try:
                # NOTE: unpickling executes arbitrary code, only load state
                # files coming from a trusted source
                with open(fname, "rb") as f:
                    self.MA, self.waypoints = pickle.load(f)
                self.run()
            except Exception as e:
                raise Exception(f"Cannot load state file. {e}") from e

    #####################################################

    def run(self):
        '''
        build all widgets and plots, connect their callbacks and display the scene
        '''
        self.output = widgets.Output()

        # create an MDAnalysis instance of input protein (for viewing purposes)
        if hasattr(self.MA, "mol"):
            self.MA.mol.write_pdb("tmp.pdb", conformations=[0], split_struc=False)
            self.mymol = mda.Universe('tmp.pdb')

        ### MENU ITEMS ###

        # surface representation dropdown menu
        options = list(self.MA.surfaces) if self.MA is not None else []
        val = options if len(options) > 0 else ["none"]

        self.drop_background = widgets.Dropdown(
            options=val,
            value=val[0],
            description='Surf.:',
            layout=Layout(flex='1 1 0%', width='auto'))

        if len(options) == 0:
            self.drop_background.disabled = True

        self.drop_background.observe(self.drop_background_event, names='value')

        # dataset selector dropdown menu
        options2 = ["none"]
        if self.MA is not None:
            options2.extend(f for f in list(self.MA._datasets) if "grid_" not in f)

        self.drop_dataset = widgets.Dropdown(
            options=options2,
            value=options2[0],
            description='Dataset:',
            layout=Layout(flex='1 1 0%', width='auto'))

        self.drop_dataset.disabled = len(options2) == 1
        self.drop_dataset.observe(self.drop_dataset_event, names='value')

        # pathfinder method dropdown menu
        self.drop_path = widgets.Dropdown(
            options=["Euclidean", "A*"],
            value="Euclidean",
            description='Path:',
            layout=Layout(flex='1 1 0%', width='auto'))

        self.drop_path.observe(self.drop_path_event, names='value')

        # text box holding current coordinates
        self.mybox = widgets.Textarea(placeholder='coordinates',
                                      description='crds:',
                                      disabled=False,
                                      layout=Layout(flex='1 1 0%', width='auto'))

        self.mybox.observe(self.trail_update_event, names='value')

        # text box holding number of sampling points
        self.samplebox = widgets.Text(value='10',
                                      description='sampling:',
                                      disabled=False,
                                      layout=Layout(flex='1 1 0%', width='auto'))

        self.samplebox.observe(self.trail_update_event, names='value')

        # button to save PDB file
        self.button_pdb = widgets.Button(
            description='Save PDB',
            disabled=True,
            layout=Layout(flex='1 1 0%', width='auto'))

        self.button_pdb.on_click(self.button_pdb_event)

        # button to save state file
        self.button_save_state = widgets.Button(
            description='Save state',
            disabled=False,
            layout=Layout(flex='1 1 0%', width='auto'))

        self.button_save_state.on_click(self.button_save_state_event)

        # button to load state file
        self.button_load_state = widgets.Button(
            description='Load state',
            disabled=False,
            layout=Layout(flex='1 1 0%', width='auto'))

        self.button_load_state.on_click(self.button_load_state_event)

        # latent space range slider
        self.range_slider = widgets.FloatRangeSlider(
            description='cmap range:',
            disabled=True,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.1f',
            layout=Layout(flex='1 1 0%', width='auto'))

        self.range_slider.observe(self.range_slider_event, names='value')

        if self.MA is None:
            self.button_save_state.disabled = True
            self.button_pdb.disabled = True

        # waypoints is a list when empty but an array once populated (e.g.
        # after loading a state), so test its length rather than "== []"
        if len(self.waypoints) == 0:
            self.button_pdb.disabled = True

        ### LATENT SPACE REPRESENTATION ###

        # surface
        if len(options) > 0:
            sc = self.MA.surfaces[options[0]]
        else:
            sc = []

        self.latent = go.FigureWidget()

        if len(sc) > 0:
            self.latent.add_heatmap(x=self.MA.xvals, y=self.MA.yvals, z=sc,
                                    zmin=np.min(sc), zmax=np.max(sc),
                                    colorscale='viridis', name="latent_space")
        else:
            # no surface available: add a transparent placeholder so that the
            # plot is still clickable
            if self.MA is not None:
                self.MA.setup_grid(samples=50)
                xvals, yvals = self.MA.xvals, self.MA.yvals
            else:
                xvals = np.linspace(0, 1, 10)
                yvals = np.linspace(0, 1, 10)

            surf_empty = np.zeros((len(xvals), len(yvals)))
            self.latent.add_heatmap(x=xvals, y=yvals, z=surf_empty,
                                    opacity=0.0, showscale=False, name="latent_space")

        # dataset (first entry of options2 is always "none", hence > 1)
        if len(options2) > 1:
            mydata = as_numpy(self.MA.get_encoded(options2[1]))
            color = "white" if len(sc) > 0 else "black"
            self.latent.add_scatter(x=mydata[:, 0].flatten(), y=mydata[:, 1].flatten(),
                                    showlegend=False, opacity=0.9, mode="markers",
                                    marker=dict(color=color, size=5),
                                    name=options2[1], visible=False)
        else:
            self.latent.add_scatter(x=[], y=[])

        # path
        self.latent.add_scatter(x=np.array([]), y=np.array([]),
                                showlegend=False, opacity=0.9, mode='lines+markers',
                                marker=dict(color='red', size=4))

        self.latent.update_layout(xaxis_title="latent vector 1", yaxis_title="latent vector 2",
                                  autosize=True, width=400, height=350,
                                  margin=dict(l=75, r=0, t=25, b=0))

        self.latent.update_xaxes(showspikes=False)
        self.latent.update_yaxes(showspikes=False)

        if len(sc) > 0:
            scmin = np.min(sc)
            scmax = np.max(sc)
            self.range_slider.value = (scmin, scmax)

            # step below to avoid situations whereby temporarily min>max
            try:
                self.range_slider.min = scmin
                self.range_slider.max = scmax
            except Exception:
                self.range_slider.max = scmax
                self.range_slider.min = scmin

            self.range_slider.step = (scmax-scmin)/100.0
            self.range_slider.disabled = False

        # 3D protein representation (triggered by update of textbox, sampling box, or pathfinding method)
        self.protein = widgets.interactive_output(self.interact_3D,
                                                  {'mybox': self.mybox,
                                                   'samplebox': self.samplebox,
                                                   'path': self.drop_path})

        ### WIDGETS ARRANGEMENT ###

        self.block0 = widgets.VBox([self.drop_dataset, self.range_slider,
                                    self.drop_background, self.drop_path,
                                    self.samplebox, self.mybox,
                                    self.button_pdb, self.button_save_state,
                                    self.button_load_state],
                                   layout=Layout(flex='1 1 2', width='auto', border="solid"))

        self.block1 = widgets.VBox([self.latent], layout=Layout(flex='1 1 auto', width='auto'))

        # make all items displayed clickable
        with self.output:
            for item in self.latent.data:
                item.on_click(self.on_click)

        self.block2 = widgets.VBox([self.protein], layout=Layout(flex='1 5 auto', width='auto'))

        self.scene = widgets.HBox([self.block0, self.block1, self.block2])
        self.scene.layout.align_items = 'center'

        # restore previously saved waypoints (triggers trail and 3D update)
        if len(self.waypoints) > 0:
            self.mybox.value = " ".join(self.waypoints.flatten().astype(str))

        display.clear_output(wait=True)
        display.display(self.scene, self.output)