Skip to content

Commit

Permalink
Merge pull request #196 from saalfeldlab/feature/streak-finder
Browse files Browse the repository at this point in the history
Various tools for de-streaking
  • Loading branch information
minnerbe authored Jan 7, 2025
2 parents 9f9e41c + f3b6e17 commit c1c8906
Show file tree
Hide file tree
Showing 12 changed files with 1,030 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package org.janelia.alignment.destreak;

import ij.IJ;
import ij.ImagePlus;
import ij.process.FloatProcessor;
import ij.process.ImageProcessor;

import java.io.Serializable;

/**
* This class detects streaks in an image and returns a corresponding mask.
* <p>
* The finder first applies a derivative filter in the x-direction to detect vertical edges. Then, it applies a mean
* filter in the y-direction to smooth out the edges in the y-direction. The resulting image is then thresholded
* (from above and below) to create a mask of the streaks. Finally, an optional Gaussian blur is applied to the mask to
* smooth it. The mask is 0 where there are no streaks and 255 where there are streaks.
* <p>
* There are three parameters that can be set:
* <ul>
* <li>meanFilterSize: the number of pixels to average in the y-direction (e.g., 0 means no averaging, 50 means averaging +/-50 pixels in y)</li>
* <li>threshold: the threshold used to convert the streak mask to a binary mask</li>
* <li>blurRadius: the radius of the Gaussian blur applied to the streak mask (0 means no smoothing)</li>
* </ul>
*/
public class StreakFinder implements Serializable {

private final int meanFilterSize;
private final double threshold;
private final int blurRadius;

public StreakFinder(final int meanFilterSize, final double threshold, final int blurRadius) {
if (meanFilterSize < 0) {
throw new IllegalArgumentException("meanFilterSize must be non-negative");
}
if (threshold < 0) {
throw new IllegalArgumentException("threshold must be non-negative");
}
if (blurRadius < 0) {
throw new IllegalArgumentException("blurRadius must be 0 (no blur) or positive");
}

this.meanFilterSize = meanFilterSize;
this.threshold = threshold;
this.blurRadius = blurRadius;
}

public ImagePlus createStreakMask(final ImagePlus input) {
ImageProcessor filtered = differenceFilterX(input.getProcessor());
filtered = meanFilterY(filtered, meanFilterSize);
filtered = bidirectionalThreshold(filtered, threshold);

final ImagePlus mask = new ImagePlus("Mask", filtered);
if (blurRadius > 0) {
IJ.run(mask, "Gaussian Blur...", String.format("sigma=%d", blurRadius));
}
return mask;
}

private static ImageProcessor differenceFilterX(final ImageProcessor in) {
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight());
final int width = in.getWidth();
final int height = in.getHeight();

for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
final float left = in.getf(projectPeriodically(x - 1, width), y);
final float right = in.getf(projectPeriodically(x + 1, width), y);
out.setf(x, y, (right - left) / 2);
}
}
return out;
}

private static ImageProcessor meanFilterY(final ImageProcessor in, final int size) {
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight());
final int width = in.getWidth();
final int height = in.getHeight();
final int n = 2 * size + 1;

for (int x = 0; x < width; x++) {
// initialize running sum
float sum = in.getf(x, 0);
for (int y = 1; y <= size; y++) {
sum += 2 * in.getf(x, y);
}
out.setf(x, 0, sum / n);

// update running sum by adding the next value and subtracting the oldest value
for (int y = 1; y < height; y++) {
final float oldest = in.getf(x, projectPeriodically(y - size - 1, height));
final float newest = in.getf(x, projectPeriodically(y + size, height));
sum += newest - oldest;
out.setf(x, y, sum / n);
}
}
return out;
}

private static ImageProcessor bidirectionalThreshold(final ImageProcessor in, final double threshold) {
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight());
final int width = in.getWidth();
final int height = in.getHeight();

for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
final float value = Math.abs(in.getf(x, y));
out.setf(x, y, (value > threshold) ? 255 : 0);
}
}
return out;
}

private static int projectPeriodically(final int index, final int max) {
if (index < 0) {
return -index;
} else if (index >= max) {
return 2 * max - index - 2;
} else {
return index;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.janelia.alignment.inpainting;

import java.util.Random;

/**
* A statistic that yields a small perturbation of a given 2D direction for each sample.
*/
public class AnisotropicDirection2D implements DirectionalStatistic {

private final Random random;
private final double[] primalAxis;
private final double[] secondaryAxis;
private final double perturbation;

/**
* Creates a new statistic with a random seed.
*/
public AnisotropicDirection2D(final double[] primalAxis, final double perturbation) {
this(primalAxis, perturbation, new Random());
}

/**
* Creates a new statistic with the given random number generator.
*
* @param random the random number generator to use
*/
public AnisotropicDirection2D(final double[] primalAxis, final double perturbation, final Random random) {
final double norm = Math.sqrt(primalAxis[0] * primalAxis[0] + primalAxis[1] * primalAxis[1]);
this.primalAxis = new double[] { primalAxis[0] / norm, primalAxis[1] / norm };
this.secondaryAxis = new double[] { -primalAxis[1] / norm, primalAxis[0] / norm };
this.perturbation = perturbation;
this.random = random;
}

@Override
public void sample(final double[] direction) {
// TODO: this should be a von Mises distribution instead of this homegrown implementation
final int sign = random.nextBoolean() ? 1 : -1;
final double eps = perturbation * random.nextGaussian();
final double norm = 1 + eps * eps; // because axes are orthonormal
direction[0] = (sign * primalAxis[0] + eps * secondaryAxis[0]) / norm;
direction[1] = (sign * primalAxis[1] + eps * secondaryAxis[1]) / norm;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.janelia.alignment.inpainting;

/**
* Interface for distributions that model the direction of a ray used in {@link RayCastingInpainter}.
*/
public interface DirectionalStatistic {

/**
* Initializes the direction of the next ray. The array that is passed in is filled with the direction.
*
* @param direction the array in which to initialize the direction
*/
void sample(double[] direction);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package org.janelia.alignment.inpainting;

import java.util.Random;

/**
* A statistic that yields a completely random 2D direction for each sample.
*/
public class RandomDirection2D implements DirectionalStatistic {

private final Random random;

/**
* Creates a new statistic with a random seed.
*/
public RandomDirection2D() {
this(new Random());
}

/**
* Creates a new statistic with the given random number generator.
*
* @param random the random number generator to use
*/
public RandomDirection2D(final Random random) {
this.random = random;
}

@Override
public void sample(final double[] direction) {
final double angle = random.nextDouble() * 2 * Math.PI;
direction[0] = Math.cos(angle);
direction[1] = Math.sin(angle);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.janelia.alignment.inpainting;

import java.util.Random;

/**
* A statistic that yields a completely random 3D direction for each sample.
*/
public class RandomDirection3D implements DirectionalStatistic {

private final Random random;

/**
* Creates a new statistic with a random seed.
*/
public RandomDirection3D() {
this(new Random());
}

/**
* Creates a new statistic with the given random number generator.
*
* @param random the random number generator to use
*/
public RandomDirection3D(final Random random) {
this.random = random;
}

@Override
public void sample(final double[] direction) {
final double x = random.nextGaussian();
final double y = random.nextGaussian();
final double z = random.nextGaussian();
final double norm = Math.sqrt(x * x + y * y + z * z);
direction[0] = x / norm;
direction[1] = y / norm;
direction[2] = z / norm;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package org.janelia.alignment.inpainting;

import net.imglib2.RealInterval;

import net.imglib2.Cursor;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.RealLocalizable;
import net.imglib2.RealRandomAccess;
import net.imglib2.interpolation.randomaccess.NLinearInterpolatorFactory;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.Views;


/**
* Infer missing values in an image (up to 3D) by ray casting (which is equivalent to diffusion of image values).
* <p>
* This is adapted from the hotknife repository for testing purposes.
*/
public class RayCastingInpainter {

private final int nRays;
private final long maxRayLength;
private final DirectionalStatistic directionStatistic;

private final double[] direction = new double[3];
private final Result result = new Result();

public RayCastingInpainter(final int nRays, final int maxInpaintingDiameter, final DirectionalStatistic directionStatistic) {
this.nRays = nRays;
this.maxRayLength = maxInpaintingDiameter;
this.directionStatistic = directionStatistic;
}

private static boolean isInside(final RealLocalizable p, final RealInterval r) {
for (int d = 0; d < p.numDimensions(); ++d) {
final double l = p.getDoublePosition(d);
if (l < r.realMin(d) || l > r.realMax(d)) {
return false;
}
}
return true;
}

/**
* Inpaints missing values in an image (up to 3D) by casting rays in random directions and averaging the values of
* the first non-masked pixel.
*
* @param img the image to inpaint
* @param mask the mask
*/
public void inpaint(final RandomAccessibleInterval<FloatType> img, final RandomAccessibleInterval<FloatType> mask) {
final Cursor<FloatType> imgCursor = Views.iterable(img).localizingCursor();

final RealRandomAccess<FloatType> imageAccess = Views.interpolate(Views.extendBorder(img), new NLinearInterpolatorFactory<>()).realRandomAccess();
final RealRandomAccess<FloatType> maskAccess = Views.interpolate(Views.extendBorder(mask), new NLinearInterpolatorFactory<>()).realRandomAccess();

while (imgCursor.hasNext()) {
final FloatType o = imgCursor.next();
final float m = maskAccess.setPositionAndGet(imgCursor).get();
if (m == 0.0) {
// pixel not masked, no inpainting necessary
continue;
}

double weightSum = 0;
double valueSum = 0;

// interpolate value by casting rays in random directions and averaging (weighted by distances) the
// values of the first non-masked pixel
for (int i = 0; i < nRays; ++i) {
final Result result = castRay(maskAccess, mask, imgCursor);
if (result != null) {
final double weight = 1.0 / result.distance;
weightSum += weight;
final double value = imageAccess.setPositionAndGet(result.position).getRealDouble();
valueSum += value * weight;
}
}

final float v = (float) (valueSum / weightSum);
final float w = m / 255.0f;
final float oldValue = o.get();
final float newValue = v * w + oldValue * (1 - w);
o.set(newValue);
}
}

/**
* Casts a ray from the given position in a random direction until it hits a non-masked (i.e., non-NaN) pixel
* or exits the image boundary.
*
* @param mask the mask indicating which pixels are masked (> 0) and which are not (0)
* @param interval the interval of the image
* @param position the position from which to cast the ray
* @return the result of the ray casting or null if the ray exited the image boundary without hitting a
* non-masked pixel
*/
private Result castRay(final RealRandomAccess<FloatType> mask, final Interval interval, final RealLocalizable position) {
mask.setPosition(position);
directionStatistic.sample(direction);
long steps = 0;

while(true) {
mask.move(direction);
++steps;

if (!isInside(mask, interval) || steps > maxRayLength) {
// the ray exited the image boundaries without hitting a non-masked pixel
return null;
}

final float value = mask.get().get();
if (value < 1.0) {
// the ray reached a non-masked pixel
mask.localize(result.position);
result.distance = steps;
return result;
}
}
}


private static class Result {
public double[] position = new double[3];
public double distance = 0;
}
}
Loading

0 comments on commit c1c8906

Please sign in to comment.