-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathclip_processor.py
81 lines (68 loc) · 3.09 KB
/
clip_processor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import List
import numpy as np
import torch
from transformers.utils import is_tf_available
# TensorFlow is a hard requirement of this processor: fail at import time with
# an installation hint instead of failing later on the first `tf.` reference.
if not is_tf_available():
    raise ValueError("Please run `pip install tensorflow` to use the processor.")

import tensorflow as tf  # type: ignore  # noqa: E402
# Per-channel RGB normalization statistics (presumably the CLIP image mean and
# standard deviation), multiplied by 255 because whiten() applies them directly
# to pixel values that have not been rescaled to [0, 1].
MEAN_RGB = [channel * 255 for channel in (0.48145466, 0.4578275, 0.40821073)]
STDDEV_RGB = [channel * 255 for channel in (0.26862954, 0.26130258, 0.27577711)]
def crop_image(image: tf.Tensor, center_crop_fraction: float = 0.875):
    """Return a centered square crop of ``image``.

    The square's side length is ``center_crop_fraction`` times the smaller of
    the first two dimensions of ``image``. It is truncated to an integer by
    ``tf.cast``. The start offset on each of those two axes is half of the
    leftover extent, also truncated. Any axes after the first two are kept
    untouched.

    Args:
        image: Tensor whose first two axes are cropped (presumably height and
            width; confirm against callers).
        center_crop_fraction: Fraction of the shorter of the first two sides
            to keep. A value of 1 keeps the largest centered square.

    Returns:
        The cropped tensor.
    """
    spatial = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
    side_float = center_crop_fraction * tf.math.minimum(spatial[0], spatial[1])
    # Offsets are computed from the float side length, before it is truncated.
    offsets = tf.cast((spatial - side_float) / 2.0, dtype=tf.int32)
    side = tf.cast(side_float, dtype=tf.int32)
    row, col = offsets[0], offsets[1]
    return image[row : row + side, col : col + side, :]  # noqa: E203
def whiten(image: tf.Tensor) -> tf.Tensor:
    """Normalize an image channel-wise with ``MEAN_RGB`` and ``STDDEV_RGB``.

    The input is converted to a float32 tensor. Each value then becomes
    ``(value - mean) / stddev``. The statistics are shaped ``[1, 1, 3]`` so
    they broadcast along the last axis.

    Args:
        image: Image tensor, or anything ``tf.convert_to_tensor`` accepts, in
            the same 0-255 value scale as the statistics.

    Returns:
        The normalized float32 tensor.
    """
    pixels = tf.cast(tf.convert_to_tensor(image), tf.float32)
    mean = tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=pixels.dtype)
    stddev = tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=pixels.dtype)
    return (pixels - mean) / stddev
def tf_image_reshape_crop(image: tf.Tensor, crop_size: int) -> tf.Tensor:
    """Resize the shorter side to ``crop_size``, then take the centered square.

    Resizing is bilinear without antialiasing and preserves the aspect ratio.
    If the first axis is strictly longer than the second, the second axis is
    fitted to ``crop_size``; otherwise the first axis is. The result is then
    center-cropped to a square whose side equals the shorter resized side.

    Args:
        image: Image tensor accepted by ``tf.image.resize``.
        crop_size: Target length of the shorter side, and so of the output square.

    Returns:
        The resized and center-cropped tensor.
    """
    # With preserve_aspect_ratio=True, tf.image.resize scales the image to fit
    # inside the requested box. Using 100000 for the long side makes that bound
    # effectively inactive, so the short side lands on crop_size. This holds as
    # long as the scaled long side stays under 100000 pixels.
    unbounded = 100000

    def _fit_second_axis():
        return tf.image.resize(
            image, (unbounded, crop_size), method="bilinear", preserve_aspect_ratio=True, antialias=False
        )

    def _fit_first_axis():
        return tf.image.resize(
            image, (crop_size, unbounded), method="bilinear", preserve_aspect_ratio=True, antialias=False
        )

    dims = tf.shape(image)
    resized = tf.cond(dims[0] > dims[1], _fit_second_axis, _fit_first_axis)
    return crop_image(image=resized, center_crop_fraction=1)
def _single_image_preprocess(image: np.ndarray, crop_size: int = 224, resize_only: bool = False):
    """Turn one image into a normalized torch tensor using TensorFlow ops.

    Args:
        image: The image, as a numpy array or a TensorFlow tensor. It is passed
            through ``tf.constant`` first.
        crop_size: Side length of the square output.
        resize_only: If True, the image is stretched directly to
            ``crop_size x crop_size`` (aspect ratio not preserved). If False,
            the shorter side is resized to ``crop_size`` and the centered
            square is cropped out.

    Returns:
        A float32 torch tensor of shape (crop_size, crop_size, channels),
        normalized by ``whiten``.
    """
    pixels = tf.constant(image)
    if not resize_only:
        pixels = tf_image_reshape_crop(pixels, crop_size)
    else:
        pixels = tf.image.resize(
            pixels, (crop_size, crop_size), method="bilinear", preserve_aspect_ratio=False, antialias=False
        )
    normalized = whiten(pixels)
    return torch.asarray(normalized.numpy())
def image_preprocess(images: List[np.ndarray], crop_size: int = 224, resize_only: bool = False):
    """Preprocess a batch of images with TensorFlow resizing and return a torch batch.

    Each image is processed independently by ``_single_image_preprocess``
    (resize, optional center crop, per-channel normalization). The results are
    then stacked into one channels-first tensor.

    Args:
        images: An iterable of images as numpy arrays. Each one is converted to
            a TensorFlow tensor inside ``_single_image_preprocess``.
        crop_size: Side length of the square output images.
        resize_only: If True, each image is stretched directly to
            ``crop_size x crop_size``. If False, its shorter side is resized to
            ``crop_size`` and the centered square is cropped.

    Returns:
        A float32 torch tensor of shape ``[num_images, 3, crop_size, crop_size]``
        (channels-first). An empty input yields a tensor with a leading
        dimension of 0 instead of raising.
    """
    # _single_image_preprocess already calls tf.constant on its input, so the
    # raw arrays are passed straight through instead of being converted twice.
    processed_images = [
        _single_image_preprocess(image, crop_size=crop_size, resize_only=resize_only) for image in images
    ]
    if not processed_images:
        # torch.stack rejects an empty sequence. Return an empty batch with the
        # same dtype and channels-first layout as the non-empty path.
        return torch.empty((0, 3, crop_size, crop_size), dtype=torch.float32)
    # Each processed image is (height, width, channels); reorder the stacked
    # batch to (batch, channels, height, width).
    return torch.permute(torch.stack(processed_images, 0), (0, 3, 1, 2))