%matplotlib inline
%reload_ext autoreload
%autoreload 2

class DataSet[source]

DataSet(x:ndarray, y:Optional[ndarray]=None, w:Optional[ndarray]=None)

Class holding indexable input, target and weight data

class WeightedDataLoader[source]

WeightedDataLoader(*args, **kwds) :: DataLoader

PyTorch DataLoader with support for optional weights and targets

class DataPair[source]

DataPair(trn_dl:WeightedDataLoader, val_dl:WeightedDataLoader)

Single object holding the training and validation data loaders, to simplify passing data to model training

n = 105  # number of pseudodata samples drawn for each of the training and validation sets
# Draw the training sample first, then the validation sample, from identically-configured generators
trn, val = [PseudoData(paper_sig, 1).sample(n) for _split in range(2)]
trn
(array([[ 0.2709233 ,  1.7907513 ,  1.674339  ],
        [ 0.01646089, -0.7460501 ,  0.6675601 ],
        [-0.23959269,  0.22303306,  0.8689174 ],
        [ 1.2815254 ,  0.7106047 ,  0.08947958],
        [-0.8260263 ,  0.22419322,  0.16161819],
        [ 0.73751366, -0.5809242 ,  0.7030808 ],
        [-0.6777991 ,  0.03431615,  1.2666872 ],
        [-1.0142534 , -1.6662276 ,  0.0513282 ],
        [ 0.51317775,  0.14159556,  0.12035868],
        [-1.1800369 ,  0.8506802 ,  0.97652584],
        [ 1.1993155 , -1.072847  ,  0.61093956],
        [ 0.25319907,  0.77248424,  0.45021617],
        [-0.46734175, -1.9731585 ,  0.01964188],
        [ 1.8196396 ,  1.4788404 ,  0.33221117],
        [-0.20271742, -0.11154626,  1.9630086 ],
        [-1.2942913 ,  0.4507109 ,  0.33215162],
        [-1.0124866 ,  0.11795827,  0.18895386],
        [ 0.67416126,  0.72892964,  0.14242482],
        [ 2.2467654 ,  2.0556061 ,  0.42249516],
        [ 0.00452289,  1.5189322 ,  1.3089397 ],
        [ 0.08481363,  0.18277758,  0.0652803 ],
        [-0.90703595, -0.27753952,  0.569084  ],
        [-1.4561021 ,  0.14298095,  1.3649249 ],
        [-0.8529808 , -1.0866692 ,  0.43750086],
        [ 0.5432855 ,  0.13593058,  0.68547803],
        [-0.45232975,  1.1170769 ,  0.61820304],
        [-1.1627467 ,  0.86478496,  0.23350638],
        [-0.06172046, -0.23253925,  0.00733259],
        [-0.19780564, -0.24313924,  1.1055272 ],
        [ 0.9458955 ,  0.9133286 ,  0.08356244],
        [ 1.352391  ,  1.0247046 ,  0.7359769 ],
        [-2.0561    ,  0.59478337,  0.7743417 ],
        [ 0.24981733,  1.7199372 ,  0.01900457],
        [ 0.25799215,  0.833846  ,  0.07385374],
        [ 1.856978  ,  0.02050481,  1.0800731 ],
        [ 0.22542433, -0.09563376,  0.18086645],
        [-0.38875923, -0.4359761 ,  0.399867  ],
        [-0.57199   ,  0.5305414 ,  0.73443264],
        [-0.6141388 ,  0.19322483,  0.48166612],
        [-0.15746458, -0.90498894,  0.39864063],
        [-0.9140586 ,  0.3680044 ,  0.426017  ],
        [ 0.18768083,  0.6604895 ,  0.75721216],
        [ 1.8002547 , -0.01180843,  0.18264978],
        [-1.289833  , -0.6021565 ,  0.43629807],
        [ 0.2147916 , -0.7965742 ,  0.82796824],
        [ 1.8953354 ,  0.0129039 ,  0.3998597 ],
        [ 0.29255927, -0.6748278 ,  0.93584406],
        [ 0.58833283, -0.28720692,  0.5214911 ],
        [-2.0833862 , -1.7748507 ,  0.16057676],
        [-1.0621885 , -1.2542429 ,  0.21099381],
        [ 1.3584772 , -0.03226585,  1.4647148 ],
        [ 0.84536374,  0.41062376,  0.07028887],
        [ 1.4430865 , -0.71170956,  0.2108295 ],
        [-0.28445148,  1.3497715 ,  0.21197194],
        [-1.1325945 , -0.40913674,  0.02816439],
        [ 0.2276339 , -1.8691366 ,  0.27375263],
        [-0.2991192 ,  0.07881462,  0.5764913 ],
        [ 0.15207613,  1.3913307 ,  1.040601  ],
        [-0.47383276,  0.14700344,  0.86801696],
        [-1.1767349 , -0.00886578,  0.34807572],
        [-0.19519526, -0.27044535,  0.0688965 ],
        [-0.93481356, -0.68222743,  0.15970075],
        [ 1.0284197 , -0.2849253 ,  0.2366046 ],
        [ 0.8099784 , -0.49223515,  0.18879302],
        [-0.3318982 , -0.83862686,  1.1068472 ],
        [-1.5343724 , -0.51016635,  1.0323029 ],
        [ 0.7069417 ,  0.6758482 ,  0.5583674 ],
        [-0.40763295,  1.3891104 ,  0.01530016],
        [ 0.3446544 ,  0.73966247,  0.09296039],
        [ 0.1901904 , -0.65265316,  0.00857579],
        [ 1.0367985 , -0.47355187,  0.10772868],
        [ 0.47425756, -0.30729997,  0.14100465],
        [ 0.03410796,  1.5316759 ,  0.32480058],
        [ 0.31121576, -1.479379  ,  0.39946228],
        [ 0.21554323, -0.1999662 ,  0.82143867],
        [-0.25494063, -0.17463271,  0.8407937 ],
        [ 0.2660136 ,  2.4576206 ,  0.2343428 ],
        [ 1.2129698 ,  2.0038714 ,  0.28311628],
        [ 0.83543515,  0.5551884 ,  0.33021584],
        [-0.74448174,  0.01862068,  1.0495887 ],
        [ 0.02735333,  0.16676682,  0.07513639],
        [ 0.9743578 , -1.0512656 ,  0.69908047],
        [ 0.68901646, -1.3249681 ,  0.00999293],
        [ 0.76707995, -1.7836072 ,  1.157289  ],
        [ 2.511846  , -0.23994523,  0.16924003],
        [ 0.47591376,  0.03348687,  0.12068149],
        [-1.3176295 ,  0.3267708 ,  0.29100758],
        [-0.59481734, -0.23143479,  0.13755725],
        [-1.3499573 ,  0.8282714 ,  0.4230522 ],
        [ 0.0086707 ,  0.8173143 ,  1.5349022 ],
        [-0.33481938, -1.0679154 ,  0.34164077],
        [-0.7154699 , -0.06304839,  0.10494501],
        [-1.6924286 ,  0.9839281 ,  0.55981714],
        [-0.43082276, -1.4312286 ,  0.37992394],
        [ 1.7234279 ,  0.00332284,  0.00386309],
        [-0.34244117,  0.05155335,  0.5306555 ],
        [-0.01136534,  1.7406671 ,  0.89349294],
        [ 0.51440746,  0.5952819 ,  0.01716166],
        [ 1.570294  ,  0.37689707,  0.10193883],
        [-0.40601325,  1.071294  ,  0.04981167],
        [ 0.24173588,  0.23040244,  0.8558464 ],
        [-0.03069817, -0.4609025 ,  0.00453525],
        [-0.50449324, -0.00548542,  0.6813995 ],
        [-0.86046463,  0.768944  ,  0.6232264 ],
        [ 0.33241466,  0.8674967 ,  0.13564499]], dtype=float32), array([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], dtype=float32), None)
# Unpack each sampled (inputs, targets, weights) tuple into a DataSet
trn_ds = DataSet(*trn)
val_ds = DataSet(*val)
assert len(trn_ds) == n
trn_ds[1]
(tensor([ 0.0165, -0.7461,  0.6676]), tensor([1.]), None)
# Training loader: shuffled each epoch; the incomplete final batch is dropped
trn_dl = WeightedDataLoader(trn_ds,
                            shuffle=True,
                            drop_last=True,
                            batch_size=10)
# Validation loader: fixed order, partial final batch kept
val_dl = WeightedDataLoader(val_ds,
                            shuffle=False,
                            batch_size=10)
next(iter(trn_dl))
(tensor([[-0.8530, -1.0867,  0.4375],
         [-0.4673, -1.9732,  0.0196],
         [-0.2549, -0.1746,  0.8408],
         [ 1.0284, -0.2849,  0.2366],
         [ 0.2660,  2.4576,  0.2343],
         [-0.4060,  1.0713,  0.0498],
         [ 0.2417,  0.2304,  0.8558],
         [ 0.1521,  1.3913,  1.0406],
         [-0.2991,  0.0788,  0.5765],
         [ 1.8196,  1.4788,  0.3322]]), tensor([[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]]), None)
# Walk every training batch, printing the batch index, inputs and targets
# (the weights element wb is unpacked but not printed)
for i, (xb, yb, wb) in enumerate(trn_dl):
    print(i, xb, yb)
0 tensor([[ 0.1521,  1.3913,  1.0406],
        [ 1.5703,  0.3769,  0.1019],
        [ 1.3524,  1.0247,  0.7360],
        [ 0.1877,  0.6605,  0.7572],
        [-1.1800,  0.8507,  0.9765],
        [ 1.1993, -1.0728,  0.6109],
        [-0.6141,  0.1932,  0.4817],
        [-2.0834, -1.7749,  0.1606],
        [-0.3424,  0.0516,  0.5307],
        [-0.4308, -1.4312,  0.3799]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
1 tensor([[-0.1978, -0.2431,  1.1055],
        [ 1.7234,  0.0033,  0.0039],
        [-1.1767, -0.0089,  0.3481],
        [ 0.3324,  0.8675,  0.1356],
        [-0.0617, -0.2325,  0.0073],
        [-0.0114,  1.7407,  0.8935],
        [-1.0622, -1.2542,  0.2110],
        [-1.1627,  0.8648,  0.2335],
        [ 0.2580,  0.8338,  0.0739],
        [-1.1326, -0.4091,  0.0282]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
2 tensor([[-1.5344, -0.5102,  1.0323],
        [-0.0307, -0.4609,  0.0045],
        [ 0.2148, -0.7966,  0.8280],
        [-0.5720,  0.5305,  0.7344],
        [ 0.2709,  1.7908,  1.6743],
        [ 1.8196,  1.4788,  0.3322],
        [ 0.2417,  0.2304,  0.8558],
        [-0.2991,  0.0788,  0.5765],
        [ 0.7671, -1.7836,  1.1573],
        [ 1.8570,  0.0205,  1.0801]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
3 tensor([[ 1.2130,  2.0039,  0.2831],
        [ 0.8354,  0.5552,  0.3302],
        [ 1.8953,  0.0129,  0.3999],
        [-0.8530, -1.0867,  0.4375],
        [ 0.2926, -0.6748,  0.9358],
        [ 0.2254, -0.0956,  0.1809],
        [ 1.4431, -0.7117,  0.2108],
        [ 0.0045,  1.5189,  1.3089],
        [-0.8260,  0.2242,  0.1616],
        [-1.2898, -0.6022,  0.4363]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
4 tensor([[ 0.5433,  0.1359,  0.6855],
        [-1.3176,  0.3268,  0.2910],
        [ 1.8003, -0.0118,  0.1826],
        [ 0.2532,  0.7725,  0.4502],
        [ 0.0274,  0.1668,  0.0751],
        [-0.3888, -0.4360,  0.3999],
        [ 0.9459,  0.9133,  0.0836],
        [ 0.6742,  0.7289,  0.1424],
        [-0.7445,  0.0186,  1.0496],
        [-0.2396,  0.2230,  0.8689]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
5 tensor([[-0.9348, -0.6822,  0.1597],
        [ 0.0848,  0.1828,  0.0653],
        [ 0.1902, -0.6527,  0.0086],
        [ 1.2815,  0.7106,  0.0895],
        [ 2.5118, -0.2399,  0.1692],
        [-0.5948, -0.2314,  0.1376],
        [ 0.4759,  0.0335,  0.1207],
        [ 0.8100, -0.4922,  0.1888],
        [-0.6778,  0.0343,  1.2667],
        [-0.4076,  1.3891,  0.0153]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
6 tensor([[-1.0143, -1.6662,  0.0513],
        [ 0.5144,  0.5953,  0.0172],
        [-0.1952, -0.2704,  0.0689],
        [-0.5045, -0.0055,  0.6814],
        [ 0.9744, -1.0513,  0.6991],
        [-0.2027, -0.1115,  1.9630],
        [-1.3500,  0.8283,  0.4231],
        [-2.0561,  0.5948,  0.7743],
        [ 1.3585, -0.0323,  1.4647],
        [ 0.3112, -1.4794,  0.3995]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
7 tensor([[ 0.7375, -0.5809,  0.7031],
        [ 0.3447,  0.7397,  0.0930],
        [-1.4561,  0.1430,  1.3649],
        [ 1.0368, -0.4736,  0.1077],
        [ 2.2468,  2.0556,  0.4225],
        [-0.2549, -0.1746,  0.8408],
        [ 0.0165, -0.7461,  0.6676],
        [-0.2845,  1.3498,  0.2120],
        [ 0.2155, -0.2000,  0.8214],
        [-1.2943,  0.4507,  0.3322]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
8 tensor([[ 1.0284, -0.2849,  0.2366],
        [-0.8605,  0.7689,  0.6232],
        [ 0.0341,  1.5317,  0.3248],
        [ 0.5883, -0.2872,  0.5215],
        [-0.4523,  1.1171,  0.6182],
        [ 0.5132,  0.1416,  0.1204],
        [ 0.6890, -1.3250,  0.0100],
        [-0.7155, -0.0630,  0.1049],
        [-0.1575, -0.9050,  0.3986],
        [-0.3319, -0.8386,  1.1068]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
9 tensor([[-0.9141,  0.3680,  0.4260],
        [ 0.8454,  0.4106,  0.0703],
        [-0.4673, -1.9732,  0.0196],
        [ 0.7069,  0.6758,  0.5584],
        [-1.6924,  0.9839,  0.5598],
        [ 0.2660,  2.4576,  0.2343],
        [ 0.4743, -0.3073,  0.1410],
        [-0.3348, -1.0679,  0.3416],
        [-0.4738,  0.1470,  0.8680],
        [ 0.2498,  1.7199,  0.0190]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
# Bundle both loaders into a single object for passing to training code
data = DataPair(trn_dl=trn_dl, val_dl=val_dl)
data.trn_ds
<__main__.DataSet at 0x7ff67d7418d0>

Paper data

get_paper_data[source]

get_paper_data(n:int, bs=2000, n_test:int=0)

Function returning training, validation, and (optionally) testing data, generated according to the pseudodata used in the INFERNO paper

n = 10
# Without n_test, training and validation sets each hold n samples
data = get_paper_data(n)
assert len(data.trn_ds) == len(data.val_ds) == n
# With n_test set, the call unpacks into the train/val data and the test data
data, test = get_paper_data(n, n_test=2*n)
assert len(data.trn_ds) == len(data.val_ds) == n
# Compare integer lengths in terms of n rather than a hard-coded literal / float product
assert len(test.dataset) == 2*n