class AbsCallback

AbsCallback()

Abstract callback that passes through all action points, indicating where callbacks can affect the model. See ModelWrapper etc. for exactly where each action point is called.
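
To illustrate the pattern, here is a minimal, self-contained sketch of an abstract callback whose hooks are all no-ops, driven by a toy training loop. The hook names (on_train_begin, on_epoch_begin, on_batch_begin, on_backwards_end, on_batch_end, on_epoch_end, on_train_end), the set_wrapper method, the stop flag and the ToyWrapper class are hypothetical stand-ins chosen for this example; they are not taken from the library, and ModelWrapper's real action points may differ.

```python
from typing import List


class AbsCallback:
    """Hypothetical sketch of an abstract callback: every hook is a no-op by default."""

    def __init__(self) -> None:
        self.wrapper = None  # assumed: the training loop registers itself here

    def set_wrapper(self, wrapper) -> None:
        self.wrapper = wrapper

    # Action points (names are illustrative assumptions)
    def on_train_begin(self) -> None: pass
    def on_epoch_begin(self) -> None: pass
    def on_batch_begin(self) -> None: pass
    def on_backwards_end(self) -> None: pass
    def on_batch_end(self) -> None: pass
    def on_epoch_end(self) -> None: pass
    def on_train_end(self) -> None: pass


class ToyWrapper:
    """Hypothetical stand-in for ModelWrapper: shows only *when* hooks fire."""

    def __init__(self, cbs: List[AbsCallback]) -> None:
        self.cbs = cbs
        self.stop = False  # a callback may set this to end training early
        for cb in cbs:
            cb.set_wrapper(self)

    def _fire(self, hook: str) -> None:
        # Call the named hook on every registered callback, in order
        for cb in self.cbs:
            getattr(cb, hook)()

    def fit(self, n_epochs: int, n_batches: int) -> None:
        self._fire('on_train_begin')
        for _ in range(n_epochs):
            self._fire('on_epoch_begin')
            for _ in range(n_batches):
                self._fire('on_batch_begin')
                # forward pass, loss and backward pass would happen here
                self._fire('on_backwards_end')
                self._fire('on_batch_end')
            self._fire('on_epoch_end')
            if self.stop:
                break
        self._fire('on_train_end')


class Recorder(AbsCallback):
    """Overrides two hooks to log the order in which they fire."""

    def __init__(self) -> None:
        super().__init__()
        self.events: List[str] = []

    def on_epoch_begin(self) -> None: self.events.append('epoch')
    def on_batch_end(self) -> None: self.events.append('batch')


rec = Recorder()
ToyWrapper([rec]).fit(n_epochs=2, n_batches=3)
print(rec.events)
# ['epoch', 'batch', 'batch', 'batch', 'epoch', 'batch', 'batch', 'batch']
```

Because every hook defaults to a no-op, a concrete callback overrides only the action points it cares about.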

class LossTracker

LossTracker(loss_is_meaned:bool=True) :: AbsCallback

Tracks training and validation losses during training. Losses are assumed to be averaged over each batch and are re-averaged over the epoch unless loss_is_meaned is False.
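
Re-averaging matters because the last batch of an epoch is often smaller than the others; simply averaging the per-batch means would over-weight it. The sketch below weights each batch mean by its batch size. The epoch_loss helper is hypothetical, and the loss_is_meaned=False branch (count every batch value once) is an assumption of this sketch, not confirmed library behaviour.

```python
from typing import List, Tuple


def epoch_loss(batches: List[Tuple[float, int]], loss_is_meaned: bool = True) -> float:
    """Combine per-batch losses into one epoch value (hypothetical helper).

    batches: (loss_value, batch_size) pairs.
    If loss_is_meaned, each value is a per-sample mean, so it is weighted by
    its batch size and the total is divided by the number of samples.
    Otherwise (assumption of this sketch) every batch counts once.
    """
    total, count = 0.0, 0
    for loss, size in batches:
        weight = size if loss_is_meaned else 1
        total += loss * weight
        count += weight
    return total / count


batches = [(1.0, 30), (4.0, 10)]        # a full batch, then a short final batch
print(epoch_loss(batches))              # 1.75: (1.0*30 + 4.0*10) / 40
print(epoch_loss(batches, False))       # 2.5: plain mean of the two batch values
```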

class EarlyStopping

EarlyStopping(patience:int, loss_is_meaned:bool=True) :: AbsCallback

Tracks validation loss during training and terminates training if the loss has not decreased for patience epochs. Losses are assumed to be averaged over each batch and are re-averaged over the epoch unless loss_is_meaned is False.
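
The patience rule can be sketched as a counter that resets whenever a new best validation loss is seen. PatienceTracker and run are hypothetical names, and stopping once the counter reaches patience (rather than exceeding it) is an assumption about the exact off-by-one convention.

```python
import math
from typing import List


class PatienceTracker:
    """Hypothetical sketch of the early-stopping rule."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best = math.inf
        self.epochs_since_best = 0

    def update(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.epochs_since_best = 0   # improvement: reset the counter
        else:
            self.epochs_since_best += 1  # no improvement this epoch
        return self.epochs_since_best >= self.patience


def run(val_losses: List[float], patience: int) -> int:
    """Return how many epochs are run before stopping."""
    tracker = PatienceTracker(patience)
    for epoch, loss in enumerate(val_losses, start=1):
        if tracker.update(loss):
            return epoch
    return len(val_losses)


losses = [1.0, 0.8, 0.85, 0.9, 0.7, 0.75]
print(run(losses, patience=2))  # 4: epochs 3 and 4 fail to beat 0.8
print(run(losses, patience=3))  # 6: epoch 5 sets a new best, so all epochs run
```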

class SaveBest

SaveBest(savename:Union[str, Path], auto_reload:bool=True, loss_is_meaned:bool=True) :: AbsCallback

Tracks validation loss during training and automatically saves a copy of the weights to the indicated file whenever validation loss decreases. Losses are assumed to be averaged over each batch and are re-averaged over the epoch unless loss_is_meaned is False.
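
A minimal sketch of the checkpoint-on-improvement idea, assuming that auto_reload means "return the best saved weights, rather than the final ones, when training ends". Weights are modelled as a plain dict and written with pickle purely for illustration; SaveBestSketch and its method signatures are hypothetical and do not reflect how the library actually serialises a model.

```python
import math
import pickle
import tempfile
from pathlib import Path
from typing import Dict, Union

Weights = Dict[str, float]  # stand-in for a model's parameters


class SaveBestSketch:
    """Hypothetical sketch: checkpoint weights on every new best validation loss."""

    def __init__(self, savename: Union[str, Path], auto_reload: bool = True) -> None:
        self.savename = Path(savename)
        self.auto_reload = auto_reload
        self.best = math.inf

    def on_epoch_end(self, val_loss: float, weights: Weights) -> None:
        if val_loss < self.best:
            self.best = val_loss
            with self.savename.open('wb') as f:
                pickle.dump(weights, f)  # writes a snapshot, not a live reference

    def on_train_end(self, weights: Weights) -> Weights:
        # Assumed meaning of auto_reload: hand back the best checkpoint
        if self.auto_reload and self.savename.exists():
            with self.savename.open('rb') as f:
                return pickle.load(f)
        return weights


with tempfile.TemporaryDirectory() as tmp:
    cb = SaveBestSketch(Path(tmp) / 'best.pkl')
    history = [(0.9, {'w': 1.0}), (0.5, {'w': 2.0}), (0.7, {'w': 3.0})]
    for loss, w in history:
        cb.on_epoch_end(loss, w)
    final = cb.on_train_end({'w': 3.0})
print(final)  # {'w': 2.0}: weights from the epoch with the lowest loss
```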

class PredHandler

PredHandler() :: AbsCallback

Default callback for predictions. Collects predictions over batches and returns them as a stacked array.
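
The collect-then-stack pattern can be sketched without any array library: keep one list per batch, then concatenate along the batch axis so that row i corresponds to input i. PredCollector and its hook signatures are hypothetical. With NumPy arrays, the final step would typically be np.concatenate over the list of batch arrays.

```python
from typing import List, Sequence

Row = List[float]


class PredCollector:
    """Hypothetical sketch of a prediction callback."""

    def on_pred_begin(self) -> None:
        self.preds: List[List[Row]] = []  # one entry per batch

    def on_batch_end(self, batch_preds: Sequence[Row]) -> None:
        # Copy the rows so later changes to the batch do not leak in
        self.preds.append([list(row) for row in batch_preds])

    def get_preds(self) -> List[Row]:
        # Concatenate along the batch axis: row i is the prediction for input i
        return [row for batch in self.preds for row in batch]


h = PredCollector()
h.on_pred_begin()
h.on_batch_end([[0.1, 0.9], [0.4, 0.6]])
h.on_batch_end([[0.7, 0.3]])  # a smaller final batch
print(h.get_preds())  # [[0.1, 0.9], [0.4, 0.6], [0.7, 0.3]]
```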

class PaperSystMod

PaperSystMod(r:float=0, l:float=3) :: AbsCallback

Prediction callback for modifying input data from the INFERNO paper according to the specified nuisance parameters, r and l.
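
As we read the INFERNO paper's synthetic benchmark, the nuisance r shifts the mean of the first background feature and λ is the rate of an exponential third feature, nominally 3, which matches the defaults r=0, l=3. A sketch under that assumption: shift feature 0 by r and rescale feature 2 by 3/l, since multiplying an Exp(rate 3) variable by 3/l yields an Exp(rate l) variable. PaperSystModSketch and its modify method are hypothetical, and the exact transformation the library applies is not confirmed here. In the paper the nuisances act on the background, so in practice this would presumably be applied to background events.

```python
from typing import List

Event = List[float]  # (x0, x1, x2), as in the INFERNO toy problem


class PaperSystModSketch:
    """Hypothetical sketch: apply nuisance shifts to nominal inputs at prediction time.

    Assumes nominal data were generated with r=0 and an exponential third
    feature of rate 3, so scaling x2 by 3/l turns Exp(rate=3) into Exp(rate=l).
    """

    def __init__(self, r: float = 0, l: float = 3) -> None:
        self.r, self.l = r, l

    def modify(self, events: List[Event]) -> List[Event]:
        # Shift the first feature's mean by r; rescale the exponential feature
        return [[x0 + self.r, x1, x2 * 3 / self.l] for x0, x1, x2 in events]


print(PaperSystModSketch(r=0.5, l=2.0).modify([[1.0, 2.0, 0.5]]))  # [[1.5, 2.0, 0.75]]
print(PaperSystModSketch().modify([[1.0, 2.0, 0.5]]))              # defaults leave data unchanged
```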

class GradClip

GradClip(clip:float, clip_norm:bool=True) :: AbsCallback

Training callback implementing gradient clipping.
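
The two common clipping modes differ in whether they preserve the gradient's direction. PyTorch itself provides torch.nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_value_; we assume, without having checked the source, that clip_norm=True selects norm clipping and clip_norm=False selects value clipping. The stdlib sketch below applies both modes to a flat gradient vector; clip_gradients is a hypothetical helper.

```python
import math
from typing import List


def clip_gradients(grads: List[float], clip: float, clip_norm: bool = True) -> List[float]:
    """Return a clipped copy of a flat gradient vector (hypothetical helper)."""
    if clip_norm:
        # Norm clipping: rescale the whole vector if its L2 norm exceeds clip,
        # which preserves the gradient's direction
        norm = math.sqrt(sum(g * g for g in grads))
        if norm > clip:
            scale = clip / norm
            return [g * scale for g in grads]
        return list(grads)
    # Value clipping: clamp each component independently to [-clip, clip]
    return [max(-clip, min(clip, g)) for g in grads]


print(clip_gradients([3.0, 4.0], 1.0))                   # approximately [0.6, 0.8]
print(clip_gradients([3.0, 4.0], 1.0, clip_norm=False))  # [1.0, 1.0]: direction changes
```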