Do manual affine relative to first slice
minnerbe committed Jan 24, 2025
1 parent 0fe94fb commit 0830a89
Showing 1 changed file with 40 additions and 48 deletions.
@@ -5,14 +5,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;

import mpicbg.models.AbstractAffineModel1D;
import mpicbg.models.AffineModel1D;
import mpicbg.models.IllDefinedDataPointsException;
import mpicbg.models.NotEnoughDataPointsException;
import mpicbg.models.TranslationModel1D;
import net.imglib2.Cursor;
import net.imglib2.RandomAccess;
import net.imglib2.converter.Converters;
@@ -57,10 +54,6 @@
* @param <T> pixel type, determined automatically from the input stack (either 8bit or 16bit)
*/
public class SparkNormalizeLayerIntensityN5<T extends NativeType<T> & IntegerType<T>> implements Serializable {
public static enum NormalizationModel {
TRANSLATION,
AFFINE
}

@SuppressWarnings({"FieldMayBeFinal", "unused"})
public static class Options extends AbstractOptions implements Serializable {
@@ -88,9 +81,9 @@ public static class Options extends AbstractOptions implements Serializable {
usage = "If specified, generates a scale pyramid with given factors, e.g. 2,2,1")
public String factors;

@Option(name = "--model",
usage = "Normalization model: TRANSLATION or AFFINE")
private NormalizationModel model = NormalizationModel.TRANSLATION;
@Option(name = "--spreadIntensities",
usage = "Spread intensities to full range of first layer in addition to shifting them")
private boolean spreadIntensities = false;

public Options(final String[] args) {
final CmdLineParser parser = new CmdLineParser(this);
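
The new --spreadIntensities flag replaces the earlier TRANSLATION/AFFINE model choice: with the flag off, each layer is only shifted so that its mean content intensity matches the first layer's; with it on, intensities are additionally rescaled to the first layer's trimmed range. A minimal sketch of the two mappings with made-up layer statistics (in the class, the real values come from the masked content pixels of the downscaled stack):

public class SpreadIntensitiesSketch {
    public static void main(final String... args) {
        final double firstMean = 120.0, firstSpread = 80.0; // first layer
        final double curMean = 140.0, curSpread = 100.0;    // current layer

        // flag off: pure shift, y = x + (firstMean - curMean)
        final double shiftOnly = firstMean - curMean;       // -20.0

        // flag on: y = scale * x + shift, matching mean and spread
        final double scale = firstSpread / curSpread;       // 0.8
        final double shift = firstMean - curMean * scale;   // 120 - 112 = 8.0
        System.out.printf("shift-only: x %+.1f; affine: %.2f * x %+.1f%n",
                shiftOnly, scale, shift);
    }
}

In both cases a pixel at the current layer's mean maps exactly to the first layer's mean; the affine variant also equalizes contrast.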
@@ -131,6 +124,8 @@ public static void main(final String... args) throws IOException, InterruptedExc
}


private static final double DEFAULT_CUTOFF = 0.03;

private final String fullScaleInputDataset;
private final String downScaledInputDataset;
private final String fullScaleOutputDataset;
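
The new DEFAULT_CUTOFF feeds computeIntensitySpread (added further down in this diff): the content pixels are sorted and the darkest and brightest 3% are dropped before the range is taken, so isolated hot or dead pixels cannot inflate it. A self-contained illustration with synthetic data:

import java.util.Arrays;

public class CutoffSketch {
    public static void main(final String... args) {
        final double[] pixels = new double[1000];
        Arrays.setAll(pixels, i -> 100.0 + i * 0.05); // smooth ramp, 100..150
        pixels[0] = 0;                                // dead pixel
        pixels[999] = 4095;                           // hot pixel
        Arrays.sort(pixels);
        final int n = (int) Math.round(pixels.length * 0.03); // 30
        final double trimmed = pixels[pixels.length - n - 1] - pixels[n];
        final double raw = pixels[pixels.length - 1] - pixels[0];
        // raw range is blown up to 4095 by the two outliers,
        // while the trimmed range stays close to the true 50
        System.out.printf("raw = %.1f, trimmed = %.2f%n", raw, trimmed);
    }
}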
@@ -150,19 +145,10 @@ private SparkNormalizeLayerIntensityN5(final Options options, final DatasetAttri

private void run() throws IOException {

final List<? extends AbstractAffineModel1D<?>> transformations;
final List<AffineModel1D> transformations;
try (final N5Reader n5reader = new N5FSReader(options.n5Path)) {
final Img<T> downScaledImg = N5Utils.open(n5reader, downScaledInputDataset);
switch (options.model) {
case TRANSLATION:
transformations = computeTransformations(downScaledImg, TranslationModel1D::new);
break;
case AFFINE:
transformations = computeTransformations(downScaledImg, AffineModel1D::new);
break;
default:
throw new IllegalArgumentException("Unsupported normalization model: " + options.model);
}
transformations = computeTransformations(downScaledImg, options.spreadIntensities);
}

if (transformations.size() != attributes.getDimensions()[2]) {
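
With the enum and the model supplier gone, run() always produces AffineModel1D instances; a translation is just the special case with unit scale. In mpicbg, AffineModel1D.set(m00, m01) defines the mapping y = m00 * x + m01, which is exactly what the new computeTransformations fills in below. A short sketch of that API, assuming mpicbg on the classpath and the double[]-based apply methods of current mpicbg releases:

import mpicbg.models.AffineModel1D;

public class AffineModel1DSketch {
    public static void main(final String... args) {
        final AffineModel1D shiftOnly = new AffineModel1D();
        shiftOnly.set(1.0, -20.0);                 // translation: y = x - 20

        final AffineModel1D shiftAndScale = new AffineModel1D();
        shiftAndScale.set(0.8, 8.0);               // affine: y = 0.8 * x + 8

        final double[] p = {140.0};
        shiftAndScale.applyInPlace(p);             // p[0] is now 120.0
        System.out.println(p[0]);
    }
}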
@@ -195,10 +181,9 @@ private void run() throws IOException {
}


private <A extends AbstractAffineModel1D<A>>
List<A> computeTransformations(
List<AffineModel1D> computeTransformations(
final RandomAccessibleInterval<T> rai,
final Supplier<A> modelSupplier
final boolean spreadIntensities
) {

// create mask from pixels that have "content" throughout the stack
@@ -221,48 +206,55 @@ List<A> computeTransformations(
.filter(pixel -> pixel.getInteger() == 1)
.count();

// Match intensity of content pixels in each layer and compute relative transformations
final List<A> models = new ArrayList<>(downScaledStack.size());
models.add(modelSupplier.get());
final double[][] previousLayerPixels = new double[1][nContentPixels];
final double[][] currentLayerPixels = new double[1][nContentPixels];
final double[] weights = new double[nContentPixels];
Arrays.fill(weights, 1);
extractContentPixels(downScaledStack.get(0), zProjectedContentMask, previousLayerPixels);
// Match intensity of content pixels in each layer to that of the first layer
final List<AffineModel1D> models = new ArrayList<>(downScaledStack.size());
models.add(new AffineModel1D());

final double[] layerPixels = new double[nContentPixels];
extractContentPixels(downScaledStack.get(0), zProjectedContentMask, layerPixels);
final double firstLayerIntensityAverage = Arrays.stream(layerPixels).average().orElseThrow(NoSuchElementException::new);
final double firstLayerIntensitySpread = computeIntensitySpread(layerPixels, DEFAULT_CUTOFF);

for (int z = 1; z < downScaledStack.size(); ++z) {
final A model = modelSupplier.get();
final IntervalView<T> currentLayer = downScaledStack.get(z);
extractContentPixels(currentLayer, zProjectedContentMask, currentLayerPixels);
try {
model.fit(currentLayerPixels, previousLayerPixels, weights);
} catch (final NotEnoughDataPointsException | IllDefinedDataPointsException e) {
throw new RuntimeException("Could not estimate model for layer " + z, e);
extractContentPixels(currentLayer, zProjectedContentMask, layerPixels);
final double currentLayerIntensityAverage = Arrays.stream(layerPixels).average().orElseThrow(NoSuchElementException::new);

final AffineModel1D model = new AffineModel1D();
if (spreadIntensities) {
final double currentLayerIntensitySpread = computeIntensitySpread(layerPixels, DEFAULT_CUTOFF);
final double scale = firstLayerIntensitySpread / currentLayerIntensitySpread;
final double shift = firstLayerIntensityAverage - currentLayerIntensityAverage * scale;
model.set(scale, shift);
} else {
final double shift = firstLayerIntensityAverage - currentLayerIntensityAverage;
model.set(1.0, shift);
}
models.add(model);
System.arraycopy(currentLayerPixels[0], 0, previousLayerPixels[0], 0, currentLayerPixels[0].length);
}

// Make transformations relative to the first layer
for (int z = 1; z < models.size(); ++z) {
models.get(z).preConcatenate(models.get(z - 1));
}

return models;
}

private void extractContentPixels(final RandomAccessibleInterval<T> layer, final Img<T> mask, final double[][] contentPixels) {
private void extractContentPixels(final RandomAccessibleInterval<T> layer, final Img<T> mask, final double[] contentPixels) {
final Cursor<T> layerCursor = Views.flatIterable(layer).localizingCursor();
final RandomAccess<T> maskAccess = mask.randomAccess();
int i = 0;
while (layerCursor.hasNext()) {
layerCursor.fwd();
maskAccess.setPosition(layerCursor);
if (maskAccess.get().getInteger() == 1) {
contentPixels[0][i++] = layerCursor.get().getInteger();
contentPixels[i++] = layerCursor.get().getInteger();
}
}
// Arrays.sort(contentPixels, Comparator.comparingDouble(a -> a[0]));
}

private double computeIntensitySpread(final double[] pixels, final double cutoff) {
Arrays.sort(pixels);
final int n = (int) Math.round(pixels.length * cutoff);
final double min = pixels[n];
final double max = pixels[pixels.length - n - 1];
return max - min;
}

private RandomAccessibleInterval<T> applyTransformations(
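Taken together, the rewrite replaces pairwise model fitting (each layer fit against its predecessor, then chained via preConcatenate) with a closed-form mapping computed directly against the first layer. A condensed, self-contained sketch of the new logic, using plain double[] arrays in place of the masked content pixels of the downscaled stack:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RelativeToFirstSketch {

    static final double CUTOFF = 0.03; // mirrors DEFAULT_CUTOFF

    // trimmed intensity range, as in the added computeIntensitySpread
    static double spread(final double[] pixels) {
        final double[] sorted = pixels.clone();
        Arrays.sort(sorted);
        final int n = (int) Math.round(sorted.length * CUTOFF);
        return sorted[sorted.length - n - 1] - sorted[n];
    }

    static double mean(final double[] pixels) {
        return Arrays.stream(pixels).average().orElseThrow();
    }

    // one {scale, shift} pair per layer, such that y = scale * x + shift
    static List<double[]> normalize(final List<double[]> layers, final boolean spreadIntensities) {
        final List<double[]> models = new ArrayList<>();
        models.add(new double[] {1.0, 0.0}); // first layer stays untouched
        final double firstMean = mean(layers.get(0));
        final double firstSpread = spread(layers.get(0));
        for (int z = 1; z < layers.size(); ++z) {
            final double scale = spreadIntensities ? firstSpread / spread(layers.get(z)) : 1.0;
            final double shift = firstMean - mean(layers.get(z)) * scale;
            models.add(new double[] {scale, shift});
        }
        return models;
    }

    public static void main(final String... args) {
        final List<double[]> layers = List.of(
                new double[] {100, 110, 120, 130, 140},  // first layer
                new double[] {120, 132, 144, 156, 168}); // brighter, higher contrast
        for (final double[] m : normalize(layers, true))
            System.out.printf("y = %.2f * x %+.2f%n", m[0], m[1]);
    }
}

Chaining relative models, as the removed preConcatenate loop did, gives the same result when every pairwise fit is exact, but any error between adjacent layers propagates to all layers behind it; anchoring every layer on the first one avoids that accumulation.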
