Fix training crash issue on multi-node runs when dataloader_num_workers > 0 #1721

Open · wants to merge 1 commit into base: main
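
The diff below routes DataLoader worker start-up through a "forkserver" multiprocessing context whenever dataloader_num_workers > 0, in get_train_dataloader, get_eval_dataloader, and get_test_dataloader. As a minimal standalone sketch of that pattern (the build_loader helper and the toy dataset are illustrative, not part of this PR):

import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset


def build_loader(dataset, num_workers: int) -> DataLoader:
    # "forkserver" starts workers from a clean helper process instead of fork()-ing
    # a main process that may already hold runtime threads and locks; with 0 workers,
    # DataLoader expects no multiprocessing context at all.
    mp_context = multiprocessing.get_context("forkserver") if num_workers > 0 else None
    return DataLoader(
        dataset,
        batch_size=8,
        num_workers=num_workers,
        multiprocessing_context=mp_context,
    )


if __name__ == "__main__":
    loader = build_loader(TensorDataset(torch.arange(32, dtype=torch.float32)), num_workers=2)
    for batch in loader:
        pass

The PR does not spell out the root cause, but the Python documentation notes that fork()-ing a process that already runs threads is unreliable, which is the usual reason to prefer "forkserver" once dataloader worker processes are actually spawned.
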
153 changes: 153 additions & 0 deletions optimum/habana/transformers/trainer.py
@@ -22,6 +22,7 @@
import inspect
import json
import math
import multiprocessing
import os
import random
import shutil
@@ -87,6 +88,7 @@
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    seed_worker,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
@@ -333,6 +335,157 @@ def _move_model_to_device(self, model, device):
        if self.args.use_habana and hasattr(model, "tie_weights"):
            model.tie_weights()

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        if self.args.dataloader_num_workers > 0:
            # Start dataloader workers from a "forkserver" context to avoid the
            # multi-node training crash seen with the default "fork" start method.
            multiprocessing_context = multiprocessing.get_context("forkserver")
        else:
            # A multiprocessing context is only valid when num_workers > 0.
            multiprocessing_context = None

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "multiprocessing_context": multiprocessing_context,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
            dataloader_params["multiprocessing_context"] = multiprocessing_context

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
                If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # If we have persistent workers, don't do a fork bomb especially as eval datasets
        # don't change during training
        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
        if (
            hasattr(self, "_eval_dataloaders")
            and dataloader_key in self._eval_dataloaders
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

        eval_dataset = (
            self.eval_dataset[eval_dataset]
            if isinstance(eval_dataset, str)
            else eval_dataset
            if eval_dataset is not None
            else self.eval_dataset
        )
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        if self.args.dataloader_num_workers > 0:
            # Same worker start-up choice as in get_train_dataloader.
            multiprocessing_context = multiprocessing.get_context("forkserver")
        else:
            multiprocessing_context = None

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "multiprocessing_context": multiprocessing_context,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
            dataloader_params["multiprocessing_context"] = multiprocessing_context

        # accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version
        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = eval_dataloader
            else:
                self._eval_dataloaders = {dataloader_key: eval_dataloader}

        return self.accelerator.prepare(eval_dataloader)

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        if self.args.dataloader_num_workers > 0:
            # Same worker start-up choice as in get_train_dataloader.
            multiprocessing_context = multiprocessing.get_context("forkserver")
        else:
            multiprocessing_context = None

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "multiprocessing_context": multiprocessing_context,
        }

        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
            dataloader_params["multiprocessing_context"] = multiprocessing_context

        # We use the same batch_size as for eval.
        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None
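
For reference, a hedged usage sketch of the TrainingArguments fields that the rewritten dataloader builders read (values are illustrative; in optimum-habana these fields are normally set on GaudiTrainingArguments, which extends TrainingArguments):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    dataloader_num_workers=4,            # > 0: workers now start via the "forkserver" context
    dataloader_persistent_workers=True,  # eval dataloaders are cached and reused between evaluations
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=2,
    dataloader_drop_last=False,
)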