Commit
* Fix loss log when using TP
* Make evaluation work with DP / TP
* Final changes
commit 4b37209 (1 parent: 571effd)
Showing 6 changed files with 156 additions and 22 deletions.
@@ -14,16 +14,28 @@
 # limitations under the License.
 """Custom operations related to accelerate for Neuron."""

 import torch
 from accelerate.utils.operations import recursively_apply

+from ...utils import is_neuronx_distributed_available
 from ...utils.require_utils import requires_torch_xla


 @requires_torch_xla
 def _xla_gather(tensor, out_of_graph: bool = False):
     import torch_xla.core.xla_model as xm

+    groups = None
+    if is_neuronx_distributed_available():
+        from neuronx_distributed.parallel_layers.parallel_state import (
+            get_data_parallel_group,
+            model_parallel_is_initialized,
+        )
+
+        if model_parallel_is_initialized():
+            groups = get_data_parallel_group(as_list=True)
(A review thread is attached to the line above; earlier comments in the thread were minimized, and michaelbenayoun (Author, Member) commented here.)

     def _xla_gather_one(tensor):
         if tensor.ndim == 0:
             tensor = tensor.clone()[None]
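As background for the `groups` value set above: `get_data_parallel_group(as_list=True)` returns the data parallel groups as a list of lists of global worker ids. The exact layout depends on how neuronx_distributed was initialized; the snippet below is only an illustrative sketch, and the assumed setup (8 workers, tensor parallel size 2, data parallel size 4) is not taken from this commit.

# Hypothetical illustration of data parallel groups (assumed layout, not from the patch):
# 8 workers, tensor parallel size 2, data parallel size 4.
world_size = 8
tp_size = 2
dp_size = world_size // tp_size

# One data parallel group per tensor parallel rank: the workers in a group
# hold replicas of the same model shard but see different data.
groups = [
    [tp_rank + tp_size * dp_rank for dp_rank in range(dp_size)]
    for tp_rank in range(tp_size)
]
print(groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]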
@@ -32,9 +44,20 @@ def _xla_gather_one(tensor):
             tensor = tensor.contiguous()

         if out_of_graph:
-            gathered = xm.mesh_reduce("nested_xla_gather", tensor, torch.cat)
+            gathered_tensors = xm.mesh_reduce("nested_xla_gather", tensor, lambda x: x)
+            if groups is not None:
+                new_gathered_tensors = []
+                # groups contains a list of replica groups, so visiting the first group is
+                # enough: the gathered values should be the same across the other groups.
+                replicas_to_consider = set(groups[0])
+                for idx, tensor in enumerate(gathered_tensors):
+                    if idx not in replicas_to_consider:
+                        continue
+                    new_gathered_tensors.append(tensor)
+                gathered_tensors = new_gathered_tensors
+            gathered = torch.cat(gathered_tensors)
         else:
-            gathered = xm.all_gather(tensor)
+            gathered = xm.all_gather(tensor, groups=groups, pin_layout=False)
         return gathered

     res = recursively_apply(_xla_gather_one, tensor, error_on_other_type=True)
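To make the intent of the new out-of-graph branch easier to follow, here is a minimal, self-contained sketch of the same dedup-and-concatenate step on plain CPU tensors (no XLA device involved). It assumes, as the patch's comment states, that every data parallel group holds the same values, so keeping only the first group before concatenating avoids duplicates coming from the tensor parallel axis; the input data below is invented purely for illustration.

import torch

# Pretend xm.mesh_reduce returned one tensor per global replica (8 replicas here).
# Replicas that are tensor parallel peers (e.g. 0 and 1) hold identical values.
gathered_tensors = [torch.tensor([float(i // 2)]) for i in range(8)]

# Data parallel groups as lists of global replica ids (see the earlier illustration).
groups = [[0, 2, 4, 6], [1, 3, 5, 7]]

# Keep only tensors coming from the first data parallel group; the other groups
# would contribute duplicates along the tensor parallel axis.
replicas_to_consider = set(groups[0])
deduplicated = [t for idx, t in enumerate(gathered_tensors) if idx in replicas_to_consider]

gathered = torch.cat(deduplicated)
print(gathered)  # tensor([0., 1., 2., 3.]) -- one value per data parallel rank

The in-graph branch reaches a similar result by passing `groups` to `xm.all_gather`, which restricts the collective to each data parallel group instead of filtering afterwards.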
Review question on this change: Is it always certain that we want to gather over data parallel groups?