diff --git a/avalanche/benchmarks/utils/ffcv_support/center_crop.py b/avalanche/benchmarks/utils/ffcv_support/center_crop.py
new file mode 100644
index 000000000..4461d89ce
--- /dev/null
+++ b/avalanche/benchmarks/utils/ffcv_support/center_crop.py
@@ -0,0 +1,138 @@
+"""
+Implementation of the CenterCrop transformation for FFCV.
+"""
+
+from typing import Callable, Tuple
+from ffcv.fields.decoders import SimpleRGBImageDecoder
+from ffcv.pipeline.state import State
+from ffcv.pipeline.allocation_query import AllocationQuery
+import numpy as np
+from dataclasses import replace
+from ffcv.fields.rgb_image import IMAGE_MODES
+from ffcv.pipeline.compiler import Compiler
+from ffcv.libffcv import imdecode
+
+
+def get_center_crop_torchvision_alike(
+    image_height, image_width, output_size, img, out_buffer
+):
+    crop_height = output_size[0]
+    crop_width = output_size[1]
+
+    padding_h = (crop_height - image_height) // 2 if crop_height > image_height else 0
+    padding_w = (crop_width - image_width) // 2 if crop_width > image_width else 0
+
+    crop_t = (
+        int(round((image_height - crop_height) / 2.0))
+        if image_height > crop_height
+        else 0
+    )
+    crop_l = (
+        int(round((image_width - crop_width) / 2.0)) if image_width > crop_width else 0
+    )
+    crop_height_effective = min(crop_height, image_height)
+    crop_width_effective = min(crop_width, image_width)
+
+    if crop_height_effective != crop_height or crop_width_effective != crop_width:
+        # The output is larger than the image on at least one side:
+        # fill the destination with the padding color before copying.
+        out_buffer[:] = 0
+
+    out_buffer[
+        padding_h : padding_h + crop_height_effective,
+        padding_w : padding_w + crop_width_effective,
+    ] = img[
+        crop_t : crop_t + crop_height_effective, crop_l : crop_l + crop_width_effective
+    ]
+
+    return out_buffer
+
+
+class CenterCropRGBImageDecoderTVAlike(SimpleRGBImageDecoder):
+    """Decoder for :class:`~ffcv.fields.RGBImageField` that performs a center crop operation.
+
+    It supports both variable and constant resolution datasets.
+
+    Unlike the original CenterCropRGBImageDecoder from FFCV,
+    this decoder behaves like torchvision's CenterCrop.
+ """ + + def __init__(self, output_size): + super().__init__() + self.output_size = output_size + + def declare_state_and_memory( + self, previous_state: State + ) -> Tuple[State, AllocationQuery]: + widths = self.metadata["width"] + heights = self.metadata["height"] + # We convert to uint64 to avoid overflows + self.max_width = np.uint64(widths.max()) + self.max_height = np.uint64(heights.max()) + output_shape = (self.output_size[0], self.output_size[1], 3) + my_dtype = np.dtype(" Callable: + jpg = IMAGE_MODES["jpg"] + + mem_read = self.memory_read + my_range = Compiler.get_iterator() + imdecode_c = Compiler.compile(imdecode) + c_crop = Compiler.compile(self.get_crop_generator) + output_size = self.output_size + + def decode(batch_indices, my_storage, metadata, storage_state): + destination, temp_storage = my_storage + for dst_ix in my_range(len(batch_indices)): + source_ix = batch_indices[dst_ix] + field = metadata[source_ix] + image_data = mem_read(field["data_ptr"], storage_state) + height = np.uint32(field["height"]) + width = np.uint32(field["width"]) + + if field["mode"] == jpg: + temp_buffer = temp_storage[dst_ix] + imdecode_c( + image_data, + temp_buffer, + height, + width, + height, + width, + 0, + 0, + 1, + 1, + False, + False, + ) + selected_size = 3 * height * width + temp_buffer = temp_buffer.reshape(-1)[:selected_size] + temp_buffer = temp_buffer.reshape(height, width, 3) + else: + temp_buffer = image_data.reshape(height, width, 3) + + c_crop(height, width, output_size, temp_buffer, destination[dst_ix]) + + return destination[: len(batch_indices)] + + decode.is_parallel = True + return decode + + @property + def get_crop_generator(self): + return get_center_crop_torchvision_alike + + +__all__ = ["CenterCropRGBImageDecoderTVAlike"] diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py index 1c84949f3..f72940b88 100644 --- a/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py @@ -20,15 +20,18 @@ import torch from avalanche.benchmarks.utils.transforms import flat_transforms_recursive +from avalanche.benchmarks.utils.ffcv_support.center_crop import ( + CenterCropRGBImageDecoderTVAlike, +) from torchvision.transforms import ToTensor as ToTensorTV from torchvision.transforms import PILToTensor as PILToTensorTV from torchvision.transforms import Normalize as NormalizeTV from torchvision.transforms import ConvertImageDtype as ConvertTV from torchvision.transforms import RandomResizedCrop as RandomResizedCropTV +from torchvision.transforms import CenterCrop as CenterCropTV from torchvision.transforms import RandomHorizontalFlip as RandomHorizontalFlipTV from torchvision.transforms import RandomCrop as RandomCropTV -from torchvision.transforms import Lambda from ffcv.transforms import ToTensor as ToTensorFFCV from ffcv.transforms import ToDevice as ToDeviceFFCV @@ -282,6 +285,15 @@ def _apply_transforms_pre_optimization( elif len(size) == 1: size = [size[0], size[0]] result[-1] = RandomResizedCropRGBImageDecoder(size, t.scale, t.ratio) + elif isinstance(t, CenterCropTV) and isinstance( + result[-1], SimpleRGBImageDecoder + ): + size = t.size + if isinstance(size, int): + size = [size, size] + elif len(size) == 1: + size = [size[0], size[0]] + result[-1] = CenterCropRGBImageDecoderTVAlike(size) else: result.append(t) diff --git a/examples/ffcv/ffcv_try_speed.py b/examples/ffcv/ffcv_try_speed.py index 
--- a/examples/ffcv/ffcv_try_speed.py
+++ b/examples/ffcv/ffcv_try_speed.py
@@ -28,6 +28,10 @@ from torchvision.transforms import Compose, ToTensor, Normalize
 
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import (
+    BatchSampler,
+    SequentialSampler,
+)
 
 from tqdm import tqdm
 
 
@@ -114,16 +118,19 @@ def benchmark_ffcv_speed(
     start_time = time.time()
 
     ffcv_loader = HybridFfcvLoader(
-        avl_set,
-        None,
-        batch_size,
-        dict(num_workers=num_workers, drop_last=True),
+        dataset=avl_set,
+        batch_sampler=BatchSampler(
+            SequentialSampler(avl_set),
+            batch_size=batch_size,
+            drop_last=True,
+        ),
+        ffcv_loader_parameters=dict(num_workers=num_workers),
         device=device,
         print_ffcv_summary=False,
     )
 
     for _ in tqdm(range(epochs)):
-        for batch in ffcv_loader:
+        for batch in tqdm(ffcv_loader):
             # "Touch" tensors to make sure they already moved to GPU
             batch[0][0]
             batch[-1][0]
@@ -152,7 +159,7 @@ def benchmark_pytorch_speed(benchmark, device, batch_size=128, num_workers=1, ep
     batch: Tuple[torch.Tensor]
 
     for _ in tqdm(range(epochs)):
-        for batch in torch_loader:
+        for batch in tqdm(torch_loader):
             batch = tuple(x.to(device, non_blocking=True) for x in batch)
 
             # "Touch" tensors to make sure they already moved to GPU
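
Note (not part of the patch): a minimal sketch of the crop semantics implemented by
get_center_crop_torchvision_alike above, assuming plain HWC uint8 NumPy arrays. When the
requested output is larger than the image, the result is centered and zero-padded, as
torchvision's CenterCrop does; FFCV's own CenterCropRGBImageDecoder instead crops a
ratio of the image and rescales it to the output size, hence the need for the new decoder.

    import numpy as np

    from avalanche.benchmarks.utils.ffcv_support.center_crop import (
        get_center_crop_torchvision_alike,
    )

    img = np.arange(4 * 4 * 3, dtype=np.uint8).reshape(4, 4, 3)

    # Output larger than the input: zero padding around the centered image.
    out = np.empty((6, 6, 3), dtype=np.uint8)
    get_center_crop_torchvision_alike(4, 4, (6, 6), img, out)
    assert (out[1:5, 1:5] == img).all() and out[0].sum() == 0

    # Output smaller than the input: a plain central crop.
    out = np.empty((2, 2, 3), dtype=np.uint8)
    get_center_crop_torchvision_alike(4, 4, (2, 2), img, out)
    assert (out == img[1:3, 1:3]).all()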
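
Note (not part of the patch): in examples/ffcv/ffcv_try_speed.py the batching previously
delegated to the loader (positional batch_size plus drop_last=True) is now expressed
explicitly through a BatchSampler wrapped around a SequentialSampler. A small
self-contained check of that equivalence, using a range as a stand-in for avl_set:

    from torch.utils.data.sampler import BatchSampler, SequentialSampler

    batch_sampler = BatchSampler(
        SequentialSampler(range(10)),  # stand-in for the dataset
        batch_size=4,
        drop_last=True,  # the last, incomplete batch [8, 9] is dropped
    )
    assert list(batch_sampler) == [[0, 1, 2, 3], [4, 5, 6, 7]]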