
[feature req.] cropping utilities #3

Open
choltz95 opened this issue Mar 13, 2021 · 1 comment

@choltz95

TensorFlow's `tf.image` module includes some handy cropping utilities, e.g. central cropping (`tf.image.central_crop`) and random cropping (`tf.image.random_crop`), among others.

Any plans to do something similar for this library? How difficult would it be to implement?
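For a sense of the difficulty: a central crop is straightforward in JAX because its offsets depend only on the (static) shapes. The sketch below is a hypothetical helper, `central_crop`, not part of JAX or this library. It assumes an HWC image and crop sizes given in pixels, whereas `tf.image.central_crop` takes a `central_fraction` instead.

```python
import jax.numpy as jnp


def central_crop(image, crop_height, crop_width):
    """Return the central crop_height x crop_width window of an HWC image.

    Hypothetical helper for illustration; not an existing JAX API.
    When the leftover margin is odd, the top/left offset is floored, so the
    extra pixel is dropped from the bottom/right.
    """
    height, width = image.shape[0], image.shape[1]
    if crop_height > height or crop_width > width:
        raise ValueError(
            f"Crop ({crop_height}, {crop_width}) exceeds image ({height}, {width})."
        )
    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    # All bounds are Python ints, so ordinary slicing works; under jax.jit the
    # crop sizes would need to be static arguments.
    return image[top:top + crop_height, left:left + crop_width]


image = jnp.arange(5 * 6 * 3).reshape(5, 6, 3)
crop = central_crop(image, 3, 4)
print(crop.shape)  # (3, 4, 3)
```

Random cropping is only slightly harder: the offsets become random scalars, which is where `jax.lax.dynamic_slice` comes in.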

@josephrocca

josephrocca commented Sep 12, 2021

There's a simple random crop implementation here that has been handy for me. It uses `jax.lax.dynamic_slice`:

```python
import jax


def random_crop(key, image, crop_sizes):
    """Crop an image randomly to the specified sizes.

    Given an input image, crops it to `crop_sizes`. If a crop size is smaller
    than the corresponding image dimension, the offset along that dimension is
    chosen at random. To crop an image deterministically, use
    `jax.lax.dynamic_slice` and specify the offsets and crop sizes yourself.

    Args:
        key: Key for the pseudo-random number generator.
        image: A JAX array representing an image.
        crop_sizes: A sequence of integers, one per image dimension, giving the
          crop size along that dimension. Its length must equal the rank of the
          image, and no crop size may exceed the corresponding image dimension.
          The sizes must be concrete Python ints, not traced values.

    Returns:
        The cropped image, a JAX array whose shape equals `crop_sizes`.
    """
    image_shape = image.shape
    crop_sizes = tuple(crop_sizes)
    assert len(image_shape) == len(crop_sizes), (
        f"Number of image dims {len(image_shape)} and number of crop_sizes "
        f"{len(crop_sizes)} do not match."
    )
    # Check each dimension: `image_shape >= crop_sizes` compares tuples
    # lexicographically and would miss a too-large crop in a later dimension.
    assert all(c <= s for c, s in zip(crop_sizes, image_shape)), (
        f"Crop sizes {crop_sizes} must not exceed image size {image_shape} "
        f"in any dimension."
    )
    # One independent key per dimension, so each offset is drawn separately.
    random_keys = jax.random.split(key, len(crop_sizes))
    # randint's maxval is exclusive, hence the +1: each offset lies in
    # [0, img_size - crop_size].
    slice_starts = [
        jax.random.randint(k, (), 0, img_size - crop_size + 1)
        for k, img_size, crop_size in zip(random_keys, image_shape, crop_sizes)
    ]
    return jax.lax.dynamic_slice(image, slice_starts, crop_sizes)
```
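The docstring's pointer to `jax.lax.dynamic_slice` for deterministic crops, and the common need to crop a whole batch with independent offsets, can be sketched as follows. `crop_at` and `random_crop_batch` are hypothetical helpers, not part of JAX or this library; the batch version assumes NHWC images and maps over one PRNG key per image with `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def crop_at(image, offsets, crop_sizes):
    """Deterministic crop: take `crop_sizes` starting at `offsets` (one per dim)."""
    return jax.lax.dynamic_slice(image, offsets, crop_sizes)


def random_crop_batch(key, images, crop_hw):
    """Randomly crop every HWC image in an NHWC batch to crop_hw = (h, w).

    Hypothetical helper: each image gets its own key, hence its own offsets.
    crop_hw must be concrete Python ints (static under jax.jit).
    """
    n, height, width, channels = images.shape
    crop_h, crop_w = crop_hw
    keys = jax.random.split(key, n)

    def crop_one(k, img):
        k_y, k_x = jax.random.split(k)
        # maxval is exclusive, so +1 lets the crop reach the far edge.
        y = jax.random.randint(k_y, (), 0, height - crop_h + 1)
        x = jax.random.randint(k_x, (), 0, width - crop_w + 1)
        # Keep all channels: the channel offset is always 0.
        return jax.lax.dynamic_slice(img, (y, x, 0), (crop_h, crop_w, channels))

    return jax.vmap(crop_one)(keys, images)


image = jnp.arange(16).reshape(4, 4)
print(crop_at(image, (1, 2), (2, 2)))  # rows 1-2, columns 2-3: values 6, 7, 10, 11

batch = jnp.zeros((8, 32, 32, 3))
out = random_crop_batch(jax.random.PRNGKey(0), batch, (24, 24))
print(out.shape)  # (8, 24, 24, 3)
```

Note that `jax.lax.dynamic_slice` clamps start indices so the slice stays in bounds rather than raising, so an out-of-range offset passed to `crop_at` is silently shifted; validate offsets up front if that matters.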
