diff --git a/src/sparcscore/ml/utils.py b/src/sparcscore/ml/utils.py index 0add2a5..ba65384 100644 --- a/src/sparcscore/ml/utils.py +++ b/src/sparcscore/ml/utils.py @@ -103,3 +103,45 @@ def combine_datasets_balanced(list_of_datasets, class_labels, train_per_class, v val_dataset = torch.utils.data.ConcatDataset(val_dataset) return train_dataset, val_dataset, test_dataset + + +def split_dataset_regression(dataset, train_size, test_size, val_size, seed=None): + """ + Split a dataset into train, test, and validation set. + + Parameters + ---------- + dataset : torch.utils.data.Dataset + Dataset to be split. + train_size : int + Number of samples in the train set. + test_size : int + Number of samples in the test set. + val_size : int + Number of samples in the validation set. + + Returns + ------- + train : torch.utils.data.Dataset + Train dataset. + val : torch.utils.data.Dataset + Validation dataset. + test : torch.utils.data.Dataset + Test dataset. + """ + residual_size = len(dataset) - train_size - test_size - val_size + + if residual_size < 0: + raise ValueError( + f"Dataset with length {len(dataset)} is too small to be split into test set of size {test_size}, " + f"train set of size {train_size}, and validation set of size {val_size}. " + ) + + if seed is not None: + gen = torch.Generator() + gen.manual_seed(seed) + train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size], generator=gen) + else: + train, test, val, _ = torch.utils.data.random_split(dataset, [train_size, test_size, val_size, residual_size]) + + return train, val, test