Various fixes for TP (#260)
* Fix loss log when using TP

* Make evaluation work with DP / TP

* Final changes
michaelbenayoun authored Oct 23, 2023
1 parent 571effd commit 4b37209
Showing 6 changed files with 156 additions and 22 deletions.
37 changes: 27 additions & 10 deletions optimum/neuron/accelerate/accelerator.py
@@ -69,7 +69,6 @@
xm = None

if is_neuronx_distributed_available():
from neuronx_distributed import parallel_layers
from neuronx_distributed.utils.model_utils import move_model_to_device


@@ -143,15 +142,26 @@ def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, z
if num_steps != 1:
self.gradient_accumulation_steps = num_steps

def _prepare_data_loader_for_tp(self, data_loader: DataLoader) -> DataLoader:
def _prepare_data_loader_for_distributed(
self, data_loader: DataLoader, num_replicas: int, rank: int
) -> DataLoader:
# TODO: make it more robust, similar to the prepare_data_loader function in `accelerate`.
if isinstance(data_loader.sampler, DistributedSampler):
return data_loader
sampler = DistributedSampler(
data_loader.dataset,
num_replicas=parallel_layers.parallel_state.get_data_parallel_size(),
rank=parallel_layers.parallel_state.get_data_parallel_rank(),
)

orig_sampler = data_loader.sampler
if isinstance(orig_sampler, torch.utils.data.SequentialSampler):
shuffle = False
else:
shuffle = True
if not isinstance(orig_sampler, torch.utils.data.RandomSampler):
logger.warning(
f"The sampler {orig_sampler} is going to be replaced by a torch.utils.data.DistributedSampler. This "
"new sampler will shuffle the dataset, it might not be the expected behaviour."
)

sampler = DistributedSampler(data_loader.dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

data_loader_for_tp = DataLoader(
data_loader.dataset,
batch_size=data_loader.batch_size,
@@ -166,8 +176,15 @@ def _prepare_data_loader_for_tp(self, data_loader: DataLoader) -> DataLoader:

def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optional[bool] = None):
if self.state.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
data_loader = self._prepare_data_loader_for_tp(data_loader)
from neuronx_distributed import parallel_layers

num_replicas = parallel_layers.parallel_state.get_data_parallel_size()
rank = parallel_layers.parallel_state.get_data_parallel_rank()
else:
num_replicas = xm.xrt_world_size()
rank = xm.get_ordinal()
if self.state.num_processes > 1:
data_loader = self._prepare_data_loader_for_distributed(data_loader, num_replicas=num_replicas, rank=rank)
data_loader = MpDeviceLoader(data_loader, self.device)
return data_loader
# TODO: fix that.
@@ -204,7 +221,7 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device
model_parallel_is_initialized,
)

if not is_neuronx_distributed_available() or not model_parallel_is_initialized():
if not model_parallel_is_initialized():
sharding_groups = None
grad_norm_groups = None
else:
@@ -329,7 +346,7 @@ def _prepare_model_for_tp(
cpu_ids = [id(v) for v in model.parameters()]
# TODO: enable self.device (if needed).
model = self.state.tp_plugin.parallelize_model(model, device=None)
if os.environ.get("XLA_USE_BF16", "0") == "1":
if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
model.to(torch.bfloat16)
else:
model.to(torch.float32)
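For context on the data loader change above: torch.utils.data.DistributedSampler is what actually shards the dataset, given the number of data-parallel replicas and the rank of the current one (under tensor parallelism these come from neuronx_distributed's parallel_state; otherwise from xm.xrt_world_size() and xm.get_ordinal()). A small CPU-only sketch of that sharding behaviour, using a made-up toy dataset:

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))

# Two data-parallel replicas: each rank's sampler yields a disjoint half of the
# dataset, so together the replicas cover every sample exactly once per epoch.
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    print(rank, [batch[0].tolist() for batch in loader])
# 0 [[0, 2], [4, 6]]
# 1 [[1, 3], [5, 7]]
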
27 changes: 25 additions & 2 deletions optimum/neuron/accelerate/utils/operations.py
@@ -14,16 +14,28 @@
# limitations under the License.
"""Custom operations related to accelerate for Neuron."""


import torch
from accelerate.utils.operations import recursively_apply

from ...utils import is_neuronx_distributed_available
from ...utils.require_utils import requires_torch_xla


@requires_torch_xla
def _xla_gather(tensor, out_of_graph: bool = False):
import torch_xla.core.xla_model as xm

groups = None
if is_neuronx_distributed_available():
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_group,
model_parallel_is_initialized,
)

if model_parallel_is_initialized():
groups = get_data_parallel_group(as_list=True)

@bocchris-aws (Contributor) commented on Nov 3, 2023:

Is it always certain that we want to gather over data parallel groups?

@michaelbenayoun (Author, Member) replied on Nov 3, 2023:

For now I would say yes. This function is only used for evaluation purposes if I'm not mistaken.

(A short sketch of this data-parallel gather follows the diff below.)


def _xla_gather_one(tensor):
if tensor.ndim == 0:
tensor = tensor.clone()[None]
@@ -32,9 +44,20 @@ def _xla_gather_one(tensor):
tensor = tensor.contiguous()

if out_of_graph:
gathered = xm.mesh_reduce("nested_xla_gather", tensor, torch.cat)
gathered_tensors = xm.mesh_reduce("nested_xla_gather", tensor, lambda x: x)
if groups is not None:
new_gathered_tensors = []
# Since `groups` contains a list of replica groups, visiting the first group of
# replicas is enough because the values should be the same across the other axes.
replicas_to_consider = set(groups[0])
for idx, tensor in enumerate(gathered_tensors):
if idx not in replicas_to_consider:
continue
new_gathered_tensors.append(tensor)
gathered_tensors = new_gathered_tensors
gathered = torch.cat(gathered_tensors)
else:
gathered = xm.all_gather(tensor)
gathered = xm.all_gather(tensor, groups=groups, pin_layout=False)
return gathered

res = recursively_apply(_xla_gather_one, tensor, error_on_other_type=True)
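On the question raised in the review thread above: with tensor parallelism, every tensor-parallel rank inside a model replica holds the same evaluation tensors, so gathering across the data-parallel group alone reconstructs the full batch without duplicates. Below is a minimal, single-process sketch of the deduplication idea used in _xla_gather; the group layout and values are made up for illustration, and nothing here calls the XLA or neuronx_distributed APIs.

import torch

# Illustrative layout: world size 4 = 2 model replicas (tensor-parallel degree 2)
# x 2 data-parallel shards. Ranks 0-1 form replica 0, ranks 2-3 form replica 1;
# the data-parallel groups connect matching tensor-parallel ranks across replicas.
dp_groups = [[0, 2], [1, 3]]

# Pretend these were collected from all ranks by an out-of-graph mesh reduce:
# both tensor-parallel ranks of a replica hold identical evaluation outputs.
per_rank_tensors = [
    torch.tensor([0.0, 0.0]),  # rank 0, replica 0
    torch.tensor([0.0, 0.0]),  # rank 1, replica 0 (duplicate of rank 0)
    torch.tensor([1.0, 1.0]),  # rank 2, replica 1
    torch.tensor([1.0, 1.0]),  # rank 3, replica 1 (duplicate of rank 2)
]

# Visiting the first data-parallel group is enough: it contains exactly one
# rank per replica, so concatenating those tensors yields each shard once.
replicas_to_consider = set(dp_groups[0])
gathered = torch.cat(
    [t for idx, t in enumerate(per_rank_tensors) if idx in replicas_to_consider]
)
print(gathered)  # tensor([0., 0., 1., 1.])
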
15 changes: 10 additions & 5 deletions optimum/neuron/distributed/base.py
@@ -461,6 +461,15 @@ def save_model_checkpoint_as_sharded(
optimizer: Optional["torch.optim.Optimizer"] = None,
):
cls._check_model_was_parallelized(model)

from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_rank,
)

data_parallel_rank = get_data_parallel_rank()
tensor_parallel_rank = get_tensor_model_parallel_rank()

if not isinstance(output_dir, Path):
output_dir = Path(output_dir)

@@ -474,12 +483,8 @@
state_dict["optimizer_state_dict"] = optimizer.state_dict()

output_path = output_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_rank,
)

if get_data_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 0:
if data_parallel_rank == 0 and tensor_parallel_rank == 0:
if output_path.is_dir():
shutil.rmtree(output_path, ignore_errors=True)
output_path.mkdir()
3 changes: 2 additions & 1 deletion optimum/neuron/generation/utils.py
@@ -1269,6 +1269,8 @@ def greedy_search(
else:
next_token_logits = outputs.logits[:, -1, :]

xm.mark_step()

# pre-process distribution
# Move to cpu to handle arbitrary logits_processor
next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu"))
@@ -1302,7 +1304,6 @@
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# update generated ids, model inputs, and length for next step

batch_size, _ = input_ids.shape
update_indices = torch.stack(
[torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1
81 changes: 79 additions & 2 deletions optimum/neuron/trainers.py
@@ -35,8 +35,13 @@
TRAINER_STATE_NAME,
TRAINING_ARGS_NAME,
)
from transformers.trainer_pt_utils import reissue_pt_warnings
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalLoopOutput
from transformers.trainer_pt_utils import (
reissue_pt_warnings,
)
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
)
from transformers.utils import is_sagemaker_mp_enabled

from ..utils import check_if_transformers_greater, logging
@@ -55,6 +60,7 @@
TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
get_model_param_count,
is_precompilation,
is_topology_supported,
patch_generation_mixin_to_neuron_generation_mixin,
patched_finfo,
prepare_environment_for_neuron,
@@ -130,6 +136,12 @@ def __init__(self, *args, **kwargs):
if not isinstance(self, Trainer):
raise TypeError(f"{self.__class__.__name__} can only be mixed with Trainer subclasses.")

if not is_topology_supported():
num_devices = xm.xrt_world_size()
raise ValueError(
f"Topology not supported. Supported number of devices: 1, 2, 8 or a multiple of 32. Got: {num_devices}."
)

training_args = kwargs.get("args", None)
if training_args is None and len(args) >= 2:
training_args = args[1]
@@ -255,6 +267,9 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
return None
return super()._get_train_sampler()

def _get_eval_sampler(self, eval_dataset: torch.utils.data.Dataset) -> Optional[torch.utils.data.Sampler]:
return torch.utils.data.SequentialSampler(eval_dataset)

@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args)
@@ -295,6 +310,68 @@ def _inner_training_loop(
ignore_keys_for_eval=ignore_keys_for_eval,
)

def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log:
logs: Dict[str, float] = {}

xm.mark_step()

if self.args.tp_plugin.tensor_parallel_size > 1:
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_group,
get_data_parallel_size,
)

dp_size = get_data_parallel_size()
tr_loss_div = tr_loss / dp_size
tr_loss_scalar = xm.all_reduce(
xm.REDUCE_SUM,
tr_loss_div,
groups=get_data_parallel_group(as_list=True),
)
tr_loss_scalar = tr_loss_scalar.detach().item()
else:
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

# reset tr_loss to zero
tr_loss -= tr_loss

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()

self.log(logs)

metrics = None
if self.control.should_evaluate:
if isinstance(self.eval_dataset, dict):
metrics = {}
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
dataset_metrics = self.evaluate(
eval_dataset=eval_dataset,
ignore_keys=ignore_keys_for_eval,
metric_key_prefix=f"eval_{eval_dataset_name}",
)
metrics.update(dataset_metrics)
else:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])

if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)

def _save_checkpoint_with_accelerator(self, model, trial, metrics=None):
if self.accelerator.distributed_type is NeuronDistributedType.XLA_FSDP and not self.is_fsdp_enabled:
# TODO: handle this case better?
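The loss-logging change above computes the logged loss as an average over data-parallel replicas only: each replica contributes tr_loss / dp_size, and a sum all-reduce over the data-parallel group yields the mean without counting tensor-parallel duplicates. A small single-process sketch of that arithmetic (plain torch; the values are made up for illustration):

import torch

# One training-loss value per data-parallel replica. With tensor parallelism,
# every tensor-parallel rank of a replica holds the same value, so only one
# copy per replica takes part in the reduction.
dp_size = 4
per_replica_loss = torch.tensor([2.0, 2.5, 1.5, 2.0])

# What `xm.all_reduce(xm.REDUCE_SUM, tr_loss / dp_size, groups=dp_groups)`
# produces on every member of a data-parallel group:
reduced = torch.sum(per_replica_loss / dp_size)

assert torch.isclose(reduced, per_replica_loss.mean())  # same as the plain mean
print(reduced.item())  # 2.0
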
15 changes: 13 additions & 2 deletions optimum/neuron/utils/training_utils.py
@@ -49,13 +49,12 @@
from ...utils.logging import set_verbosity as set_verbosity_optimum
from ..generation import NeuronGenerationMixin
from . import is_torch_xla_available
from .require_utils import requires_torch_xla


if TYPE_CHECKING:
from transformers import PreTrainedModel

if is_torch_xla_available():
import torch_xla.distributed.parallel_loader as pl

TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP = "4.30.0.dev0"
TRANSFORMERS_MIN_VERSION_USE_ACCELERATE = "4.30.0.dev0"
@@ -145,6 +144,15 @@ def is_model_officially_supported(model: "PreTrainedModel") -> bool:
return class_name in _SUPPORTED_MODEL_NAMES


@requires_torch_xla
def is_topology_supported() -> bool:
import torch_xla.core.xla_model as xm

num_devices = xm.xrt_world_size()
allowed_number_of_devices = [1, 2, 8]
return num_devices in allowed_number_of_devices or num_devices % 32 == 0


class FirstAndLastDataset(Dataset):
def __init__(
self, dataloader: DataLoader, num_repeat: int = 10, gradient_accumulation_steps: int = 1, world_size: int = 1
@@ -270,11 +278,14 @@ def patch_transformers_for_neuron_sdk():
transformers.utils.logging.set_verbosity = set_verbosity


@requires_torch_xla
def skip_first_batches(dataloader, num_batches=0):
"""
Wrapper around `accelerate.data_loader.skip_first_batches` to handle `pl.ParallelLoader` when using
`torch_xla.distributed`, for XLA FSDP for instance.
"""
import torch_xla.distributed.parallel_loader as pl

if isinstance(dataloader, (pl.ParallelLoader, pl.PerDeviceLoader)):
dataloader._loader = skip_first_batches(dataloader._loader, num_batches=num_batches)
else:
