diff --git a/src/filter/mod.rs b/src/filter/mod.rs index 97f9c8de..7a9c6a77 100644 --- a/src/filter/mod.rs +++ b/src/filter/mod.rs @@ -59,27 +59,6 @@ pub fn bilateral_filter( (-0.5 * x.powi(2) / sigma_squared).exp() } - /// Effectively a meshgrid command with flattened outputs. - fn window_coords(window_size: u32) -> (Vec, Vec) { - let window_start = (-(window_size as f32) / 2.0).floor() as i32; - let window_end = (window_size as f32 / 2.0).floor() as i32 + 1; - let window_range = window_start..window_end; - - let cc = window_range - .clone() - .cycle() - .take(window_range.len().pow(2)) - .collect(); - - let n = window_size as usize + 1; - let mut rr = Vec::with_capacity(n * window_range.len()); - for i in window_range { - rr.extend(std::iter::repeat(i).take(n)); - } - - (rr, cc) - } - /// Create look-up table of Gaussian weights for color dimension. fn compute_color_lut(bins: u32, sigma: f32, max_value: f32) -> Vec { let step_size = max_value / bins as f32; @@ -90,14 +69,22 @@ pub fn bilateral_filter( /// Create look-up table of weights corresponding to flattened 2-D Gaussian kernel. fn compute_spatial_lut(window_size: u32, sigma: f32) -> Vec { - let (rr, cc) = window_coords(window_size); - let it = rr.into_iter().zip(cc); + let window_start = (-(window_size as f32) / 2.0).floor() as i32; + let window_end = (window_size as f32 / 2.0).floor() as i32 + 1; + let window_range = window_start..window_end; + + let cc = window_range.clone().cycle().take(window_range.len().pow(2)); + + let n = window_size as usize + 1; + let rr = window_range.flat_map(|i| std::iter::repeat(i).take(n)); + let sigma_squared = sigma.powi(2); - it.map(|(r, c)| { - let dist = ((r as f32).powi(2) + (c as f32).powi(2)).sqrt(); - gaussian_weight(dist, sigma_squared) - }) - .collect() + rr.zip(cc) + .map(|(r, c)| { + let dist = ((r as f32).powi(2) + (c as f32).powi(2)).sqrt(); + gaussian_weight(dist, sigma_squared) + }) + .collect() } let max_value = *image.iter().max().unwrap() as f32; @@ -122,13 +109,10 @@ pub fn bilateral_filter( let window_col_abs = (col as i32 + window_col).clamp(0, width as i32 - 1); // Wrap to edge. let kc = window_col + window_extent; let range_bin = (kr * window_size + kc) as usize; - let range_weight = range_lut[range_bin]; let val = image.get_pixel(window_col_abs as u32, window_row_abs as u32)[0]; - let color_dist = (window_center_val - val as i32).abs(); - let color_bin = (color_dist as f32 * color_dist_scale) as usize; - let color_bin = min(color_bin, max_color_bin); - let color_weight = color_lut[color_bin]; - let weight = range_weight * color_weight; + let color_dist = (window_center_val - val as i32).abs() as f32; + let color_bin = ((color_dist * color_dist_scale) as usize).min(max_color_bin); + let weight = range_lut[range_bin] * color_lut[color_bin]; total_val += val as f32 * weight; total_weight += weight; }