From 51428d7756e62b9b8ee5379f38e9fd576eeb36e5 Mon Sep 17 00:00:00 2001
From: Andrey Kulagin
Date: Mon, 14 Dec 2020 21:29:35 +0300
Subject: [PATCH] Apex with DataParallel fixes (#1032)

* Apex with DataParallel fixes

* Codestyle

* Add noqa to use protected members

* Changelog upd, codestyle
---
 CHANGELOG.md                 |   3 +
 catalyst/utils/components.py | 109 +++++++++++++++++++++++++++++------
 2 files changed, 94 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7fff0d1d83..6f043c1cef 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -66,6 +66,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - prevented modifying config during the experiment and runner initialization ([#1004](https://github.com/catalyst-team/catalyst/pull/1004))
 - a few test for RecSys MAP computation ([#1018](https://github.com/catalyst-team/catalyst/pull/1014))
 - leave batch size the same for default distributed training ([#1023](https://github.com/catalyst-team/catalyst/issues/1023))
+- ([#1032](https://github.com/catalyst-team/catalyst/pull/1032))
+    - Apex: now you can use apex for multiple models training
+    - Apex: DataParallel is allowed for opt_level other than "O1"
diff --git a/catalyst/utils/components.py b/catalyst/utils/components.py
index 46db777742..b90197c6fc 100644
--- a/catalyst/utils/components.py
+++ b/catalyst/utils/components.py
@@ -6,7 +6,14 @@ import torch.distributed
 
 from catalyst.settings import IS_XLA_AVAILABLE
-from catalyst.typing import Criterion, Device, Model, Optimizer, Scheduler
+from catalyst.typing import (
+    Criterion,
+    Device,
+    Model,
+    Optimizer,
+    RunnerModel,
+    Scheduler,
+)
 from catalyst.utils.distributed import (
     check_amp_available,
     check_apex_available,
@@ -19,14 +26,84 @@ from catalyst.utils.torch import get_device
 
 
+def _patch_forward(model):
+    import apex
+
+    input_caster_lambda = (
+        lambda tensor: tensor.to(
+            apex.amp._amp_state.opt_properties.options[  # noqa: WPS437
+                "cast_model_type"
+            ]
+        )
+        if tensor.is_floating_point()
+        else tensor
+    )
+    output_caster_lambda = (
+        lambda tensor: tensor.to(
+            apex.amp._amp_state.opt_properties.options.get(  # noqa: WPS437
+                "cast_model_outputs", torch.float32
+            )
+        )
+        if tensor.is_floating_point()
+        else tensor
+    )
+
+    def new_fwd(
+        *args,
+        old_fwd=model.forward,
+        input_caster=input_caster_lambda,
+        output_caster=output_caster_lambda,
+        **kwargs,
+    ):
+        return apex.amp._initialize.applier(  # noqa: WPS437
+            old_fwd(
+                *apex.amp._initialize.applier(  # noqa: WPS437
+                    args, input_caster
+                ),
+                **apex.amp._initialize.applier(  # noqa: WPS437
+                    kwargs, input_caster
+                ),
+            ),
+            output_caster,
+        )
+
+    model.forward = new_fwd
+    return model
+
+
+# apex issue https://github.com/deepset-ai/FARM/issues/210
+# solution: https://github.com/NVIDIA/apex/issues/503#issuecomment-566181771
+def _wrap_into_data_parallel_with_apex(
+    model: RunnerModel, optimizer: Optimizer, distributed_params: Dict
+):
+    if isinstance(model, nn.Module):
+        model = nn.Sequential(model)
+        model, optimizer = initialize_apex(
+            model, optimizer, **distributed_params
+        )
+        model = torch.nn.DataParallel(model[0])
+        model = _patch_forward(model)
+    elif isinstance(model, dict):
+        model = {k: nn.Sequential(v) for k, v in model.items()}
+        model, optimizer = initialize_apex(
+            model, optimizer, **distributed_params
+        )
+        model = {k: nn.DataParallel(v[0]) for k, v in model.items()}
+        model = {k: _patch_forward(v) for k, v in model.items()}
+    else:
+        raise NotImplementedError()
+
+    return model, optimizer
+
+
 def process_components(
-    model: Model,
+    model: RunnerModel,
     criterion: Criterion = None,
     optimizer: Optimizer = None,
     scheduler: Scheduler = None,
     distributed_params: Dict = None,
     device: Device = None,
-) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
+) -> Tuple[RunnerModel, Criterion, Optimizer, Scheduler, Device]:
     """
     Returns the processed model, criterion, optimizer, scheduler and device.
@@ -114,27 +191,18 @@ def process_components(
         )
     # data parallel run (dp) (with apex support)
     else:
-        # apex issue https://github.com/deepset-ai/FARM/issues/210
-        use_apex = (is_apex_enabled and torch.cuda.device_count() == 1) or (
-            is_apex_enabled
-            and torch.cuda.device_count() > 1
-            and distributed_params.get("opt_level", "O0") == "O1"
+        is_data_parallel = (
+            torch.cuda.device_count() > 1
+            and device.type != "cpu"
+            and device.index is None
         )
 
-        if use_apex:
-            assert isinstance(
-                model, nn.Module
-            ), "Apex training is not available for KV model"
-
+        if is_apex_enabled and not is_data_parallel:
             model, optimizer = initialize_apex(
                 model, optimizer, **distributed_params
             )
 
-        if (
-            torch.cuda.device_count() > 1
-            and device.type != "cpu"
-            and device.index is None
-        ):
+        elif not is_apex_enabled and is_data_parallel:
             if isinstance(model, nn.Module):
                 model = nn.DataParallel(model)
             elif isinstance(model, dict):
@@ -142,6 +210,11 @@ def process_components(
             else:
                 raise NotImplementedError()
 
+        elif is_apex_enabled and is_data_parallel:
+            model, optimizer = _wrap_into_data_parallel_with_apex(
+                model, optimizer, distributed_params
+            )
+
     model: Model = maybe_recursive_call(model, "to", device=device)
 
     return model, criterion, optimizer, scheduler, device
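
A minimal usage sketch of the code path this patch enables: with apex available, more than one visible GPU, and no explicit device index, process_components now routes a key-value (dict) model through _wrap_into_data_parallel_with_apex for any opt_level, instead of rejecting it. The distributed_params keys ("apex", "opt_level") and the toy modules below are illustrative assumptions, not part of the patch; actually running this requires a multi-GPU machine with NVIDIA apex installed.

import torch
from torch import nn

from catalyst.utils.components import process_components

# Toy key-value model and a single optimizer over all of its parameters
# (illustrative only; real runners build these from the experiment config).
model = {
    "generator": nn.Linear(128, 128),
    "discriminator": nn.Linear(128, 1),
}
optimizer = torch.optim.Adam(
    [p for m in model.values() for p in m.parameters()], lr=1e-3
)

# Assumed distributed_params layout ("apex"/"opt_level" keys). With apex
# installed and >1 GPU visible, each submodel comes back as a DataParallel
# wrapper around the amp-initialized module, with forward() patched to cast
# inputs/outputs as in _patch_forward above.
model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    optimizer=optimizer,
    distributed_params={"apex": True, "opt_level": "O2"},
)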