Commit

Rolled back changes in sampler as it leads to diverging and noisy loss with TP>1.
Christian Bock committed Nov 6, 2023
1 parent 1d2a95e commit 87fb264
Showing 1 changed file with 10 additions and 31 deletions.
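For context, the restored `_prepare_data_loader_for_tp` shards the dataset over the data-parallel group reported by `neuronx_distributed`, rather than over every XLA process, so all tensor-parallel ranks in the same data-parallel replica consume identical batches. A minimal plain-PyTorch sketch of that distinction follows; the topology numbers are hypothetical and the snippet is not part of this commit:

# Illustration only (plain PyTorch, hypothetical sizes); not part of the commit.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float())

# Assume 8 processes with tensor-parallel degree 4 -> 2 data-parallel replicas.
world_size, tp_size = 8, 4
dp_size = world_size // tp_size  # analogous to parallel_state.get_data_parallel_size()
dp_rank = 0                      # analogous to parallel_state.get_data_parallel_rank()

# What the restored helper does: shard across data-parallel replicas only.
dp_sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank)

# Sharding across all processes instead would hand each tensor-parallel rank a
# different slice of the data, even within the same data-parallel replica.
per_process_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=dp_rank)

print(len(list(dp_sampler)), len(list(per_process_sampler)))  # 8 vs 2 samples per replica

With the data-parallel sampler each of the two replicas sees 8 of the 16 samples, whereas sharding by the full world size gives each of the 8 processes its own 2-sample slice, so ranks sharing a tensor-parallel group no longer train on the same data.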
41 changes: 10 additions & 31 deletions optimum/neuron/accelerate/accelerator.py
@@ -69,6 +69,7 @@
 xm = None
 
 if is_neuronx_distributed_available():
+    from neuronx_distributed import parallel_layers
     from neuronx_distributed.utils.model_utils import move_model_to_device

@@ -142,26 +143,15 @@ def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, z
         if num_steps != 1:
             self.gradient_accumulation_steps = num_steps
 
-    def _prepare_data_loader_for_distributed(
-        self, data_loader: DataLoader, num_replicas: int, rank: int
-    ) -> DataLoader:
+    def _prepare_data_loader_for_tp(self, data_loader: DataLoader) -> DataLoader:
         # TODO: make it more robust, similar to the prepare_data_loader function in `accelerate`.
         if isinstance(data_loader.sampler, DistributedSampler):
             return data_loader
-
-        orig_sampler = data_loader.sampler
-        if isinstance(orig_sampler, torch.utils.data.SequentialSampler):
-            shuffle = False
-        else:
-            shuffle = True
-            if not isinstance(orig_sampler, torch.utils.data.RandomSampler):
-                logger.warning(
-                    f"The sampler {orig_sampler} is going to be replaced by a torch.utils.data.DistributedSampler. This "
-                    "new sampler will shuffle the dataset, it might not be the expected behaviour."
-                )
-
-        sampler = DistributedSampler(data_loader.dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
-
+        sampler = DistributedSampler(
+            data_loader.dataset,
+            num_replicas=parallel_layers.parallel_state.get_data_parallel_size(),
+            rank=parallel_layers.parallel_state.get_data_parallel_rank(),
+        )
         data_loader_for_tp = DataLoader(
             data_loader.dataset,
             batch_size=data_loader.batch_size,

@@ -176,15 +166,8 @@ def _prepare_data_loader_for_distributed
 
     def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optional[bool] = None):
         if self.state.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
-            from neuronx_distributed import parallel_layers
-
-            num_replicas = parallel_layers.parallel_state.get_data_parallel_size()
-            rank = parallel_layers.parallel_state.get_data_parallel_rank()
-        else:
-            num_replicas = xm.xrt_world_size()
-            rank = xm.get_ordinal()
+            data_loader = self._prepare_data_loader_for_tp(data_loader)
         if self.state.num_processes > 1:
-            data_loader = self._prepare_data_loader_for_distributed(data_loader, num_replicas=num_replicas, rank=rank)
             data_loader = MpDeviceLoader(data_loader, self.device)
         return data_loader
         # TODO: fix that.

@@ -221,7 +204,7 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device
             model_parallel_is_initialized,
         )
 
-        if not model_parallel_is_initialized():
+        if not is_neuronx_distributed_available() or not model_parallel_is_initialized():
            sharding_groups = None
            grad_norm_groups = None
        else:

@@ -343,14 +326,10 @@ def prepare_model_for_xla_fsdp(
     def _prepare_model_for_tp(
         self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
     ):
-        if model in self._models or Parallelizer.was_parallelized(model):
-            return model
-
         cpu_ids = [id(v) for v in model.parameters()]
         # TODO: enable self.device (if needed).
         model = self.state.tp_plugin.parallelize_model(model, device=None)
-
-        if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
+        if os.environ.get("XLA_USE_BF16", "0") == "1":
             model.to(torch.bfloat16)
         else:
             model.to(torch.float32)
