
Upgrade to Keras2? #12

Open
EladNoy opened this issue Jun 5, 2017 · 4 comments

Comments


EladNoy commented Jun 5, 2017

With Keras 1 now deprecated, a Keras 2 version would be greatly appreciated.
I tried converting it myself, but Keras 2 no longer supports the BatchNormalization `mode=2` option, so it will probably require some sort of workaround.
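For reference, the Keras 1 call that has no direct Keras 2 equivalent looked roughly like this (a sketch based on the Keras 1 docs, not code from this repo):

```python
from keras.layers import BatchNormalization

# Keras 1: mode=2 meant feature-wise normalization using the statistics of the
# current batch at both training and test time (no moving averages).
# The `mode` argument was removed in Keras 2, so this call fails there.
bn = BatchNormalization(mode=2, axis=-1)
```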


engharat commented Jul 4, 2017

I was stuck on the same problem. I ended up writing a batchnorm version that always behaves like mode=2: you can edit the Keras file where BatchNormalization is defined and modify it so that it never uses the accumulated training statistics.


frnk99 commented Jul 15, 2017

Can you share the code, @engharat? Please.


engharat commented Jul 16, 2017

Sure. Here is a link to the code: https://drive.google.com/open?id=0B0E8DCU-EnYRR2l3aV9oTkJORHc . The file needs to be put in the same folder as your script and imported, of course; then you can substitute any occurrence of the BatchNormalization layer in the generator/discriminator code with the BatchNormGAN layer.

Or if you prefer the code:

```python
# -*- coding: utf-8 -*-
from __future__ import absolute_import

from keras.engine import Layer, InputSpec
from keras import initializers
from keras import regularizers
from keras import constraints
from keras import backend as K
from keras.legacy import interfaces


class BatchNormGAN(Layer):
    """Batch normalization layer (Ioffe and Szegedy, 2014).

    Normalize the activations of the previous layer at each batch,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `BatchNormGAN`.
        momentum: Momentum for the moving average.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        moving_mean_initializer: Initializer for the moving mean.
        moving_variance_initializer: Initializer for the moving variance.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
    """

    @interfaces.legacy_batchnorm_support
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(BatchNormGAN, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)
        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]

        normed, mean, variance = K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        # Always normalize with the current batch statistics (the old Keras 1
        # `mode=2` behaviour), never with the moving averages at inference:
        return normed  # instead of K.in_train_phase(normed, normalize_inference, training=training)

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(BatchNormGAN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```
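
To illustrate the substitution, here is a minimal usage sketch (the file name `batchnorm_gan.py` and the toy generator are hypothetical; only the import and the layer swap matter):

```python
from keras.models import Sequential
from keras.layers import Dense, Activation
from batchnorm_gan import BatchNormGAN  # the file above, saved next to the script

# Hypothetical generator: BatchNormGAN goes exactly where
# BatchNormalization would otherwise appear.
generator = Sequential([
    Dense(256, input_dim=100),
    BatchNormGAN(),
    Activation('relu'),
    Dense(784, activation='tanh'),
])
```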


frnk99 commented Jul 22, 2017

Thank you, @engharat!
