
Add a tf.deform_grid_layer to use as a layer for keras models #11

Open
CaptainDario opened this issue Jul 8, 2021 · 1 comment

CaptainDario commented Jul 8, 2021

I am trying to use this package as a custom layer in TensorFlow (Keras).
My code currently looks like this:

```python
import tensorflow as tf
import elasticdeform.tf as etf


@tf.function
def elastic_deform_tf(x):
    """
    Elastic deformation layer.

    Args:
        x: a Tensor (batch_size, height, width, channels) of images to preprocess.
           The reshapes below assume 64x64 images with a single channel.

    Returns:
        Images on which an elastic deformation was applied.
    """
    # -1 instead of x.shape[0]: the batch size is None while Keras builds the model
    x = tf.reshape(x, (-1, 64, 64))
    x = tf.cast(x, tf.float32)

    # generate a deformation grid; unlike np.random.randn, tf.random.normal
    # draws new values on every call instead of once at tracing time
    displacement = tf.random.normal((3, 1, 1, 1), dtype=tf.float64) * 2

    # perform forward deformation
    x_deformed = etf.deform_grid(x, displacement)

    # restore the channel axis so the output matches the input layout
    return tf.reshape(x_deformed, (-1, 64, 64, 1))


class ElasticDeformTFLayer(tf.keras.layers.Layer):
    def __init__(self, name="elastic_deform_tf", **kwargs):
        super().__init__(name=name, **kwargs)
        self.preprocess = elastic_deform_tf

    def call(self, inputs):
        return self.preprocess(inputs)

    def get_config(self):
        config = super().get_config()
        return config
```
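Inside a `@tf.function`, NumPy calls such as `np.random.randn` execute only while TensorFlow traces the function, so their result is frozen into the graph as a constant and reused on every later call. TensorFlow random ops such as `tf.random.normal` are part of the graph and sample again each time. A small check (the function names `numpy_noise` and `tf_noise` are hypothetical):

```python
import numpy as np
import tensorflow as tf


@tf.function
def numpy_noise():
    # np.random.randn runs once, during tracing;
    # the resulting array is stored in the graph as a constant
    return tf.constant(np.random.randn(3))


@tf.function
def tf_noise():
    # tf.random.normal is a graph op, so it samples again on every call
    return tf.random.normal((3,), dtype=tf.float64)


a, b = numpy_noise(), numpy_noise()
c, d = tf_noise(), tf_noise()
print(np.array_equal(a.numpy(), b.numpy()))  # True: identical "random" values
print(np.array_equal(c.numpy(), d.numpy()))  # almost surely False
```

So when the deformation runs inside a `tf.function` or a Keras model, the displacement should come from a TensorFlow op rather than from NumPy.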

However, this is very slow and does not run on the GPU.
It would be very helpful if this package provided a layer that can be used directly in TensorFlow/Keras models.
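For illustration only, a minimal pure-TensorFlow sketch of such a layer, built entirely from TF ops so it can run on the GPU. This is not elasticdeform's API and not its B-spline algorithm: as an assumption, it swaps in bicubic upsampling of a coarse random displacement grid (`tf.image.resize`) followed by bilinear sampling with `tf.gather_nd`, and it clamps coordinates at the image border rather than using elasticdeform's `mode`/`cval` handling. The names `bilinear_sample`, `random_elastic_deform`, `RandomElasticDeform`, `grid_size` and `sigma` are hypothetical.

```python
import tensorflow as tf


def bilinear_sample(images, ys, xs):
    """Sample images (B, H, W, C) at float coordinates ys, xs of shape (B, H, W).

    Coordinates outside the image are clamped to the border.
    """
    shape = tf.shape(images)
    b, h, w = shape[0], shape[1], shape[2]
    y0f, x0f = tf.floor(ys), tf.floor(xs)
    # interpolation weights, with a trailing axis to broadcast over channels
    wy = (ys - y0f)[..., None]
    wx = (xs - x0f)[..., None]
    y0, x0 = tf.cast(y0f, tf.int32), tf.cast(x0f, tf.int32)
    y0c, y1c = tf.clip_by_value(y0, 0, h - 1), tf.clip_by_value(y0 + 1, 0, h - 1)
    x0c, x1c = tf.clip_by_value(x0, 0, w - 1), tf.clip_by_value(x0 + 1, 0, w - 1)
    batch_idx = tf.broadcast_to(tf.range(b)[:, None, None], tf.shape(y0))

    def gather(yy, xx):
        # pick pixel values at integer (batch, y, x) positions
        return tf.gather_nd(images, tf.stack([batch_idx, yy, xx], axis=-1))

    top = gather(y0c, x0c) * (1.0 - wx) + gather(y0c, x1c) * wx
    bottom = gather(y1c, x0c) * (1.0 - wx) + gather(y1c, x1c) * wx
    return top * (1.0 - wy) + bottom * wy


def random_elastic_deform(images, grid_size=3, sigma=2.0):
    """Warp each image in a (B, H, W, C) batch with its own random smooth field."""
    images = tf.convert_to_tensor(images, dtype=tf.float32)
    shape = tf.shape(images)
    b, h, w = shape[0], shape[1], shape[2]
    # coarse random displacements (in pixels) for the y and x directions
    coarse = tf.random.normal(tf.stack([b, grid_size, grid_size, 2])) * sigma
    # upsample to one displacement vector per pixel (bicubic, not B-spline)
    dense = tf.image.resize(coarse, tf.stack([h, w]), method="bicubic")
    ys, xs = tf.meshgrid(tf.cast(tf.range(h), tf.float32),
                         tf.cast(tf.range(w), tf.float32), indexing="ij")
    return bilinear_sample(images, ys[None] + dense[..., 0], xs[None] + dense[..., 1])


class RandomElasticDeform(tf.keras.layers.Layer):
    """Applies random_elastic_deform during training only."""

    def __init__(self, grid_size=3, sigma=2.0, **kwargs):
        super().__init__(**kwargs)
        self.grid_size = grid_size
        self.sigma = sigma

    def call(self, inputs, training=None):
        # augmentation is skipped at inference time
        if not training:
            return inputs
        return random_elastic_deform(inputs, self.grid_size, self.sigma)

    def get_config(self):
        config = super().get_config()
        config.update({"grid_size": self.grid_size, "sigma": self.sigma})
        return config


images = tf.random.uniform((4, 64, 64, 1))
layer = RandomElasticDeform(grid_size=3, sigma=2.0)
augmented = layer(images, training=True)
print(augmented.shape)  # (4, 64, 64, 1)
```

Unlike the code above, each image in the batch gets its own displacement field, and the layer passes inputs through unchanged when `training` is false. Whether bicubic upsampling is a close enough substitute for elasticdeform's B-spline interpolation would need to be checked for a given use case.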

@gvtulder
Owner

Thanks. I agree, it would be nice if the deformations could be computed on the GPU. It's somewhere on my to-do list, but it's not a small change.
