diff --git a/inferno/extensions/containers/graph.py b/inferno/extensions/containers/graph.py
index 0bbf57b0..45e82286 100755
--- a/inferno/extensions/containers/graph.py
+++ b/inferno/extensions/containers/graph.py
@@ -130,7 +130,7 @@ def is_node_in_graph(self, name):
         -------
         bool
         """
-        return name in self.graph.node
+        return name in self.graph.nodes

     def is_source_node(self, name):
         """
@@ -187,7 +187,7 @@ def output_nodes(self):
         list
             A list of names (str) of the output nodes.
         """
-        return [name for name, node_attributes in self.graph.node.items()
+        return [name for name, node_attributes in self.graph.nodes.items()
                 if node_attributes.get('is_output_node', False)]

     @property
@@ -201,7 +201,7 @@ def input_nodes(self):
         list
             A list of names (str) of the input nodes.
         """
-        return [name for name, node_attributes in self.graph.node.items()
+        return [name for name, node_attributes in self.graph.nodes.items()
                 if node_attributes.get('is_input_node', False)]

     @property
diff --git a/inferno/extensions/optimizers/__init__.py b/inferno/extensions/optimizers/__init__.py
index 7235cead..842dd8f9 100755
--- a/inferno/extensions/optimizers/__init__.py
+++ b/inferno/extensions/optimizers/__init__.py
@@ -1,2 +1,3 @@
 from .adam import Adam
-from .annealed_adam import AnnealedAdam
\ No newline at end of file
+from .annealed_adam import AnnealedAdam
+from .ranger import Ranger, RangerQH, RangerVA
diff --git a/inferno/extensions/optimizers/ranger.py b/inferno/extensions/optimizers/ranger.py
new file mode 100644
index 00000000..c01ea0d5
--- /dev/null
+++ b/inferno/extensions/optimizers/ranger.py
@@ -0,0 +1,8 @@
+# easy support for additional ranger optimizers from
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+try:
+    from ranger import Ranger, RangerVA, RangerQH
+except ImportError:
+    Ranger = None
+    RangerVA = None
+    RangerQH = None
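A minimal usage sketch for the guarded import above, assuming the optional ranger package may be absent at runtime (build_optimizer is a hypothetical helper, not part of this diff):

    from torch.optim import Adam
    from inferno.extensions.optimizers import Ranger

    def build_optimizer(parameters, lr=1e-4):
        # ranger.py binds Ranger to None when the package is missing,
        # so we can probe for it and fall back to a stock optimizer.
        if Ranger is None:
            return Adam(parameters, lr=lr)
        return Ranger(parameters, lr=lr)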
diff --git a/inferno/io/transform/base.py b/inferno/io/transform/base.py
index 6c23113c..7a63804b 100755
--- a/inferno/io/transform/base.py
+++ b/inferno/io/transform/base.py
@@ -58,7 +58,7 @@ def __call__(self, *tensors, **transform_function_kwargs):
             transformed = self.batch_function(tensors, **transform_function_kwargs)
             return pyu.from_iterable(transformed)
         elif hasattr(self, 'tensor_function'):
-            transformed = [self.tensor_function(tensor, **transform_function_kwargs)
+            transformed = [self._apply_tensor_function(tensor, **transform_function_kwargs)
                            if tensor_index in apply_to else tensor
                            for tensor_index, tensor in enumerate(tensors)]
             return pyu.from_iterable(transformed)
@@ -77,9 +77,17 @@ def __call__(self, *tensors, **transform_function_kwargs):
         else:
             raise NotImplementedError

+    # noinspection PyUnresolvedReferences
+    def _apply_tensor_function(self, tensor, **transform_function_kwargs):
+        if isinstance(tensor, list):
+            return [self._apply_tensor_function(tens, **transform_function_kwargs) for tens in tensor]
+        return self.tensor_function(tensor, **transform_function_kwargs)
+
     # noinspection PyUnresolvedReferences
     def _apply_image_function(self, tensor, **transform_function_kwargs):
         assert pyu.has_callable_attr(self, 'image_function')
+        if isinstance(tensor, list):
+            return [self._apply_image_function(tens, **transform_function_kwargs) for tens in tensor]
         # 2D case
         if tensor.ndim == 4:
             return np.array([np.array([self.image_function(image, **transform_function_kwargs)
@@ -106,6 +114,8 @@ def _apply_image_function(self, tensor, **transform_function_kwargs):
     # noinspection PyUnresolvedReferences
     def _apply_volume_function(self, tensor, **transform_function_kwargs):
         assert pyu.has_callable_attr(self, 'volume_function')
+        if isinstance(tensor, list):
+            return [self._apply_volume_function(tens, **transform_function_kwargs) for tens in tensor]
         # 3D case
         if tensor.ndim == 5:
             # tensor is bczyx
@@ -125,7 +135,7 @@ def _apply_volume_function(self, tensor, **transform_function_kwargs):
             # We're applying the volume function on the volume itself
             return self.volume_function(tensor, **transform_function_kwargs)
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Volume function not implemented for ndim %i" % tensor.ndim)


 class Compose(object):
diff --git a/inferno/io/transform/image.py b/inferno/io/transform/image.py
index db97992c..77b5a016 100755
--- a/inferno/io/transform/image.py
+++ b/inferno/io/transform/image.py
@@ -596,5 +596,5 @@ def batch_function(self, image):
         pad_r = image_shape - new_shape - pad_l
         padding = [(0,0)] + list(zip(pad_l, pad_r))
         img = np.pad(img, padding, 'constant', constant_values=self.pad_const)
-        seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)
-        return img, seg
\ No newline at end of file
+        seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)
+        return img, seg
diff --git a/inferno/io/transform/volume.py b/inferno/io/transform/volume.py
index 721203d2..8afaade2 100755
--- a/inferno/io/transform/volume.py
+++ b/inferno/io/transform/volume.py
@@ -63,6 +63,7 @@ def volume_function(self, volume):
         return volume


+# TODO this is obsolete
 class AdditiveRandomNoise3D(Transform):
     """ Add gaussian noise to 3d volume
@@ -105,7 +106,7 @@ def __init__(self, sigma, mode='gaussian', **super_kwargs):
         self.sigma = sigma

     # TODO check if volume is tensor and use torch functions in that case
-    def volume_function(self, volume):
+    def tensor_function(self, volume):
         volume += np.random.normal(loc=0, scale=self.sigma, size=volume.shape)
         return volume
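With the list recursion above, a transform defined via tensor_function (or image_function / volume_function) now maps element-wise over a list of tensors instead of failing. A small sketch, assuming the gaussian-noise transform whose tensor_function is shown above is importable as AdditiveNoise (the class name is not visible in this hunk, so treat it as an assumption; shapes are illustrative):

    import numpy as np
    from inferno.io.transform.volume import AdditiveNoise

    transform = AdditiveNoise(sigma=0.1)
    # Differently shaped volumes can be transformed in a single call,
    # since list inputs are recursed into rather than rejected.
    volumes = [np.zeros((32, 64, 64)), np.zeros((16, 32, 32))]
    noisy = transform(volumes)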
diff --git a/inferno/io/volumetric/lazy_volume_loader.py b/inferno/io/volumetric/lazy_volume_loader.py
index e194a270..b57006b6 100644
--- a/inferno/io/volumetric/lazy_volume_loader.py
+++ b/inferno/io/volumetric/lazy_volume_loader.py
@@ -1,5 +1,7 @@
 import numpy as np
 import os
+import pickle
+from concurrent import futures

 # try to load io libraries (h5py and z5py)
 try:
@@ -20,10 +22,39 @@
 from ...utils import python_utils as pyu


+# TODO support h5py as well
+def filter_base_sequence(input_path, input_key,
+                         window_size, stride,
+                         filter_function, n_threads):
+    with z5py.File(input_path, 'r') as f:
+        ds = f[input_key]
+        shape = list(ds.shape)
+        sequence = vu.slidingwindowslices(shape=shape, window_size=window_size,
+                                          strides=stride, shuffle=True,
+                                          add_overhanging=True)
+
+        def check_slice(slice_id, slice_):
+            print("Checking slice_id", slice_id)
+            data = ds[slice_]
+            if filter_function(data):
+                return None
+            else:
+                return slice_
+
+        with futures.ThreadPoolExecutor(n_threads) as tp:
+            tasks = [tp.submit(check_slice, slice_id, slice_)
+                     for slice_id, slice_ in enumerate(sequence)]
+            filtered_sequence = [t.result() for t in tasks]
+
+    filtered_sequence = [seq for seq in filtered_sequence if seq is not None]
+    return filtered_sequence
+
+
 class LazyVolumeLoaderBase(SyncableDataset):
     def __init__(self, dataset, window_size, stride, downsampling_ratio=None, padding=None,
                  padding_mode='reflect', transforms=None, return_index_spec=False, name=None,
-                 data_slice=None):
+                 data_slice=None, base_sequence=None):
         super(LazyVolumeLoaderBase, self).__init__()
         assert len(window_size) == dataset.ndim, "%i, %i" % (len(window_size), dataset.ndim)
         assert len(stride) == dataset.ndim
@@ -58,7 +89,22 @@ def __init__(self, dataset, window_size, stride, downsampling_ratio=None, paddin
         else:
             raise NotImplementedError

-        self.base_sequence = self.make_sliding_windows()
+        if base_sequence is None:
+            self.base_sequence = self.make_sliding_windows()
+        else:
+            self.base_sequence = self.load_base_sequence(base_sequence)
+
+    @staticmethod
+    def load_base_sequence(base_sequence):
+        if isinstance(base_sequence, (list, tuple)):
+            return base_sequence
+        elif isinstance(base_sequence, str):
+            assert os.path.exists(base_sequence)
+            with open(base_sequence, 'rb') as f:
+                base_sequence = pickle.load(f)
+            return base_sequence
+        else:
+            raise ValueError("Unsupported base_sequence format, must be either list-like or str")

     def normalize_slice(self, data_slice):
         if data_slice is None:
@@ -185,7 +231,7 @@ def __init__(self, file_impl, path,
             assert os.path.exists(path), path
             self.path = path
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Not implemented for type %s" % type(path))

         if isinstance(path_in_file, dict):
             assert name is not None
diff --git a/inferno/io/volumetric/volume.py b/inferno/io/volumetric/volume.py
index f4a7fc93..3ee3ab6c 100755
--- a/inferno/io/volumetric/volume.py
+++ b/inferno/io/volumetric/volume.py
@@ -228,7 +228,7 @@ def __init__(self, path, path_in_h5_dataset=None, data_slice=None, transforms=No
         if self.data_slice is not None and slicing_config_for_name.get('is_multichannel', False):
             self.data_slice = (slice(None),) + self.data_slice

-        assert 'window_size' in slicing_config_for_name
+        assert 'window_size' in slicing_config_for_name, str(slicing_config_for_name)
         assert 'stride' in slicing_config_for_name

         # Read in volume from file (can be hdf5, n5 or zarr)
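A sketch of how the two additions above fit together: precompute a filtered sliding-window sequence once, pickle it, and hand it to the loader via the new base_sequence argument instead of recomputing windows. Paths, keys, shapes and the filter rule are placeholder assumptions:

    import pickle
    from inferno.io.volumetric.lazy_volume_loader import filter_base_sequence

    # Reject windows that are entirely empty; a truthy return drops the slice.
    def filter_function(data):
        return data.sum() == 0

    filtered = filter_base_sequence('/path/to/data.n5', 'raw',
                                    window_size=[32, 256, 256], stride=[16, 128, 128],
                                    filter_function=filter_function, n_threads=8)
    with open('base_sequence.pkl', 'wb') as f:
        pickle.dump(filtered, f)

    # Loaders built on LazyVolumeLoaderBase can now be constructed with
    # base_sequence='base_sequence.pkl' (or with the list itself).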
diff --git a/inferno/io/volumetric/volumetric_utils.py b/inferno/io/volumetric/volumetric_utils.py
index d7f321a0..a33b5ebe 100755
--- a/inferno/io/volumetric/volumetric_utils.py
+++ b/inferno/io/volumetric/volumetric_utils.py
@@ -42,12 +42,13 @@ def dimension_window(start, stop, wsize, stride, dimsize, ds_dim):
     # otherwise predict the whole volume
     if dataslice is not None:
         assert len(dataslice) == dim, "Dataslice must be a tuple with len = data dimension."
-        starts = [sl.start for sl in dataslice]
-        stops = [sl.stop - wsize for sl, wsize in zip(dataslice, window_size)]
+        starts = [0 if sl.start is None else sl.start for sl in dataslice]
+        stops = [sh - wsize if sl.stop is None else sl.stop - wsize
+                 for sl, wsize, sh in zip(dataslice, window_size, shape)]
     else:
         starts = dim * [0]
-        stops = [dimsize - wsize if wsize != dimsize else dimsize
-                 for dimsize, wsize in zip(shape, window_size)]
+        stops = [dimsize - wsize if wsize != dimsize else dimsize
+                 for dimsize, wsize in zip(shape, window_size)]

     assert all(stp > strt for strt, stp in zip(starts, stops)),\
         "%s, %s" % (str(starts), str(stops))
@@ -128,7 +129,7 @@ def _to_list(x):
     nslices = [_1Dwindow(startmin, startmax, nhoodsiz, st, dsample, datalen, shuffle)
                if windowspec == 'x' else [slice(ws, ws + 1) for ws in _to_list(windowspec)]
                for startmin, startmax, datalen, nhoodsiz, st, windowspec, dsample in zip(startmins, startmaxs, shape,
-                                                                                         nhoodsize, stride, window, ds)]
+                                                                                        nhoodsize, stride, window, ds)]

     return it.product(*nslices)
diff --git a/inferno/trainers/basic.py b/inferno/trainers/basic.py
index c90c8ef5..70b53874 100755
--- a/inferno/trainers/basic.py
+++ b/inferno/trainers/basic.py
@@ -27,6 +27,14 @@
 from .callbacks import Console
 from ..utils.exceptions import assert_, NotSetError, NotTorchModuleError, DeviceError

+# NOTE for distributed training, we might also need
+# from apex.parallel import DistributedDataParallel as DDP
+# but I don't know where exactly to put it.
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+

 class Trainer(object):
     """A basic trainer.
@@ -126,10 +134,44 @@ def __init__(self, model=None):
         # Print console
         self._console = Console()

+        # Train with mixed precision, only works
+        # if we have apex
+        self._mixed_precision = False
+        self._apex_opt_level = 'O1'
+
         # Public
         if model is not None:
             self.model = model

+    @property
+    def mixed_precision(self):
+        return self._mixed_precision
+
+    # this needs to be called after model and optimizer are set
+    @mixed_precision.setter
+    def mixed_precision(self, mp):
+        if mp:
+            assert_(amp is not None, "Cannot use mixed precision training without apex library", RuntimeError)
+            assert_(self.model is not None and self._optimizer is not None,
+                    "Model and optimizer need to be set before activating mixed precision", RuntimeError)
+            # in order to support BCE loss
+            amp.register_float_function(torch, 'sigmoid')
+            # For now, we don't allow to set 'keep_batchnorm' and 'loss_scale'
+            self.model, self._optimizer = amp.initialize(self.model, self._optimizer,
+                                                         opt_level=self._apex_opt_level,
+                                                         keep_batchnorm_fp32=None)
+        self._mixed_precision = mp
+
+    @property
+    def apex_opt_level(self):
+        return self._apex_opt_level
+
+    @apex_opt_level.setter
+    def apex_opt_level(self, opt_level):
+        assert_(opt_level in ('O0', 'O1', 'O2', 'O3'),
+                "Invalid optimization level", ValueError)
+        self._apex_opt_level = opt_level
+
     @property
     def console(self):
         """Get the current console."""
@@ -1368,17 +1410,21 @@ def apply_model_and_loss(self, inputs, target, backward=True, mode=None):
             kwargs['trainer'] = self
         if mode == 'train':
             loss = self.criterion(prediction, target, **kwargs) \
-                if len(target) != 0 else self.criterion(prediction, **kwargs)
+                if len(target) != 0 else self.criterion(prediction, **kwargs)
         elif mode == 'eval':
             loss = self.validation_criterion(prediction, target, **kwargs) \
-                if len(target) != 0 else self.validation_criterion(prediction, **kwargs)
+                if len(target) != 0 else self.validation_criterion(prediction, **kwargs)
         else:
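A usage sketch for the new mixed precision switch. Order matters, because the mixed_precision setter runs amp.initialize on the already-set model and optimizer; the toy model and optimizer settings here are illustrative only:

    import torch.nn as nn
    from inferno.trainers.basic import Trainer

    model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 1, 3))
    trainer = Trainer(model)
    trainer.build_optimizer('Adam', lr=1e-4)
    trainer.apex_opt_level = 'O2'   # optional, defaults to 'O1'
    trainer.mixed_precision = True  # calls amp.initialize on model and optimizer
    # From here on, apply_model_and_loss wraps backward in amp.scale_loss.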
            raise ValueError
         if backward:
             # Backprop if required
             # retain_graph option is needed for some custom
             # loss functions like malis, False per default
-            loss.backward(retain_graph=self.retain_graph)
+            if self.mixed_precision:
+                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                    scaled_loss.backward(retain_graph=self.retain_graph)
+            else:
+                loss.backward(retain_graph=self.retain_graph)
         return prediction, loss

     def train_for(self, num_iterations=None, break_callback=None):
@@ -1676,7 +1722,7 @@ def load(self, from_directory=None, best=False, filename=None, map_location=None
             'best_checkpoint.pytorch'.
         filename : str
             Overrides the default filename.
-        device : function, torch.device, string or a dict
+        map_location : function, torch.device, string or a dict
             Specify how to remap storage locations.

         Returns
diff --git a/inferno/trainers/callbacks/essentials.py b/inferno/trainers/callbacks/essentials.py
index 704e3831..f62d2cc4 100755
--- a/inferno/trainers/callbacks/essentials.py
+++ b/inferno/trainers/callbacks/essentials.py
@@ -277,9 +277,10 @@ def norm_or_value(self):
     def after_model_and_loss_is_applied(self, **_):
         tu.clip_gradients_(self.trainer.model.parameters(), self.mode, self.norm_or_value)

+
 class GarbageCollection(Callback):
     """
-    Callback that triggers garbage collection at the end of every 
+    Callback that triggers garbage collection at the end of every
     training iteration in order to reduce the memory footprint of training
     """
diff --git a/inferno/trainers/callbacks/scheduling.py b/inferno/trainers/callbacks/scheduling.py
index 3a1fd347..9efaf478 100644
--- a/inferno/trainers/callbacks/scheduling.py
+++ b/inferno/trainers/callbacks/scheduling.py
@@ -301,9 +301,10 @@ def end_of_validation_run(self, **_):

     @staticmethod
     def is_significantly_less_than(x, y, min_relative_delta):
+        eps = 1.e-6
         if x > y:
             return False
-        relative_delta = abs(y - x) / abs(y)
+        relative_delta = abs(y - x) / (abs(y) + eps)
         return relative_delta > min_relative_delta
diff --git a/inferno/utils/io_utils.py b/inferno/utils/io_utils.py
index 5b5ee6e8..8ef4350a 100755
--- a/inferno/utils/io_utils.py
+++ b/inferno/utils/io_utils.py
@@ -49,7 +49,7 @@ def yaml2dict(path):
         # Forgivable mistake that path is a dict already
         return path
     with open(path, 'r') as f:
-        readict = yaml.load(f)
+        readict = yaml.load(f, Loader=yaml.FullLoader)
     return readict
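The eps term added to is_significantly_less_than above guards against division by zero when the best validation score y is exactly zero. A quick check of the new behaviour, as a standalone re-implementation for illustration:

    def is_significantly_less_than(x, y, min_relative_delta, eps=1.e-6):
        if x > y:
            return False
        relative_delta = abs(y - x) / (abs(y) + eps)
        return relative_delta > min_relative_delta

    assert is_significantly_less_than(0.5, 1.0, 0.1)       # 50% improvement
    assert not is_significantly_less_than(0.99, 1.0, 0.1)  # only 1% better
    assert is_significantly_less_than(-0.5, 0.0, 0.1)      # no ZeroDivisionError at y == 0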