diff --git a/niftynet/contrib/segmentation_application_mo.py b/niftynet/contrib/segmentation_application_mo.py new file mode 100755 index 00000000..2dfa8801 --- /dev/null +++ b/niftynet/contrib/segmentation_application_mo.py @@ -0,0 +1,416 @@ +import tensorflow as tf + +from niftynet.application.base_application import BaseApplication +from niftynet.engine.application_factory import \ + ApplicationNetFactory, InitializerFactory, OptimiserFactory +from niftynet.engine.application_variables import \ + CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES +from niftynet.engine.sampler_grid import GridSampler +from niftynet.engine.sampler_resize import ResizeSampler +from niftynet.engine.sampler_uniform import UniformSampler +from niftynet.engine.sampler_weighted import WeightedSampler +from niftynet.engine.sampler_balanced import BalancedSampler +from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator +from niftynet.engine.windows_aggregator_resize import ResizeSamplesAggregator +from niftynet.io.image_reader import ImageReader +from niftynet.layer.binary_masking import BinaryMaskingLayer +from niftynet.layer.discrete_label_normalisation import \ + DiscreteLabelNormalisationLayer +from niftynet.layer.histogram_normalisation import \ + HistogramNormalisationLayer +from niftynet.layer.loss_segmentation import LossFunction +from niftynet.layer.mean_variance_normalisation import \ + MeanVarNormalisationLayer +from niftynet.layer.pad import PadLayer +from niftynet.layer.post_processing import PostProcessingLayer +from niftynet.layer.rand_flip import RandomFlipLayer +from niftynet.layer.rand_rotation import RandomRotationLayer +from niftynet.layer.rand_spatial_scaling import RandomSpatialScalingLayer +from niftynet.evaluation.segmentation_evaluator import SegmentationEvaluator +from niftynet.layer.rand_elastic_deform import RandomElasticDeformationLayer + +SUPPORTED_INPUT = set(['image', 'label', 'weight', 'sampler', 'inferred']) + + +class 
SegmentationApplicationMO(BaseApplication): + REQUIRED_CONFIG_SECTION = "SEGMENTATION" + + def __init__(self, net_param, action_param, action): + super(SegmentationApplicationMO, self).__init__() + tf.logging.info('starting segmentation application') + self.action = action + + self.net_param = net_param + self.action_param = action_param + + self.data_param = None + self.segmentation_param = None + self.SUPPORTED_SAMPLING = { + 'uniform': (self.initialise_uniform_sampler, + self.initialise_grid_sampler, + self.initialise_grid_aggregator), + 'weighted': (self.initialise_weighted_sampler, + self.initialise_grid_sampler, + self.initialise_grid_aggregator), + 'resize': (self.initialise_resize_sampler, + self.initialise_resize_sampler, + self.initialise_resize_aggregator), + 'balanced': (self.initialise_balanced_sampler, + self.initialise_grid_sampler, + self.initialise_grid_aggregator), + } + + def initialise_dataset_loader( + self, data_param=None, task_param=None, data_partitioner=None): + + self.data_param = data_param + self.segmentation_param = task_param + + file_lists = self.get_file_lists(data_partitioner) + # read each line of csv files into an instance of Subject + if self.is_training: + self.readers = [] + for file_list in file_lists: + reader = ImageReader({'image', 'label', 'weight', 'sampler'}) + reader.initialise(data_param, task_param, file_list) + self.readers.append(reader) + + elif self.is_inference: + # in the inference process use image input only + inference_reader = ImageReader({'image'}) + file_list = data_partitioner.inference_files + inference_reader.initialise(data_param, task_param, file_list) + self.readers = [inference_reader] + elif self.is_evaluation: + file_list = data_partitioner.inference_files + reader = ImageReader({'image', 'label', 'inferred'}) + reader.initialise(data_param, task_param, file_list) + self.readers = [reader] + else: + raise ValueError('Action `{}` not supported. 
Expected one of {}' + .format(self.action, self.SUPPORTED_ACTIONS)) + + foreground_masking_layer = None + if self.net_param.normalise_foreground_only: + foreground_masking_layer = BinaryMaskingLayer( + type_str=self.net_param.foreground_type, + multimod_fusion=self.net_param.multimod_foreground_type, + threshold=0.0) + + mean_var_normaliser = MeanVarNormalisationLayer( + image_name='image', binary_masking_func=foreground_masking_layer) + histogram_normaliser = None + if self.net_param.histogram_ref_file: + histogram_normaliser = HistogramNormalisationLayer( + image_name='image', + modalities=vars(task_param).get('image'), + model_filename=self.net_param.histogram_ref_file, + binary_masking_func=foreground_masking_layer, + norm_type=self.net_param.norm_type, + cutoff=self.net_param.cutoff, + name='hist_norm_layer') + + label_normalisers = None + if self.net_param.histogram_ref_file and \ + task_param.label_normalisation: + label_normalisers = [DiscreteLabelNormalisationLayer( + image_name='label', + modalities=vars(task_param).get('label'), + model_filename=self.net_param.histogram_ref_file)] + if self.is_evaluation: + label_normalisers.append( + DiscreteLabelNormalisationLayer( + image_name='inferred', + modalities=vars(task_param).get('inferred'), + model_filename=self.net_param.histogram_ref_file)) + label_normalisers[-1].key = label_normalisers[0].key + + normalisation_layers = [] + if self.net_param.normalisation: + normalisation_layers.append(histogram_normaliser) + if self.net_param.whitening: + normalisation_layers.append(mean_var_normaliser) + if task_param.label_normalisation and \ + (self.is_training or not task_param.output_prob): + normalisation_layers.extend(label_normalisers) + + augmentation_layers = [] + if self.is_training: + if self.action_param.random_flipping_axes != -1: + augmentation_layers.append(RandomFlipLayer( + flip_axes=self.action_param.random_flipping_axes)) + if self.action_param.scaling_percentage: + 
augmentation_layers.append(RandomSpatialScalingLayer( + min_percentage=self.action_param.scaling_percentage[0], + max_percentage=self.action_param.scaling_percentage[1])) + if self.action_param.rotation_angle or \ + self.action_param.rotation_angle_x or \ + self.action_param.rotation_angle_y or \ + self.action_param.rotation_angle_z: + rotation_layer = RandomRotationLayer() + if self.action_param.rotation_angle: + rotation_layer.init_uniform_angle( + self.action_param.rotation_angle) + else: + rotation_layer.init_non_uniform_angle( + self.action_param.rotation_angle_x, + self.action_param.rotation_angle_y, + self.action_param.rotation_angle_z) + augmentation_layers.append(rotation_layer) + + # add deformation layer + if self.action_param.do_elastic_deformation: + spatial_rank = list(self.readers[0].spatial_ranks.values())[0] + augmentation_layers.append(RandomElasticDeformationLayer( + spatial_rank=spatial_rank, + num_controlpoints=self.action_param.num_ctrl_points, + std_deformation_sigma=self.action_param.deformation_sigma, + proportion_to_augment=self.action_param.proportion_to_deform)) + + volume_padding_layer = [] + if self.net_param.volume_padding_size: + volume_padding_layer.append(PadLayer( + image_name=SUPPORTED_INPUT, + border=self.net_param.volume_padding_size, + mode=self.net_param.volume_padding_mode + )) + + # only add augmentation to first reader (not validation reader) + self.readers[0].add_preprocessing_layers( + volume_padding_layer + + normalisation_layers + + augmentation_layers) + + for reader in self.readers[1:]: + reader.add_preprocessing_layers( + volume_padding_layer + + normalisation_layers) + + def initialise_uniform_sampler(self): + self.sampler = [[UniformSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + windows_per_image=self.action_param.sample_per_volume, + queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_weighted_sampler(self): + self.sampler 
= [[WeightedSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + windows_per_image=self.action_param.sample_per_volume, + queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_resize_sampler(self): + self.sampler = [[ResizeSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + shuffle_buffer=self.is_training, + queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_grid_sampler(self): + self.sampler = [[GridSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + spatial_window_size=self.action_param.spatial_window_size, + window_border=self.action_param.border, + queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_balanced_sampler(self): + self.sampler = [[BalancedSampler( + reader=reader, + data_param=self.data_param, + batch_size=self.net_param.batch_size, + windows_per_image=self.action_param.sample_per_volume, + queue_length=self.net_param.queue_length) for reader in + self.readers]] + + def initialise_grid_aggregator(self): + self.output_decoder = GridSamplesAggregator( + image_reader=self.readers[0], + output_path=self.action_param.save_seg_dir, + window_border=self.action_param.border, + interp_order=self.action_param.output_interp_order, + postfix=self.action_param.output_postfix) + + def initialise_resize_aggregator(self): + self.output_decoder = ResizeSamplesAggregator( + image_reader=self.readers[0], + output_path=self.action_param.save_seg_dir, + window_border=self.action_param.border, + interp_order=self.action_param.output_interp_order, + postfix=self.action_param.output_postfix) + + def initialise_sampler(self): + if self.is_training: + self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]() + elif self.is_inference: + self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]() + + def 
initialise_aggregator(self): + self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]() + + def initialise_network(self): + w_regularizer = None + b_regularizer = None + reg_type = self.net_param.reg_type.lower() + decay = self.net_param.decay + if reg_type == 'l2' and decay > 0: + from tensorflow.contrib.layers.python.layers import regularizers + w_regularizer = regularizers.l2_regularizer(decay) + b_regularizer = regularizers.l2_regularizer(decay) + elif reg_type == 'l1' and decay > 0: + from tensorflow.contrib.layers.python.layers import regularizers + w_regularizer = regularizers.l1_regularizer(decay) + b_regularizer = regularizers.l1_regularizer(decay) + + self.net = ApplicationNetFactory.create(self.net_param.name)( + num_classes=self.segmentation_param.num_classes, + w_initializer=InitializerFactory.get_initializer( + name=self.net_param.weight_initializer), + b_initializer=InitializerFactory.get_initializer( + name=self.net_param.bias_initializer), + w_regularizer=w_regularizer, + b_regularizer=b_regularizer, + acti_func=self.net_param.activation_function) + + def connect_data_and_network(self, + outputs_collector=None, + gradients_collector=None): + # def data_net(for_training): + # with tf.name_scope('train' if for_training else 'validation'): + # sampler = self.get_sampler()[0][0 if for_training else -1] + # data_dict = sampler.pop_batch_op() + # image = tf.cast(data_dict['image'], tf.float32) + # return data_dict, self.net(image, is_training=for_training) + + def switch_sampler(for_training): + with tf.name_scope('train' if for_training else 'validation'): + sampler = self.get_sampler()[0][0 if for_training else -1] + return sampler.pop_batch_op() + + if self.is_training: + # if self.action_param.validation_every_n > 0: + # data_dict, net_out = tf.cond(tf.logical_not(self.is_validation), + # lambda: data_net(True), + # lambda: data_net(False)) + # else: + # data_dict, net_out = data_net(True) + if self.action_param.validation_every_n > 0: + 
data_dict = tf.cond(tf.logical_not(self.is_validation), + lambda: switch_sampler(for_training=True), + lambda: switch_sampler(for_training=False)) + else: + data_dict = switch_sampler(for_training=True) + + image = tf.cast(data_dict['image'], tf.float32) + net_args = {'is_training': self.is_training, + 'keep_prob': self.net_param.keep_prob} + net_out = self.net(image, **net_args) + + with tf.name_scope('Optimiser'): + optimiser_class = OptimiserFactory.create( + name=self.action_param.optimiser) + self.optimiser = optimiser_class.get_instance( + learning_rate=self.action_param.lr) + loss_func = LossFunction( + n_class=self.segmentation_param.num_classes, + loss_type=self.action_param.loss_type, + softmax=self.segmentation_param.softmax) + data_loss = loss_func( + prediction=net_out, + ground_truth=data_dict.get('label', None), + weight_map=data_dict.get('weight', None)) + reg_losses = tf.get_collection( + tf.GraphKeys.REGULARIZATION_LOSSES) + if self.net_param.decay > 0.0 and reg_losses: + reg_loss = tf.reduce_mean( + [tf.reduce_mean(reg_loss) for reg_loss in reg_losses]) + loss = data_loss + reg_loss + else: + loss = data_loss + grads = self.optimiser.compute_gradients(loss) + # collecting gradients variables + gradients_collector.add_to_collection([grads]) + # collecting output variables + outputs_collector.add_to_collection( + var=data_loss, name='loss', + average_over_devices=False, collection=CONSOLE) + outputs_collector.add_to_collection( + var=data_loss, name='loss', + average_over_devices=True, summary_type='scalar', + collection=TF_SUMMARIES) + + # outputs_collector.add_to_collection( + # var=image*180.0, name='image', + # average_over_devices=False, summary_type='image3_sagittal', + # collection=TF_SUMMARIES) + + # outputs_collector.add_to_collection( + # var=image, name='image', + # average_over_devices=False, + # collection=NETWORK_OUTPUT) + + # outputs_collector.add_to_collection( + # var=tf.reduce_mean(image), name='mean_image', + # 
average_over_devices=False, summary_type='scalar', + # collection=CONSOLE) + elif self.is_inference: + # converting logits into final output for + # classification probabilities or argmax classification labels + data_dict = switch_sampler(for_training=False) + image = tf.cast(data_dict['image'], tf.float32) + net_args = {'is_training': self.is_training, + 'keep_prob': self.net_param.keep_prob} + net_out = self.net(image, **net_args) + + output_prob = self.segmentation_param.output_prob + num_classes = self.segmentation_param.num_classes + if num_classes > 1: + post_process_layer_proba = PostProcessingLayer( + 'SOFTMAX', num_classes=num_classes) + post_process_layer_argmax = PostProcessingLayer( + 'ARGMAX', num_classes=num_classes) + else: + post_process_layer_proba = PostProcessingLayer( + 'IDENTITY', num_classes=num_classes) + post_process_layer_argmax = PostProcessingLayer( + 'IDENTITY', num_classes=num_classes) + + net_out_proba = post_process_layer_proba(net_out) + net_out_argmax = post_process_layer_argmax(net_out) + + outputs_collector.add_to_collection( + var=net_out_proba, name='window_proba', + average_over_devices=False, collection=NETWORK_OUTPUT) + outputs_collector.add_to_collection( + var=net_out_argmax, name='window_argmax', + average_over_devices=False, collection=NETWORK_OUTPUT) + outputs_collector.add_to_collection( + var=data_dict['image_location'], name='location', + average_over_devices=False, collection=NETWORK_OUTPUT) + self.initialise_aggregator() + + def interpret_output(self, batch_output): + if self.is_inference: + return self.output_decoder.decode_dict_batch({ + 'window_proba': batch_output['window_proba'], + 'window_argmax': batch_output['window_argmax']} + , batch_output['location']) + return True + + def initialise_evaluator(self, eval_param): + self.eval_param = eval_param + self.evaluator = SegmentationEvaluator(self.readers[0], + self.segmentation_param, + eval_param) + + def add_inferred_output(self, data_param, task_param): + 
return self.add_inferred_output_like(data_param, task_param, 'label') diff --git a/niftynet/engine/windows_aggregator_grid.py b/niftynet/engine/windows_aggregator_grid.py index d26341c7..9fbaa9bd 100755 --- a/niftynet/engine/windows_aggregator_grid.py +++ b/niftynet/engine/windows_aggregator_grid.py @@ -59,6 +59,47 @@ def decode_batch(self, window, location): z_start:z_end, ...] = window[batch_id, ...] return True + def decode_dict_batch(self, window, location): + """ + Create the aggregation in case of multiple outputs with same location + information. The dictionary keys are used in the saving name when + calling save_dict_current_image + :param window: + :param location: + :param name_opt: + :return: + """ + n_samples = location.shape[0] + location_init = np.copy(location) + dummy = None + for w in window: + dummy = np.ones_like(window[w]) + window[w], _ = self.crop_batch(window[w], location_init, + self.window_border) + location_init = np.copy(location) + _, location = self.crop_batch(dummy, location_init, self.window_border) + for batch_id in range(n_samples): + image_id, x_start, y_start, z_start, x_end, y_end, z_end = \ + location[batch_id, :] + if image_id != self.image_id: + # image name changed: + # save current image and create an empty image + self._save_dict_current_image() + if self._is_stopping_signal(location[batch_id]): + return False + self.image_out = {} + for w in window: + self.image_out[w] = self._initialise_empty_image( + image_id=image_id, + n_channels=window[w].shape[-1], + dtype=window[w].dtype) + for w in window: + self.image_out[w][x_start:x_end, + y_start:y_end, + z_start:z_end, ...] = window[w][batch_id, ...] 
+                # NOTE(review): removed leftover debug print of per-window sums
+        return True
+
     def _initialise_empty_image(self, image_id, n_channels, dtype=np.float):
         self.image_id = image_id
         spatial_shape = self.input_image[self.name].shape[:3]
@@ -89,3 +130,25 @@ def _save_current_image(self):
                                 self.output_interp_order)
         self.log_inferred(subject_name, filename)
         return
+
+    def _save_dict_current_image(self):
+        if self.input_image is None:
+            return
+
+        for layer in reversed(self.reader.preprocessors):
+            if isinstance(layer, PadLayer):
+                self.image_out, _ = layer.inverse_op(self.image_out)
+            if isinstance(layer, DiscreteLabelNormalisationLayer):
+                self.image_out, _ = layer.inverse_op(self.image_out)
+        subject_name = self.reader.get_subject_id(self.image_id)
+        for i in self.image_out:
+            filename = "{}_{}{}.nii.gz".format(i, subject_name,
+                                               self.postfix)
+            source_image_obj = self.input_image[self.name]
+            misc_io.save_data_array(self.output_path,
+                                    filename,
+                                    self.image_out[i],
+                                    source_image_obj,
+                                    self.output_interp_order)
+            self.log_inferred(subject_name, filename)
+        return
diff --git a/niftynet/engine/windows_aggregator_resize.py b/niftynet/engine/windows_aggregator_resize.py
index 361e8395..27d3653d 100755
--- a/niftynet/engine/windows_aggregator_resize.py
+++ b/niftynet/engine/windows_aggregator_resize.py
@@ -55,6 +55,31 @@ def decode_batch(self, window, location):
             self._save_current_image(window[batch_id, ...],
                                      resize_to_shape)
         return True
+    def decode_dict_batch(self, window, location):
+        """
+        Resizing each output image window element of the window dictionary in
+        the batch as an image volume
+        location specifies the original input image (so that the
+        interpolation order, original shape information retained in the
+        generated outputs). The key in the dictionary will be added to the
+        postfix to differentiate between dictionary output elements
+        :param window:
+        :param location:
+        :return:
+        """
+        n_samples = location.shape[0]
+        for batch_id in range(n_samples):
+            if self._is_stopping_signal(location[batch_id]):
+                return False
+            for w in window:
+                self.image_id = location[batch_id, 0]
+                resize_to_shape = self._initialise_image_shape(
+                    image_id=self.image_id,
+                    n_channels=window[w].shape[-1])
+                self._save_current_image(window[w][batch_id, ...],
+                                         resize_to_shape, name=w)
+        return True
+
     def _initialise_image_shape(self, image_id, n_channels):
         self.image_id = image_id
         spatial_shape = self.input_image[self.name].shape[:3]
@@ -65,7 +89,7 @@ def _initialise_image_shape(self, image_id, n_channels):
             empty_image, _ = layer(empty_image)
         return empty_image.shape
 
-    def _save_current_image(self, image_out, resize_to):
+    def _save_current_image(self, image_out, resize_to, name=''):
         if self.input_image is None:
             return
         window_shape = resize_to
@@ -92,7 +116,7 @@ def _save_current_image(self):
         if isinstance(layer, DiscreteLabelNormalisationLayer):
             image_out, _ = layer.inverse_op(image_out)
         subject_name = self.reader.get_subject_id(self.image_id)
-        filename = "{}{}.nii.gz".format(subject_name, self.postfix)
+        filename = "{}{}.nii.gz".format(subject_name, self.postfix + name)
         source_image_obj = self.input_image[self.name]
         misc_io.save_data_array(self.output_path,
                                 filename,