Skip to content

Commit

Permalink
add method for data splitting regression
Browse files Browse the repository at this point in the history
  • Loading branch information
namsaraeva committed May 13, 2024
1 parent 32b805f commit 72ee7d6
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/sparcscore/ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,45 @@ def combine_datasets_balanced(list_of_datasets, class_labels, train_per_class, v
val_dataset = torch.utils.data.ConcatDataset(val_dataset)

return train_dataset, val_dataset, test_dataset


def split_dataset_regression(dataset, train_size, test_size, val_size, seed=None):
    """
    Randomly split a dataset into train, validation, and test subsets.

    Samples that are not assigned to any of the three subsets
    (``len(dataset) - train_size - test_size - val_size`` of them) are
    discarded.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset to be split.
    train_size : int
        Number of samples in the train set.
    test_size : int
        Number of samples in the test set.
    val_size : int
        Number of samples in the validation set.
    seed : int, optional
        Seed for a dedicated ``torch.Generator`` so the split is
        reproducible. If ``None`` (default), ``random_split`` falls back to
        its default generator.

    Returns
    -------
    train : torch.utils.data.Dataset
        Train dataset.
    val : torch.utils.data.Dataset
        Validation dataset.
    test : torch.utils.data.Dataset
        Test dataset.

    Raises
    ------
    ValueError
        If any requested size is negative, or if the dataset has fewer
        samples than the three requested sizes combined.

    Notes
    -----
    The return order is (train, val, test), which differs from the order of
    the size parameters (train, test, val).
    """
    # A negative size would inflate the residual and slip past the length
    # check below, after which random_split would slice with a negative
    # length and silently hand back overlapping, wrongly sized subsets.
    if min(train_size, test_size, val_size) < 0:
        raise ValueError(
            f"Split sizes must be non-negative, got train_size={train_size}, "
            f"test_size={test_size}, val_size={val_size}."
        )

    residual_size = len(dataset) - train_size - test_size - val_size

    if residual_size < 0:
        raise ValueError(
            f"Dataset with length {len(dataset)} is too small to be split into test set of size {test_size}, "
            f"train set of size {train_size}, and validation set of size {val_size}. "
        )

    # Only pass a generator when a seed was given, so the unseeded path keeps
    # using random_split's own default generator.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    # The fourth subset soaks up the leftover samples and is thrown away.
    train, test, val, _ = torch.utils.data.random_split(
        dataset,
        [train_size, test_size, val_size, residual_size],
        **split_kwargs,
    )

    return train, val, test

0 comments on commit 72ee7d6

Please sign in to comment.