From bad54406562b8ed71e32526f4054155ac714c156 Mon Sep 17 00:00:00 2001
From: itsluketwist
Date: Sat, 10 Aug 2024 11:23:04 +0100
Subject: [PATCH] Add transformer model and more options to data loader.

---
 src/loader.py          | 18 ++++++++++++++++++
 src/models/__init__.py |  2 ++
 src/models/rnn.py      |  8 ++++----
 src/models/tcn.py      |  2 +-
 src/models/te.py       | 40 ++++++++++++++++++++++++++++++++++++++++
 src/train.py           | 22 ++++++++++++++++++++++
 src/utils.py           | 24 ++++++++++++++++++++++++
 7 files changed, 111 insertions(+), 5 deletions(-)
 create mode 100644 src/models/te.py

diff --git a/src/loader.py b/src/loader.py
index d55a12f..0a07279 100644
--- a/src/loader.py
+++ b/src/loader.py
@@ -29,12 +29,15 @@ def __init__(
         verbose: bool = False,
         sequence_len: int = 10,
         normalize: bool = False,
+        as_sequence: bool = True,
         **kwargs,
     ):
         self.data = pd.read_parquet(data_file)
         self.verbose = verbose
         self.sequence_len = sequence_len
         self.normalize = normalize
+        self.as_sequence = as_sequence
+        self.vector_len = 116
 
     def __len__(self):
         return len(self.data)
@@ -46,6 +49,9 @@ def __getitem__(self, idx: int) -> tuple[Tensor, str | int]:
         if self.normalize:
             sequence = normalize(sequence)
 
+        if not self.as_sequence:
+            sequence = sequence.reshape([self.sequence_len * self.vector_len])
+
         info = item["info"]
         result: str | int
         if self.verbose:
@@ -65,6 +71,7 @@ def get_train_dataloader(
     parquet_file: str = TRAINING_DATA,
     dataset_class: Dataset = GameSequenceDataset,
     sequence_len: int = 10,
+    as_sequence: bool = True,
 ) -> tuple[DataLoader, DataLoader]:
     """
     Create dataloaders for training a model with NBA game data.
@@ -75,6 +82,8 @@ def get_train_dataloader(
     batch_size: int = 64
     parquet_file: str = TRAINING_DATA
     dataset_class: Dataset = GameSequenceDataset
+    sequence_len: int = 10
+    as_sequence: bool = True
 
     Returns
     -------
@@ -83,6 +92,7 @@
     raw_data = dataset_class(
         data_file=parquet_file,
         sequence_len=sequence_len,
+        as_sequence=as_sequence,
     )
 
     num_train = int(len(raw_data) * train_split)
@@ -111,6 +121,7 @@ def get_eval_dataloader(
     parquet_file: str = EVALUATION_DATA,
     dataset_class: Dataset = GameSequenceDataset,
     sequence_len: int = 10,
+    as_sequence: bool = True,
 ) -> DataLoader:
     """
     Create a dataloader for evaluating a model with NBA game data.
@@ -119,6 +130,8 @@ def get_eval_dataloader(
     ----------
     parquet_file: str = EVALUATION_DATA
     dataset_class: Dataset = GameSequenceDataset
+    sequence_len: int = 10
+    as_sequence: bool = True
 
     Returns
     -------
@@ -127,6 +140,7 @@
     raw_data = dataset_class(
         data_file=parquet_file,
         sequence_len=sequence_len,
+        as_sequence=as_sequence,
     )
     return DataLoader(
         dataset=raw_data,
@@ -139,6 +153,7 @@ def get_sample_dataloader(
     parquet_file: str = FINAL_WEEK_DATA,
     dataset_class: Dataset = GameSequenceDataset,
     sequence_len: int = 10,
+    as_sequence: bool = True,
 ) -> DataLoader:
     """
     Create a dataloader for providing sample NBA game data.
@@ -148,6 +163,8 @@ def get_sample_dataloader(
     count: int = 10
     parquet_file: str = FINAL_WEEK_DATA
     dataset_class: Dataset = GameSequenceDataset
+    sequence_len: int = 10
+    as_sequence: bool = True
 
     Returns
     -------
@@ -157,6 +174,7 @@
         data_file=parquet_file,
         verbose=True,
         sequence_len=sequence_len,
+        as_sequence=as_sequence,
     )
     idxs = np.random.choice(range(0, len(raw_data)), size=(count,))
     _subset = Subset(raw_data, idxs)
diff --git a/src/models/__init__.py b/src/models/__init__.py
index add0058..50f141f 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -1,5 +1,6 @@
 from src.models.rnn import GRU, LSTM, RNN
 from src.models.tcn import TCN
+from src.models.te import TE
 
 
 __all__ = [
@@ -7,4 +8,5 @@
     "LSTM",
     "RNN",
     "TCN",
+    "TE",
 ]
diff --git a/src/models/rnn.py b/src/models/rnn.py
index 121795c..37d3eab 100644
--- a/src/models/rnn.py
+++ b/src/models/rnn.py
@@ -7,7 +7,7 @@
 
 
 class _BaseRNN(nn.Module, ABC):
-    """Base class for RNN models."""
+    """Base class for recurrent-cell models."""
 
     def __init__(self, hidden_size: int, input_size: int = 116, dropout: float = 0.0):
         super(_BaseRNN, self).__init__()
@@ -36,7 +36,7 @@ def __repr__(self) -> str:
 
 
 class RNN(_BaseRNN):
-    """Vanilla RNN class."""
+    """Vanilla RNN (recurrent neural network) class to predict sequential NBA data."""
 
     def get_rnn_layer(self) -> nn.Module:
         return nn.RNN(
@@ -57,7 +57,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 class LSTM(_BaseRNN):
-    """LSTM class."""
+    """LSTM (long short-term memory) class to predict sequential NBA data."""
 
     def get_rnn_layer(self) -> nn.Module:
         return nn.LSTM(
@@ -79,7 +79,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 class GRU(_BaseRNN):
-    """GRU class."""
+    """GRU (gated recurrent unit) class to predict sequential NBA data."""
 
     def get_rnn_layer(self) -> nn.Module:
         return nn.GRU(
diff --git a/src/models/tcn.py b/src/models/tcn.py
index 397c645..d9ab46e 100644
--- a/src/models/tcn.py
+++ b/src/models/tcn.py
@@ -5,7 +5,7 @@
 
 
 class TCN(nn.Module):
-    """Class for TCN model structured to predict the NBA data."""
+    """Class for TCN (temporal convolutional network) model to predict sequential NBA data."""
 
     def __init__(
         self,
diff --git a/src/models/te.py b/src/models/te.py
new file mode 100644
index 0000000..7772087
--- /dev/null
+++ b/src/models/te.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+
+
+class TE(nn.Module):
+    """Class for TE (transformer encoder) model to predict sequential NBA data."""
+
+    def __init__(
+        self,
+        input_size: int = 116,
+        sequence_len: int = 8,
+        hidden_size: int = 512,
+        **te_kwargs,
+    ):
+        super(TE, self).__init__()
+
+        self._input_size = input_size * sequence_len
+        self._sequence_len = sequence_len
+        self._output_size = 1
+        self._num_layers = 1
+
+        self.te = nn.TransformerEncoderLayer(
+            d_model=self._input_size,
+            nhead=self._sequence_len,
+            dim_feedforward=hidden_size,
+            batch_first=True,
+            **te_kwargs,
+        )
+
+        self.linear = nn.Linear(self._input_size, self._output_size)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = self.te(input)
+        output = self.linear(output)
+        output = self.sigmoid(output)
+        return output
+
+    def __repr__(self) -> str:
+        return type(self).__name__
diff --git a/src/train.py b/src/train.py
index c48f83a..56ac89d 100644
--- a/src/train.py
+++ b/src/train.py
@@ -6,12 +6,14 @@
 from torch.optim import Adam
 
 from src.loader import (
+    ALL_GAME_DATA,
     EVALUATION_DATA_FULL_22_23,
     EVALUATION_DATA_HALF_21_22,
     STREAK_DATA_EVALUATION_LONG,
     STREAK_DATA_EVALUATION_SHORT,
     STREAK_DATA_TRAINING_LONG,
     STREAK_DATA_TRAINING_SHORT,
+    TRAINING_DATA,
     get_eval_dataloader,
     get_train_dataloader,
 )
@@ -40,6 +42,8 @@ def run_train(
     output_path: str = "output",
     save_return: bool = False,
     weight_decay: float = 0.0,
+    use_all_data: bool = False,
+    data_as_sequence: bool = False,
 ):
     """
     Train the chosen model, given the hyperparameters.
@@ -64,6 +68,10 @@ def run_train(
         Whether or not to save the returned model and history data to file.
     weight_decay: float = 0.0
         The rate of decay for regularization.
+    use_all_data: bool = False
+        Whether to use all available game data when training, or just the training data.
+    data_as_sequence: bool = False
+        Whether to provide each sample as a sequence of vectors, or as a single flattened vector.
 
     Returns
     -------
@@ -75,6 +83,8 @@ def run_train(
         train_split=train_split,
         batch_size=batch_size,
         sequence_len=sequence_len,
+        parquet_file=ALL_GAME_DATA if use_all_data else TRAINING_DATA,
+        as_sequence=data_as_sequence,
     )
 
     loss_func = nn.BCELoss()  # init loss function
@@ -134,6 +144,8 @@ def run_train(
         model=model,
         loader=get_eval_dataloader(
             parquet_file=EVALUATION_DATA_HALF_21_22,
+            sequence_len=sequence_len,
+            as_sequence=data_as_sequence,
         ),
     )
     logger.info(f"Accuracy from season remainder: {hist.eval_accuracy_half:.4f}")
@@ -142,6 +154,8 @@ def run_train(
         model=model,
         loader=get_eval_dataloader(
             parquet_file=EVALUATION_DATA_FULL_22_23,
+            sequence_len=sequence_len,
+            as_sequence=data_as_sequence,
         ),
     )
     logger.info(f"Accuracy from next season: {hist.eval_accuracy_next:.4f}")
@@ -151,6 +165,8 @@ def run_train(
         model=model,
         loader=get_eval_dataloader(
             parquet_file=STREAK_DATA_TRAINING_SHORT,
+            sequence_len=sequence_len,
+            as_sequence=data_as_sequence,
         ),
         print_report=False,
     )
@@ -162,6 +178,8 @@ def run_train(
         model=model,
         loader=get_eval_dataloader(
             parquet_file=STREAK_DATA_TRAINING_LONG,
+            sequence_len=sequence_len,
+            as_sequence=data_as_sequence,
         ),
         print_report=False,
     )
@@ -173,6 +191,8 @@ def run_train(
         model=model,
         loader=get_eval_dataloader(
             parquet_file=STREAK_DATA_EVALUATION_SHORT,
+            sequence_len=sequence_len,
+            as_sequence=data_as_sequence,
         ),
         print_report=False,
     )
@@ -184,6 +204,8 @@ def run_train(
         model=model,
         loader=get_eval_dataloader(
             parquet_file=STREAK_DATA_EVALUATION_LONG,
+            sequence_len=sequence_len,
+            as_sequence=data_as_sequence,
         ),
         print_report=False,
     )
diff --git a/src/utils.py b/src/utils.py
index 01ab954..e69ce63 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,3 +1,4 @@
+import os
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import List
@@ -78,3 +79,26 @@ def update(
         self.train_accuracy.append(float(train_accuracy))
         self.test_loss.append(float(test_loss))
         self.test_accuracy.append(float(test_accuracy))
+
+
+def check_dir(path: str, output: bool = True) -> str:
+    """
+    Utility to check that a directory exists, and to create it if not.
+
+    Parameters
+    ----------
+    path: str
+    output: bool = True
+
+    Returns
+    -------
+    str
+        The given path, which is now guaranteed to exist.
+    """
+    if output and not path.startswith("output/"):
+        path = "output/" + path.lstrip("/")
+
+    if not os.path.exists(path=path):
+        os.makedirs(path)
+
+    return path
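
A minimal shape check for how the new TE model pairs with the loader's
as_sequence=False option. This is a sketch, not code included in the patch,
and it relies only on the defaults shown above (116 features per game vector,
TE's default sequence_len of 8, the loader's default batch size of 64); the
sequence_len given to the dataset must match the one the TE model was built
with.

    import torch

    from src.models import TE

    sequence_len = 8
    vector_len = 116  # matches GameSequenceDataset.vector_len above

    # With as_sequence=False the dataset reshapes each (sequence_len, 116)
    # sequence into one flattened vector, so a batch is (batch_size, 928).
    batch = torch.randn(64, sequence_len * vector_len)

    model = TE(input_size=vector_len, sequence_len=sequence_len)
    probs = model(batch)

    print(probs.shape)  # torch.Size([64, 1]), one prediction per game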
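
The new check_dir helper prefixes "output/" by default and then creates any
missing directories. A hypothetical call, with the made-up run name
"te-run-01" and the default output=True:

    from src.utils import check_dir

    results_dir = check_dir("te-run-01")
    print(results_dir)  # output/te-run-01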