Apex with DataParallel fixes (#1032)
* Apex with DataParallel fixes

* Codestyle

* Add noqa to use protected members

* Changelog upd, codestyle
and-kul authored Dec 14, 2020
1 parent dfb9d5a commit 51428d7
Showing 2 changed files with 94 additions and 18 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -66,6 +66,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- prevented modifying config during the experiment and runner initialization ([#1004](https://github.com/catalyst-team/catalyst/pull/1004))
- a few tests for RecSys MAP computation ([#1018](https://github.com/catalyst-team/catalyst/pull/1014))
- leave batch size the same for default distributed training ([#1023](https://github.com/catalyst-team/catalyst/issues/1023))
- ([#1032](https://github.com/catalyst-team/catalyst/pull/1032))
- Apex: now you can use apex for multiple models training
- Apex: DataParallel is allowed for opt_level other than "O1"



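To make the two new entries concrete, here is a minimal usage sketch; it is not taken from the repository. It assumes apex is installed, at least two CUDA devices are visible, the run is a plain (non-distributed) one, and that apex gets enabled by the availability checks in catalyst's settings; the single distributed_params key shown, opt_level, is assumed to be forwarded unchanged to initialize_apex.

import torch
from torch import nn

from catalyst.utils.components import process_components

model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())

# with more than one GPU, opt_level values other than "O1" previously skipped apex
# entirely; after this commit they go through _wrap_into_data_parallel_with_apex
model, criterion, optimizer, scheduler, device = process_components(
    model=model,
    optimizer=optimizer,
    distributed_params={"opt_level": "O2"},
)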
109 changes: 91 additions & 18 deletions catalyst/utils/components.py
@@ -6,7 +6,14 @@
import torch.distributed

from catalyst.settings import IS_XLA_AVAILABLE
from catalyst.typing import Criterion, Device, Model, Optimizer, Scheduler
from catalyst.typing import (
Criterion,
Device,
Model,
Optimizer,
RunnerModel,
Scheduler,
)
from catalyst.utils.distributed import (
check_amp_available,
check_apex_available,
@@ -19,14 +26,84 @@
from catalyst.utils.torch import get_device


def _patch_forward(model):
    """Re-apply apex amp's input/output casting to ``model.forward``."""
import apex

input_caster_lambda = (
lambda tensor: tensor.to(
apex.amp._amp_state.opt_properties.options[ # noqa: WPS437
"cast_model_type"
]
)
if tensor.is_floating_point()
else tensor
)
output_caster_lambda = (
lambda tensor: tensor.to(
apex.amp._amp_state.opt_properties.options.get( # noqa: WPS437
"cast_model_outputs", torch.float32
)
)
if tensor.is_floating_point()
else tensor
)

def new_fwd(
*args,
old_fwd=model.forward,
input_caster=input_caster_lambda,
output_caster=output_caster_lambda,
**kwargs,
):
return apex.amp._initialize.applier( # noqa: WPS437
old_fwd(
*apex.amp._initialize.applier( # noqa: WPS437
args, input_caster
),
**apex.amp._initialize.applier( # noqa: WPS437
kwargs, input_caster
),
),
output_caster,
)

model.forward = new_fwd
return model
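Editorial note, not part of the committed file: per the apex issue linked below, the casting that apex.amp.initialize installs on forward has to be re-applied once the model is wrapped in DataParallel, which is what _patch_forward does through apex's private _amp_state. A stripped-down sketch of the same idea with fixed, assumed dtypes, handling only flat tensor arguments (apex's applier additionally recurses into containers):

def _patch_forward_fixed_dtypes(model, in_dtype=torch.float16, out_dtype=torch.float32):
    """Illustration only: cast floating-point inputs/outputs around ``forward``."""
    old_fwd = model.forward

    def _cast(value, dtype):
        if torch.is_tensor(value) and value.is_floating_point():
            return value.to(dtype)
        return value

    def new_fwd(*args, **kwargs):
        args = tuple(_cast(a, in_dtype) for a in args)
        kwargs = {k: _cast(v, in_dtype) for k, v in kwargs.items()}
        return _cast(old_fwd(*args, **kwargs), out_dtype)

    model.forward = new_fwd
    return model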


# apex issue https://github.com/deepset-ai/FARM/issues/210
# solution: https://github.com/NVIDIA/apex/issues/503#issuecomment-566181771
def _wrap_into_data_parallel_with_apex(
    model: RunnerModel, optimizer: Optimizer, distributed_params: Dict
):
    """Initialize apex first, then wrap the model (or every model of a KV model)
    into ``DataParallel`` and re-patch ``forward`` so the casting still happens."""
if isinstance(model, nn.Module):
model = nn.Sequential(model)
model, optimizer = initialize_apex(
model, optimizer, **distributed_params
)
model = torch.nn.DataParallel(model[0])
model = _patch_forward(model)
elif isinstance(model, dict):
model = {k: nn.Sequential(v) for k, v in model.items()}
model, optimizer = initialize_apex(
model, optimizer, **distributed_params
)
model = {k: nn.DataParallel(v[0]) for k, v in model.items()}
model = {k: _patch_forward(v) for k, v in model.items()}
else:
raise NotImplementedError()

return model, optimizer
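A hedged usage sketch of the KV-model branch above, again not part of the commit: it assumes apex is installed, at least two GPUs are available, the models already live on GPU (apex expects parameters on the target device before initialization), and that opt_level is accepted by initialize_apex.

models = {
    "generator": nn.Linear(16, 16).cuda(),
    "discriminator": nn.Linear(16, 1).cuda(),
}
optimizer = torch.optim.Adam(
    list(models["generator"].parameters())
    + list(models["discriminator"].parameters())
)

models, optimizer = _wrap_into_data_parallel_with_apex(
    models, optimizer, {"opt_level": "O2"}
)
# every value of `models` is now a DataParallel module with a re-patched forward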


def process_components(
model: Model,
model: RunnerModel,
criterion: Criterion = None,
optimizer: Optimizer = None,
scheduler: Scheduler = None,
distributed_params: Dict = None,
device: Device = None,
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
) -> Tuple[RunnerModel, Criterion, Optimizer, Scheduler, Device]:
"""
Returns the processed model, criterion, optimizer, scheduler and device.
@@ -114,34 +191,30 @@ def process_components(
)
# data parallel run (dp) (with apex support)
else:
# apex issue https://github.com/deepset-ai/FARM/issues/210
use_apex = (is_apex_enabled and torch.cuda.device_count() == 1) or (
is_apex_enabled
and torch.cuda.device_count() > 1
and distributed_params.get("opt_level", "O0") == "O1"
is_data_parallel = (
torch.cuda.device_count() > 1
and device.type != "cpu"
and device.index is None
)

if use_apex:
assert isinstance(
model, nn.Module
), "Apex training is not available for KV model"

if is_apex_enabled and not is_data_parallel:
model, optimizer = initialize_apex(
model, optimizer, **distributed_params
)

if (
torch.cuda.device_count() > 1
and device.type != "cpu"
and device.index is None
):
elif not is_apex_enabled and is_data_parallel:
if isinstance(model, nn.Module):
model = nn.DataParallel(model)
elif isinstance(model, dict):
model = {k: nn.DataParallel(v) for k, v in model.items()}
else:
raise NotImplementedError()

elif is_apex_enabled and is_data_parallel:
model, optimizer = _wrap_into_data_parallel_with_apex(
model, optimizer, distributed_params
)

model: Model = maybe_recursive_call(model, "to", device=device)

return model, criterion, optimizer, scheduler, device
