diff --git a/src/setfit/data.py b/src/setfit/data.py index ff5a0c33..8e6149d6 100644 --- a/src/setfit/data.py +++ b/src/setfit/data.py @@ -170,7 +170,7 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i df = df.groupby(label_column) # sample num_samples, or at least as much as possible - df = df.apply(lambda x: x.sample(min(num_samples, len(x)))) + df = df.apply(lambda x: x.sample(min(num_samples, len(x)), random_state=seed)) df = df.reset_index(drop=True) all_samples = Dataset.from_pandas(df, features=dataset.features)