import os
import glob
import shutil
import math
import numpy as np
import time
import torch
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Tuple
from molearn.data import PDBData
import json
@dataclass
class TrainerProgress:
    """Mutable bookkeeping for one training run.

    :ivar best_loss: lowest validation loss recorded so far, ``None`` until one is recorded.
    :ivar best_checkpoint: path of the checkpoint file saved for ``best_loss``, if any.
    :ivar epochs_since_improve: number of consecutive epochs without a new best loss.
    :ivar converged: set to True once the patience limit has been exceeded.
    :ivar repeat_index: repeat counter used to suffix checkpoint and log file names.
    """

    best_loss: Optional[float] = field(default=None)
    best_checkpoint: Optional[str] = field(default=None)
    epochs_since_improve: int = field(default=0)
    converged: bool = field(default=False)
    repeat_index: int = field(default=0)
@dataclass
class FitResult:
    """Summary of a finished training run, as returned by the training loop."""

    # Number of epochs executed by the call that produced this result.
    epochs_run: int
    # Lowest validation loss recorded during the run (None if none was recorded).
    best_loss: Optional[float]
    # Path of the checkpoint saved for the best validation loss, if any.
    best_checkpoint: Optional[str]
    # Path of the most recently written checkpoint, if any.
    last_checkpoint: Optional[str]
    # Path of the log file the run wrote to.
    log_file: Optional[str]
    # Log dictionary of the last completed epoch (empty when no epoch ran).
    final_metrics: Dict[str, float] = field(default_factory=dict)
class TrainingFailure(Exception):
    """Raised when training cannot continue, e.g. when a loss becomes NaN."""
class Trainer:
    """
    Trainer class that defines a number of useful methods for training an autoencoder.

    :ivar autoencoder: any torch.nn.Module network that has methods ``autoencoder.encode`` and ``autoencoder.decode`` with the weights associated with these operations accessible via ``autoencoder.encoder`` and ``autoencoder.decoder``. This can be set using set_autoencoder
    :ivar _autoencoder_kwargs: kwargs used to initialise the network. Saved in every checkpoint under the key ``'network_kwargs'``
    :ivar torch.optim.Optimizer optimiser: pytorch optimiser with access to self.autoencoder.parameters()
    :ivar torch.device device: The device used for all operations.
    :ivar int epoch: the current epoch
    :ivar TrainerProgress progress: Tracks best validation loss, checkpoint metadata, and convergence state.
    :ivar std: Standard deviation of the training dataset, taken from the data object. Can be used to unscale structures produced by the network.
    :ivar mean: Mean of the training dataset, taken from the data object.
    :ivar mol: Biobox molecule containing a single example frame of the protein being trained on. This can be used to save examples during training. It is also used to save a temporary pdb that may be used to initialise third-party packages.
    :ivar torch.utils.data.DataLoader train_dataloader: Training data
    :ivar torch.utils.data.DataLoader valid_dataloader: Validation data
    :ivar _data: :func:`PDBData <molearn.data.PDBData>` object given to :func:`set_data <molearn.trainers.Trainer.set_data>`
    """
def __init__(self, device=None, json_log=False):
    """
    Set up device selection, progress tracking and default log/checkpoint locations.

    :param torch.Device device: device used for all operations. If ``None`` it is determined automatically based on ``torch.cuda.is_available()``.
    :param bool json_log: True to write every log entry with ``json.dumps``; otherwise :func:`log` writes comma separated rows.
    """
    # Compare against None instead of truthiness: the integer ordinal 0 is a
    # valid way to name the first CUDA device and must not trigger
    # auto-detection.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.device = device
    print(f"device: {self.device}")
    self.progress = TrainerProgress()
    self.epoch = 0
    self.verbose = True
    self.json_log = json_log
    self.log_folder = "fn_checkpoints"
    self.log_filename = "log.dat"
    self.checkpoint_folder = "fn_checkpoints"
    self._last_log_file = None
    self._last_checkpoint = None
    self._repeat = 0
def _reset_progress(self) -> None:
"""Reset convergence tracking without touching the global epoch."""
self.progress.best_loss = None
self.progress.best_checkpoint = None
self.progress.epochs_since_improve = 0
self.progress.converged = False
def _ensure_setup(self) -> None:
"""Validate that the trainer has all prerequisites before fitting."""
missing = []
if not hasattr(self, "autoencoder"):
missing.append("autoencoder")
if not hasattr(self, "optimiser"):
missing.append("optimiser")
if not hasattr(self, "train_dataloader"):
missing.append("train_dataloader")
if not hasattr(self, "valid_dataloader"):
missing.append("valid_dataloader")
if missing:
raise RuntimeError(
"Trainer is missing required attributes: " + ", ".join(missing)
)
def _run_phase(
self,
dataloader,
step_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
prefix: str,
*,
backward: bool,
dry_run: bool = False,
) -> Tuple[Dict[str, float], Dict[str, float]]:
"""Shared mini-batch loop for training/validation phases."""
totals: Dict[str, float] = {}
count = 0
for batch in dataloader:
batch = batch[0].to(self.device)
batch_size = batch.shape[0]
if backward and not dry_run:
self.optimiser.zero_grad()
outputs = step_fn(batch)
outputs["loss"].backward()
self.optimiser.step()
else:
with torch.no_grad():
outputs = step_fn(batch)
for key, value in outputs.items():
totals[key] = totals.get(key, 0.0) + value.item() * batch_size
count += batch_size
if count == 0:
return {}, {}
averaged = {key: totals[key] / count for key in totals}
prefixed = {f"{prefix}_{key}": averaged[key] for key in averaged}
return prefixed, averaged
def _update_progress(self, loss: float, patience: Optional[int]) -> bool:
"""Update best checkpoint tracking and convergence status."""
patience = patience if patience is not None else math.inf
if self.progress.best_loss is None or loss < self.progress.best_loss:
self.progress.best_loss = float(loss)
self.progress.epochs_since_improve = 0
is_best = True
else:
self.progress.epochs_since_improve += 1
if self.progress.epochs_since_improve > patience:
self.progress.converged = True
is_best = False
return is_best
def _train_loop(
    self,
    *,
    max_epochs: Optional[int],
    patience: Optional[int],
    log_filename: str,
    log_folder: str,
    checkpoint_folder: str,
    verbose: Optional[bool],
) -> FitResult:
    """Core optimisation loop used by all public run methods.

    Every epoch runs a training phase and a validation phase, updates the best-loss
    tracking, writes a checkpoint and appends one entry to the log file.

    :param max_epochs: stop after this many epochs; ``None`` for no epoch limit.
    :param patience: passed to the convergence tracking; ``None`` disables convergence based stopping.
    :param str log_filename: name of the log file, passed to :func:`prepare_logs`.
    :param str log_folder: folder of the log file, passed to :func:`prepare_logs`.
    :param str checkpoint_folder: folder in which checkpoints are written.
    :param verbose: if not ``None``, overwrites ``self.verbose``.
    :returns: :class:`FitResult` describing the run.
    :raises ValueError: if neither ``max_epochs`` nor ``patience`` is given, or the validation step returns no ``'loss'``.
    :raises RuntimeError: if the trainer is missing network, optimiser or dataloaders.
    :raises TrainingFailure: if the training or validation loss is NaN.
    """
    if max_epochs is None and patience is None:
        raise ValueError("Provide at least one stopping criterion (epochs or patience)")
    self._ensure_setup()
    self._reset_progress()
    self.get_repeat(checkpoint_folder)
    log_file = self.prepare_logs(log_filename, log_folder)
    if verbose is not None:
        self.verbose = verbose
    epoch_limit = math.inf if max_epochs is None else max_epochs
    epochs_run = 0
    logs: Dict[str, float] = {}
    while epochs_run < epoch_limit and not self.progress.converged:
        time1 = time.time()
        # Select the mode explicitly for each phase: a previous validation
        # phase or dry run leaves the network in eval mode, which would
        # otherwise silently disable dropout/batch-norm updates in training.
        self.autoencoder.train()
        train_logs, _ = self._run_phase(
            self.train_dataloader,
            self.train_step,
            "train",
            backward=True,
            dry_run=False,
        )
        time2 = time.time()
        self.autoencoder.eval()
        valid_logs, raw_valid = self._run_phase(
            self.valid_dataloader,
            self.valid_step,
            "valid",
            backward=False,
            dry_run=False,
        )
        time3 = time.time()
        logs = {**train_logs, **valid_logs}
        logs.update(
            epoch=self.epoch,
            train_seconds=time2 - time1,
            valid_seconds=time3 - time2,
        )
        valid_loss = logs.get("valid_loss")
        if valid_loss is None:
            raise ValueError("valid_step must return a 'loss' entry")
        is_best = self._update_progress(valid_loss, patience)
        self.checkpoint(
            self.epoch,
            logs,
            checkpoint_folder,
            is_best,
            has_converged=self.progress.converged,
        )
        time4 = time.time()
        logs.update(
            checkpoint_seconds=time4 - time3,
            total_seconds=time4 - time1,
        )
        # The entry is logged before the NaN check so the failing epoch is
        # still recorded in the log file.
        self.log(log_file, logs)
        if np.isnan(logs["valid_loss"]) or np.isnan(logs["train_loss"]):
            raise TrainingFailure("nan received, failing")
        self.epoch += 1
        epochs_run += 1
        self.results_epoch = raw_valid
    return FitResult(
        epochs_run=epochs_run,
        best_loss=self.progress.best_loss,
        best_checkpoint=self.progress.best_checkpoint,
        last_checkpoint=self._last_checkpoint,
        log_file=log_file,
        final_metrics=logs,
    )
def fit(
self,
data: PDBData,
autoencoder,
*,
autoencoder_kwargs: Optional[Dict[str, Any]] = None,
optimiser_kwargs: Optional[Dict[str, Any]] = None,
data_kwargs: Optional[Dict[str, Any]] = None,
epochs: Optional[int] = None,
patience: Optional[int] = None,
log_filename: str = "log.dat",
log_folder: str = "checkpoint_folder",
checkpoint_folder: str = "checkpoint_folder",
verbose: Optional[bool] = None,
) -> FitResult:
"""End-to-end helper that wires data, model, optimiser, and training.
:returns: :class:`FitResult` describing the executed training procedure.
"""
if autoencoder_kwargs is None:
autoencoder_kwargs = {}
if optimiser_kwargs is None:
optimiser_kwargs = {}
if data_kwargs is None:
data_kwargs = {}
self.set_autoencoder(autoencoder, **autoencoder_kwargs)
self.set_data(data, **data_kwargs)
self.prepare_optimiser(**optimiser_kwargs)
return self._train_loop(
max_epochs=epochs,
patience=patience,
log_filename=log_filename,
log_folder=log_folder,
checkpoint_folder=checkpoint_folder,
verbose=verbose,
)
def get_network_summary(self):
"""
returns a dictionary containing information about the size of the autoencoder.
"""
def get_parameters(trainable_only, model):
return sum(
p.numel()
for p in model.parameters()
if (p.requires_grad and trainable_only)
)
return dict(
encoder_trainable=get_parameters(True, self.autoencoder.encoder),
encoder_total=get_parameters(False, self.autoencoder.encoder),
decoder_trainable=get_parameters(True, self.autoencoder.decoder),
decoder_total=get_parameters(False, self.autoencoder.decoder),
autoencoder_trainable=get_parameters(True, self.autoencoder),
autoencoder_total=get_parameters(False, self.autoencoder),
)
def set_autoencoder(self, autoencoder, **kwargs):
"""
:param autoencoder: (:func:`autoencoder <molearn.models>`,) torch network class that implements ``autoencoder.encode``, and ``autoencoder.decode``. Please pass the class not the instance
:param \*\*kwargs: any other kwargs given to this method will be used to initialise the network ``self.autoencoder = autoencoder(**kwargs)``
"""
if isinstance(autoencoder, type):
self.autoencoder = autoencoder(**kwargs).to(self.device)
else:
self.autoencoder = autoencoder.to(self.device)
self._autoencoder_kwargs = kwargs
def set_dataloader(self, train_dataloader=None, valid_dataloader=None):
"""
:param torch.DataLoader train_dataloader: Alternatively set using ``trainer.train_dataloader = dataloader``
:param torch.DataLoader valid_dataloader: Alternatively set using ``trainer.valid_dataloader = dataloader``
"""
if train_dataloader is not None:
self.train_dataloader = train_dataloader
if valid_dataloader is not None:
self.valid_dataloader = valid_dataloader
def set_data(self, data, **kwargs):
"""
Sets up internal variables and gives trainer access to dataloaders.
``self.train_dataloader``, ``self.valid_dataloader``, ``self.std``, ``self.mean``, ``self.mol`` will all be obtained from this object.
:param :func:`PDBData <molearn.data.PDBData>` data: data object to be set.
:param \*\*kwargs: will be passed on to :func:`data.get_dataloader(**kwargs) <molearn.data.PDBData.get_dataloader>`
"""
self.set_dataloader(*data.get_dataloader(**kwargs))
self._data = data
self.mol = data.mol
self.standardize = data.standardize
self.std = data.std
self.mean = data.mean
self.n_idx = data.indices['N'].to(device=self.device)
self.ca_idx = data.indices['CA'].to(device=self.device)
self.c_idx = data.indices['C'].to(device=self.device)
self.o_idx = data.indices['O'].to(device=self.device)
self.cb_idx = data.indices['CB'].to(device=self.device)
self.cb_valid_idx = data.indices['CB'][data.indices['CB'] >= 0].to(device=self.device)
def prepare_optimiser(self, lr=1e-3, weight_decay=0.0001, **optimiser_kwargs):
"""
The Default optimiser is ``AdamW`` and is saved in ``self.optimiser``.
With no optional arguments this function is the same as doing:
``trainer.optimiser = torch.optim.AdawW(self.autoencoder.parameters(), lr=1e-3, weight_decay = 0.0001)``
:param float lr: (default: 1e-3) optimiser learning rate.
:param float weight_decay: (default: 0.0001) optimiser weight_decay
:param \*\*optimiser_kwargs: other kwargs that are passed onto AdamW
"""
self.optimiser = torch.optim.AdamW(
self.autoencoder.parameters(),
lr=lr,
weight_decay=weight_decay,
**optimiser_kwargs,
)
def log(self, log_file, log_dict, verbose=None):
"""
Then contents of log_dict are dumped using ``json.dumps(log_dict)`` and printed and/or appended to ``self.log_filename``
This function is called from :func:`self.run <molearn.trainers.Trainer.run>`
:param str log_file: file to which the log_dict will be saved. If the file does not exist it will be created.
:param dict log_dict: dictionary to be printed or saved
:param bool verbose: if True or self.verbose is true the output will be printed
"""
if verbose or self.verbose:
max_key_len = max([len(k) for k in log_dict.keys()])
if "epoch" in log_dict:
cur_epoch = log_dict["epoch"]
print(f"{'epoch': <{max_key_len+1}}: {cur_epoch}")
for k, v in log_dict.items():
if k != "epoch":
print(f"{k: <{max_key_len+1}}: {v:.6f}")
print()
if not self.json_log:
# create header if file doesn't exist => first epoch
if not os.path.isfile(log_file):
with open(log_file, "a") as f:
f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n")
with open(log_file, "a") as f:
# just try to format if it is not a Failure
if "Failure" not in log_dict.values():
f.write(f"{','.join([str(v) for v in log_dict.values()])}\n")
else:
dump = json.dumps(log_dict)
f.write(dump + "\n")
else:
dump = json.dumps(log_dict)
with open(log_file, "a") as f:
f.write(dump + "\n")
def prepare_logs(self, log_filename, log_folder):
"""
:param str log_filename: (default: 'log.dat') The name of the log file.
:param str log_folder: (default: 'checkpoint_folder') The folder in which the log file will be saved.
:returns: The full path to the log file.
"""
if not os.path.exists(log_folder):
os.mkdir(log_folder)
if hasattr(self, "_repeat") and self._repeat > 0:
log_filename = f"{log_folder}/{self.log_filename}_{self._repeat}"
else:
log_filename = f"{log_folder}/{self.log_filename}"
self._last_log_file = log_filename
return log_filename
def run_until_converge(
self,
patience=16,
log_filename="log.dat",
log_folder="checkpoint_folder",
checkpoint_folder="checkpoint_folder",
verbose=None,
):
"""
Train until convergence. This method will call :func:`Trainer.run <molearn.trainers.Trainer.run>` in a loop until the training has converged.
The training will stop when the validation loss has not improved for ``patience`` epochs.
:param int patience: how many epochs to wait before stopping training if the validation loss has not improved.
:param str log_filename: the name of the log file.
:param str log_folder: the folder in which the log file will be saved.
:param str checkpoint_folder: the folder in which the checkpoint will be saved.
:param bool verbose: if True, the epoch logs will be printed as well as written to log_filename
:returns: :class:`FitResult` with information about the training run.
"""
return self._train_loop(
max_epochs=None,
patience=patience,
log_filename=log_filename,
log_folder=log_folder,
checkpoint_folder=checkpoint_folder,
verbose=verbose,
)
def run(
self,
epochs=100,
log_filename="log.dat",
log_folder="checkpoint_folder",
checkpoint_folder="checkpoint_folder",
verbose=None,
):
"""
Calls the following in a loop:
- :func:`Trainer.train_epoch <molearn.trainers.Trainer.train_epoch>`
- :func:`Trainer.valid_epoch <molearn.trainers.Trainer.valid_epoch>`
- :func:`Trainer.checkpoint <molearn.trainers.Trainer.checkpoint>`
- :func:`Trainer.checkpoint <molearn.trainers.Trainer.checkpoint>`
- :func:`Trainer.log <molearn.trainers.Trainer.log>`
:param int epochs: (default: 100). run this many epochs.
:param str log_filename: (default: None) If log_filename already exists, all logs are appended to the existing file. Else new log file file is created.
:param str log_folder: (default: None) If not None log_folder directory is created and the log file is saved within this folder
:param int checkpoint_frequency: (default: 1) The frequency at which last.ckpt is saved. A checkpoint is saved every epoch if ``'valid_loss'`` is lower else when ``self.epoch`` is divisible by checkpoint_frequency.
:param str checkpoint_folder: (default: 'checkpoint_folder') Where to save checkpoints.
:param bool verbose: (default: None) set trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename
:returns: :class:`FitResult` with information about the training run.
"""
return self._train_loop(
max_epochs=epochs,
patience=None,
log_filename=log_filename,
log_folder=log_folder,
checkpoint_folder=checkpoint_folder,
verbose=verbose,
)
def dry_run(self,
log_filename="dry_run.dat",
log_folder="tmp_folder",):
"""
Pass the training and validation set through the network once without updating model parameters.
This is useful for debugging or restoring statistics for models loaded from checkpoint.
:param str log_filename: the name of the log file.
:param str log_folder: the folder in which the log file will be saved.
"""
self.verbose = True
self.autoencoder.eval()
log_file = self.prepare_logs(log_filename, log_folder)
with torch.no_grad():
time1 = time.time()
logs = self.train_epoch(dry_run=True)
time2 = time.time()
logs.update(self.valid_epoch())
time3 = time.time()
logs.update(
epoch=self.epoch,
train_seconds=time2 - time1,
valid_seconds=time3 - time2,
total_seconds=time3 - time1,
)
self.log(log_file, logs)
if np.isnan(logs["valid_loss"]) or np.isnan(logs["train_loss"]):
raise TrainingFailure("nan received, failing")
def train_epoch(self, dry_run=False):
"""
Train one epoch. Called once an epoch from :func:`trainer.run <molearn.trainers.Trainer.run>`
This method performs the following functions:
- Sets network to train mode via ``self.autoencoder.train()``
- for each batch in self.train_dataloader implements typical pytorch training protocol:
* zero gradients with call ``self.optimiser.zero_grad()``
* Use training implemented in trainer.train_step ``result = self.train_step(batch)``
* Determine gradients using keyword ``'loss'`` e.g. ``result['loss'].backward()``
* Update network gradients. ``self.optimiser.step``
- All results are aggregated via averaging and returned with ``'train_'`` prepended on the dictionary key
:param int epoch: The epoch is passed as an argument however epoch number can also be accessed from self.epoch.
:returns: Return all results from train_step averaged. These results will be printed and/or logged in :func:`trainer.run() <molearn.trainers.Trainer.run>` via a call to :func:`self.log(results) <molearn.trainers.Trainer.log>`
:rtype: dict
"""
self.autoencoder.train()
prefixed, _ = self._run_phase(
self.train_dataloader,
self.train_step,
"train",
backward=True,
dry_run=dry_run,
)
return prefixed
def train_step(self, batch):
"""
Called from :func:`Trainer.train_epoch <molearn.trainers.Trainer.train_epoch>`.
:param torch.Tensor batch: Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of protein frames normalised. To recover original data multiple by ``self.std``.
:returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that :func:`self.train_epoch <molearn.trainers.Trainer.train_epoch>` will call ``result['loss'].backwards()`` to obtain gradients.
:rtype: dict
"""
results = self.common_step(batch)
results["loss"] = results["mse_loss"]
return results
def common_step(self, batch):
"""
Called from both train_step and valid_step.
Calculates the mean squared error loss for self.autoencoder.
Encoded and decoded frames are saved in self._internal under keys ``encoded`` and ``decoded`` respectively should you wish to use them elsewhere.
:param torch.Tensor batch: Tensor of shape [Batch size, Number of Atoms, 3] A mini-batch of protein frames normalised. To recover original data multiple by ``self.std``.
:returns: Return calculated mse_loss
:rtype: dict
"""
self._internal = {}
encoded = self.autoencoder.encode(batch)
self._internal["encoded"] = encoded
decoded = self.autoencoder.decode(encoded)[:, : batch.size(1), :]
self._internal["decoded"] = decoded
return dict(mse_loss=((batch - decoded) ** 2).mean())
def valid_epoch(self):
"""
Called once an epoch from :func:`trainer.run <molearn.trainers.Trainer.run>` within a no_grad context.
This method performs the following functions:
- Sets network to eval mode via ``self.autoencoder.eval()``
- for each batch in ``self.valid_dataloader`` calls :func:`trainer.valid_step <molearn.trainers.Trainer.valid_step>` to retrieve validation loss
- All results are aggregated via averaging and returned with ``'valid_'`` prepended on the dictionary key
* The loss with key ``'loss'`` is returned as ``'valid_loss'`` this will be the loss value by which the best checkpoint is determined.
:param int epoch: The epoch is passed as an argument however epoch number can also be accessed from self.epoch.
:returns: Return all results from valid_step averaged. These results will be printed and/or logged in :func:`Trainer.run() <molearn.trainers.Trainer.run>` via a call to :func:`self.log(results) <molearn.trainers.Trainer.log>`
:rtype: dict
"""
self.autoencoder.eval()
prefixed, averaged = self._run_phase(
self.valid_dataloader,
self.valid_step,
"valid",
backward=False,
)
self.results_epoch = averaged
return prefixed
def valid_step(self, batch):
"""
Called from :func:`Trainer.valid_epoch<molearn.trainer.Trainer.valid_epoch>` on every mini-batch.
:param torch.Tensor batch: Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of protein frames normalised. To recover original data multiple by ``self.std``.
:returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that will be the score via which the best checkpoint is determined.
:rtype: dict
"""
results = self.common_step(batch)
results["loss"] = results["mse_loss"]
return results
def update_hyperparameters(self, **kwargs):
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
else:
raise ValueError(f"Trainer has no attribute {key}")
self._reset_progress()
def update_optimiser_hyperparameters(self, **kwargs):
"""
Update optimeser hyperparameter e.g. ``trainer.update_optimiser_hyperparameters(lr = 1e3)``
:param \*\*kwargs: each key value pair in \*\*kwargs is inserted into ``self.optimiser``
"""
for g in self.optimiser.param_groups:
for key, value in kwargs.items():
g[key] = value
self._reset_progress()
def checkpoint(self, epoch, valid_logs, checkpoint_folder, is_best, has_converged=False):
"""
Checkpoint the current network. The checkpoint will be saved as ``'last.ckpt'``.
If valid_logs[loss_key] is better than self.progress.best_loss then this checkpoint will replace self.progress.best_checkpoint and ``'last.ckpt'`` will be renamed to ``f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'`` and the former best (filename saved as ``self.progress.best_checkpoint``) will be deleted
:param int epoch: current epoch, will be saved within the ckpt. Current epoch can usually be obtained with ``self.epoch``
:param dict valid_logs: results dictionary containing loss_key.
:param str checkpoint_folder: The folder in which to save the checkpoint.
:param bool is_best: if the current checkpoint is the best so far, this will be used to determine if the checkpoint should be renamed to ``f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'``.
:param bool has_converged: if True, the checkpoint will be saved as ``f'{checkpoint_folder}/checkpoint_converged.ckpt'``. This is used to indicate that training has converged.
"""
valid_loss = valid_logs["valid_loss"]
last_checkpoint_path = f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt'
torch.save(
{
"epoch": epoch,
"model_state_dict": self.autoencoder.state_dict(),
"optimizer_state_dict": self.optimiser.state_dict(),
"loss": valid_loss,
"network_kwargs": self._autoencoder_kwargs,
"atoms": self._data.atoms,
"std": self.std,
"mean": self.mean,
},
last_checkpoint_path,
)
self._last_checkpoint = last_checkpoint_path
if is_best:
filename = f'{checkpoint_folder}/checkpoint{f"_{self._repeat}" if self._repeat>0 else ""}_epoch{epoch}_loss{valid_loss}.ckpt'
shutil.copyfile(
last_checkpoint_path,
filename,
)
if self.progress.best_checkpoint is not None and os.path.exists(self.progress.best_checkpoint):
os.remove(self.progress.best_checkpoint)
self.progress.best_checkpoint = filename
if has_converged:
filename = f'{checkpoint_folder}/checkpoint_converged.ckpt'
if self.progress.best_checkpoint is None:
raise RuntimeError("No best checkpoint available to mark as converged")
shutil.copyfile(self.progress.best_checkpoint, filename)
def load_checkpoint(
self, checkpoint_name="best", checkpoint_folder="", load_optimiser=True
):
"""
Load checkpoint.
:param str checkpoint_name: (default: ``'best'``) if ``'best'`` then checkpoint_folder is searched for all files beginning with ``'checkpoint_'`` and loss values are extracted from the filename by assuming all characters after ``'loss'`` and before ``'.ckpt'`` are a float. The checkpoint with the lowest loss is loaded. checkpoint_name is not ``'best'`` we search for a checkpoint file at ``f'{checkpoint_folder}/{checkpoint_name}'``.
:param str checkpoint_folder: Folder whithin which to search for checkpoints.
:param bool load_optimiser: (default: True) Should optimiser state dictionary be loaded.
"""
if checkpoint_name == "best":
if self.progress.best_checkpoint is not None:
_name = self.progress.best_checkpoint
else:
ckpts = glob.glob(checkpoint_folder + "/checkpoint_*")
indexs = [x.rfind("loss") for x in ckpts]
losses = [float(x[y + 4 : -5]) for x, y in zip(ckpts, indexs)]
_name = ckpts[np.argmin(losses)]
elif checkpoint_name == "last":
_name = f"{checkpoint_folder}/last.ckpt"
else:
_name = f"{checkpoint_folder}/{checkpoint_name}"
checkpoint = torch.load(_name, map_location=self.device, weights_only=False)
if not hasattr(self, "autoencoder"):
raise NotImplementedError(
"self.autoencoder does not exist, I have no way of knowing what network you want to load checkoint weights into yet, please set the network first"
)
self.autoencoder.load_state_dict(checkpoint["model_state_dict"])
if load_optimiser:
if not hasattr(self, "optimiser"):
raise NotImplementedError(
"self.optimiser does not exist, I have no way of knowing what optimiser you previously used, please set it first."
)
self.optimiser.load_state_dict(checkpoint["optimizer_state_dict"])
self.epoch = checkpoint["epoch"] + 1
self.std = checkpoint["std"]
self.mean = checkpoint["mean"]
def get_repeat(self, checkpoint_folder):
if not os.path.exists(checkpoint_folder):
os.makedirs(checkpoint_folder)
if not hasattr(self, "_repeat"):
self._repeat = 0
for i in range(1000):
if not os.path.exists(
checkpoint_folder
+ f'/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt'
):
break # os.mkdir(checkpoint_folder)
else:
self._repeat += 1
else:
raise Exception(
"Something went wrong, you surely havnt done 1000 repeats?"
)
self.progress.repeat_index = self._repeat
def get_scale(
    self, ref_loss: float, tar_loss: float, scale_scale: float = 1.0
):
    """
    Get a factor that brings ``tar_loss`` to the same order of magnitude as ``ref_loss``.

    The order of magnitude of a loss is taken as ``floor(log10(abs(loss)))``.

    :param float ref_loss: the reference loss
    :param float tar_loss: the loss that should be scaled to be in the same order of magnitude as ``ref_loss``
    :param float scale_scale: additional multiplier to in-/decrease the scale further
    :return float scaling_factor: ``10 ** (magnitude of ref_loss - magnitude of tar_loss) * scale_scale``
    """

    def order_of_magnitude(value):
        return math.floor(math.log10(abs(value)))

    with torch.no_grad():
        exponent = order_of_magnitude(ref_loss) - order_of_magnitude(tar_loss)
        return 10 ** exponent * scale_scale
# Running this module as a script does nothing.
if __name__ == "__main__":
    pass