FFCV: add CenterCrop, fix try speed example.
lrzpellegrini committed Oct 11, 2023
1 parent 0515a47 commit d7a096e
Showing 3 changed files with 164 additions and 7 deletions.
avalanche/benchmarks/utils/ffcv_support/center_crop.py (138 additions, 0 deletions)

@@ -0,0 +1,138 @@
"""
Implementation of the CenterCrop transformation for FFCV
"""

from typing import Callable, Tuple
from ffcv.fields.decoders import SimpleRGBImageDecoder
from ffcv.pipeline.state import State
from ffcv.pipeline.allocation_query import AllocationQuery
import numpy as np
from dataclasses import replace
from ffcv.fields.rgb_image import IMAGE_MODES
from ffcv.pipeline.compiler import Compiler
from ffcv.libffcv import imdecode


def get_center_crop_torchvision_alike(
image_height, image_width, output_size, img, out_buffer
):
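    """Center-crop ``img`` to ``output_size``, writing the result to ``out_buffer``.

    Mirrors torchvision's CenterCrop: if the requested output is larger than
    the image along an edge, the missing border is filled with zeros.
    """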
crop_height = output_size[0]
crop_width = output_size[1]

padding_h = (crop_height - image_height) // 2 if crop_height > image_height else 0
padding_w = (crop_width - image_width) // 2 if crop_width > image_width else 0

crop_t = (
int(round((image_height - crop_height) / 2.0))
if image_height > crop_height
else 0
)
crop_l = (
int(round((image_width - crop_width) / 2.0)) if image_width > crop_width else 0
)
crop_height_effective = min(crop_height, image_height)
crop_width_effective = min(crop_width, image_width)

if crop_height_effective != crop_height or crop_width_effective != crop_width:
out_buffer[:] = 0 # Set padding color
out_buffer[
padding_h : padding_h + crop_height_effective,
padding_w : padding_w + crop_width_effective,
] = img[
crop_t : crop_t + crop_height_effective, crop_l : crop_l + crop_width_effective
]

return out_buffer


class CenterCropRGBImageDecoderTVAlike(SimpleRGBImageDecoder):
"""Decoder for :class:`~ffcv.fields.RGBImageField` that performs a center crop operation.
It supports both variable and constant resolution datasets.
Differently from the original CenterCropRGBImageDecoder from FFCV,
this operates like torchvision CenterCrop.
"""

def __init__(self, output_size):
super().__init__()
self.output_size = output_size

def declare_state_and_memory(
self, previous_state: State
) -> Tuple[State, AllocationQuery]:
widths = self.metadata["width"]
heights = self.metadata["height"]
# We convert to uint64 to avoid overflows
self.max_width = np.uint64(widths.max())
self.max_height = np.uint64(heights.max())
output_shape = (self.output_size[0], self.output_size[1], 3)
my_dtype = np.dtype("<u1")

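        # Two allocations are requested: one for the final cropped output, plus
        # a scratch buffer big enough for the largest decoded image in the dataset.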
return (
replace(previous_state, jit_mode=True, shape=output_shape, dtype=my_dtype),
(
AllocationQuery(output_shape, my_dtype),
AllocationQuery(
(self.max_height * self.max_width * np.uint64(3),), my_dtype
),
),
)

def generate_code(self) -> Callable:
jpg = IMAGE_MODES["jpg"]

mem_read = self.memory_read
my_range = Compiler.get_iterator()
imdecode_c = Compiler.compile(imdecode)
c_crop = Compiler.compile(self.get_crop_generator)
output_size = self.output_size

def decode(batch_indices, my_storage, metadata, storage_state):
destination, temp_storage = my_storage
for dst_ix in my_range(len(batch_indices)):
source_ix = batch_indices[dst_ix]
field = metadata[source_ix]
image_data = mem_read(field["data_ptr"], storage_state)
height = np.uint32(field["height"])
width = np.uint32(field["width"])

if field["mode"] == jpg:
temp_buffer = temp_storage[dst_ix]
imdecode_c(
image_data,
temp_buffer,
height,
width,
height,
width,
0,
0,
1,
1,
False,
False,
)
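                    # Trim the flat scratch buffer to the decoded byte count and
                    # view it as (height, width, 3) before cropping.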
selected_size = 3 * height * width
temp_buffer = temp_buffer.reshape(-1)[:selected_size]
temp_buffer = temp_buffer.reshape(height, width, 3)
else:
temp_buffer = image_data.reshape(height, width, 3)

c_crop(height, width, output_size, temp_buffer, destination[dst_ix])

return destination[: len(batch_indices)]

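        # is_parallel lets the FFCV compiler parallelize the per-sample loop.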
decode.is_parallel = True
return decode

@property
def get_crop_generator(self):
return get_center_crop_torchvision_alike


__all__ = ["CenterCropRGBImageDecoderTVAlike"]
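
For reference, the padding behavior described in the class docstring can be checked with a small self-contained NumPy sketch (the sizes below are made up for illustration and are not part of the commit):

import numpy as np

# A 4x6 all-white RGB "image".
img = np.full((4, 6, 3), 255, dtype=np.uint8)

# Request an 8x8 center crop. The crop is larger than the image, so,
# following torchvision's CenterCrop semantics, the output is zero-padded.
out = np.zeros((8, 8, 3), dtype=np.uint8)
pad_h = (8 - 4) // 2  # 2 rows of zero padding above the image
pad_w = (8 - 6) // 2  # 1 column of zero padding to the left
out[pad_h : pad_h + 4, pad_w : pad_w + 6] = img

assert (out[2:6, 1:7] == 255).all()  # the image sits in the center
assert (out[:2] == 0).all()  # top border is padding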
avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py (13 additions, 1 deletion)

@@ -20,15 +20,18 @@
import torch

from avalanche.benchmarks.utils.transforms import flat_transforms_recursive
+from avalanche.benchmarks.utils.ffcv_support.center_crop import (
+    CenterCropRGBImageDecoderTVAlike,
+)

from torchvision.transforms import ToTensor as ToTensorTV
from torchvision.transforms import PILToTensor as PILToTensorTV
from torchvision.transforms import Normalize as NormalizeTV
from torchvision.transforms import ConvertImageDtype as ConvertTV
from torchvision.transforms import RandomResizedCrop as RandomResizedCropTV
from torchvision.transforms import CenterCrop as CenterCropTV
from torchvision.transforms import RandomHorizontalFlip as RandomHorizontalFlipTV
from torchvision.transforms import RandomCrop as RandomCropTV
from torchvision.transforms import Lambda

from ffcv.transforms import ToTensor as ToTensorFFCV
from ffcv.transforms import ToDevice as ToDeviceFFCV
@@ -282,6 +285,15 @@ def _apply_transforms_pre_optimization(
            elif len(size) == 1:
                size = [size[0], size[0]]
            result[-1] = RandomResizedCropRGBImageDecoder(size, t.scale, t.ratio)
+        elif isinstance(t, CenterCropTV) and isinstance(
+            result[-1], SimpleRGBImageDecoder
+        ):
+            size = t.size
+            if isinstance(size, int):
+                size = [size, size]
+            elif len(size) == 1:
+                size = [size[0], size[0]]
+            result[-1] = CenterCropRGBImageDecoderTVAlike(size)
        else:
            result.append(t)

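The net effect of the new elif branch: in an eval-style torchvision pipeline, the CenterCrop step is fused into the FFCV decoder instead of running as a separate transform. A minimal sketch of a pipeline that would trigger this optimization (sizes and normalization statistics are illustrative only):

from torchvision.transforms import CenterCrop, Compose, Normalize, ToTensor

# With the branch above, the CenterCrop(224) stage is absorbed into the
# decoder (SimpleRGBImageDecoder -> CenterCropRGBImageDecoderTVAlike),
# so the crop happens during decoding rather than afterwards.
eval_transform = Compose(
    [
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)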
examples/ffcv/ffcv_try_speed.py (13 additions, 6 deletions)

@@ -28,6 +28,10 @@
from torchvision.transforms import Compose, ToTensor, Normalize

from torch.utils.data import DataLoader
+from torch.utils.data.sampler import (
+    BatchSampler,
+    SequentialSampler,
+)
from tqdm import tqdm


@@ -114,16 +118,19 @@ def benchmark_ffcv_speed(

start_time = time.time()
ffcv_loader = HybridFfcvLoader(
-        avl_set,
-        None,
-        batch_size,
-        dict(num_workers=num_workers, drop_last=True),
+        dataset=avl_set,
+        batch_sampler=BatchSampler(
+            SequentialSampler(avl_set),
+            batch_size=batch_size,
+            drop_last=True,
+        ),
+        ffcv_loader_parameters=dict(num_workers=num_workers),
device=device,
print_ffcv_summary=False,
)

for _ in tqdm(range(epochs)):
-        for batch in ffcv_loader:
+        for batch in tqdm(ffcv_loader):
# "Touch" tensors to make sure they already moved to GPU
batch[0][0]
batch[-1][0]
@@ -152,7 +159,7 @@ def benchmark_pytorch_speed(benchmark, device, batch_size=128, num_workers=1, ep

    batch: Tuple[torch.Tensor, ...]
for _ in tqdm(range(epochs)):
-        for batch in torch_loader:
+        for batch in tqdm(torch_loader):
batch = tuple(x.to(device, non_blocking=True) for x in batch)

# "Touch" tensors to make sure they already moved to GPU
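
The fixed example no longer relies on the old positional HybridFfcvLoader arguments: the batch size and drop_last flag now travel through an explicit BatchSampler. The sampler pattern itself can be sketched with plain torch APIs (the toy dataset below is illustrative only):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.sampler import BatchSampler, SequentialSampler

# BatchSampler wraps a per-index sampler and yields lists of indices,
# which is the interface the loader consumes via batch_sampler=...
dataset = TensorDataset(torch.arange(10))
sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=True)
print(list(sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7]] - the partial batch is dropped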
