%matplotlib inline
%reload_ext autoreload
%autoreload 2
class DataSet [source]
DataSet(x: ndarray, y: Optional[ndarray] = None, w: Optional[ndarray] = None)
Class holding indexable input, target, and weight data.
class WeightedDataLoader [source]
WeightedDataLoader(*args, **kwds) :: DataLoader
PyTorch DataLoader with support for optional weights and targets.
class DataPair [source]
DataPair(trn_dl: WeightedDataLoader, val_dl: WeightedDataLoader)
Single class holding the training and validation data, to simplify passing data for model training.
n = 105
# Draw a training sample and then a validation sample, each of size n,
# from freshly constructed PseudoData(paper_sig, 1) generators.
trn, val = (PseudoData(paper_sig, 1).sample(n) for _ in range(2))
trn
(array([[ 0.2709233 , 1.7907513 , 1.674339 ], [ 0.01646089, -0.7460501 , 0.6675601 ], [-0.23959269, 0.22303306, 0.8689174 ], [ 1.2815254 , 0.7106047 , 0.08947958], [-0.8260263 , 0.22419322, 0.16161819], [ 0.73751366, -0.5809242 , 0.7030808 ], [-0.6777991 , 0.03431615, 1.2666872 ], [-1.0142534 , -1.6662276 , 0.0513282 ], [ 0.51317775, 0.14159556, 0.12035868], [-1.1800369 , 0.8506802 , 0.97652584], [ 1.1993155 , -1.072847 , 0.61093956], [ 0.25319907, 0.77248424, 0.45021617], [-0.46734175, -1.9731585 , 0.01964188], [ 1.8196396 , 1.4788404 , 0.33221117], [-0.20271742, -0.11154626, 1.9630086 ], [-1.2942913 , 0.4507109 , 0.33215162], [-1.0124866 , 0.11795827, 0.18895386], [ 0.67416126, 0.72892964, 0.14242482], [ 2.2467654 , 2.0556061 , 0.42249516], [ 0.00452289, 1.5189322 , 1.3089397 ], [ 0.08481363, 0.18277758, 0.0652803 ], [-0.90703595, -0.27753952, 0.569084 ], [-1.4561021 , 0.14298095, 1.3649249 ], [-0.8529808 , -1.0866692 , 0.43750086], [ 0.5432855 , 0.13593058, 0.68547803], [-0.45232975, 1.1170769 , 0.61820304], [-1.1627467 , 0.86478496, 0.23350638], [-0.06172046, -0.23253925, 0.00733259], [-0.19780564, -0.24313924, 1.1055272 ], [ 0.9458955 , 0.9133286 , 0.08356244], [ 1.352391 , 1.0247046 , 0.7359769 ], [-2.0561 , 0.59478337, 0.7743417 ], [ 0.24981733, 1.7199372 , 0.01900457], [ 0.25799215, 0.833846 , 0.07385374], [ 1.856978 , 0.02050481, 1.0800731 ], [ 0.22542433, -0.09563376, 0.18086645], [-0.38875923, -0.4359761 , 0.399867 ], [-0.57199 , 0.5305414 , 0.73443264], [-0.6141388 , 0.19322483, 0.48166612], [-0.15746458, -0.90498894, 0.39864063], [-0.9140586 , 0.3680044 , 0.426017 ], [ 0.18768083, 0.6604895 , 0.75721216], [ 1.8002547 , -0.01180843, 0.18264978], [-1.289833 , -0.6021565 , 0.43629807], [ 0.2147916 , -0.7965742 , 0.82796824], [ 1.8953354 , 0.0129039 , 0.3998597 ], [ 0.29255927, -0.6748278 , 0.93584406], [ 0.58833283, -0.28720692, 0.5214911 ], [-2.0833862 , -1.7748507 , 0.16057676], [-1.0621885 , -1.2542429 , 0.21099381], [ 1.3584772 , -0.03226585, 
1.4647148 ], [ 0.84536374, 0.41062376, 0.07028887], [ 1.4430865 , -0.71170956, 0.2108295 ], [-0.28445148, 1.3497715 , 0.21197194], [-1.1325945 , -0.40913674, 0.02816439], [ 0.2276339 , -1.8691366 , 0.27375263], [-0.2991192 , 0.07881462, 0.5764913 ], [ 0.15207613, 1.3913307 , 1.040601 ], [-0.47383276, 0.14700344, 0.86801696], [-1.1767349 , -0.00886578, 0.34807572], [-0.19519526, -0.27044535, 0.0688965 ], [-0.93481356, -0.68222743, 0.15970075], [ 1.0284197 , -0.2849253 , 0.2366046 ], [ 0.8099784 , -0.49223515, 0.18879302], [-0.3318982 , -0.83862686, 1.1068472 ], [-1.5343724 , -0.51016635, 1.0323029 ], [ 0.7069417 , 0.6758482 , 0.5583674 ], [-0.40763295, 1.3891104 , 0.01530016], [ 0.3446544 , 0.73966247, 0.09296039], [ 0.1901904 , -0.65265316, 0.00857579], [ 1.0367985 , -0.47355187, 0.10772868], [ 0.47425756, -0.30729997, 0.14100465], [ 0.03410796, 1.5316759 , 0.32480058], [ 0.31121576, -1.479379 , 0.39946228], [ 0.21554323, -0.1999662 , 0.82143867], [-0.25494063, -0.17463271, 0.8407937 ], [ 0.2660136 , 2.4576206 , 0.2343428 ], [ 1.2129698 , 2.0038714 , 0.28311628], [ 0.83543515, 0.5551884 , 0.33021584], [-0.74448174, 0.01862068, 1.0495887 ], [ 0.02735333, 0.16676682, 0.07513639], [ 0.9743578 , -1.0512656 , 0.69908047], [ 0.68901646, -1.3249681 , 0.00999293], [ 0.76707995, -1.7836072 , 1.157289 ], [ 2.511846 , -0.23994523, 0.16924003], [ 0.47591376, 0.03348687, 0.12068149], [-1.3176295 , 0.3267708 , 0.29100758], [-0.59481734, -0.23143479, 0.13755725], [-1.3499573 , 0.8282714 , 0.4230522 ], [ 0.0086707 , 0.8173143 , 1.5349022 ], [-0.33481938, -1.0679154 , 0.34164077], [-0.7154699 , -0.06304839, 0.10494501], [-1.6924286 , 0.9839281 , 0.55981714], [-0.43082276, -1.4312286 , 0.37992394], [ 1.7234279 , 0.00332284, 0.00386309], [-0.34244117, 0.05155335, 0.5306555 ], [-0.01136534, 1.7406671 , 0.89349294], [ 0.51440746, 0.5952819 , 0.01716166], [ 1.570294 , 0.37689707, 0.10193883], [-0.40601325, 1.071294 , 0.04981167], [ 0.24173588, 0.23040244, 0.8558464 ], [-0.03069817, 
-0.4609025 , 0.00453525], [-0.50449324, -0.00548542, 0.6813995 ], [-0.86046463, 0.768944 , 0.6232264 ], [ 0.33241466, 0.8674967 , 0.13564499]], dtype=float32), arraydtype=float32), None)
# Unpack each sampled (x, y, w) tuple into a DataSet.
trn_ds = DataSet(*trn)
val_ds = DataSet(*val)
assert len(trn_ds) == n
# Inspect a single indexed item: (inputs, target, weight).
trn_ds[1]
(tensor([ 0.0165, -0.7461, 0.6676]), tensor([1.]), None)
# Training loader: shuffled, and the incomplete final batch is dropped.
trn_dl = WeightedDataLoader(
    trn_ds,
    batch_size=10,
    shuffle=True,
    drop_last=True,
)
# Validation loader: fixed order, all samples kept.
val_dl = WeightedDataLoader(
    val_ds,
    batch_size=10,
    shuffle=False,
)
# Peek at the first training batch.
next(iter(trn_dl))
(tensor([[-0.8530, -1.0867, 0.4375], [-0.4673, -1.9732, 0.0196], [-0.2549, -0.1746, 0.8408], [ 1.0284, -0.2849, 0.2366], [ 0.2660, 2.4576, 0.2343], [-0.4060, 1.0713, 0.0498], [ 0.2417, 0.2304, 0.8558], [ 0.1521, 1.3913, 1.0406], [-0.2991, 0.0788, 0.5765], [ 1.8196, 1.4788, 0.3322]]), tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]), None)
# Print every training batch's index, inputs and targets (weights are unpacked but not shown).
for i, batch in enumerate(trn_dl):
    xb, yb, wb = batch
    print(i, xb, yb)
0 tensor([[ 0.1521, 1.3913, 1.0406], [ 1.5703, 0.3769, 0.1019], [ 1.3524, 1.0247, 0.7360], [ 0.1877, 0.6605, 0.7572], [-1.1800, 0.8507, 0.9765], [ 1.1993, -1.0728, 0.6109], [-0.6141, 0.1932, 0.4817], [-2.0834, -1.7749, 0.1606], [-0.3424, 0.0516, 0.5307], [-0.4308, -1.4312, 0.3799]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 1 tensor([[-0.1978, -0.2431, 1.1055], [ 1.7234, 0.0033, 0.0039], [-1.1767, -0.0089, 0.3481], [ 0.3324, 0.8675, 0.1356], [-0.0617, -0.2325, 0.0073], [-0.0114, 1.7407, 0.8935], [-1.0622, -1.2542, 0.2110], [-1.1627, 0.8648, 0.2335], [ 0.2580, 0.8338, 0.0739], [-1.1326, -0.4091, 0.0282]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 2 tensor([[-1.5344, -0.5102, 1.0323], [-0.0307, -0.4609, 0.0045], [ 0.2148, -0.7966, 0.8280], [-0.5720, 0.5305, 0.7344], [ 0.2709, 1.7908, 1.6743], [ 1.8196, 1.4788, 0.3322], [ 0.2417, 0.2304, 0.8558], [-0.2991, 0.0788, 0.5765], [ 0.7671, -1.7836, 1.1573], [ 1.8570, 0.0205, 1.0801]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 3 tensor([[ 1.2130, 2.0039, 0.2831], [ 0.8354, 0.5552, 0.3302], [ 1.8953, 0.0129, 0.3999], [-0.8530, -1.0867, 0.4375], [ 0.2926, -0.6748, 0.9358], [ 0.2254, -0.0956, 0.1809], [ 1.4431, -0.7117, 0.2108], [ 0.0045, 1.5189, 1.3089], [-0.8260, 0.2242, 0.1616], [-1.2898, -0.6022, 0.4363]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 4 tensor([[ 0.5433, 0.1359, 0.6855], [-1.3176, 0.3268, 0.2910], [ 1.8003, -0.0118, 0.1826], [ 0.2532, 0.7725, 0.4502], [ 0.0274, 0.1668, 0.0751], [-0.3888, -0.4360, 0.3999], [ 0.9459, 0.9133, 0.0836], [ 0.6742, 0.7289, 0.1424], [-0.7445, 0.0186, 1.0496], [-0.2396, 0.2230, 0.8689]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 5 tensor([[-0.9348, -0.6822, 0.1597], [ 0.0848, 0.1828, 0.0653], [ 0.1902, -0.6527, 0.0086], [ 1.2815, 0.7106, 0.0895], [ 2.5118, -0.2399, 0.1692], [-0.5948, -0.2314, 0.1376], [ 0.4759, 0.0335, 0.1207], [ 0.8100, -0.4922, 0.1888], 
[-0.6778, 0.0343, 1.2667], [-0.4076, 1.3891, 0.0153]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 6 tensor([[-1.0143, -1.6662, 0.0513], [ 0.5144, 0.5953, 0.0172], [-0.1952, -0.2704, 0.0689], [-0.5045, -0.0055, 0.6814], [ 0.9744, -1.0513, 0.6991], [-0.2027, -0.1115, 1.9630], [-1.3500, 0.8283, 0.4231], [-2.0561, 0.5948, 0.7743], [ 1.3585, -0.0323, 1.4647], [ 0.3112, -1.4794, 0.3995]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 7 tensor([[ 0.7375, -0.5809, 0.7031], [ 0.3447, 0.7397, 0.0930], [-1.4561, 0.1430, 1.3649], [ 1.0368, -0.4736, 0.1077], [ 2.2468, 2.0556, 0.4225], [-0.2549, -0.1746, 0.8408], [ 0.0165, -0.7461, 0.6676], [-0.2845, 1.3498, 0.2120], [ 0.2155, -0.2000, 0.8214], [-1.2943, 0.4507, 0.3322]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 8 tensor([[ 1.0284, -0.2849, 0.2366], [-0.8605, 0.7689, 0.6232], [ 0.0341, 1.5317, 0.3248], [ 0.5883, -0.2872, 0.5215], [-0.4523, 1.1171, 0.6182], [ 0.5132, 0.1416, 0.1204], [ 0.6890, -1.3250, 0.0100], [-0.7155, -0.0630, 0.1049], [-0.1575, -0.9050, 0.3986], [-0.3319, -0.8386, 1.1068]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]) 9 tensor([[-0.9141, 0.3680, 0.4260], [ 0.8454, 0.4106, 0.0703], [-0.4673, -1.9732, 0.0196], [ 0.7069, 0.6758, 0.5584], [-1.6924, 0.9839, 0.5598], [ 0.2660, 2.4576, 0.2343], [ 0.4743, -0.3073, 0.1410], [-0.3348, -1.0679, 0.3416], [-0.4738, 0.1470, 0.8680], [ 0.2498, 1.7199, 0.0190]]) tensor([[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]])
# Bundle both loaders into one object; its trn_ds attribute exposes the training DataSet.
data = DataPair(trn_dl=trn_dl, val_dl=val_dl)
data.trn_ds
<__main__.DataSet at 0x7ff67d7418d0>
get_paper_data [source]
get_paper_data(n: int, bs=2000, n_test: int = 0)
Function returning training and validation data and, if n_test is given, testing data, generated from the pseudodata used in the INFERNO paper.
n = 10
# With the default n_test, only the training/validation DataPair is returned.
data = get_paper_data(n)
assert len(data.trn_ds) == len(data.val_ds) == n
# With n_test set, a test object is returned as well; its samples are reached via .dataset.
data, test = get_paper_data(n, n_test=2*n)
assert len(data.trn_ds) == len(data.val_ds) == n
# Tie the expected sizes to n (integer comparison) rather than a hard-coded constant.
assert len(test.dataset) == 2*n