diff --git a/output/pretrained/GRU_state_data.pth b/output/pretrained/GRU_state_data.pth new file mode 100644 index 0000000..7cf2198 Binary files /dev/null and b/output/pretrained/GRU_state_data.pth differ diff --git a/output/pretrained/LSTM_state_data.pth b/output/pretrained/LSTM_state_data.pth new file mode 100644 index 0000000..7cce356 Binary files /dev/null and b/output/pretrained/LSTM_state_data.pth differ diff --git a/output/pretrained/RNN_state_data.pth b/output/pretrained/RNN_state_data.pth new file mode 100644 index 0000000..8621394 Binary files /dev/null and b/output/pretrained/RNN_state_data.pth differ diff --git a/output/pretrained/TCN_state_data.pth b/output/pretrained/TCN_state_data.pth new file mode 100644 index 0000000..df9886b Binary files /dev/null and b/output/pretrained/TCN_state_data.pth differ diff --git a/output/pretrained/TE_state_data.pth b/output/pretrained/TE_state_data.pth new file mode 100644 index 0000000..a5b6329 Binary files /dev/null and b/output/pretrained/TE_state_data.pth differ diff --git a/predict.ipynb b/predict.ipynb index d2294d8..6b26893 100644 --- a/predict.ipynb +++ b/predict.ipynb @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -156,14 +156,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([1, 8, 116])\n" + "Loaded data of size: torch.Size([1, 8, 116])\n" ] } ], @@ -180,307 +180,24 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'class': 'RNN', 'args': {'hidden_size': 512, 'input_size': 116, 'dropout': 0.0}, 'state_dict': OrderedDict([('rnn.weight_ih_l0', tensor([[-0.0067, 0.0513, 0.0022, ..., 0.0505, 0.0357, -0.0072],\n", - " [-0.0040, 0.0515, 0.0603, ..., 0.0076, 0.0066, 0.0903],\n", - " [-0.0774, 0.0698, -0.0212, ..., 0.0375, 0.0144, 0.1040],\n", - " ...,\n", - " [-0.0374, 0.0139, 0.0079, ..., 0.0131, 0.0327, 0.0791],\n", - " [ 0.0031, -0.0372, -0.0067, ..., -0.0415, -0.0025, -0.0009],\n", - " [ 0.0070, -0.0163, 0.0259, ..., -0.0180, 0.0130, -0.0185]],\n", - " device='cuda:0')), ('rnn.weight_hh_l0', tensor([[-0.0209, 0.0219, 0.0452, ..., 0.0315, -0.0160, -0.0151],\n", - " [ 0.0197, -0.0233, 0.0225, ..., -0.0175, -0.0293, -0.0221],\n", - " [ 0.0273, 0.0355, -0.0377, ..., 0.0149, 0.0264, -0.0339],\n", - " ...,\n", - " [ 0.0442, 0.0426, -0.0126, ..., 0.0084, 0.0014, -0.0122],\n", - " [-0.0286, 0.0230, -0.0102, ..., 0.0086, -0.0029, -0.0285],\n", - " [ 0.0363, -0.0273, 0.0279, ..., -0.0157, -0.0305, 0.0143]],\n", - " device='cuda:0')), ('rnn.bias_ih_l0', tensor([ 1.8111e-02, -2.6781e-02, -4.0634e-03, 1.6743e-02, -1.0715e-02,\n", - " 1.6265e-02, -3.2948e-02, -6.8702e-03, 1.4201e-02, 2.3823e-02,\n", - " 3.2213e-02, 1.2273e-02, -2.6945e-02, -2.6195e-02, 3.4248e-02,\n", - " 2.2180e-02, 9.6105e-03, -3.5814e-02, -1.8378e-02, -7.7094e-04,\n", - " -9.4703e-04, 3.3616e-02, 1.9444e-02, -8.7114e-03, 3.2739e-02,\n", - " 9.9509e-03, -6.2903e-03, 1.3472e-02, -2.0024e-02, -3.2941e-02,\n", - " 1.0864e-02, -3.4586e-02, -1.3443e-03, 1.2753e-02, -7.4560e-03,\n", - " -3.0947e-03, 5.6593e-03, -2.2091e-02, -3.0970e-02, -2.5012e-02,\n", - " -3.6564e-02, 3.6434e-02, 2.9558e-02, 3.0064e-02, -1.9407e-02,\n", - " -1.9144e-02, 4.5596e-03, 2.0145e-02, 1.6468e-02, -2.4431e-02,\n", - " 2.3784e-03, 2.9147e-02, 3.7242e-02, 8.1427e-03, 1.0948e-03,\n", - " -3.4101e-02, 1.8542e-02, 3.5021e-02, -2.5783e-02, 9.3484e-03,\n", - " 2.4564e-02, -3.0976e-02, -1.7374e-03, -1.5906e-02, -2.5417e-02,\n", - " 7.2955e-03, 3.3041e-02, -2.0101e-02, -3.7639e-02, -1.4605e-02,\n", - " -2.6385e-02, -1.9132e-02, -2.5175e-02, 1.8294e-02, 3.6208e-02,\n", - " 3.3664e-02, -5.0429e-03, 3.9298e-03, 2.4817e-02, -1.1454e-02,\n", - " -2.6925e-02, -8.7122e-04, 2.5935e-02, -9.8170e-03, 7.2944e-05,\n", - " -2.7697e-02, -1.2776e-02, -1.1467e-02, -2.4725e-02, 2.5394e-02,\n", - " -3.4422e-02, 8.9304e-03, -4.5742e-03, -1.7344e-02, -9.3744e-04,\n", - " 2.9017e-02, -3.1157e-02, -1.6356e-02, -1.7084e-02, -2.1292e-02,\n", - " 2.2797e-02, -2.6110e-02, 2.2817e-02, 1.0459e-02, -1.5681e-02,\n", - " -1.9369e-02, 1.4659e-02, -2.0604e-02, -2.3151e-02, 3.3874e-03,\n", - " -1.5110e-02, -1.6369e-02, -7.0151e-03, -1.8548e-02, -2.6604e-02,\n", - " -6.5618e-03, -3.6241e-02, -3.1688e-02, 2.5403e-05, -3.2969e-02,\n", - " 1.2356e-02, 2.0256e-02, 9.6857e-04, 3.8288e-02, -2.5391e-02,\n", - " -3.3861e-02, 1.7714e-02, 3.1606e-02, 2.4646e-02, -2.1113e-02,\n", - " 2.2660e-02, -2.8754e-03, -5.5579e-03, -2.7844e-03, -2.8713e-02,\n", - " 1.1034e-02, -5.9582e-03, -1.5935e-02, 3.9038e-02, -2.8265e-02,\n", - " -2.4773e-02, -2.4171e-02, -8.6292e-03, -3.7858e-02, 5.8170e-03,\n", - " 7.3982e-03, 3.0470e-02, 1.6998e-02, -1.3035e-02, -9.9555e-04,\n", - " 6.4096e-04, 2.7425e-02, 3.3348e-03, 1.8546e-02, -2.2431e-02,\n", - " -1.3993e-02, 1.9206e-02, -3.2559e-02, -1.8681e-03, 6.6771e-03,\n", - " 1.4759e-02, -1.6621e-02, 1.7838e-02, 1.5146e-03, 1.4538e-02,\n", - " 2.1896e-02, -8.3440e-03, 2.8657e-02, 2.3448e-02, -1.9553e-02,\n", - " 2.3599e-02, -2.0774e-02, -3.5980e-03, 1.3032e-03, 4.0859e-02,\n", - " -3.0607e-02, 2.6080e-02, -7.6080e-03, 2.1783e-03, 1.9371e-02,\n", - " -2.2019e-02, 1.4630e-02, -1.1102e-02, 3.7803e-02, 2.8926e-02,\n", - " 3.0758e-02, -9.1583e-03, 2.7705e-02, -2.1965e-03, 3.1656e-02,\n", - " 1.8851e-02, 2.7763e-02, 9.9924e-03, -6.7093e-03, -3.6289e-02,\n", - " -1.5082e-04, -1.0631e-02, -5.8395e-03, -3.2592e-02, 2.3484e-03,\n", - " 3.7604e-02, 1.6621e-03, 1.0569e-02, -3.9286e-02, 1.5574e-02,\n", - " -1.1596e-02, 2.1687e-02, -4.4336e-02, 2.5966e-02, -2.0243e-02,\n", - " 2.1614e-02, -8.6071e-03, 4.1544e-03, 1.0928e-02, 4.1057e-03,\n", - " 1.5348e-02, -3.1164e-03, -2.9517e-02, 4.2468e-02, -1.5965e-02,\n", - " -3.0531e-02, 2.7999e-02, 1.5462e-02, -2.1790e-02, -2.2933e-02,\n", - " 3.6634e-02, 3.0408e-02, -3.3588e-02, -2.9649e-03, -3.0793e-02,\n", - " 3.4174e-02, -1.2204e-02, -4.1533e-03, 2.0980e-02, -1.9647e-02,\n", - " -1.3871e-02, -3.8874e-04, 2.0256e-02, 1.9322e-02, -3.7720e-02,\n", - " -1.1272e-02, -1.4170e-02, -2.0485e-02, 2.3206e-02, -7.5841e-03,\n", - " -6.0333e-03, -3.2532e-02, -1.0999e-02, -5.7598e-03, -2.3579e-02,\n", - " -1.7936e-02, 3.6662e-02, -3.7169e-03, 1.5431e-02, 5.8733e-03,\n", - " 3.3884e-02, -5.9398e-04, -2.4953e-02, -2.5950e-02, 1.9242e-03,\n", - " -2.5873e-02, 4.0144e-02, -1.2363e-02, -2.8360e-02, 9.1527e-03,\n", - " -1.3067e-02, 2.8207e-02, 9.3364e-03, 1.2728e-02, 4.4359e-02,\n", - " -3.2946e-02, 1.7478e-02, -3.1623e-02, -2.3094e-02, 3.1974e-02,\n", - " -9.5789e-03, 2.8828e-02, 1.8058e-02, 1.6583e-02, -1.1617e-02,\n", - " -1.9145e-02, -3.0421e-02, -2.8697e-02, -3.5240e-03, -2.7886e-02,\n", - " -1.1412e-02, 2.6778e-02, -3.1331e-02, -1.9696e-02, 2.4519e-02,\n", - " 9.5947e-03, 4.3017e-02, 2.1170e-03, 2.1445e-02, 3.2667e-02,\n", - " -3.3909e-02, -1.7720e-02, -1.3568e-02, 4.1741e-02, -3.7873e-02,\n", - " 2.0438e-02, 2.9061e-02, 1.3704e-02, 1.7363e-02, -7.7676e-03,\n", - " -1.0324e-02, 2.0690e-03, -1.4933e-02, 8.7419e-03, -2.6359e-02,\n", - " -6.7530e-03, 4.8337e-03, -1.4728e-02, -2.5673e-02, 1.1972e-02,\n", - " -4.5589e-03, -1.6134e-02, -1.5268e-02, -2.3441e-02, -1.5927e-02,\n", - " 3.5573e-02, 1.4648e-03, 1.4684e-02, -2.4955e-02, 1.7429e-02,\n", - " -2.3515e-02, -3.4481e-02, -1.9777e-03, 5.8714e-03, 4.3375e-02,\n", - " 8.5882e-03, 3.1015e-02, 2.5195e-02, 3.2845e-02, 4.9105e-03,\n", - " 3.2035e-03, -3.3776e-02, -3.3765e-02, 4.3348e-02, 3.4604e-02,\n", - " 1.1769e-02, -3.2888e-02, 1.3825e-03, -4.0359e-03, -4.2319e-02,\n", - " -2.9038e-02, -2.3940e-02, -3.0994e-02, -6.1948e-03, -3.6867e-02,\n", - " 1.8567e-02, -7.8213e-03, -5.4482e-03, 2.3144e-02, 2.7622e-03,\n", - " 3.1960e-02, -2.8063e-02, 1.0038e-03, 3.3530e-02, -3.6069e-02,\n", - " -1.3959e-02, 1.3825e-02, 1.1711e-02, 6.4103e-03, 2.2153e-03,\n", - " 1.8895e-02, 2.2527e-02, 1.1061e-02, -2.3915e-02, -5.6058e-03,\n", - " 3.0207e-02, 1.3919e-02, -1.9913e-05, 1.0912e-02, -3.9912e-02,\n", - " -2.4519e-02, 2.4014e-02, -3.4215e-02, 2.0771e-02, -1.4133e-02,\n", - " -2.0537e-02, -3.5628e-02, -2.6749e-02, -2.0842e-02, 5.8522e-03,\n", - " 8.3789e-03, 1.5420e-02, -2.8678e-02, -3.7518e-02, 4.8427e-03,\n", - " -2.1849e-02, -2.1646e-02, 9.3861e-03, -7.7176e-03, 1.6586e-02,\n", - " -6.7733e-04, -1.8504e-04, -1.7768e-02, 2.0643e-02, -1.0033e-02,\n", - " -3.4722e-03, -3.3094e-02, 1.8807e-02, 2.9360e-02, -2.1536e-02,\n", - " 2.1318e-02, 1.7793e-02, 2.1258e-02, -6.5998e-03, 1.9706e-02,\n", - " -3.3606e-02, -2.8671e-02, 2.7924e-02, 3.3058e-04, -1.4342e-03,\n", - " -2.2750e-02, 2.2617e-02, 3.8280e-02, 3.8726e-02, -7.8967e-03,\n", - " -9.9757e-03, -1.9764e-02, 1.0883e-02, -1.5125e-03, -4.1490e-03,\n", - " -1.8074e-02, 1.3813e-02, 3.9276e-02, -2.1581e-02, -3.8515e-02,\n", - " -2.7119e-02, -2.3735e-02, 5.3212e-03, 4.4095e-04, -3.9134e-02,\n", - " 3.8355e-02, -1.0873e-02, 2.9719e-02, -2.9585e-03, 9.3589e-03,\n", - " -5.2759e-03, 2.0814e-02, -3.4896e-02, 1.0858e-02, 3.7425e-02,\n", - " -2.8287e-02, -1.4037e-02, -4.1064e-03, 2.6441e-02, -6.4868e-03,\n", - " -2.5794e-02, -4.3410e-03, -3.5065e-02, -1.0213e-02, -2.0883e-02,\n", - " 2.8050e-02, 3.7213e-02, -1.7304e-02, -1.2649e-02, -2.0664e-03,\n", - " -2.0959e-02, -2.9738e-02, -2.9440e-03, 2.8278e-02, -6.0002e-03,\n", - " 8.7061e-03, 2.0209e-02, -2.5403e-02, -2.3993e-02, -2.9707e-02,\n", - " -3.6108e-02, -3.3118e-02, -1.8849e-02, -3.6540e-02, 6.7718e-03,\n", - " 8.8591e-04, 4.3042e-02, -2.7515e-03, 3.1393e-02, -1.2268e-02,\n", - " -1.1763e-02, 2.6841e-02, 1.0779e-03, -3.3840e-02, 6.0528e-03,\n", - " 3.2110e-02, 3.3504e-02, -2.2105e-02, -4.0328e-04, 7.5812e-03,\n", - " -2.0404e-02, -7.9293e-03, 4.1881e-02, 3.5226e-02, -1.4845e-02,\n", - " -3.6957e-02, 9.6239e-03, 4.6920e-03, 1.1712e-02, 3.7794e-02,\n", - " 3.6824e-02, -7.3588e-03, -2.2133e-02, -2.1761e-02, 5.3850e-03,\n", - " 1.4853e-02, -1.2601e-02, 7.1829e-03, -1.6853e-02, -4.1580e-02,\n", - " -3.7828e-02, -3.3310e-02], device='cuda:0')), ('rnn.bias_hh_l0', tensor([-0.0176, -0.0317, 0.0152, -0.0377, -0.0333, 0.0278, 0.0297, 0.0351,\n", - " 0.0381, -0.0003, 0.0294, -0.0190, 0.0195, 0.0132, 0.0208, -0.0257,\n", - " 0.0195, 0.0126, -0.0045, -0.0221, -0.0300, 0.0114, -0.0163, -0.0008,\n", - " 0.0237, 0.0310, -0.0072, 0.0245, -0.0039, -0.0181, 0.0438, 0.0281,\n", - " -0.0427, -0.0133, 0.0319, -0.0065, 0.0303, 0.0245, -0.0175, 0.0374,\n", - " 0.0244, -0.0405, 0.0097, -0.0194, -0.0042, 0.0263, -0.0378, 0.0130,\n", - " -0.0234, -0.0166, 0.0241, -0.0217, 0.0143, 0.0173, 0.0180, 0.0326,\n", - " 0.0092, -0.0147, -0.0226, -0.0286, -0.0322, -0.0092, 0.0353, -0.0164,\n", - " -0.0315, -0.0254, -0.0020, 0.0119, -0.0012, 0.0442, 0.0402, -0.0008,\n", - " 0.0433, -0.0091, 0.0166, 0.0034, 0.0262, 0.0136, -0.0017, 0.0311,\n", - " -0.0171, 0.0257, 0.0329, -0.0260, 0.0236, -0.0273, -0.0074, -0.0250,\n", - " 0.0003, 0.0107, 0.0294, -0.0162, 0.0175, 0.0156, 0.0381, 0.0103,\n", - " 0.0240, -0.0195, -0.0362, 0.0167, 0.0025, 0.0237, -0.0267, 0.0225,\n", - " 0.0300, 0.0183, -0.0282, -0.0027, 0.0003, -0.0377, 0.0229, -0.0398,\n", - " 0.0295, 0.0309, -0.0159, -0.0329, 0.0056, -0.0320, 0.0094, 0.0255,\n", - " 0.0203, -0.0397, -0.0035, -0.0382, -0.0025, -0.0171, 0.0344, -0.0216,\n", - " 0.0309, 0.0176, -0.0360, 0.0288, -0.0303, -0.0314, 0.0212, 0.0338,\n", - " -0.0062, 0.0047, -0.0395, 0.0442, -0.0274, 0.0106, -0.0110, -0.0402,\n", - " -0.0046, 0.0155, 0.0026, 0.0180, -0.0257, -0.0024, -0.0146, -0.0140,\n", - " -0.0072, 0.0031, -0.0078, 0.0296, -0.0155, 0.0194, -0.0212, -0.0295,\n", - " 0.0244, 0.0327, -0.0265, -0.0345, 0.0077, 0.0015, 0.0237, -0.0251,\n", - " -0.0289, 0.0157, 0.0149, -0.0166, -0.0207, 0.0037, 0.0017, 0.0059,\n", - " -0.0047, -0.0158, 0.0236, 0.0035, -0.0181, -0.0279, -0.0088, -0.0084,\n", - " -0.0277, -0.0356, 0.0324, 0.0081, -0.0273, -0.0011, 0.0052, 0.0176,\n", - " 0.0106, -0.0329, -0.0004, -0.0357, 0.0265, -0.0152, 0.0043, -0.0003,\n", - " -0.0188, 0.0207, -0.0134, 0.0222, 0.0004, -0.0240, -0.0240, 0.0255,\n", - " 0.0164, 0.0254, 0.0120, -0.0208, 0.0090, 0.0311, 0.0153, -0.0235,\n", - " 0.0330, 0.0208, -0.0073, -0.0332, -0.0193, -0.0014, 0.0211, 0.0022,\n", - " -0.0268, -0.0219, -0.0094, -0.0182, 0.0199, -0.0374, 0.0348, -0.0184,\n", - " 0.0188, -0.0328, 0.0308, -0.0107, 0.0043, -0.0444, -0.0118, 0.0150,\n", - " 0.0055, -0.0325, -0.0277, -0.0045, -0.0081, 0.0367, 0.0225, -0.0317,\n", - " 0.0355, -0.0138, 0.0006, -0.0264, 0.0091, 0.0057, -0.0073, 0.0076,\n", - " -0.0098, 0.0186, 0.0072, -0.0265, 0.0030, -0.0239, -0.0243, 0.0366,\n", - " -0.0212, 0.0252, -0.0160, -0.0317, -0.0279, -0.0190, 0.0024, 0.0225,\n", - " -0.0113, -0.0094, 0.0356, 0.0007, -0.0389, 0.0397, -0.0249, 0.0325,\n", - " 0.0187, 0.0060, -0.0251, 0.0117, -0.0160, 0.0040, -0.0244, -0.0084,\n", - " -0.0273, 0.0174, 0.0289, -0.0352, 0.0174, 0.0242, -0.0345, -0.0065,\n", - " -0.0024, -0.0131, -0.0112, 0.0040, -0.0047, -0.0230, 0.0353, 0.0218,\n", - " 0.0094, -0.0336, 0.0085, 0.0114, 0.0072, 0.0265, 0.0088, -0.0417,\n", - " -0.0056, 0.0130, -0.0225, 0.0252, -0.0158, -0.0283, -0.0365, -0.0385,\n", - " 0.0346, -0.0274, -0.0225, -0.0085, -0.0087, 0.0217, 0.0403, 0.0262,\n", - " 0.0150, 0.0203, -0.0279, -0.0308, 0.0306, 0.0325, 0.0170, 0.0196,\n", - " 0.0358, 0.0219, -0.0400, -0.0344, -0.0391, 0.0369, -0.0327, -0.0326,\n", - " 0.0382, 0.0209, -0.0336, -0.0180, -0.0270, 0.0266, -0.0240, -0.0274,\n", - " 0.0421, -0.0317, -0.0158, -0.0274, -0.0378, 0.0313, -0.0349, -0.0144,\n", - " -0.0143, 0.0094, -0.0067, 0.0050, 0.0080, 0.0198, -0.0269, 0.0014,\n", - " -0.0177, -0.0059, -0.0225, -0.0071, -0.0334, -0.0043, 0.0261, 0.0051,\n", - " 0.0290, 0.0276, -0.0375, 0.0385, 0.0231, 0.0419, 0.0393, -0.0190,\n", - " 0.0128, -0.0225, 0.0292, -0.0173, 0.0119, 0.0294, -0.0175, 0.0419,\n", - " -0.0377, 0.0152, -0.0027, 0.0025, -0.0190, 0.0294, 0.0107, 0.0080,\n", - " -0.0327, 0.0271, 0.0208, 0.0087, 0.0068, -0.0042, -0.0295, 0.0082,\n", - " 0.0378, 0.0266, -0.0104, 0.0289, 0.0036, 0.0146, 0.0143, 0.0260,\n", - " 0.0127, 0.0099, -0.0329, -0.0232, -0.0216, 0.0134, -0.0124, -0.0256,\n", - " -0.0236, -0.0099, -0.0253, -0.0007, 0.0071, 0.0185, -0.0110, 0.0432,\n", - " -0.0086, -0.0222, 0.0099, -0.0433, -0.0059, -0.0301, -0.0405, 0.0258,\n", - " -0.0208, 0.0038, 0.0264, 0.0312, 0.0165, -0.0218, 0.0059, 0.0186,\n", - " -0.0191, -0.0402, -0.0113, 0.0353, 0.0080, -0.0191, 0.0003, 0.0301,\n", - " 0.0246, -0.0021, 0.0341, 0.0326, -0.0086, 0.0029, 0.0189, 0.0211,\n", - " 0.0033, 0.0405, 0.0371, 0.0071, 0.0190, 0.0238, 0.0212, 0.0324,\n", - " -0.0279, -0.0007, 0.0115, -0.0079, -0.0093, 0.0196, 0.0124, -0.0009,\n", - " -0.0044, -0.0348, 0.0012, 0.0109, 0.0279, -0.0114, -0.0116, 0.0169,\n", - " 0.0062, -0.0183, 0.0398, 0.0074, -0.0445, -0.0308, -0.0066, 0.0267,\n", - " -0.0064, -0.0042, -0.0330, 0.0313, 0.0374, 0.0262, -0.0180, -0.0016,\n", - " -0.0275, -0.0062, -0.0219, -0.0216, -0.0240, 0.0352, 0.0468, 0.0332],\n", - " device='cuda:0')), ('linear.weight', tensor([[-8.8547e-02, -5.9987e-02, -4.7513e-02, 2.4096e-02, -1.0198e-01,\n", - " 2.5908e-02, 8.3311e-02, 7.6904e-02, -8.2395e-02, -9.8762e-02,\n", - " -5.7265e-02, -3.2875e-02, 1.2428e-01, 4.0255e-02, 9.0196e-02,\n", - " 9.2717e-02, -1.1692e-01, -6.9366e-02, -9.2311e-02, -2.7361e-02,\n", - " 2.3696e-02, 7.2859e-02, 1.2173e-01, -5.6736e-02, -4.5124e-02,\n", - " 2.0660e-04, 2.3291e-02, 7.1301e-02, -1.1238e-02, -1.7722e-02,\n", - " -1.2992e-01, 6.7982e-02, 4.1920e-02, -9.6538e-02, -1.0276e-01,\n", - " 8.4963e-02, 4.8128e-02, -2.5694e-02, 7.7580e-02, 5.2343e-02,\n", - " -4.0465e-02, 1.0426e-02, -8.2975e-02, -1.2294e-02, 1.0158e-01,\n", - " -6.1852e-02, 1.0340e-01, 8.4143e-03, 5.4585e-02, -1.2020e-02,\n", - " 5.1561e-02, 9.0128e-02, 4.5203e-02, 2.4200e-02, -8.1520e-02,\n", - " 7.0582e-02, 8.1724e-02, -3.4081e-02, 1.0137e-01, -7.7322e-02,\n", - " -6.0439e-02, -5.1627e-02, -3.0337e-02, -8.9423e-02, 8.6312e-02,\n", - " -5.1982e-02, 5.7188e-02, 7.0377e-02, -5.1014e-02, -6.3353e-02,\n", - " 4.6164e-02, 5.9232e-02, -1.1451e-01, 3.1713e-02, 1.8907e-03,\n", - " -9.4574e-02, -4.1265e-02, 9.1570e-02, -2.6643e-03, -4.2570e-02,\n", - " -8.6614e-02, 8.9580e-02, 6.1936e-02, 1.1297e-01, 1.0044e-01,\n", - " 6.1658e-04, 9.5859e-02, -1.8678e-02, -3.1659e-02, 2.1996e-02,\n", - " 1.2178e-02, 4.6491e-02, 8.1361e-02, 5.3897e-02, 1.8912e-03,\n", - " 6.5380e-02, 7.7466e-02, 6.9567e-03, 4.4857e-02, -3.4188e-02,\n", - " -1.1944e-02, -4.2887e-02, -2.2774e-02, 7.9200e-02, -4.8751e-02,\n", - " -1.4714e-02, 6.3899e-02, 9.2667e-02, 1.5898e-02, 4.0466e-03,\n", - " -5.9873e-02, 9.4463e-02, -8.5658e-02, -2.0654e-02, -5.2903e-02,\n", - " -4.2401e-03, 9.2405e-02, 6.5300e-02, 5.8294e-02, -4.1732e-02,\n", - " -7.4746e-02, 6.4393e-02, -7.3649e-02, 6.5130e-02, -9.0391e-02,\n", - " -5.2148e-02, -6.5804e-02, 1.1155e-01, 7.5387e-02, -8.6350e-02,\n", - " -3.0850e-02, 7.9514e-02, 9.9159e-02, 3.2748e-02, 3.3221e-02,\n", - " 4.5976e-02, 8.8669e-02, 4.3326e-02, -5.4980e-02, -7.7956e-02,\n", - " -4.8407e-02, -5.9281e-02, 8.5664e-02, -6.4136e-02, -1.0308e-01,\n", - " -7.8024e-02, 7.8702e-02, 1.2721e-01, -7.7224e-02, -6.8875e-03,\n", - " 1.1804e-02, -1.0920e-01, 7.8305e-02, -7.9952e-02, 5.7828e-02,\n", - " -7.2264e-02, 8.7700e-02, -7.0189e-02, 4.3205e-02, -1.0677e-01,\n", - " -8.8330e-02, 7.0265e-02, 3.8109e-02, 2.8679e-02, -4.5497e-02,\n", - " -1.4503e-02, 9.2823e-02, 5.9961e-02, -6.4451e-03, 6.5338e-02,\n", - " -7.9279e-02, 5.0901e-02, -8.8057e-02, 8.4255e-02, 7.6836e-02,\n", - " -8.5072e-02, -8.0984e-02, 5.7213e-02, -1.0367e-01, 8.1646e-02,\n", - " 7.0346e-02, -4.9342e-02, -5.9608e-02, -8.9457e-02, 1.5482e-02,\n", - " -8.6755e-02, 1.0971e-01, -1.0970e-01, 7.1382e-02, 6.6606e-02,\n", - " 3.9987e-02, -7.9434e-02, 4.4270e-02, 5.2716e-02, -7.7227e-02,\n", - " -1.1333e-01, -2.6065e-02, -9.9209e-02, -8.1538e-02, 3.2267e-02,\n", - " 7.2184e-02, -7.9998e-02, -1.0031e-01, 9.7022e-02, -1.1757e-01,\n", - " -6.9832e-02, 1.0073e-01, 2.1854e-02, -1.0080e-02, 2.2683e-03,\n", - " -1.4065e-02, 6.0701e-02, -7.0993e-02, 8.5523e-02, 7.8003e-02,\n", - " -1.1570e-01, 4.2775e-02, -9.4603e-02, 6.4485e-02, 9.0059e-02,\n", - " -7.9190e-02, -7.1474e-02, 6.5204e-02, 1.2171e-01, -8.8447e-02,\n", - " 6.1292e-02, -1.1522e-01, -7.4163e-02, 9.8226e-02, -3.4326e-03,\n", - " 3.5089e-02, -7.7413e-02, -5.8655e-02, -4.4211e-02, -9.1385e-03,\n", - " -7.9697e-02, -1.8058e-02, -3.0261e-02, 4.2157e-02, 5.6010e-06,\n", - " 3.7692e-02, -9.8621e-04, 1.0169e-01, -9.9847e-02, 3.1335e-02,\n", - " 7.2065e-02, -1.1815e-02, 8.0157e-02, 4.5651e-02, -2.6291e-02,\n", - " -3.9500e-03, 8.8305e-02, 4.4117e-02, -6.2695e-02, 1.9198e-02,\n", - " -9.7623e-02, 1.1319e-01, 8.5181e-03, -7.8295e-02, -5.6044e-02,\n", - " -8.4191e-02, -9.2241e-02, 7.8314e-02, 4.6674e-02, -1.1581e-01,\n", - " 9.0101e-02, 8.5546e-02, -7.1386e-02, 9.6367e-02, 5.5332e-02,\n", - " -5.7802e-02, -3.6574e-02, -7.7728e-02, 1.2894e-01, -7.5711e-02,\n", - " 2.9625e-02, -4.1130e-02, -3.2909e-02, -5.0653e-02, -6.0805e-02,\n", - " 2.4993e-02, -7.7794e-02, -7.4821e-02, -1.1293e-01, 7.7086e-02,\n", - " 9.1735e-02, 8.2461e-02, 4.8319e-02, -1.0312e-01, -7.6090e-02,\n", - " -8.7980e-02, -9.1191e-02, 8.2028e-02, 4.1118e-03, -5.5895e-02,\n", - " 9.5860e-02, -9.6931e-02, -4.4530e-02, -5.5381e-02, -3.4383e-02,\n", - " -1.0726e-01, 7.8146e-02, 1.0706e-01, 5.9457e-02, 7.6715e-02,\n", - " -9.5877e-03, 7.5848e-02, -9.6188e-02, -5.5072e-02, -6.8968e-02,\n", - " -1.1800e-01, 7.5767e-02, -3.3765e-03, 2.3612e-02, 1.0352e-01,\n", - " 7.0269e-02, -4.7941e-03, -3.7258e-02, 9.8977e-02, -5.2837e-02,\n", - " -1.3844e-02, -9.2345e-02, 5.5284e-02, 2.0484e-03, 5.8924e-02,\n", - " 8.0741e-02, -6.0709e-02, -9.2244e-02, 4.2808e-02, -7.8506e-02,\n", - " -2.5839e-02, -3.5147e-02, -8.7390e-02, -1.0298e-01, -1.1235e-01,\n", - " 9.3587e-02, 3.1214e-02, 3.0557e-02, 9.5419e-02, 8.9070e-02,\n", - " -1.7131e-02, -7.0812e-02, -5.9715e-02, 3.2449e-02, 6.0174e-02,\n", - " 7.8788e-02, 7.9160e-02, 1.0599e-01, -6.1668e-02, 5.5055e-02,\n", - " -4.9683e-02, -6.1380e-02, 5.7822e-02, 5.5926e-02, 5.6107e-02,\n", - " 8.9982e-02, -1.7426e-02, 9.6545e-02, -4.2272e-02, 6.7213e-02,\n", - " 8.2237e-02, -1.0420e-01, -6.5223e-02, 9.7535e-02, -8.2105e-02,\n", - " 1.7870e-02, -7.3713e-02, -2.3103e-02, -5.2375e-02, -7.8572e-02,\n", - " -1.0300e-01, 2.7166e-02, 8.0863e-03, 6.1729e-02, -5.1810e-02,\n", - " 1.5730e-02, 8.5515e-02, 3.4076e-03, -7.4814e-02, -6.7597e-03,\n", - " 3.5729e-02, -1.8539e-02, 5.3722e-02, -3.1593e-03, -9.8434e-02,\n", - " 2.9996e-02, 2.0447e-02, -7.2871e-02, 7.0552e-02, -7.8399e-02,\n", - " 6.0796e-02, 9.5971e-02, -4.1587e-02, -2.0611e-02, 7.5822e-02,\n", - " 9.1397e-02, 1.0871e-01, -8.1552e-02, 8.3245e-02, 4.1961e-02,\n", - " 1.0330e-01, -5.2573e-02, 6.6681e-02, -9.3816e-02, 5.5165e-02,\n", - " -8.1054e-02, -1.1313e-01, 9.4115e-02, 1.0250e-01, -9.8301e-02,\n", - " -2.4801e-02, -6.7894e-02, -4.7212e-02, 7.1672e-02, 1.2984e-01,\n", - " -3.7132e-02, -3.8868e-02, 8.7328e-02, -3.5943e-02, -7.3443e-02,\n", - " -1.6196e-02, 7.8980e-02, -2.2536e-02, 2.5644e-02, -9.1489e-02,\n", - " 6.7848e-02, -7.2270e-02, -1.0718e-01, -7.7530e-02, -1.1823e-01,\n", - " 8.6737e-02, 1.0586e-01, -6.0456e-02, -1.0192e-01, 6.5803e-02,\n", - " -1.1555e-02, -7.7942e-04, 9.1762e-02, -6.6495e-02, -8.3189e-02,\n", - " -2.4862e-02, -1.1741e-02, 8.8719e-02, -7.6499e-02, 4.5282e-02,\n", - " -8.9169e-02, 7.1551e-02, -5.9672e-02, 7.5346e-02, 1.1082e-02,\n", - " 6.8779e-02, 6.4291e-02, -2.9776e-02, -2.9347e-02, -7.6854e-02,\n", - " 5.4371e-02, 8.7839e-03, -8.5555e-02, -9.6236e-02, -6.0593e-02,\n", - " -4.2329e-02, -6.4755e-02, -2.0376e-02, 3.2041e-02, 4.3369e-02,\n", - " 4.6920e-02, 6.0962e-02, 8.8417e-02, -9.9226e-02, -9.1065e-02,\n", - " -1.9233e-02, 6.4499e-02, 1.2042e-02, -1.8561e-02, -2.2591e-02,\n", - " -1.2839e-03, 4.2153e-02, -5.2112e-02, -3.9568e-02, 3.2787e-02,\n", - " -9.6947e-02, 2.3297e-02, -9.0590e-02, -5.0173e-02, -6.2493e-02,\n", - " -8.2204e-02, -8.3147e-02, -1.1300e-01, 4.3085e-02, -4.4801e-02,\n", - " -1.2093e-01, -7.5007e-02, -6.6533e-02, -6.9864e-02, 8.7778e-02,\n", - " 2.4411e-02, 7.6407e-02, -5.6045e-02, -6.7802e-02, -1.1631e-02,\n", - " 8.3432e-02, -4.4422e-02, 1.5315e-02, 8.7830e-02, -7.9375e-02,\n", - " -8.2172e-02, -6.2139e-02, 2.9427e-03, 8.0568e-02, -3.9482e-02,\n", - " 5.7513e-02, 2.9916e-02]], device='cuda:0')), ('linear.bias', tensor([-0.0032], device='cuda:0'))])}\n", - "Loaded model: RNN\n" + "Loaded model: TCN\n" ] } ], "source": [ - "# 3. load the model from .pth file\n", + "# 3. load the model from .pth file containing the state data\n", "\n", "from src.persist import load_model\n", "\n", - "model_type = \"lstm\" # choose between rnn, lstm, gru, tcn, te\n", - "model_path = f\"output/pretrained/{model_type}.pth\"\n", + "model_type = \"TCN\" # choose between RNN, LSTM, GRU, TCN, TE\n", + "model_path = f\"output/pretrained/{model_type}_state_data.pth\"\n", "model = load_model(file_path=model_path)\n", "\n", "print(f\"Loaded model: {model}\")" @@ -488,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -496,7 +213,7 @@ "output_type": "stream", "text": [ "The HOME team will win.\n", - "Home win prediction: 0.6114086508750916\n" + "Home win prediction: 0.9627264142036438\n" ] } ], diff --git a/train.ipynb b/train.ipynb index bdc9e98..74dc259 100644 --- a/train.ipynb +++ b/train.ipynb @@ -24,7 +24,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 1, @@ -55,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -75,7 +75,7 @@ "=================================================================" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -101,10 +101,11 @@ " data_as_sequence=True,\n", " output_path=check_dir(\"rnn\"),\n", " # adjust\n", - " epochs=30,\n", - " init_learning_rate=1e-3,\n", + " epochs=50,\n", + " init_learning_rate=1e-5,\n", " weight_decay=1e-5,\n", - " # data_use_all=True,\n", + " # enable for final training\n", + " data_use_all=True,\n", ")" ] }, @@ -127,12 +128,12 @@ "Layer (type:depth-idx) Param #\n", "=================================================================\n", "LSTM --\n", - "├─LSTM: 1-1 4,677,632\n", - "├─Linear: 1-2 1,025\n", + "├─LSTM: 1-1 17,743,872\n", + "├─Linear: 1-2 2,049\n", "├─Sigmoid: 1-3 --\n", "=================================================================\n", - "Total params: 4,678,657\n", - "Trainable params: 4,678,657\n", + "Total params: 17,745,921\n", + "Trainable params: 17,745,921\n", "Non-trainable params: 0\n", "=================================================================" ] @@ -144,7 +145,7 @@ ], "source": [ "lstm_model = models.LSTM(\n", - " hidden_size=1024,\n", + " hidden_size=2048,\n", ")\n", "\n", "summary(lstm_model)" @@ -152,187 +153,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 17:03:21] INFO : src.train - Loading data...\n", - "[2024-08-11 17:03:26] INFO : src.train - Beginning to train the network...\n", - "[2024-08-11 17:03:29] INFO : src.train - EPOCH: 1/30\n", - "[2024-08-11 17:03:29] INFO : src.train - Train loss: 0.6831, Train accuracy: 0.5747\n", - "[2024-08-11 17:03:29] INFO : src.train - Val loss: 0.6745, Val accuracy: 0.5801\n", - "\n", - "[2024-08-11 17:03:33] INFO : src.train - EPOCH: 2/30\n", - "[2024-08-11 17:03:33] INFO : src.train - Train loss: 0.6768, Train accuracy: 0.5771\n", - "[2024-08-11 17:03:33] INFO : src.train - Val loss: 0.6672, Val accuracy: 0.6003\n", - "\n", - "[2024-08-11 17:03:36] INFO : src.train - EPOCH: 3/30\n", - "[2024-08-11 17:03:36] INFO : src.train - Train loss: 0.6723, Train accuracy: 0.5866\n", - "[2024-08-11 17:03:36] INFO : src.train - Val loss: 0.6661, Val accuracy: 0.5952\n", - "\n", - "[2024-08-11 17:03:39] INFO : src.train - EPOCH: 4/30\n", - "[2024-08-11 17:03:39] INFO : src.train - Train loss: 0.6770, Train accuracy: 0.5783\n", - "[2024-08-11 17:03:39] INFO : src.train - Val loss: 0.6672, Val accuracy: 0.5952\n", - "\n", - "[2024-08-11 17:03:42] INFO : src.train - EPOCH: 5/30\n", - "[2024-08-11 17:03:42] INFO : src.train - Train loss: 0.6686, Train accuracy: 0.5918\n", - "[2024-08-11 17:03:42] INFO : src.train - Val loss: 0.6783, Val accuracy: 0.5662\n", - "\n", - "[2024-08-11 17:03:45] INFO : src.train - EPOCH: 6/30\n", - "[2024-08-11 17:03:45] INFO : src.train - Train loss: 0.6737, Train accuracy: 0.5806\n", - "[2024-08-11 17:03:45] INFO : src.train - Val loss: 0.6640, Val accuracy: 0.5986\n", - "\n", - "[2024-08-11 17:03:48] INFO : src.train - EPOCH: 7/30\n", - "[2024-08-11 17:03:48] INFO : src.train - Train loss: 0.6722, Train accuracy: 0.5812\n", - "[2024-08-11 17:03:48] INFO : src.train - Val loss: 0.6650, Val accuracy: 0.5943\n", - "\n", - "[2024-08-11 17:03:51] INFO : src.train - EPOCH: 8/30\n", - "[2024-08-11 17:03:51] INFO : src.train - Train loss: 0.6743, Train accuracy: 0.5792\n", - "[2024-08-11 17:03:51] INFO : src.train - Val loss: 0.6567, Val accuracy: 0.6042\n", - "\n", - "[2024-08-11 17:03:54] INFO : src.train - EPOCH: 9/30\n", - "[2024-08-11 17:03:54] INFO : src.train - Train loss: 0.6734, Train accuracy: 0.5792\n", - "[2024-08-11 17:03:54] INFO : src.train - Val loss: 0.6676, Val accuracy: 0.5965\n", - "\n", - "[2024-08-11 17:03:58] INFO : src.train - EPOCH: 10/30\n", - "[2024-08-11 17:03:58] INFO : src.train - Train loss: 0.6750, Train accuracy: 0.5806\n", - "[2024-08-11 17:03:58] INFO : src.train - Val loss: 0.6703, Val accuracy: 0.5982\n", - "\n", - "[2024-08-11 17:04:01] INFO : src.train - EPOCH: 11/30\n", - "[2024-08-11 17:04:01] INFO : src.train - Train loss: 0.6725, Train accuracy: 0.5842\n", - "[2024-08-11 17:04:01] INFO : src.train - Val loss: 0.6606, Val accuracy: 0.6029\n", - "\n", - "[2024-08-11 17:04:04] INFO : src.train - EPOCH: 12/30\n", - "[2024-08-11 17:04:04] INFO : src.train - Train loss: 0.6693, Train accuracy: 0.5909\n", - "[2024-08-11 17:04:04] INFO : src.train - Val loss: 0.6661, Val accuracy: 0.5999\n", - "\n", - "[2024-08-11 17:04:07] INFO : src.train - EPOCH: 13/30\n", - "[2024-08-11 17:04:07] INFO : src.train - Train loss: 0.6664, Train accuracy: 0.5953\n", - "[2024-08-11 17:04:07] INFO : src.train - Val loss: 0.6627, Val accuracy: 0.6060\n", - "\n", - "[2024-08-11 17:04:10] INFO : src.train - EPOCH: 14/30\n", - "[2024-08-11 17:04:10] INFO : src.train - Train loss: 0.6657, Train accuracy: 0.6010\n", - "[2024-08-11 17:04:10] INFO : src.train - Val loss: 0.6597, Val accuracy: 0.6198\n", - "\n", - "[2024-08-11 17:04:13] INFO : src.train - EPOCH: 15/30\n", - "[2024-08-11 17:04:13] INFO : src.train - Train loss: 0.6642, Train accuracy: 0.6018\n", - "[2024-08-11 17:04:13] INFO : src.train - Val loss: 0.6687, Val accuracy: 0.5969\n", - "\n", - "[2024-08-11 17:04:16] INFO : src.train - EPOCH: 16/30\n", - "[2024-08-11 17:04:16] INFO : src.train - Train loss: 0.6645, Train accuracy: 0.6028\n", - "[2024-08-11 17:04:16] INFO : src.train - Val loss: 0.6513, Val accuracy: 0.6167\n", - "\n", - "[2024-08-11 17:04:19] INFO : src.train - EPOCH: 17/30\n", - "[2024-08-11 17:04:19] INFO : src.train - Train loss: 0.6636, Train accuracy: 0.6092\n", - "[2024-08-11 17:04:19] INFO : src.train - Val loss: 0.6556, Val accuracy: 0.5999\n", - "\n", - "[2024-08-11 17:04:22] INFO : src.train - EPOCH: 18/30\n", - "[2024-08-11 17:04:22] INFO : src.train - Train loss: 0.6620, Train accuracy: 0.6034\n", - "[2024-08-11 17:04:22] INFO : src.train - Val loss: 0.6520, Val accuracy: 0.6124\n", - "\n", - "[2024-08-11 17:04:26] INFO : src.train - EPOCH: 19/30\n", - "[2024-08-11 17:04:26] INFO : src.train - Train loss: 0.6682, Train accuracy: 0.5967\n", - "[2024-08-11 17:04:26] INFO : src.train - Val loss: 0.6515, Val accuracy: 0.6284\n", - "\n", - "[2024-08-11 17:04:29] INFO : src.train - EPOCH: 20/30\n", - "[2024-08-11 17:04:29] INFO : src.train - Train loss: 0.6619, Train accuracy: 0.6100\n", - "[2024-08-11 17:04:29] INFO : src.train - Val loss: 0.6494, Val accuracy: 0.6293\n", - "\n", - "[2024-08-11 17:04:32] INFO : src.train - EPOCH: 21/30\n", - "[2024-08-11 17:04:32] INFO : src.train - Train loss: 0.6647, Train accuracy: 0.6030\n", - "[2024-08-11 17:04:32] INFO : src.train - Val loss: 0.6632, Val accuracy: 0.6146\n", - "\n", - "[2024-08-11 17:04:35] INFO : src.train - EPOCH: 22/30\n", - "[2024-08-11 17:04:35] INFO : src.train - Train loss: 0.6617, Train accuracy: 0.6068\n", - "[2024-08-11 17:04:35] INFO : src.train - Val loss: 0.6564, Val accuracy: 0.5991\n", - "\n", - "[2024-08-11 17:04:38] INFO : src.train - EPOCH: 23/30\n", - "[2024-08-11 17:04:38] INFO : src.train - Train loss: 0.6588, Train accuracy: 0.6150\n", - "[2024-08-11 17:04:38] INFO : src.train - Val loss: 0.6505, Val accuracy: 0.6280\n", - "\n", - "[2024-08-11 17:04:41] INFO : src.train - EPOCH: 24/30\n", - "[2024-08-11 17:04:41] INFO : src.train - Train loss: 0.6588, Train accuracy: 0.6139\n", - "[2024-08-11 17:04:41] INFO : src.train - Val loss: 0.6527, Val accuracy: 0.6306\n", - "\n", - "[2024-08-11 17:04:44] INFO : src.train - EPOCH: 25/30\n", - "[2024-08-11 17:04:44] INFO : src.train - Train loss: 0.6604, Train accuracy: 0.6102\n", - "[2024-08-11 17:04:44] INFO : src.train - Val loss: 0.6530, Val accuracy: 0.6280\n", - "\n", - "[2024-08-11 17:04:47] INFO : src.train - EPOCH: 26/30\n", - "[2024-08-11 17:04:47] INFO : src.train - Train loss: 0.6616, Train accuracy: 0.6102\n", - "[2024-08-11 17:04:47] INFO : src.train - Val loss: 0.6545, Val accuracy: 0.6219\n", - "\n", - "[2024-08-11 17:04:50] INFO : src.train - EPOCH: 27/30\n", - "[2024-08-11 17:04:50] INFO : src.train - Train loss: 0.6573, Train accuracy: 0.6163\n", - "[2024-08-11 17:04:50] INFO : src.train - Val loss: 0.6830, Val accuracy: 0.5848\n", - "\n", - "[2024-08-11 17:04:53] INFO : src.train - EPOCH: 28/30\n", - "[2024-08-11 17:04:53] INFO : src.train - Train loss: 0.6595, Train accuracy: 0.6157\n", - "[2024-08-11 17:04:53] INFO : src.train - Val loss: 0.6599, Val accuracy: 0.6206\n", - "\n", - "[2024-08-11 17:04:57] INFO : src.train - EPOCH: 29/30\n", - "[2024-08-11 17:04:57] INFO : src.train - Train loss: 0.6574, Train accuracy: 0.6108\n", - "[2024-08-11 17:04:57] INFO : src.train - Val loss: 0.6467, Val accuracy: 0.6331\n", - "\n", - "[2024-08-11 17:05:00] INFO : src.train - EPOCH: 30/30\n", - "[2024-08-11 17:05:00] INFO : src.train - Train loss: 0.6578, Train accuracy: 0.6078\n", - "[2024-08-11 17:05:00] INFO : src.train - Val loss: 0.6621, Val accuracy: 0.5956\n", - "\n", - "[2024-08-11 17:05:00] INFO : src.train - Evaluating network...\n", - "[2024-08-11 17:05:00] INFO : src.train - Accuracy from season remainder: 0.5706\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " AWAY_WIN 0.68 0.17 0.27 163\n", - " HOME_WIN 0.56 0.93 0.70 184\n", - "\n", - " accuracy 0.57 347\n", - " macro avg 0.62 0.55 0.48 347\n", - "weighted avg 0.61 0.57 0.49 347\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 17:05:01] INFO : src.train - Accuracy from next season: 0.5747\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " AWAY_WIN 0.49 0.10 0.16 445\n", - " HOME_WIN 0.58 0.92 0.71 606\n", - "\n", - " accuracy 0.57 1051\n", - " macro avg 0.54 0.51 0.44 1051\n", - "weighted avg 0.54 0.57 0.48 1051\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 17:05:02] INFO : src.train - Accuracy on short streaks (training): 0.5247\n", - "[2024-08-11 17:05:02] INFO : src.train - Accuracy on long streaks (training): 0.5233\n", - "[2024-08-11 17:05:02] INFO : src.train - Accuracy on short streaks (evaluation): 0.4728\n", - "[2024-08-11 17:05:02] INFO : src.train - Accuracy on long streaks (evaluation): 0.4394\n" - ] - } - ], + "outputs": [], "source": [ "_m, _h = run_train(\n", " # fixed\n", @@ -341,10 +164,11 @@ " data_as_sequence=True,\n", " output_path=check_dir(\"lstm\"),\n", " # adjust\n", - " epochs=30,\n", - " init_learning_rate=1e-3,\n", + " epochs=50,\n", + " init_learning_rate=1e-4,\n", " weight_decay=1e-5,\n", - " # data_use_all=True,\n", + " # enable for final training\n", + " data_use_all=True,\n", ")" ] }, @@ -357,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -367,24 +191,24 @@ "Layer (type:depth-idx) Param #\n", "=================================================================\n", "GRU --\n", - "├─GRU: 1-1 3,508,224\n", - "├─Linear: 1-2 1,025\n", + "├─GRU: 1-1 13,307,904\n", + "├─Linear: 1-2 2,049\n", "├─Sigmoid: 1-3 --\n", "=================================================================\n", - "Total params: 3,509,249\n", - "Trainable params: 3,509,249\n", + "Total params: 13,309,953\n", + "Trainable params: 13,309,953\n", "Non-trainable params: 0\n", "=================================================================" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gru_model = models.GRU(\n", - " hidden_size=1024,\n", + " hidden_size=2048,\n", ")\n", "\n", "summary(gru_model)" @@ -392,199 +216,22 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 16:41:32] INFO : src.train - Loading data...\n", - "[2024-08-11 16:41:36] INFO : src.train - Beginning to train the network...\n", - "[2024-08-11 16:41:41] INFO : src.train - EPOCH: 1/30\n", - "[2024-08-11 16:41:41] INFO : src.train - Train loss: 0.6771, Train accuracy: 0.5806\n", - "[2024-08-11 16:41:41] INFO : src.train - Val loss: 0.6321, Val accuracy: 0.6383\n", - "\n", - "[2024-08-11 16:41:45] INFO : src.train - EPOCH: 2/30\n", - "[2024-08-11 16:41:45] INFO : src.train - Train loss: 0.6487, Train accuracy: 0.6237\n", - "[2024-08-11 16:41:45] INFO : src.train - Val loss: 0.6272, Val accuracy: 0.6457\n", - "\n", - "[2024-08-11 16:41:49] INFO : src.train - EPOCH: 3/30\n", - "[2024-08-11 16:41:49] INFO : src.train - Train loss: 0.6396, Train accuracy: 0.6332\n", - "[2024-08-11 16:41:49] INFO : src.train - Val loss: 0.6168, Val accuracy: 0.6539\n", - "\n", - "[2024-08-11 16:41:52] INFO : src.train - EPOCH: 4/30\n", - "[2024-08-11 16:41:52] INFO : src.train - Train loss: 0.6376, Train accuracy: 0.6372\n", - "[2024-08-11 16:41:52] INFO : src.train - Val loss: 0.6408, Val accuracy: 0.6439\n", - "\n", - "[2024-08-11 16:41:56] INFO : src.train - EPOCH: 5/30\n", - "[2024-08-11 16:41:56] INFO : src.train - Train loss: 0.6388, Train accuracy: 0.6337\n", - "[2024-08-11 16:41:56] INFO : src.train - Val loss: 0.6178, Val accuracy: 0.6573\n", - "\n", - "[2024-08-11 16:42:00] INFO : src.train - EPOCH: 6/30\n", - "[2024-08-11 16:42:00] INFO : src.train - Train loss: 0.6359, Train accuracy: 0.6382\n", - "[2024-08-11 16:42:00] INFO : src.train - Val loss: 0.6160, Val accuracy: 0.6595\n", - "\n", - "[2024-08-11 16:42:04] INFO : src.train - EPOCH: 7/30\n", - "[2024-08-11 16:42:04] INFO : src.train - Train loss: 0.6358, Train accuracy: 0.6383\n", - "[2024-08-11 16:42:04] INFO : src.train - Val loss: 0.6153, Val accuracy: 0.6556\n", - "\n", - "[2024-08-11 16:42:08] INFO : src.train - EPOCH: 8/30\n", - "[2024-08-11 16:42:08] INFO : src.train - Train loss: 0.6349, Train accuracy: 0.6369\n", - "[2024-08-11 16:42:08] INFO : src.train - Val loss: 0.6332, Val accuracy: 0.6500\n", - "\n", - "[2024-08-11 16:42:11] INFO : src.train - EPOCH: 9/30\n", - "[2024-08-11 16:42:11] INFO : src.train - Train loss: 0.6334, Train accuracy: 0.6416\n", - "[2024-08-11 16:42:11] INFO : src.train - Val loss: 0.6210, Val accuracy: 0.6508\n", - "\n", - "[2024-08-11 16:42:15] INFO : src.train - EPOCH: 10/30\n", - "[2024-08-11 16:42:15] INFO : src.train - Train loss: 0.6332, Train accuracy: 0.6402\n", - "[2024-08-11 16:42:15] INFO : src.train - Val loss: 0.6140, Val accuracy: 0.6573\n", - "\n", - "[2024-08-11 16:42:19] INFO : src.train - EPOCH: 11/30\n", - "[2024-08-11 16:42:19] INFO : src.train - Train loss: 0.6346, Train accuracy: 0.6370\n", - "[2024-08-11 16:42:19] INFO : src.train - Val loss: 0.6294, Val accuracy: 0.6401\n", - "\n", - "[2024-08-11 16:42:23] INFO : src.train - EPOCH: 12/30\n", - "[2024-08-11 16:42:23] INFO : src.train - Train loss: 0.6335, Train accuracy: 0.6382\n", - "[2024-08-11 16:42:23] INFO : src.train - Val loss: 0.6186, Val accuracy: 0.6590\n", - "\n", - "[2024-08-11 16:42:27] INFO : src.train - EPOCH: 13/30\n", - "[2024-08-11 16:42:27] INFO : src.train - Train loss: 0.6310, Train accuracy: 0.6440\n", - "[2024-08-11 16:42:27] INFO : src.train - Val loss: 0.6201, Val accuracy: 0.6517\n", - "\n", - "[2024-08-11 16:42:31] INFO : src.train - EPOCH: 14/30\n", - "[2024-08-11 16:42:31] INFO : src.train - Train loss: 0.6314, Train accuracy: 0.6422\n", - "[2024-08-11 16:42:31] INFO : src.train - Val loss: 0.6249, Val accuracy: 0.6491\n", - "\n", - "[2024-08-11 16:42:35] INFO : src.train - EPOCH: 15/30\n", - "[2024-08-11 16:42:35] INFO : src.train - Train loss: 0.6311, Train accuracy: 0.6418\n", - "[2024-08-11 16:42:35] INFO : src.train - Val loss: 0.6190, Val accuracy: 0.6573\n", - "\n", - "[2024-08-11 16:42:39] INFO : src.train - EPOCH: 16/30\n", - "[2024-08-11 16:42:39] INFO : src.train - Train loss: 0.6326, Train accuracy: 0.6425\n", - "[2024-08-11 16:42:39] INFO : src.train - Val loss: 0.6144, Val accuracy: 0.6586\n", - "\n", - "[2024-08-11 16:42:43] INFO : src.train - EPOCH: 17/30\n", - "[2024-08-11 16:42:43] INFO : src.train - Train loss: 0.6321, Train accuracy: 0.6382\n", - "[2024-08-11 16:42:43] INFO : src.train - Val loss: 0.6150, Val accuracy: 0.6539\n", - "\n", - "[2024-08-11 16:42:47] INFO : src.train - EPOCH: 18/30\n", - "[2024-08-11 16:42:47] INFO : src.train - Train loss: 0.6332, Train accuracy: 0.6378\n", - "[2024-08-11 16:42:47] INFO : src.train - Val loss: 0.6154, Val accuracy: 0.6560\n", - "\n", - "[2024-08-11 16:42:51] INFO : src.train - EPOCH: 19/30\n", - "[2024-08-11 16:42:51] INFO : src.train - Train loss: 0.6280, Train accuracy: 0.6462\n", - "[2024-08-11 16:42:51] INFO : src.train - Val loss: 0.6144, Val accuracy: 0.6612\n", - "\n", - "[2024-08-11 16:42:54] INFO : src.train - EPOCH: 20/30\n", - "[2024-08-11 16:42:54] INFO : src.train - Train loss: 0.6287, Train accuracy: 0.6442\n", - "[2024-08-11 16:42:54] INFO : src.train - Val loss: 0.6150, Val accuracy: 0.6595\n", - "\n", - "[2024-08-11 16:42:58] INFO : src.train - EPOCH: 21/30\n", - "[2024-08-11 16:42:58] INFO : src.train - Train loss: 0.6294, Train accuracy: 0.6444\n", - "[2024-08-11 16:42:58] INFO : src.train - Val loss: 0.6205, Val accuracy: 0.6552\n", - "\n", - "[2024-08-11 16:43:02] INFO : src.train - EPOCH: 22/30\n", - "[2024-08-11 16:43:02] INFO : src.train - Train loss: 0.6292, Train accuracy: 0.6413\n", - "[2024-08-11 16:43:02] INFO : src.train - Val loss: 0.6171, Val accuracy: 0.6608\n", - "\n", - "[2024-08-11 16:43:06] INFO : src.train - EPOCH: 23/30\n", - "[2024-08-11 16:43:06] INFO : src.train - Train loss: 0.6291, Train accuracy: 0.6446\n", - "[2024-08-11 16:43:06] INFO : src.train - Val loss: 0.6189, Val accuracy: 0.6526\n", - "\n", - "[2024-08-11 16:43:09] INFO : src.train - EPOCH: 24/30\n", - "[2024-08-11 16:43:09] INFO : src.train - Train loss: 0.6288, Train accuracy: 0.6471\n", - "[2024-08-11 16:43:09] INFO : src.train - Val loss: 0.6223, Val accuracy: 0.6534\n", - "\n", - "[2024-08-11 16:43:13] INFO : src.train - EPOCH: 25/30\n", - "[2024-08-11 16:43:13] INFO : src.train - Train loss: 0.6290, Train accuracy: 0.6432\n", - "[2024-08-11 16:43:13] INFO : src.train - Val loss: 0.6384, Val accuracy: 0.6375\n", - "\n", - "[2024-08-11 16:43:17] INFO : src.train - EPOCH: 26/30\n", - "[2024-08-11 16:43:17] INFO : src.train - Train loss: 0.6257, Train accuracy: 0.6478\n", - "[2024-08-11 16:43:17] INFO : src.train - Val loss: 0.6188, Val accuracy: 0.6521\n", - "\n", - "[2024-08-11 16:43:21] INFO : src.train - EPOCH: 27/30\n", - "[2024-08-11 16:43:21] INFO : src.train - Train loss: 0.6251, Train accuracy: 0.6521\n", - "[2024-08-11 16:43:21] INFO : src.train - Val loss: 0.6153, Val accuracy: 0.6616\n", - "\n", - "[2024-08-11 16:43:24] INFO : src.train - EPOCH: 28/30\n", - "[2024-08-11 16:43:24] INFO : src.train - Train loss: 0.6219, Train accuracy: 0.6514\n", - "[2024-08-11 16:43:24] INFO : src.train - Val loss: 0.6347, Val accuracy: 0.6401\n", - "\n", - "[2024-08-11 16:43:28] INFO : src.train - EPOCH: 29/30\n", - "[2024-08-11 16:43:28] INFO : src.train - Train loss: 0.6225, Train accuracy: 0.6526\n", - "[2024-08-11 16:43:28] INFO : src.train - Val loss: 0.6242, Val accuracy: 0.6547\n", - "\n", - "[2024-08-11 16:43:32] INFO : src.train - EPOCH: 30/30\n", - "[2024-08-11 16:43:32] INFO : src.train - Train loss: 0.6226, Train accuracy: 0.6545\n", - "[2024-08-11 16:43:32] INFO : src.train - Val loss: 0.6178, Val accuracy: 0.6560\n", - "\n", - "[2024-08-11 16:43:32] INFO : src.train - Evaluating network...\n", - "[2024-08-11 16:43:32] INFO : src.train - Accuracy from season remainder: 0.6484\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " AWAY_WIN 0.71 0.43 0.53 163\n", - " HOME_WIN 0.62 0.84 0.72 184\n", - "\n", - " accuracy 0.65 347\n", - " macro avg 0.67 0.64 0.63 347\n", - "weighted avg 0.66 0.65 0.63 347\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 16:43:33] INFO : src.train - Accuracy from next season: 0.6051\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " AWAY_WIN 0.55 0.37 0.44 445\n", - " HOME_WIN 0.63 0.78 0.70 606\n", - "\n", - " accuracy 0.61 1051\n", - " macro avg 0.59 0.57 0.57 1051\n", - "weighted avg 0.59 0.61 0.59 1051\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 16:43:34] INFO : src.train - Accuracy on short streaks (training): 0.4959\n", - "[2024-08-11 16:43:34] INFO : src.train - Accuracy on long streaks (training): 0.3098\n", - "[2024-08-11 16:43:34] INFO : src.train - Accuracy on short streaks (evaluation): 0.4686\n", - "[2024-08-11 16:43:35] INFO : src.train - Accuracy on long streaks (evaluation): 0.1667\n" - ] - } - ], + "outputs": [], "source": [ "_m, _h = run_train(\n", " # fixed\n", - " model=rnn_model,\n", + " model=gru_model,\n", " sequence_len=8,\n", " data_as_sequence=True,\n", " output_path=check_dir(\"gru\"),\n", " # adjust\n", - " epochs=30,\n", - " init_learning_rate=1e-3,\n", + " epochs=50,\n", + " init_learning_rate=1e-4,\n", " weight_decay=1e-5,\n", - " # data_use_all=True,\n", + " # enable for final training\n", + " data_use_all=True,\n", ")" ] }, @@ -597,12 +244,38 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "================================================================================\n", + "Layer (type:depth-idx) Param #\n", + "================================================================================\n", + "TCN --\n", + "├─TCN: 1-1 --\n", + "│ └─ModuleList: 2-1 --\n", + "│ │ └─TemporalBlock: 3-1 411,904\n", + "│ │ └─TemporalBlock: 3-2 230,016\n", + "│ │ └─TemporalBlock: 3-3 57,664\n", + "│ └─Conv1d: 2-2 65\n", + "│ └─Sigmoid: 2-3 --\n", + "================================================================================\n", + "Total params: 699,649\n", + "Trainable params: 699,649\n", + "Non-trainable params: 0\n", + "================================================================================" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "tcn_model = models.TCN(\n", - " channels=[32, 16, 4],\n", + " channels=[256, 128, 64],\n", ")\n", "\n", "summary(tcn_model)" @@ -616,15 +289,15 @@ "source": [ "_m, _h = run_train(\n", " # fixed\n", - " model=rnn_model,\n", + " model=tcn_model,\n", " sequence_len=8,\n", " data_as_sequence=True,\n", " output_path=check_dir(\"tcn\"),\n", " # adjust\n", - " epochs=30,\n", - " init_learning_rate=1e-3,\n", - " weight_decay=1e-5,\n", - " # data_use_all=True,\n", + " epochs=50,\n", + " init_learning_rate=(1e-4) / 2,\n", + " # enable for final training\n", + " data_use_all=True,\n", ")" ] }, @@ -637,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -650,9 +323,9 @@ "├─TransformerEncoderLayer: 1-1 --\n", "│ └─MultiheadAttention: 2-1 2,586,336\n", "│ │ └─NonDynamicallyQuantizableLinear: 3-1 862,112\n", - "│ └─Linear: 2-2 1,902,592\n", + "│ └─Linear: 2-2 2,853,888\n", "│ └─Dropout: 2-3 --\n", - "│ └─Linear: 2-4 1,901,472\n", + "│ └─Linear: 2-4 2,851,744\n", "│ └─LayerNorm: 2-5 1,856\n", "│ └─LayerNorm: 2-6 1,856\n", "│ └─Dropout: 2-7 --\n", @@ -660,20 +333,21 @@ "├─Linear: 1-2 929\n", "├─Sigmoid: 1-3 --\n", "================================================================================\n", - "Total params: 7,257,153\n", - "Trainable params: 7,257,153\n", + "Total params: 9,158,721\n", + "Trainable params: 9,158,721\n", "Non-trainable params: 0\n", "================================================================================" ] }, - "execution_count": 4, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te_model = models.TE(\n", - " hidden_size=2048,\n", + " hidden_size=3072,\n", + " dropout=0.2,\n", ")\n", "\n", "summary(te_model)" @@ -681,199 +355,21 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 17:21:42] INFO : src.train - Loading data...\n", - "[2024-08-11 17:21:46] INFO : src.train - Beginning to train the network...\n", - "[2024-08-11 17:21:48] INFO : src.train - EPOCH: 1/30\n", - "[2024-08-11 17:21:48] INFO : src.train - Train loss: 0.7967, Train accuracy: 0.5419\n", - "[2024-08-11 17:21:48] INFO : src.train - Val loss: 0.6776, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:21:51] INFO : src.train - EPOCH: 2/30\n", - "[2024-08-11 17:21:51] INFO : src.train - Train loss: 0.6900, Train accuracy: 0.5581\n", - "[2024-08-11 17:21:51] INFO : src.train - Val loss: 0.6799, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:21:53] INFO : src.train - EPOCH: 3/30\n", - "[2024-08-11 17:21:53] INFO : src.train - Train loss: 0.6882, Train accuracy: 0.5594\n", - "[2024-08-11 17:21:53] INFO : src.train - Val loss: 0.6756, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:21:56] INFO : src.train - EPOCH: 4/30\n", - "[2024-08-11 17:21:56] INFO : src.train - Train loss: 0.6860, Train accuracy: 0.5636\n", - "[2024-08-11 17:21:56] INFO : src.train - Val loss: 0.6801, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:21:58] INFO : src.train - EPOCH: 5/30\n", - "[2024-08-11 17:21:58] INFO : src.train - Train loss: 0.6884, Train accuracy: 0.5625\n", - "[2024-08-11 17:21:58] INFO : src.train - Val loss: 0.6794, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:00] INFO : src.train - EPOCH: 6/30\n", - "[2024-08-11 17:22:00] INFO : src.train - Train loss: 0.6864, Train accuracy: 0.5685\n", - "[2024-08-11 17:22:00] INFO : src.train - Val loss: 0.6897, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:03] INFO : src.train - EPOCH: 7/30\n", - "[2024-08-11 17:22:03] INFO : src.train - Train loss: 0.6878, Train accuracy: 0.5581\n", - "[2024-08-11 17:22:03] INFO : src.train - Val loss: 0.6760, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:05] INFO : src.train - EPOCH: 8/30\n", - "[2024-08-11 17:22:05] INFO : src.train - Train loss: 0.6866, Train accuracy: 0.5716\n", - "[2024-08-11 17:22:05] INFO : src.train - Val loss: 0.6746, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:08] INFO : src.train - EPOCH: 9/30\n", - "[2024-08-11 17:22:08] INFO : src.train - Train loss: 0.6868, Train accuracy: 0.5694\n", - "[2024-08-11 17:22:08] INFO : src.train - Val loss: 0.6770, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:10] INFO : src.train - EPOCH: 10/30\n", - "[2024-08-11 17:22:10] INFO : src.train - Train loss: 0.6885, Train accuracy: 0.5642\n", - "[2024-08-11 17:22:10] INFO : src.train - Val loss: 0.6793, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:13] INFO : src.train - EPOCH: 11/30\n", - "[2024-08-11 17:22:13] INFO : src.train - Train loss: 0.6873, Train accuracy: 0.5642\n", - "[2024-08-11 17:22:13] INFO : src.train - Val loss: 0.6893, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:15] INFO : src.train - EPOCH: 12/30\n", - "[2024-08-11 17:22:15] INFO : src.train - Train loss: 0.6852, Train accuracy: 0.5705\n", - "[2024-08-11 17:22:15] INFO : src.train - Val loss: 0.6785, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:17] INFO : src.train - EPOCH: 13/30\n", - "[2024-08-11 17:22:17] INFO : src.train - Train loss: 0.6863, Train accuracy: 0.5697\n", - "[2024-08-11 17:22:17] INFO : src.train - Val loss: 0.6836, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:20] INFO : src.train - EPOCH: 14/30\n", - "[2024-08-11 17:22:20] INFO : src.train - Train loss: 0.6848, Train accuracy: 0.5718\n", - "[2024-08-11 17:22:20] INFO : src.train - Val loss: 0.6782, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:22] INFO : src.train - EPOCH: 15/30\n", - "[2024-08-11 17:22:22] INFO : src.train - Train loss: 0.6866, Train accuracy: 0.5705\n", - "[2024-08-11 17:22:22] INFO : src.train - Val loss: 0.6749, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:25] INFO : src.train - EPOCH: 16/30\n", - "[2024-08-11 17:22:25] INFO : src.train - Train loss: 0.6845, Train accuracy: 0.5742\n", - "[2024-08-11 17:22:25] INFO : src.train - Val loss: 0.6775, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:27] INFO : src.train - EPOCH: 17/30\n", - "[2024-08-11 17:22:27] INFO : src.train - Train loss: 0.6837, Train accuracy: 0.5735\n", - "[2024-08-11 17:22:27] INFO : src.train - Val loss: 0.6795, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:30] INFO : src.train - EPOCH: 18/30\n", - "[2024-08-11 17:22:30] INFO : src.train - Train loss: 0.6832, Train accuracy: 0.5705\n", - "[2024-08-11 17:22:30] INFO : src.train - Val loss: 0.6752, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:32] INFO : src.train - EPOCH: 19/30\n", - "[2024-08-11 17:22:32] INFO : src.train - Train loss: 0.6833, Train accuracy: 0.5732\n", - "[2024-08-11 17:22:32] INFO : src.train - Val loss: 0.6788, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:35] INFO : src.train - EPOCH: 20/30\n", - "[2024-08-11 17:22:35] INFO : src.train - Train loss: 0.6834, Train accuracy: 0.5736\n", - "[2024-08-11 17:22:35] INFO : src.train - Val loss: 0.6831, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:37] INFO : src.train - EPOCH: 21/30\n", - "[2024-08-11 17:22:37] INFO : src.train - Train loss: 0.6831, Train accuracy: 0.5735\n", - "[2024-08-11 17:22:37] INFO : src.train - Val loss: 0.6747, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:40] INFO : src.train - EPOCH: 22/30\n", - "[2024-08-11 17:22:40] INFO : src.train - Train loss: 0.6833, Train accuracy: 0.5735\n", - "[2024-08-11 17:22:40] INFO : src.train - Val loss: 0.6769, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:42] INFO : src.train - EPOCH: 23/30\n", - "[2024-08-11 17:22:42] INFO : src.train - Train loss: 0.6833, Train accuracy: 0.5735\n", - "[2024-08-11 17:22:42] INFO : src.train - Val loss: 0.6751, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:45] INFO : src.train - EPOCH: 24/30\n", - "[2024-08-11 17:22:45] INFO : src.train - Train loss: 0.6833, Train accuracy: 0.5735\n", - "[2024-08-11 17:22:45] INFO : src.train - Val loss: 0.6747, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:48] INFO : src.train - EPOCH: 25/30\n", - "[2024-08-11 17:22:48] INFO : src.train - Train loss: 0.6837, Train accuracy: 0.5735\n", - "[2024-08-11 17:22:48] INFO : src.train - Val loss: 0.6779, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:50] INFO : src.train - EPOCH: 26/30\n", - "[2024-08-11 17:22:50] INFO : src.train - Train loss: 0.6824, Train accuracy: 0.5734\n", - "[2024-08-11 17:22:50] INFO : src.train - Val loss: 0.6869, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:53] INFO : src.train - EPOCH: 27/30\n", - "[2024-08-11 17:22:53] INFO : src.train - Train loss: 0.6834, Train accuracy: 0.5728\n", - "[2024-08-11 17:22:53] INFO : src.train - Val loss: 0.6747, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:55] INFO : src.train - EPOCH: 28/30\n", - "[2024-08-11 17:22:55] INFO : src.train - Train loss: 0.6827, Train accuracy: 0.5735\n", - "[2024-08-11 17:22:55] INFO : src.train - Val loss: 0.6807, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:22:58] INFO : src.train - EPOCH: 29/30\n", - "[2024-08-11 17:22:58] INFO : src.train - Train loss: 0.6831, Train accuracy: 0.5735\n", - "[2024-08-11 17:22:58] INFO : src.train - Val loss: 0.6790, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:23:00] INFO : src.train - EPOCH: 30/30\n", - "[2024-08-11 17:23:00] INFO : src.train - Train loss: 0.6831, Train accuracy: 0.5735\n", - "[2024-08-11 17:23:00] INFO : src.train - Val loss: 0.6751, Val accuracy: 0.5939\n", - "\n", - "[2024-08-11 17:23:01] INFO : src.train - Evaluating network...\n", - "[2024-08-11 17:23:01] INFO : src.train - Accuracy from season remainder: 0.5303\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " AWAY_WIN 0.00 0.00 0.00 163\n", - " HOME_WIN 0.53 1.00 0.69 184\n", - "\n", - " accuracy 0.53 347\n", - " macro avg 0.27 0.50 0.35 347\n", - "weighted avg 0.28 0.53 0.37 347\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 17:23:01] INFO : src.train - Accuracy from next season: 0.5766\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " AWAY_WIN 0.00 0.00 0.00 445\n", - " HOME_WIN 0.58 1.00 0.73 606\n", - "\n", - " accuracy 0.58 1051\n", - " macro avg 0.29 0.50 0.37 1051\n", - "weighted avg 0.33 0.58 0.42 1051\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2024-08-11 17:23:02] INFO : src.train - Accuracy on short streaks (training): 0.6103\n", - "[2024-08-11 17:23:02] INFO : src.train - Accuracy on long streaks (training): 0.6212\n", - "[2024-08-11 17:23:03] INFO : src.train - Accuracy on short streaks (evaluation): 0.5649\n", - "[2024-08-11 17:23:03] INFO : src.train - Accuracy on long streaks (evaluation): 0.5758\n" - ] - } - ], + "outputs": [], "source": [ "_m, _h = run_train(\n", " # fixed\n", - " model=rnn_model,\n", + " model=te_model,\n", " sequence_len=8,\n", " data_as_sequence=False,\n", " output_path=check_dir(\"te\"),\n", " # adjust\n", - " epochs=30,\n", - " init_learning_rate=1e-4,\n", - " weight_decay=1e-5,\n", - " # data_use_all=True,\n", + " epochs=50,\n", + " init_learning_rate=1e-5,\n", + " # enable for final training\n", + " data_use_all=True,\n", ")" ] },