diff --git a/normflows/flows/neural_spline/wrapper.py b/normflows/flows/neural_spline/wrapper.py index 2eefa22..f395144 100644 --- a/normflows/flows/neural_spline/wrapper.py +++ b/normflows/flows/neural_spline/wrapper.py @@ -29,6 +29,7 @@ def __init__( activation=nn.ReLU, dropout_probability=0.0, reverse_mask=False, + init_identity=True, ): """Constructor @@ -43,11 +44,12 @@ def __init__( activation (torch module): Activation function dropout_probability (float): Dropout probability of the NN reverse_mask (bool): Flag whether the reverse mask should be used + init_identity (bool): Flag, initialize transform as identity """ super().__init__() def transform_net_create_fn(in_features, out_features): - return ResidualNet( + net = ResidualNet( in_features=in_features, out_features=out_features, context_features=num_context_channels, @@ -57,6 +59,12 @@ def transform_net_create_fn(in_features, out_features): dropout_probability=dropout_probability, use_batch_norm=False, ) + if init_identity: + torch.nn.init.constant_(net.final_layer.weight, 0.0) + torch.nn.init.constant_( + net.final_layer.bias, np.log(np.exp(1 - DEFAULT_MIN_DERIVATIVE) - 1) + ) + return net self.prqct = PiecewiseRationalQuadraticCoupling( mask=create_alternating_binary_mask(num_input_channels, even=reverse_mask),