%matplotlib inline
%reload_ext autoreload
%autoreload 2

class ModelWrapper[source]

ModelWrapper(model:Module, device:device=device(type='cpu'))

Class to handle the training of a neural network (NN) on data, and prediction with it, with optional callbacks. Also supports loading and saving.

Testing

from pytorch_inferno.callback import LossTracker, EarlyStopping
from pytorch_inferno.data import get_paper_data

from fastcore.all import partialler
# Sample count passed to get_paper_data, both positionally and as n_test
n = 1000
# Wrap a small fully-connected network: 3 inputs -> 50 hidden units (ReLU) -> 1 sigmoid output.
# No device is passed, so the wrapper's default (CPU, per the signature above) is used.
model = ModelWrapper(nn.Sequential(nn.Linear(3,50),nn.ReLU(),nn.Linear(50,1),nn.Sigmoid()))
# Build the training data and a separate test set with a batch size of 64.
# NOTE(review): the first positional argument is presumably the number of training samples - confirm against get_paper_data.
data, test = get_paper_data(n, bs=64, n_test=n)
# Fit with SGD (lr=2e-3) and binary cross-entropy loss; partialler defers construction of the optimiser.
# NOTE(review): the leading 10 is presumably the number of epochs (ten Train/Valid lines are printed below),
# and EarlyStopping(5) presumably a patience of 5 - confirm against ModelWrapper.fit and EarlyStopping.
model.fit(10, data=data, opt=partialler(optim.SGD,lr=2e-3), loss=nn.BCELoss(),
          cbs=[LossTracker(),EarlyStopping(5)])
Train: 0.7777950843175252 Valid: 0.7710541779994965
Train: 0.739490024248759 Valid: 0.7274892117977142
Train: 0.6959971030553181 Valid: 0.6971303887367248
Train: 0.6743967692057292 Valid: 0.6751605319976807
Train: 0.6553388277689616 Valid: 0.6592334063053131
Train: 0.641307270526886 Valid: 0.6469841578006744
Train: 0.6313625574111938 Valid: 0.6365882155895233
Train: 0.625999140739441 Valid: 0.627690737247467
Train: 0.6166404604911804 Valid: 0.6199122014045715
Train: 0.6116449395815532 Valid: 0.6127951893806457
# Predict on the test object returned by get_paper_data.
# NOTE(review): the progress bar below reports 8 steps, presumably batches - confirm.
preds = model.predict(test)
100.00% [8/8 00:00<00:00]
# There should be exactly one prediction per requested test sample (n_test=n)
assert len(preds) == n
# Display the shape of the predictions array
preds.shape
(1000, 1)
# predict is also called directly on the test set's input features (test.dataset.x).
# NOTE(review): the progress bar below reports a single step, presumably one pass over all inputs - confirm.
preds = model.predict(test.dataset.x)
100.00% [1/1 00:00<00:00]
# Predicting from the raw inputs should likewise give one prediction per test sample
assert len(preds) == n
# Display the shape of the predictions array
preds.shape
(1000, 1)