# Notebook setup: render matplotlib figures inline, and enable IPython's
# autoreload extension in mode 2 so edited library modules are reloaded
# automatically before each cell executes.
%matplotlib inline
%reload_ext autoreload
%autoreload 2
class DataSet[source]
DataSet(x:ndarray,y:Optional[ndarray]=None,w:Optional[ndarray]=None)
Class holding indexable input data, together with optional target and weight data
class WeightedDataLoader[source]
WeightedDataLoader(*args, **kwds) ::DataLoader
PyTorch DataLoader with support for optional weights and targets
class DataPair[source]
DataPair(trn_dl:WeightedDataLoader,val_dl:WeightedDataLoader)
Container pairing the training and validation data loaders, to simplify passing data to model training
# Number of events to draw for each split
n = 105
# Draw independent training and validation samples (in that order) from the
# paper's signal pseudodata; each sample is an (x, y, w) tuple, with w None here
trn, val = [PseudoData(paper_sig, 1).sample(n) for _ in range(2)]
trn
(array([[ 0.2709233 , 1.7907513 , 1.674339 ],
[ 0.01646089, -0.7460501 , 0.6675601 ],
[-0.23959269, 0.22303306, 0.8689174 ],
[ 1.2815254 , 0.7106047 , 0.08947958],
[-0.8260263 , 0.22419322, 0.16161819],
[ 0.73751366, -0.5809242 , 0.7030808 ],
[-0.6777991 , 0.03431615, 1.2666872 ],
[-1.0142534 , -1.6662276 , 0.0513282 ],
[ 0.51317775, 0.14159556, 0.12035868],
[-1.1800369 , 0.8506802 , 0.97652584],
[ 1.1993155 , -1.072847 , 0.61093956],
[ 0.25319907, 0.77248424, 0.45021617],
[-0.46734175, -1.9731585 , 0.01964188],
[ 1.8196396 , 1.4788404 , 0.33221117],
[-0.20271742, -0.11154626, 1.9630086 ],
[-1.2942913 , 0.4507109 , 0.33215162],
[-1.0124866 , 0.11795827, 0.18895386],
[ 0.67416126, 0.72892964, 0.14242482],
[ 2.2467654 , 2.0556061 , 0.42249516],
[ 0.00452289, 1.5189322 , 1.3089397 ],
[ 0.08481363, 0.18277758, 0.0652803 ],
[-0.90703595, -0.27753952, 0.569084 ],
[-1.4561021 , 0.14298095, 1.3649249 ],
[-0.8529808 , -1.0866692 , 0.43750086],
[ 0.5432855 , 0.13593058, 0.68547803],
[-0.45232975, 1.1170769 , 0.61820304],
[-1.1627467 , 0.86478496, 0.23350638],
[-0.06172046, -0.23253925, 0.00733259],
[-0.19780564, -0.24313924, 1.1055272 ],
[ 0.9458955 , 0.9133286 , 0.08356244],
[ 1.352391 , 1.0247046 , 0.7359769 ],
[-2.0561 , 0.59478337, 0.7743417 ],
[ 0.24981733, 1.7199372 , 0.01900457],
[ 0.25799215, 0.833846 , 0.07385374],
[ 1.856978 , 0.02050481, 1.0800731 ],
[ 0.22542433, -0.09563376, 0.18086645],
[-0.38875923, -0.4359761 , 0.399867 ],
[-0.57199 , 0.5305414 , 0.73443264],
[-0.6141388 , 0.19322483, 0.48166612],
[-0.15746458, -0.90498894, 0.39864063],
[-0.9140586 , 0.3680044 , 0.426017 ],
[ 0.18768083, 0.6604895 , 0.75721216],
[ 1.8002547 , -0.01180843, 0.18264978],
[-1.289833 , -0.6021565 , 0.43629807],
[ 0.2147916 , -0.7965742 , 0.82796824],
[ 1.8953354 , 0.0129039 , 0.3998597 ],
[ 0.29255927, -0.6748278 , 0.93584406],
[ 0.58833283, -0.28720692, 0.5214911 ],
[-2.0833862 , -1.7748507 , 0.16057676],
[-1.0621885 , -1.2542429 , 0.21099381],
[ 1.3584772 , -0.03226585, 1.4647148 ],
[ 0.84536374, 0.41062376, 0.07028887],
[ 1.4430865 , -0.71170956, 0.2108295 ],
[-0.28445148, 1.3497715 , 0.21197194],
[-1.1325945 , -0.40913674, 0.02816439],
[ 0.2276339 , -1.8691366 , 0.27375263],
[-0.2991192 , 0.07881462, 0.5764913 ],
[ 0.15207613, 1.3913307 , 1.040601 ],
[-0.47383276, 0.14700344, 0.86801696],
[-1.1767349 , -0.00886578, 0.34807572],
[-0.19519526, -0.27044535, 0.0688965 ],
[-0.93481356, -0.68222743, 0.15970075],
[ 1.0284197 , -0.2849253 , 0.2366046 ],
[ 0.8099784 , -0.49223515, 0.18879302],
[-0.3318982 , -0.83862686, 1.1068472 ],
[-1.5343724 , -0.51016635, 1.0323029 ],
[ 0.7069417 , 0.6758482 , 0.5583674 ],
[-0.40763295, 1.3891104 , 0.01530016],
[ 0.3446544 , 0.73966247, 0.09296039],
[ 0.1901904 , -0.65265316, 0.00857579],
[ 1.0367985 , -0.47355187, 0.10772868],
[ 0.47425756, -0.30729997, 0.14100465],
[ 0.03410796, 1.5316759 , 0.32480058],
[ 0.31121576, -1.479379 , 0.39946228],
[ 0.21554323, -0.1999662 , 0.82143867],
[-0.25494063, -0.17463271, 0.8407937 ],
[ 0.2660136 , 2.4576206 , 0.2343428 ],
[ 1.2129698 , 2.0038714 , 0.28311628],
[ 0.83543515, 0.5551884 , 0.33021584],
[-0.74448174, 0.01862068, 1.0495887 ],
[ 0.02735333, 0.16676682, 0.07513639],
[ 0.9743578 , -1.0512656 , 0.69908047],
[ 0.68901646, -1.3249681 , 0.00999293],
[ 0.76707995, -1.7836072 , 1.157289 ],
[ 2.511846 , -0.23994523, 0.16924003],
[ 0.47591376, 0.03348687, 0.12068149],
[-1.3176295 , 0.3267708 , 0.29100758],
[-0.59481734, -0.23143479, 0.13755725],
[-1.3499573 , 0.8282714 , 0.4230522 ],
[ 0.0086707 , 0.8173143 , 1.5349022 ],
[-0.33481938, -1.0679154 , 0.34164077],
[-0.7154699 , -0.06304839, 0.10494501],
[-1.6924286 , 0.9839281 , 0.55981714],
[-0.43082276, -1.4312286 , 0.37992394],
[ 1.7234279 , 0.00332284, 0.00386309],
[-0.34244117, 0.05155335, 0.5306555 ],
[-0.01136534, 1.7406671 , 0.89349294],
[ 0.51440746, 0.5952819 , 0.01716166],
[ 1.570294 , 0.37689707, 0.10193883],
[-0.40601325, 1.071294 , 0.04981167],
[ 0.24173588, 0.23040244, 0.8558464 ],
[-0.03069817, -0.4609025 , 0.00453525],
[-0.50449324, -0.00548542, 0.6813995 ],
[-0.86046463, 0.768944 , 0.6232264 ],
[ 0.33241466, 0.8674967 , 0.13564499]], dtype=float32), array([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]], dtype=float32), None)
# Wrap the sampled (x, y, w) arrays in DataSet objects
trn_ds, val_ds = DataSet(*trn), DataSet(*val)
# Both splits were sampled with n events, so check the validation set too,
# not just the training set
assert len(trn_ds) == len(val_ds) == n
# Indexing a DataSet returns an (x, y, w) tuple for a single event
trn_ds[1]
(tensor([ 0.0165, -0.7461, 0.6676]), tensor([1.]), None)
# Training loader: shuffled, and the trailing incomplete batch is dropped
# (with n=105 and batches of 10 this yields 10 full batches per epoch)
trn_dl = WeightedDataLoader(trn_ds, shuffle=True, drop_last=True, batch_size=10)
# Validation loader: fixed order, keeps every event
val_dl = WeightedDataLoader(val_ds, shuffle=False, batch_size=10)
# Peek at a single training batch: an (x, y, w) tuple
next(iter(trn_dl))
(tensor([[-0.8530, -1.0867, 0.4375],
[-0.4673, -1.9732, 0.0196],
[-0.2549, -0.1746, 0.8408],
[ 1.0284, -0.2849, 0.2366],
[ 0.2660, 2.4576, 0.2343],
[-0.4060, 1.0713, 0.0498],
[ 0.2417, 0.2304, 0.8558],
[ 0.1521, 1.3913, 1.0406],
[-0.2991, 0.0788, 0.5765],
[ 1.8196, 1.4788, 0.3322]]), tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]]), None)
# Run through one epoch of the training loader, printing each batch's
# index, inputs and targets (weights are unpacked but not shown)
for i, (xb, yb, wb) in enumerate(trn_dl):
    print(i, xb, yb)
0 tensor([[ 0.1521, 1.3913, 1.0406],
[ 1.5703, 0.3769, 0.1019],
[ 1.3524, 1.0247, 0.7360],
[ 0.1877, 0.6605, 0.7572],
[-1.1800, 0.8507, 0.9765],
[ 1.1993, -1.0728, 0.6109],
[-0.6141, 0.1932, 0.4817],
[-2.0834, -1.7749, 0.1606],
[-0.3424, 0.0516, 0.5307],
[-0.4308, -1.4312, 0.3799]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
1 tensor([[-0.1978, -0.2431, 1.1055],
[ 1.7234, 0.0033, 0.0039],
[-1.1767, -0.0089, 0.3481],
[ 0.3324, 0.8675, 0.1356],
[-0.0617, -0.2325, 0.0073],
[-0.0114, 1.7407, 0.8935],
[-1.0622, -1.2542, 0.2110],
[-1.1627, 0.8648, 0.2335],
[ 0.2580, 0.8338, 0.0739],
[-1.1326, -0.4091, 0.0282]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
2 tensor([[-1.5344, -0.5102, 1.0323],
[-0.0307, -0.4609, 0.0045],
[ 0.2148, -0.7966, 0.8280],
[-0.5720, 0.5305, 0.7344],
[ 0.2709, 1.7908, 1.6743],
[ 1.8196, 1.4788, 0.3322],
[ 0.2417, 0.2304, 0.8558],
[-0.2991, 0.0788, 0.5765],
[ 0.7671, -1.7836, 1.1573],
[ 1.8570, 0.0205, 1.0801]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
3 tensor([[ 1.2130, 2.0039, 0.2831],
[ 0.8354, 0.5552, 0.3302],
[ 1.8953, 0.0129, 0.3999],
[-0.8530, -1.0867, 0.4375],
[ 0.2926, -0.6748, 0.9358],
[ 0.2254, -0.0956, 0.1809],
[ 1.4431, -0.7117, 0.2108],
[ 0.0045, 1.5189, 1.3089],
[-0.8260, 0.2242, 0.1616],
[-1.2898, -0.6022, 0.4363]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
4 tensor([[ 0.5433, 0.1359, 0.6855],
[-1.3176, 0.3268, 0.2910],
[ 1.8003, -0.0118, 0.1826],
[ 0.2532, 0.7725, 0.4502],
[ 0.0274, 0.1668, 0.0751],
[-0.3888, -0.4360, 0.3999],
[ 0.9459, 0.9133, 0.0836],
[ 0.6742, 0.7289, 0.1424],
[-0.7445, 0.0186, 1.0496],
[-0.2396, 0.2230, 0.8689]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
5 tensor([[-0.9348, -0.6822, 0.1597],
[ 0.0848, 0.1828, 0.0653],
[ 0.1902, -0.6527, 0.0086],
[ 1.2815, 0.7106, 0.0895],
[ 2.5118, -0.2399, 0.1692],
[-0.5948, -0.2314, 0.1376],
[ 0.4759, 0.0335, 0.1207],
[ 0.8100, -0.4922, 0.1888],
[-0.6778, 0.0343, 1.2667],
[-0.4076, 1.3891, 0.0153]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
6 tensor([[-1.0143, -1.6662, 0.0513],
[ 0.5144, 0.5953, 0.0172],
[-0.1952, -0.2704, 0.0689],
[-0.5045, -0.0055, 0.6814],
[ 0.9744, -1.0513, 0.6991],
[-0.2027, -0.1115, 1.9630],
[-1.3500, 0.8283, 0.4231],
[-2.0561, 0.5948, 0.7743],
[ 1.3585, -0.0323, 1.4647],
[ 0.3112, -1.4794, 0.3995]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
7 tensor([[ 0.7375, -0.5809, 0.7031],
[ 0.3447, 0.7397, 0.0930],
[-1.4561, 0.1430, 1.3649],
[ 1.0368, -0.4736, 0.1077],
[ 2.2468, 2.0556, 0.4225],
[-0.2549, -0.1746, 0.8408],
[ 0.0165, -0.7461, 0.6676],
[-0.2845, 1.3498, 0.2120],
[ 0.2155, -0.2000, 0.8214],
[-1.2943, 0.4507, 0.3322]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
8 tensor([[ 1.0284, -0.2849, 0.2366],
[-0.8605, 0.7689, 0.6232],
[ 0.0341, 1.5317, 0.3248],
[ 0.5883, -0.2872, 0.5215],
[-0.4523, 1.1171, 0.6182],
[ 0.5132, 0.1416, 0.1204],
[ 0.6890, -1.3250, 0.0100],
[-0.7155, -0.0630, 0.1049],
[-0.1575, -0.9050, 0.3986],
[-0.3319, -0.8386, 1.1068]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
9 tensor([[-0.9141, 0.3680, 0.4260],
[ 0.8454, 0.4106, 0.0703],
[-0.4673, -1.9732, 0.0196],
[ 0.7069, 0.6758, 0.5584],
[-1.6924, 0.9839, 0.5598],
[ 0.2660, 2.4576, 0.2343],
[ 0.4743, -0.3073, 0.1410],
[-0.3348, -1.0679, 0.3416],
[-0.4738, 0.1470, 0.8680],
[ 0.2498, 1.7199, 0.0190]]) tensor([[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.]])
# Bundle the training and validation loaders into one object for training
data = DataPair(trn_dl=trn_dl, val_dl=val_dl)
# The pair exposes the underlying training DataSet
data.trn_ds
<__main__.DataSet at 0x7ff67d7418d0>
get_paper_data[source]
get_paper_data(n:int,bs=2000,n_test:int=0)
Function returning training, validation and (optionally) testing data, generated from the pseudodata used in the INFERNO paper
n = 10
# Without n_test, a single DataPair with n training and n validation events
data = get_paper_data(n)
assert len(data.trn_ds) == len(data.val_ds) == n
# With n_test, a (DataPair, test loader) tuple is returned.
# Express every expected size in terms of n (no hard-coded 10, no float compare).
data, test = get_paper_data(n, n_test=2*n)
assert len(data.trn_ds) == len(data.val_ds) == n
assert len(test.dataset) == 2*n