[WIP] fix greedy
michaelbenayoun committed Oct 17, 2023
1 parent 28046d0 commit 6230837
Showing 4 changed files with 272 additions and 16 deletions.
8 changes: 4 additions & 4 deletions optimum/neuron/accelerate/accelerator.py

@@ -515,10 +515,10 @@ def save_state(self, output_dir: Optional[str] = None, **save_model_func_kwargs)
             return self.save_state_for_tp(output_dir=output_dir, **save_model_func_kwargs)
         return super().save_state(output_dir=output_dir, **save_model_func_kwargs)

-    def gather(self, tensor, out_of_graph: bool = False):
-        return _xla_gather(tensor, out_of_graph=out_of_graph)
+    def gather(self, tensor, out_of_graph: bool = False, gather_axis: Optional[str] = None):
+        return _xla_gather(tensor, out_of_graph=out_of_graph, gather_axis=gather_axis)

-    def gather_for_metrics(self, input_data):
+    def gather_for_metrics(self, input_data, gather_axis: Optional[str] = None):
         try:
             recursively_apply(lambda x: x, input_data, error_on_other_type=True)
             all_tensors = True
@@ -529,7 +529,7 @@ def gather_for_metrics(self, input_data):
             data = gather_object(input_data)
         else:
             # An out-of-graph gather is needed here, otherwise re-compilation happens at every evaluation step.
-            data = self.gather(input_data, out_of_graph=True)
+            data = self.gather(input_data, out_of_graph=True, gather_axis=gather_axis)

         try:
             if self.gradient_state.end_of_dataloader:
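
Note: the new gather_axis argument is only threaded through here; the actual grouping logic lives in _xla_gather (next file). Below is a minimal usage sketch of the intended call pattern, not part of the commit: accelerator is assumed to be an already-prepared NeuronAccelerator, and gather_eval_losses, model, and dataloader are placeholder names.

from typing import Optional

import torch


def gather_eval_losses(accelerator, model, dataloader, gather_axis: Optional[str] = "tp"):
    # Collect per-example losses across ranks. With gather_axis="tp" the
    # intent is to keep a single copy per tensor-parallel replica, since
    # every rank of a TP group holds an identical loss value.
    model.eval()
    all_losses = []
    for batch in dataloader:
        with torch.no_grad():
            loss = model(**batch).loss
        batch_size = batch["input_ids"].shape[0]
        # The out-of-graph gather happens inside gather_for_metrics.
        losses = accelerator.gather_for_metrics(loss.repeat(batch_size), gather_axis=gather_axis)
        all_losses.append(losses.cpu())
    return torch.cat(all_losses)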
41 changes: 39 additions & 2 deletions optimum/neuron/accelerate/utils/operations.py

@@ -14,16 +14,38 @@
 # limitations under the License.
 """Custom operations related to accelerate for Neuron."""

+from typing import Optional
+
 import torch
 from accelerate.utils.operations import recursively_apply

+from ...utils import is_neuronx_distributed_available
 from ...utils.require_utils import requires_torch_xla


 @requires_torch_xla
-def _xla_gather(tensor, out_of_graph: bool = False):
+def _xla_gather(tensor, out_of_graph: bool = False, gather_axis: Optional[str] = None):
     import torch_xla.core.xla_model as xm

+    # if gather_axis is not None:
+    #     if gather_axis not in ["dp", "tp", "pp"]:
+    #         raise ValueError(f'Wrong value for gather_axis ({gather_axis}), expected: "dp", "tp" or "pp"')
+    #     if not is_neuronx_distributed_available():
+    #         raise RuntimeError("The `neuronx_distributed` package is required.")
+
+    #     from neuronx_distributed.parallel_layers.parallel_state import (
+    #         get_tensor_model_parallel_group,
+    #         get_data_parallel_group,
+    #     )
+    #     axis2groups = {
+    #         "dp": get_data_parallel_group(as_list=True),
+    #         "tp": get_tensor_model_parallel_group(as_list=True),
+    #     }
+    #     groups = axis2groups[gather_axis]
+    # else:
+    #     groups = None
+    groups = None

     def _xla_gather_one(tensor):
         if tensor.ndim == 0:
             tensor = tensor.clone()[None]
@@ -32,7 +54,22 @@ def _xla_gather_one(tensor):
         tensor = tensor.contiguous()

         if out_of_graph:
-            gathered = xm.mesh_reduce("nested_xla_gather", tensor, torch.cat)
+            gathered = xm.mesh_reduce("nested_xla_gather", tensor, lambda x: x)
+            if groups is not None:
+                new_gathered = []
+                visited_replicas = set()
+                for idx, tensor in enumerate(gathered):
+                    for replica_idx, group in enumerate(groups):
+                        if idx in group and replica_idx not in visited_replicas:
+                            new_gathered.append(tensor)
+                            visited_replicas.add(replica_idx)
+                gathered = new_gathered
+            gathered = torch.cat(gathered)
         else:
             gathered = xm.all_gather(tensor)
         return gathered
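
Note: the de-duplication step is easiest to see on a toy input. The following CPU-only sketch replays the same filtering with hard-coded values; the group lists are illustrative (in the real code they would come from neuronx_distributed once the commented-out block is enabled).

import torch

# One tensor per rank after an out-of-graph mesh gather (world size 4).
gathered = [torch.tensor([i]) for i in range(4)]
# Two tensor-parallel replicas of size 2, e.g. dp=2, tp=2. Ranks within a
# group hold identical data, so one tensor per group is enough.
groups = [[0, 1], [2, 3]]

new_gathered = []
visited_replicas = set()
for idx, tensor in enumerate(gathered):
    for replica_idx, group in enumerate(groups):
        if idx in group and replica_idx not in visited_replicas:
            new_gathered.append(tensor)
            visited_replicas.add(replica_idx)

print(torch.cat(new_gathered))  # tensor([0, 2]): one value per replica

Each TP group contributes exactly one tensor, so a dp=2/tp=2 mesh of 4 ranks yields 2 tensors instead of 4.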
2 changes: 2 additions & 0 deletions optimum/neuron/generation/utils.py

@@ -1269,6 +1269,8 @@ def greedy_search(
             else:
                 next_token_logits = outputs.logits[:, -1, :]

+            xm.mark_step()
+
             # pre-process distribution
             # Move to cpu to handle arbitrary logits_processor
             next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu"))
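
Note: a sketch of where this lands in the decoding loop, under the assumption that the mark_step is there to cut the lazy XLA graph right after the forward pass, so the CPU-side logits processing that follows is not traced into an ever-growing graph that would recompile at every generated token. greedy_step and its arguments are illustrative names, not the commit's API.

import torch
import torch_xla.core.xla_model as xm


def greedy_step(model, input_ids):
    outputs = model(input_ids=input_ids)
    next_token_logits = outputs.logits[:, -1, :]
    # Materialize the device computation here; everything after this line
    # (logits processing and argmax on CPU) stays out of the XLA graph.
    xm.mark_step()
    next_tokens = torch.argmax(next_token_logits.to("cpu"), dim=-1)
    return torch.cat([input_ids, next_tokens.to(input_ids.device)[:, None]], dim=-1)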
237 changes: 227 additions & 10 deletions optimum/neuron/trainers.py

@@ -28,15 +28,15 @@
 from packaging import version
 from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments
 from transformers.dependency_versions_check import dep_version_check
-from transformers.integrations import is_fairscale_available
+from transformers.integrations import deepspeed_init, is_fairscale_available
 from transformers.trainer import (
     OPTIMIZER_NAME,
     SCHEDULER_NAME,
     TRAINER_STATE_NAME,
     TRAINING_ARGS_NAME,
 )
-from transformers.trainer_pt_utils import reissue_pt_warnings
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalLoopOutput
+from transformers.trainer_pt_utils import IterableDatasetShard, find_batch_size, nested_concat, nested_numpify, reissue_pt_warnings
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalLoopOutput, EvalPrediction, denumpify_detensorize, has_length
 from transformers.utils import is_sagemaker_mp_enabled

 from ..utils import check_if_transformers_greater, logging
@@ -358,6 +358,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
             self._save_checkpoint(model, trial, metrics=metrics)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)

+
     def _save_checkpoint_with_accelerator(self, model, trial, metrics=None):
         if self.accelerator.distributed_type is NeuronDistributedType.XLA_FSDP and not self.is_fsdp_enabled:
             # TODO: handle this case better?
@@ -506,13 +506,229 @@ def evaluation_loop(
         # 2. The model needs to be parallelized.
         self.accelerator.prepare_model(self.model)

-        return super().evaluation_loop(
-            dataloader,
-            description,
-            prediction_loss_only=prediction_loss_only,
-            ignore_keys=ignore_keys,
-            metric_key_prefix=metric_key_prefix,
-        )
+        args = self.args
+
+        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+        # if eval is called w/o train, handle model prep here
+        if self.is_deepspeed_enabled and self.deepspeed is None:
+            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+        if len(self.accelerator._models) == 0 and model is self.model:
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
+
+        batch_size = self.args.eval_batch_size
+
+        logger.info(f"***** Running {description} *****")
+        if has_length(dataloader):
+            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
+        else:
+            logger.info("  Num examples: Unknown")
+        logger.info(f"  Batch size = {batch_size}")
+
+        model.eval()
+
+        self.callback_handler.eval_dataloader = dataloader
+        # Do this before wrapping.
+        eval_dataset = getattr(dataloader, "dataset", None)
+
+        if args.past_index >= 0:
+            self._past = None
+
+        # Initialize containers
+        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
+        losses_host = None
+        preds_host = None
+        labels_host = None
+        inputs_host = None
+
+        # losses/preds/labels on CPU (final containers)
+        all_losses = None
+        all_preds = None
+        all_labels = None
+        all_inputs = None
+        # Will be useful when we have an iterable dataset so don't know its length.
+
+        observed_num_examples = 0
+        # Main evaluation loop
+        for step, inputs in enumerate(dataloader):
+            # Update the observed num examples
+            observed_batch_size = find_batch_size(inputs)
+            if observed_batch_size is not None:
+                observed_num_examples += observed_batch_size
+                # For batch samplers, batch_size is not known by the dataloader in advance.
+                if batch_size is None:
+                    batch_size = observed_batch_size
+
+            # Prediction step
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+            main_input_name = getattr(self.model, "main_input_name", "input_ids")
+            inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
+
+            if is_torch_xla_available():
+                xm.mark_step()
+
+            # Update containers on host
+            if loss is not None:
+                losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)), gather_axis="tp")
+                losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
+            if labels is not None:
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+            if inputs_decode is not None:
+                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
+                inputs_decode = self.accelerator.gather_for_metrics((inputs_decode), gather_axis="tp")
+                inputs_host = (
+                    inputs_decode
+                    if inputs_host is None
+                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+                )
+            if logits is not None:
+                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
+                if self.preprocess_logits_for_metrics is not None:
+                    logits = self.preprocess_logits_for_metrics(logits, labels)
+                logits = self.accelerator.gather_for_metrics((logits))
+                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+
+            if labels is not None:
+                labels = self.accelerator.gather_for_metrics((labels))
+                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+            if (
+                args.eval_accumulation_steps is not None
+                and (step + 1) % args.eval_accumulation_steps == 0
+                and (self.accelerator.sync_gradients or version.parse(accelerate_version) > version.parse("0.20.3"))
+            ):
+                if losses_host is not None:
+                    losses = nested_numpify(losses_host)
+                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+                if preds_host is not None:
+                    logits = nested_numpify(preds_host)
+                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+                if inputs_host is not None:
+                    inputs_decode = nested_numpify(inputs_host)
+                    all_inputs = (
+                        inputs_decode
+                        if all_inputs is None
+                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+                    )
+                if labels_host is not None:
+                    labels = nested_numpify(labels_host)
+                    all_labels = (
+                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+                    )
+
+                # Set back to None to begin a new accumulation
+                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
+
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
+
+        # Gather all remaining tensors and put them back on the CPU
+        if losses_host is not None:
+            losses = nested_numpify(losses_host)
+            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+        if preds_host is not None:
+            logits = nested_numpify(preds_host)
+            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+        if inputs_host is not None:
+            inputs_decode = nested_numpify(inputs_host)
+            all_inputs = (
+                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+            )
+        if labels_host is not None:
+            labels = nested_numpify(labels_host)
+            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+
+        # Number of samples
+        if has_length(eval_dataset):
+            num_samples = len(eval_dataset)
+        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
+        # methods. Therefore we need to make sure it also has the attribute.
+        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
+            num_samples = eval_dataset.num_examples
+        else:
+            if has_length(dataloader):
+                num_samples = self.num_examples(dataloader)
+            else:  # both len(dataloader.dataset) and len(dataloader) fail
+                num_samples = observed_num_examples
+        if num_samples == 0 and observed_num_examples > 0:
+            num_samples = observed_num_examples
+
+        # Metrics!
+        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
+            if args.include_inputs_for_metrics:
+                metrics = self.compute_metrics(
+                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
+                )
+            else:
+                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
+        else:
+            metrics = {}
+
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
+        if all_losses is not None:
+            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+        if hasattr(self, "jit_compilation_time"):
+            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)


 class NeuronTrainer(AugmentTrainerForNeuronMixin, Trainer):
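
Note: a quick sanity check on why the losses above are gathered with gather_axis="tp". Every rank in a tensor-parallel group computes the same loss, so a plain gather over all ranks duplicates each entry tp_size times. The arithmetic below uses assumed sizes (dp=2, tp=2, per-rank eval batch of 8); none of these numbers come from the commit.

world_size, tp_size = 4, 2
dp_size = world_size // tp_size  # 2 data-parallel replicas
per_rank_batch = 8

plain_gather = world_size * per_rank_batch  # 32 entries, each loss counted tp_size times
deduplicated = dp_size * per_rank_batch     # 16 entries, one copy per TP replica

assert plain_gather == tp_size * deduplicated
print(plain_gather, deduplicated)  # 32 16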
