From 097ab2ab02cf6e86203130aaf5d2a95bcfc5848e Mon Sep 17 00:00:00 2001 From: Shang Wang Date: Fri, 1 Oct 2021 01:55:05 -0400 Subject: [PATCH 1/2] [Optim][Adadelta] Update optim. --- hfta/ops/utils.py | 5 +- hfta/optim/__init__.py | 6 +- hfta/optim/_functional.py | 49 +++++++++++ hfta/optim/adadelta.py | 153 +++++++++++++++++---------------- hfta/optim/adadelta_test.py | 33 ++++++-- hfta/optim/lr_scheduler.py | 2 +- hfta/optim/utils.py | 165 ++++++++++++++++++++++++------------ 7 files changed, 271 insertions(+), 142 deletions(-) create mode 100644 hfta/optim/_functional.py diff --git a/hfta/ops/utils.py b/hfta/ops/utils.py index c41d614..ad734e7 100644 --- a/hfta/ops/utils.py +++ b/hfta/ops/utils.py @@ -2,7 +2,7 @@ import numpy as np import re -RE_PARSE_RATIO = re.compile('Mismatched elements: \d+ \/ \d+ \((\d+)\.(\d+)%\)') +RE_PARSE_RATIO = re.compile('Mismatched elements: (\d+) \/ (\d+)') def testcase_automator(testcase, configs): @@ -48,6 +48,5 @@ def assert_allclose( if not m: raise e else: - if (float('{}.{}'.format(m.group(1), m.group(2))) / 100 >= - population_threshold): + if (int(m.group(1)) / int(m.group(2))) >= population_threshold: raise e diff --git a/hfta/optim/__init__.py b/hfta/optim/__init__.py index 73e0c3f..1ef418c 100644 --- a/hfta/optim/__init__.py +++ b/hfta/optim/__init__.py @@ -2,19 +2,19 @@ import torch.optim from .adadelta import Adadelta, PartiallyFusedAdadelta -from .adam import Adam, PartiallyFusedAdam +#from .adam import Adam, PartiallyFusedAdam from .lr_scheduler import StepLR, PartiallyFusedStepLR from .utils import (index_array_or_return_scalar, consolidate_hyperparams_and_determine_B) _OPTIMIZERS_MAP = { torch.optim.Adadelta: Adadelta, - torch.optim.Adam: Adam, + #torch.optim.Adam: Adam, } _PARTIALLY_FUSED_OPTIMIZERS_MAP = { torch.optim.Adadelta: PartiallyFusedAdadelta, - torch.optim.Adam: PartiallyFusedAdam, + #torch.optim.Adam: PartiallyFusedAdam, } _LR_SCHEDULER_MAP = { diff --git a/hfta/optim/_functional.py b/hfta/optim/_functional.py new file mode 100644 index 0000000..ee5ff88 --- /dev/null +++ b/hfta/optim/_functional.py @@ -0,0 +1,49 @@ +import math +import torch +from torch import Tensor +from typing import List, Optional, Union + +from .utils import Coefficient, is_coefficient + + +def adadelta( + params: List[Tensor], + grads: List[Tensor], + square_avgs: List[Tensor], + acc_deltas: List[Tensor], + *, + lr: Union[float, Coefficient], + rho: Union[float, Coefficient], + eps: Union[float, Coefficient], + weight_decay: [float, Coefficient], +): + r"""Functional API that performs Adadelta algorithm computation. + See :class:`~torch.optim.Adadelta` for details. 
+ """ + + for (param, grad, square_avg, acc_delta) in zip(params, grads, square_avgs, + acc_deltas): + if is_coefficient(weight_decay) or weight_decay != 0: + if is_coefficient(weight_decay): + grad = grad + weight_decay[param] * param + else: + grad = grad.add(param, alpha=weight_decay) + + if is_coefficient(rho): + square_avg.mul_(rho[param]).add_((1 - rho[param]) * grad * grad) + else: + square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho) + if is_coefficient(eps): + std = square_avg.add(eps[param]).sqrt_() + delta = acc_delta.add(eps[param]).sqrt_().div_(std).mul_(grad) + else: + std = square_avg.add(eps).sqrt_() + delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad) + if is_coefficient(lr): + param.add_(-lr[param] * delta) + else: + param.add_(delta, alpha=-lr) + if is_coefficient(rho): + acc_delta.mul_(rho[param]).add_((1 - rho[param]) * delta * delta) + else: + acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho) diff --git a/hfta/optim/adadelta.py b/hfta/optim/adadelta.py index 9de2ec7..26e46bb 100644 --- a/hfta/optim/adadelta.py +++ b/hfta/optim/adadelta.py @@ -1,119 +1,124 @@ import numpy as np import torch +from . import _functional as F from torch.optim import Optimizer -from .utils import (_validate_range, _broadcastablize, - _move_coeff_to_same_device, _reduce_array_if_possible_for, - _zero_grad_if_cuda, index_array_or_return_scalar) +from .utils import (make_coefficient, reduce_array_if_possible_for, + index_array_or_return_scalar) from .partial import PartiallyFusedOptimizer class Adadelta(Optimizer): - """Implements Adadelta algorithm. - - It has been proposed in `ADADELTA: An Adaptive Learning Rate Method`__. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - rho (float or a list/tuple/np.array/torch.Tensor of floats, optional): - coefficient used for computing a running average of squared - gradients (default: 0.9) - eps (float or a list/tuple/np.array/torch.Tensor of floats, optional): term - added to the denominator to improve numerical stability (default: 1e-6) - lr (float or a list/tuple/np.array/torch.Tensor of floats, optional): - coefficient that scale delta before it is applied to the parameters - (default: 1.0) - weight_decay (float or a list/tuple/np.array/torch.Tensor of floats, - optional): weight decay (L2 penalty) (default: 0) - - __ https://arxiv.org/abs/1212.5701 - """ + r"""Implements Adadelta algorithm. + + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, + \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)}, + \: \lambda \text{ (weight decay)} \\ + &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)}, + \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}if \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\ + &\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} + + \epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\ + &\hspace{5mm} u_t \leftarrow u_{t-1} \rho + + \Delta x^2_t (1 - \rho) \\ + &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + rho (float or a list/tuple/np.array/torch.Tensor of floats, optional): coefficient used for computing a running average + of squared gradients (default: 0.9) + eps (float or a list/tuple/np.array/torch.Tensor of floats, optional): term added to the denominator to improve + numerical stability (default: 1e-6) + lr (float or a list/tuple/np.array/torch.Tensor of floats, optional): coefficient that scale delta before it is applied + to the parameters (default: 1.0) + weight_decay (float or a list/tuple/np.array/torch.Tensor of floats, optional): weight decay (L2 penalty) (default: 0) + + .. _ADADELTA\: An Adaptive Learning Rate Method: + https://arxiv.org/abs/1212.5701 + """ def __init__(self, params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0, B=1): - _validate_range('learning rate', lr, 0.0, float('inf')) - _validate_range('rho value', rho, 0.0, 1.0) - _validate_range('epsilon value', eps, 0.0, float('inf')) - _validate_range('weight_decay value', weight_decay, 0.0, float('inf')) - lr, rho, eps, weight_decay = _reduce_array_if_possible_for( + lr, rho, eps, weight_decay = reduce_array_if_possible_for( lr, rho, eps, weight_decay) + lr = make_coefficient('learning rate', lr, lb=0.0, ub=float('inf')) + rho = make_coefficient('rho value', rho, lb=0.0, ub=1.0) + eps = make_coefficient('epsilon value', eps, lb=0.0, ub=float('inf')) + weight_decay = make_coefficient('weight_decay value', + weight_decay, + lb=0.0, + ub=float('inf')) defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay) super(Adadelta, self).__init__(params, defaults) - _broadcastablize(self, 'lr', B) - _broadcastablize(self, 'rho', B) - _broadcastablize(self, 'eps', B) - _broadcastablize(self, 'weight_decay', B) - - def zero_grad(self): - if not _zero_grad_if_cuda(self): - super(Adadelta, self).zero_grad() @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: + params_with_grad = [] + grads = [] + square_avgs = [] + acc_deltas = [] + lr, rho, eps, weight_decay = group['lr'], group['rho'], group[ + 'eps'], group['weight_decay'] + for p in group['params']: if p.grad is None: continue - grad = p.grad - if grad.is_sparse: + params_with_grad.append(p) + if p.grad.is_sparse: raise RuntimeError('Adadelta does not support sparse gradients') + grads.append(p.grad) + state = self.state[p] - # State initialization + # Lazy state initialization if len(state) == 0: state['step'] = 0 state['square_avg'] = torch.zeros_like( p, memory_format=torch.preserve_format) state['acc_delta'] = torch.zeros_like( p, memory_format=torch.preserve_format) - _move_coeff_to_same_device(group, 'lr', p) - _move_coeff_to_same_device(group, 'rho', p) - _move_coeff_to_same_device(group, 'eps', p) - _move_coeff_to_same_device(group, 'weight_decay', p) - square_avg, acc_delta = state['square_avg'], state['acc_delta'] - lr, rho, eps, weight_decay = (group['lr'], group['rho'], group['eps'], - group['weight_decay']) + square_avgs.append(state['square_avg']) + acc_deltas.append(state['acc_delta']) state['step'] += 1 - if isinstance(weight_decay, dict) or weight_decay != 0: - if isinstance(weight_decay, dict): - grad = grad + weight_decay[p] * p - else: - grad = grad.add(p, alpha=weight_decay) - - if isinstance(rho, dict): - square_avg.mul_(rho[p]).add_((1 - rho[p]) * grad * grad) - else: - square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho) - if isinstance(eps, dict): - std = square_avg.add(eps[p]).sqrt_() - delta = acc_delta.add(eps[p]).sqrt_().div_(std).mul_(grad) - else: - std = square_avg.add(eps).sqrt_() - delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad) - if isinstance(lr, dict): - p.add_(-lr[p] * delta) - else: - p.add_(delta, alpha=-lr) - if isinstance(rho, dict): - acc_delta.mul_(rho[p]).add_((1 - rho[p]) * delta * delta) - else: - acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho) + F.adadelta(params_with_grad, + grads, + square_avgs, + acc_deltas, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay) return loss diff --git a/hfta/optim/adadelta_test.py b/hfta/optim/adadelta_test.py index 306271b..f9ec248 100644 --- a/hfta/optim/adadelta_test.py +++ b/hfta/optim/adadelta_test.py @@ -8,9 +8,21 @@ from utils import _TestNet, _optim_testing_procedure -def testcase_fused(B=3, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0): - net_array = [_TestNet() for _ in range(B)] - net_fused = _TestNet(B=B) +def testcase_fused( + B=3, + lr=1.0, + rho=0.9, + eps=1e-6, + weight_decay=0, + device=torch.device('cpu'), + dtype=torch.float, +): + if B > 1 and isinstance(lr, (int, float)): + lr = [random.uniform(0.5, 2.0) for _ in range(B)] + + kwargs = {'device': device, 'dtype': dtype} + net_array = [_TestNet(**kwargs) for _ in range(B)] + net_fused = _TestNet(B=B, **kwargs) optimizer_array = [ optim.Adadelta( net_array[b].parameters(), @@ -31,9 +43,14 @@ def testcase_fused(B=3, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0): optimizer_array) -def testcase_partially_fused(B=3): - net_array = [_TestNet() for _ in range(B)] - net_fused = _TestNet(B=B, partially_fused=True) +def testcase_partially_fused( + B=3, + device=torch.device('cpu'), + dtype=torch.float, +): + kwargs = {'device': device, 'dtype': dtype} + net_array = [_TestNet(**kwargs) for _ in range(B)] + net_fused = _TestNet(B=B, partially_fused=True, **kwargs) lr = [random.uniform(0.5, 2.0) for _ in range(B)] 
rho = [random.uniform(0.7, 0.99) for _ in range(B)] eps = [random.uniform(1e-7, 1e-5) for _ in range(B)] @@ -111,11 +128,15 @@ def testcase_partially_fused(B=3): 0.3, 0.0, ], + 'device': [torch.device('cuda:0')], + 'dtype': [torch.double], }, ) testcase_automator( testcase_partially_fused, { 'B': [1, 5, 8], + 'device': [torch.device('cuda:0')], + 'dtype': [torch.double], }, ) diff --git a/hfta/optim/lr_scheduler.py b/hfta/optim/lr_scheduler.py index f182be5..8cd566c 100644 --- a/hfta/optim/lr_scheduler.py +++ b/hfta/optim/lr_scheduler.py @@ -6,7 +6,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import EPOCH_DEPRECATION_WARNING -from .utils import (_reduce_array_if_possible_for, _to_tensor, +from .utils import (reduce_array_if_possible_for, _to_tensor, _get_coeff_like_params_map, index_array_or_return_scalar) from .partial import PartiallyFusedLRScheduler diff --git a/hfta/optim/utils.py b/hfta/optim/utils.py index 9a6e270..ecf0e53 100644 --- a/hfta/optim/utils.py +++ b/hfta/optim/utils.py @@ -3,7 +3,7 @@ import numpy as np import itertools -from hfta.ops import get_hfta_op_for +from hfta.ops import get_hfta_op_for, assert_allclose, dump_error_msg def _snatch_grads_unfused(op_list, op, b): @@ -37,73 +37,86 @@ def _snatch_parameters_unfused(op_list, op, b): def _assert_params_unfused(op_list, op, b): - np.testing.assert_allclose( - op_list[b].weight.data.numpy(), - op.weight.data.numpy(), + assert_allclose( + op_list[b].weight.data.cpu().numpy(), + op.weight.data.cpu().numpy(), rtol=1e-4, + population_threshold=1e-2, ) if op_list[b].bias is not None: - np.testing.assert_allclose( - op_list[b].bias.data.numpy(), - op.bias.data.numpy(), + assert_allclose( + op_list[b].bias.data.cpu().numpy(), + op.bias.data.cpu().numpy(), rtol=1e-4, + population_threshold=1e-2, ) def _assert_params_linear(fused_op, op, b, fused=True): try: if fused: - np.testing.assert_allclose( - fused_op.weight.data[b].numpy(), - op.weight.data.transpose(0, 1).numpy(), + assert_allclose( + fused_op.weight.data[b].cpu().numpy(), + op.weight.data.transpose(0, 1).cpu().numpy(), rtol=1e-4, + population_threshold=1e-2, ) if fused_op.bias is not None: - np.testing.assert_allclose( - fused_op.bias.data[b].numpy(), - op.bias.data.unsqueeze(0).numpy(), + assert_allclose( + fused_op.bias.data[b].cpu().numpy(), + op.bias.data.unsqueeze(0).cpu().numpy(), rtol=1e-4, + population_threshold=1e-2, ) else: _assert_params_unfused(fused_op, op, b) except AssertionError as e: - print(e) + dump_error_msg(e) def _assert_params_conv2d(fused_op, op, b, fused=True): try: if fused: - np.testing.assert_allclose( - fused_op.weight.data[b].numpy(), - op.weight.data.numpy(), + assert_allclose( + fused_op.weight.data[b].cpu().numpy(), + op.weight.data.cpu().numpy(), rtol=1e-4, + population_threshold=1e-2, ) if fused_op.bias is not None: - np.testing.assert_allclose( - fused_op.bias.data[b].numpy(), - op.bias.data.numpy(), + assert_allclose( + fused_op.bias.data[b].cpu().numpy(), + op.bias.data.cpu().numpy(), rtol=1e-4, + population_threshold=1e-2, ) else: _assert_params_unfused(fused_op, op, b) except AssertionError as e: - print(e) + dump_error_msg(e) class _TestNet(nn.Module): - def __init__(self, B=0, partially_fused=False): + def __init__( + self, + B=0, + partially_fused=False, + device=torch.device('cpu'), + dtype=torch.float, + ): super(_TestNet, self).__init__() - self.conv1 = get_hfta_op_for(nn.Conv2d, B=B)(3, 16, 3, 3) + kwargs = {'device': device, 'dtype': dtype} + self.conv1 = get_hfta_op_for(nn.Conv2d, B=B)(256, 128, 3, 
3, **kwargs) if partially_fused: - self.conv2 = [nn.Conv2d(64, 32, 5, 5) for _ in range(B)] + self.conv2 = [nn.Conv2d(128, 256, 5, 5, **kwargs) for _ in range(B)] else: - self.conv2 = get_hfta_op_for(nn.Conv2d, B=B)(64, 32, 5, 5) + self.conv2 = get_hfta_op_for(nn.Conv2d, B=B)(128, 256, 5, 5, **kwargs) if partially_fused: - self.linear1 = [nn.Linear(10, 30) for _ in range(B)] + self.linear1 = [nn.Linear(500, 1000, **kwargs) for _ in range(B)] else: - self.linear1 = get_hfta_op_for(nn.Linear, B=B)(10, 30) - self.linear2 = get_hfta_op_for(nn.Linear, B=B)(100, 20) + self.linear1 = get_hfta_op_for(nn.Linear, B=B)(500, 1000, **kwargs) + self.linear2 = get_hfta_op_for(nn.Linear, B=B)(1000, 500, **kwargs) self.partially_fused = partially_fused def snatch_parameters(self, net, b): @@ -223,22 +236,7 @@ def _optim_testing_procedure( _verify_test_nets_params(net_fused, net_array) -def _validate_range_for_element(name, e, lb, ub): - if e < lb or e > ub: - raise ValueError("Invalid {}: {}".format(name, val)) - - -def _validate_range(name, val, lb, ub): - if isinstance(val, (float, int)): - _validate_range_for_element(name, val, lb, ub) - elif isinstance(val, (list, tuple, torch.Tensor, np.ndarray)): - for e in val: - _validate_range_for_element(name, e, lb, ub) - else: - raise ValueError("Unsupported type({}): {}".format(val, type(val))) - - -def _to_tensor(coeff, B, dtype=torch.float): +def _to_tensor(coeff, B, dtype=torch.float, device=torch.device('cpu')): if isinstance(coeff, (float, int)): res = coeff elif isinstance(coeff, (list, tuple, np.ndarray)): @@ -246,10 +244,10 @@ def _to_tensor(coeff, B, dtype=torch.float): assert len(coeff) == B else: assert len(coeff.shape) == 1 and coeff.shape[0] == B - res = torch.as_tensor(coeff, dtype=dtype) + res = torch.as_tensor(coeff, dtype=dtype, device=device) elif isinstance(coeff, torch.Tensor): assert coeff.dim() == 1 and coeff.size(0) == B - res = coeff + res = coeff.to(dtype=dtype, device=device) else: raise ValueError("Unsupported type({}): {}".format(coeff, type(coeff))) return res @@ -286,14 +284,71 @@ def _broadcastablize(optimizer, name, B, is_tuple=False): group[name] = _get_coeff_like_params_map(coeff, group['params'], B) -def _move_coeff_to_same_device(group, name, p, is_tuple=False): +class Coefficient: + + def __init__(self, name, value): + if not isinstance(value, (list, tuple, torch.Tensor, np.ndarray)): + raise ValueError("Unsupported {} type({}): {}".format( + name, value, type(value))) + + self._name = name + self._value = value + self._ddt_map = {} # (device, dtype) -> tensor + + def _validate_range_for_element(self, i, e, lb=None, ub=None): + if (lb is not None and e < lb) or (ub is not None and e > ub): + raise ValueError("Invalid {}[{}]: {}".format(self._name, i, e)) + + def validate_range(self, lb=None, ub=None): + for i, e in enumerate(self._value): + self._validate_range_for_element(i, e, lb=lb, ub=ub) + + def _update_ddt_map(self, device, dtype): + if isinstance(self._value, (list, tuple, np.ndarray)): + self._ddt_map[(device, dtype)] = torch.as_tensor( + self._value, + dtype=dtype, + device=device, + ) + elif isinstance(self._value, torch.Tensor): + self._ddt_map[(device, dtype)] = self._value.to( + dtype=dtype, + device=device, + ) + else: + raise ValueError("Unsupported type({}): {}".format( + self._value, type(self._value))) + + def __getitem__(self, p): + k = (p.device, p.dtype) + if k not in self._ddt_map: + self._update_ddt_map(p.device, p.dtype) + B = self._ddt_map[k].size(0) + return self._ddt_map[k].view(B, *([1] * 
(p.dim() - 1))) + + +def is_coefficient(v): + return isinstance(v, Coefficient) + + +def _validate_range(name, val, lb=None, ub=None): + if is_coefficient(val): + val.validate_range(lb=lb, ub=ub) + else: + if (lb is not None and val < lb) or (ub is not None and val > ub): + raise ValueError("Invalid {}: {}".format(name, val)) + + +def make_coefficient(name, value, lb=None, ub=None, is_tuple=False): if is_tuple: - for coeff in group[name]: - if isinstance(coeff, dict): - coeff[p] = coeff[p].to(p.device) + res = tuple(v if isinstance(v, (int, float)) else Coefficient(name, v) + for v in value) + for r in res: + _validate_range(name, r, lb=lb, ub=ub) else: - if isinstance(group[name], dict): - group[name][p] = group[name][p].to(p.device) + res = value if isinstance(value, (int, float)) else Coefficient(name, value) + _validate_range(name, res, lb=lb, ub=ub) + return res def index_array_or_return_scalar(array_or_scalar, b): @@ -309,7 +364,7 @@ def index_array_or_return_scalar(array_or_scalar, b): array_or_scalar, type(array_or_scalar))) -def _reduce_array_if_possible(arr): +def reduce_array_if_possible(arr): if isinstance(arr, (list, tuple, np.ndarray, torch.Tensor)): first = arr[0] for e in arr[1:]: @@ -323,8 +378,8 @@ def _reduce_array_if_possible(arr): return arr -def _reduce_array_if_possible_for(*coeffs): - return (_reduce_array_if_possible(coeff) for coeff in coeffs) +def reduce_array_if_possible_for(*coeffs): + return (reduce_array_if_possible(coeff) for coeff in coeffs) def consolidate_hyperparams_and_determine_B(args, hp_names): From 9b208ac53d05042ce6de3815fca17370320263a6 Mon Sep 17 00:00:00 2001 From: Shang Wang Date: Fri, 1 Oct 2021 02:50:03 -0400 Subject: [PATCH 2/2] [Optim][Adam] Update op. --- hfta/optim/__init__.py | 6 +- hfta/optim/_functional.py | 78 +++++++++++ hfta/optim/adam.py | 280 ++++++++++++++++++-------------------- hfta/optim/adam_test.py | 26 +++- hfta/optim/utils.py | 10 +- 5 files changed, 244 insertions(+), 156 deletions(-) diff --git a/hfta/optim/__init__.py b/hfta/optim/__init__.py index 1ef418c..73e0c3f 100644 --- a/hfta/optim/__init__.py +++ b/hfta/optim/__init__.py @@ -2,19 +2,19 @@ import torch.optim from .adadelta import Adadelta, PartiallyFusedAdadelta -#from .adam import Adam, PartiallyFusedAdam +from .adam import Adam, PartiallyFusedAdam from .lr_scheduler import StepLR, PartiallyFusedStepLR from .utils import (index_array_or_return_scalar, consolidate_hyperparams_and_determine_B) _OPTIMIZERS_MAP = { torch.optim.Adadelta: Adadelta, - #torch.optim.Adam: Adam, + torch.optim.Adam: Adam, } _PARTIALLY_FUSED_OPTIMIZERS_MAP = { torch.optim.Adadelta: PartiallyFusedAdadelta, - #torch.optim.Adam: PartiallyFusedAdam, + torch.optim.Adam: PartiallyFusedAdam, } _LR_SCHEDULER_MAP = { diff --git a/hfta/optim/_functional.py b/hfta/optim/_functional.py index ee5ff88..a1fa34e 100644 --- a/hfta/optim/_functional.py +++ b/hfta/optim/_functional.py @@ -47,3 +47,81 @@ def adadelta( acc_delta.mul_(rho[param]).add_((1 - rho[param]) * delta * delta) else: acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho) + + +def adam( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[int], + *, + amsgrad: bool, + beta1: Union[float, Coefficient], + beta2: Union[float, Coefficient], + lr: Union[float, Coefficient], + weight_decay: Union[float, Coefficient], + eps: Union[float, Coefficient], +): + r"""Functional API that performs Adam algorithm computation. 
+ + See :class:`~torch.optim.Adam` for details. + """ + + for i, param in enumerate(params): + + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + + if is_coefficient(beta1): + bias_correction1 = 1 - beta1[param]**step + else: + bias_correction1 = 1 - beta1**step + if is_coefficient(beta2): + sqrt_bias_correction2 = (1 - beta2[param]**step).sqrt() + else: + sqrt_bias_correction2 = math.sqrt(1 - beta2**step) + + if is_coefficient(weight_decay) or weight_decay != 0: + if is_coefficient(weight_decay): + grad = grad + weight_decay[param] * param + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + if is_coefficient(beta1): + exp_avg.mul_(beta1[param]).add_((1 - beta1[param]) * grad) + else: + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + if is_coefficient(beta2): + exp_avg_sq.mul_(beta2[param]).add_( + (1 - beta2[param]) * grad * grad.conj()) + else: + exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) + # Use the max. for normalizing running avg. of gradient + if is_coefficient(eps): + denom = (max_exp_avg_sqs[i].sqrt() / sqrt_bias_correction2).add_( + eps[param]) + else: + denom = (max_exp_avg_sqs[i].sqrt() / sqrt_bias_correction2).add_(eps) + else: + if is_coefficient(eps): + denom = (exp_avg_sq.sqrt() / sqrt_bias_correction2).add_(eps[param]) + else: + denom = (exp_avg_sq.sqrt() / sqrt_bias_correction2).add_(eps) + + if is_coefficient(lr): + step_size = lr[param] / bias_correction1 + else: + step_size = lr / bias_correction1 + + if torch.is_tensor(step_size): + param.add_(-step_size * (exp_avg / denom)) + else: + param.addcdiv_(exp_avg, denom, value=-step_size) diff --git a/hfta/optim/adam.py b/hfta/optim/adam.py index 6cb4b92..16ce639 100644 --- a/hfta/optim/adam.py +++ b/hfta/optim/adam.py @@ -1,179 +1,171 @@ import math import torch - +from . import _functional as F from torch.optim import Optimizer -from .utils import (_validate_range, _broadcastablize, - _move_coeff_to_same_device, _reduce_array_if_possible_for, - _zero_grad_if_cuda, index_array_or_return_scalar) +from .utils import (make_coefficient, reduce_array_if_possible_for, + index_array_or_return_scalar) from .partial import PartiallyFusedOptimizer class Adam(Optimizer): r"""Implements Adam algorithm. - It has been proposed in `Adam: A Method for Stochastic Optimization`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float or a list/tuple/np.array/torch.Tensor of floats, optional): - learning rate (default: 1e-3) - betas (Tuple[float or a list/..., float or a list/...], optional): - coefficients used for computing running averages of gradient and its - square (default: (0.9, 0.999)) - eps (float or a list/tuple/np.array/torch.Tensor of floats, optional): term - added to the denominator to improve numerical stability (default: 1e-8) - weight_decay (float or a list/..., optional): weight decay (L2 penalty) - (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ + .. 
math:: + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 + \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ + &\hspace{13mm} \lambda \text{ (weight decay)}, \: amsgrad \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, + v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ + &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + \end{aligned} + + For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float or a list/tuple/np.array/torch.Tensor of floats, optional): learning rate (default: 1e-3) + betas (Tuple[float or a list/tuple/np.array/torch.Tensor of floats, float or a list/tuple/np.array/torch.Tensor of floats], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float or a list/tuple/np.array/torch.Tensor of floats, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float or a list/tuple/np.array/torch.Tensor of floats, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - B=1, - ): - _validate_range('learning rate', lr, 0.0, float('inf')) - _validate_range('epsilon value', eps, 0.0, float('inf')) - _validate_range('beta parameter at index 0', betas[0], 0.0, 1.0) - _validate_range('beta parameter at index 1', betas[1], 0.0, 1.0) - _validate_range('weight_decay value', weight_decay, 0.0, float('inf')) - lr, eps, beta1, beta2, weight_decay = _reduce_array_if_possible_for( + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + B=1): + lr, eps, beta1, beta2, weight_decay = reduce_array_if_possible_for( lr, eps, betas[0], betas[1], weight_decay) betas = (beta1, beta2) - - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - amsgrad=amsgrad, # TODO(wangshangsam): amsgrad array support. - ) + lr = make_coefficient('learning rate', lr, lb=0.0, ub=float('inf')) + eps = make_coefficient('epsilon value', eps, lb=0.0, ub=float('inf')) + betas = make_coefficient('beta parameter at index', + betas, + lb=0.0, + ub=1.0, + is_tuple=True) + weight_decay = make_coefficient('weight_decay value', + weight_decay, + lb=0.0, + ub=float('inf')) + # TODO(wangshangsam): amsgrad array support. + defaults = dict(lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad) super(Adam, self).__init__(params, defaults) - _broadcastablize(self, 'lr', B) - _broadcastablize(self, 'eps', B) - _broadcastablize(self, 'betas', B, is_tuple=True) - _broadcastablize(self, 'weight_decay', B) def __setstate__(self, state): super(Adam, self).__setstate__(state) for group in self.param_groups: group.setdefault('amsgrad', False) - def zero_grad(self): - if not _zero_grad_if_cuda(self): - super(Adam, self).zero_grad() - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + for p in group['params']: - if p.grad is None: - continue - grad = p.grad - if grad.is_sparse: - raise RuntimeError('Adam does not support sparse gradients, please ' - 'consider SparseAdam instead') - amsgrad = group['amsgrad'] - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - if amsgrad: - # Maintains max of all exp. moving avg. of sq. grad. 
values - state['max_exp_avg_sq'] = torch.zeros_like( + if p.grad is not None: + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError( + 'Adam does not support sparse gradients, please consider SparseAdam instead' + ) + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( p, memory_format=torch.preserve_format) - _move_coeff_to_same_device(group, 'lr', p) - _move_coeff_to_same_device(group, 'eps', p) - _move_coeff_to_same_device(group, 'betas', p, is_tuple=True) - _move_coeff_to_same_device(group, 'weight_decay', p) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - if amsgrad: - max_exp_avg_sq = state['max_exp_avg_sq'] - beta1, beta2 = group['betas'] - lr, eps, weight_decay = group['lr'], group['eps'], group['weight_decay'] - - state['step'] += 1 - if isinstance(beta1, dict): - bias_correction1 = 1 - beta1[p]**state['step'] - else: - bias_correction1 = 1 - beta1**state['step'] - if isinstance(beta2, dict): - sqrt_bias_correction2 = (1 - beta2[p]**state['step']).sqrt() - else: - sqrt_bias_correction2 = math.sqrt(1 - beta2**state['step']) - - if isinstance(weight_decay, dict) or weight_decay != 0: - if isinstance(weight_decay, dict): - grad = grad + weight_decay[p] * p - else: - grad = grad.add(p, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - if isinstance(beta1, dict): - exp_avg.mul_(beta1[p]).add_((1 - beta1[p]) * grad) - else: - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - if isinstance(beta2, dict): - exp_avg_sq.mul_(beta2[p]).add_((1 - beta2[p]) * grad * grad) - else: - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) - # Use the max. for normalizing running avg. of gradient - if isinstance(eps, dict): - denom = (max_exp_avg_sq.sqrt() / sqrt_bias_correction2).add_(eps[p]) - else: - denom = (max_exp_avg_sq.sqrt() / sqrt_bias_correction2).add_(eps) - else: - if isinstance(eps, dict): - denom = (exp_avg_sq.sqrt() / sqrt_bias_correction2).add_(eps[p]) - else: - denom = (exp_avg_sq.sqrt() / sqrt_bias_correction2).add_(eps) - - if isinstance(lr, dict): - step_size = lr[p] / bias_correction1 - else: - step_size = lr / bias_correction1 - if isinstance(step_size, torch.Tensor): - p.add_(-step_size * (exp_avg / denom)) - else: - p.addcdiv_(exp_avg, denom, value=-step_size) - + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + if group['amsgrad']: + # Maintains max of all exp. moving avg. of sq. grad. 
values + state['max_exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if group['amsgrad']: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + F.adam(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=group['amsgrad'], + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) return loss diff --git a/hfta/optim/adam_test.py b/hfta/optim/adam_test.py index a592be2..0e57f53 100644 --- a/hfta/optim/adam_test.py +++ b/hfta/optim/adam_test.py @@ -15,9 +15,15 @@ def testcase_fused( eps=1e-8, weight_decay=0, amsgrad=False, + device=torch.device('cpu'), + dtype=torch.float, ): - net_array = [_TestNet() for _ in range(B)] - net_fused = _TestNet(B=B) + if B > 1 and isinstance(lr, (int, float)): + lr = [random.uniform(1e-4, 1e-2) for _ in range(B)] + + kwargs = {'device': device, 'dtype': dtype} + net_array = [_TestNet(**kwargs) for _ in range(B)] + net_fused = _TestNet(B=B, **kwargs) optimizer_array = [ optim.Adam( net_array[b].parameters(), @@ -43,9 +49,15 @@ def testcase_fused( optimizer_array) -def testcase_partially_fused(B=3, amsgrad=False): - net_array = [_TestNet() for _ in range(B)] - net_fused = _TestNet(B=B, partially_fused=True) +def testcase_partially_fused( + B=3, + amsgrad=False, + device=torch.device('cpu'), + dtype=torch.float, +): + kwargs = {'device': device, 'dtype': dtype} + net_array = [_TestNet(**kwargs) for _ in range(B)] + net_fused = _TestNet(B=B, partially_fused=True, **kwargs) lr = [random.uniform(1e-4, 1e-2) for _ in range(B)] betas = ( [random.uniform(0.8, 0.99) for _ in range(B)], @@ -154,6 +166,8 @@ def testcase_partially_fused(B=3, amsgrad=False): 0.0, ], 'amsgrad': [True], + 'device': [torch.device('cuda:0')], + 'dtype': [torch.double], }, ) testcase_automator( @@ -161,5 +175,7 @@ def testcase_partially_fused(B=3, amsgrad=False): { 'B': [1, 5, 8], 'amsgrad': [True], + 'device': [torch.device('cuda:0')], + 'dtype': [torch.double], }, ) diff --git a/hfta/optim/utils.py b/hfta/optim/utils.py index ecf0e53..d92e062 100644 --- a/hfta/optim/utils.py +++ b/hfta/optim/utils.py @@ -341,10 +341,12 @@ def _validate_range(name, val, lb=None, ub=None): def make_coefficient(name, value, lb=None, ub=None, is_tuple=False): if is_tuple: - res = tuple(v if isinstance(v, (int, float)) else Coefficient(name, v) - for v in value) - for r in res: - _validate_range(name, r, lb=lb, ub=ub) + res = tuple(v if isinstance(v, (int, float)) else Coefficient( + '{}[{}]'.format(name, i), + v, + ) for i, v in enumerate(value)) + for i, r in enumerate(res): + _validate_range('{}[{}]'.format(name, i), r, lb=lb, ub=ub) else: res = value if isinstance(value, (int, float)) else Coefficient(name, value) _validate_range(name, res, lb=lb, ub=ub)
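
A minimal usage sketch for the fused optimizers in this series (illustrative only: the layer shapes, hyperparameter values, and the [B, batch, in_features] input layout assumed for the fused Linear are not taken from these patches):

    import torch
    import torch.nn as nn

    from hfta.ops import get_hfta_op_for
    from hfta.optim import Adadelta

    B = 3  # number of horizontally fused training instances

    # Fused Linear: parameters carry a leading B dimension, one slice per instance.
    model = get_hfta_op_for(nn.Linear, B=B)(10, 4)

    # Scalars are shared across instances; length-B lists/arrays give each
    # instance its own value and are wrapped into Coefficient objects by
    # make_coefficient() inside the optimizer constructor.
    optimizer = Adadelta(
        model.parameters(),
        lr=[1.0, 0.8, 1.2],
        rho=0.9,
        eps=1e-6,
        weight_decay=[0.0, 1e-4, 1e-4],
        B=B,
    )

    x = torch.randn(B, 16, 10)  # assumed fused input layout: [B, batch, in_features]
    loss = model(x).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Adam from PATCH 2/2 is driven the same way; per its updated docstring, lr, eps, weight_decay, and each entry of betas may likewise be scalars or length-B arrays/tensors.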