From fb9016fa1294c14bd69ca832d60ea000fc8a0aef Mon Sep 17 00:00:00 2001
From: Karel Peeters
Date: Sun, 10 Mar 2024 00:07:33 +0100
Subject: [PATCH] Add ceil and floor for non-multiple MaxPooling.

---
 kn-cuda-eval/src/planner.rs |  2 +-
 kn-graph/src/graph.rs       | 15 +++++----
 kn-graph/src/onnx/load.rs   | 67 ++++++++++++++++++++++++++++---------
 kn-graph/src/shape.rs       | 45 ++++++++++++++++++++-----
 4 files changed, 99 insertions(+), 30 deletions(-)

diff --git a/kn-cuda-eval/src/planner.rs b/kn-cuda-eval/src/planner.rs
index 7b23639..d0f2c9c 100644
--- a/kn-cuda-eval/src/planner.rs
+++ b/kn-cuda-eval/src/planner.rs
@@ -481,7 +481,7 @@ impl<'a> Planner<'a> {
         let input = self.visit(input)?;
         let output = self.alloc_tensor_shared(result_shape, result_dtype, Some(value));
 
-        let identity = op.identity();
+        let identity = op.identity_t();
         let (operation, is_mean) = op.operation();
 
         let post_process = if is_mean {
diff --git a/kn-graph/src/graph.rs b/kn-graph/src/graph.rs
index d7e08ed..bcdd6cb 100644
--- a/kn-graph/src/graph.rs
+++ b/kn-graph/src/graph.rs
@@ -1739,15 +1739,18 @@ impl ReduceOp {
         ReduceOp::Max,
     ];
 
-    pub fn identity<T: IntoDScalar>(self) -> T {
-        let specials = T::DTYPE.specials();
-        let result = match self {
+    pub fn identity(self, dtype: DType) -> DScalar {
+        let specials = dtype.specials();
+        match self {
             ReduceOp::Sum | ReduceOp::Mean => specials.zero,
             ReduceOp::Prod => specials.one,
             ReduceOp::Min => specials.max,
             ReduceOp::Max => specials.min,
-        };
-        T::from_dscalar(result).unwrap()
+        }
+    }
+
+    pub fn identity_t<T: IntoDScalar>(self) -> T {
+        T::from_dscalar(self.identity(T::DTYPE)).unwrap()
     }
 
     pub fn operation(self) -> (BinaryOp, bool) {
@@ -1764,7 +1767,7 @@ impl ReduceOp {
         let (op, is_mean) = self.operation();
 
         let mut count = 0;
-        let total = seq.into_iter().fold(self.identity(), |acc, x| {
+        let total = seq.into_iter().fold(self.identity_t(), |acc, x| {
             count += 1;
             op.map_t(acc, x)
         });
diff --git a/kn-graph/src/onnx/load.rs b/kn-graph/src/onnx/load.rs
index b9ccd9f..c6b7829 100644
--- a/kn-graph/src/onnx/load.rs
+++ b/kn-graph/src/onnx/load.rs
@@ -22,7 +22,7 @@ use crate::onnx::result::{Node, OnnxError, OnnxResult, UnwrapProto};
 use crate::onnx::store::Store;
 use crate::onnx::typed_value::{OnnxValue, SignedSize};
 use crate::shape;
-use crate::shape::{Shape, Size};
+use crate::shape::{DivResult, Shape, Size};
 
 // TODO we should switch to taking an extra `HashMap` parameter,
 //   so the user can decide which named axes match to what size or even the batch size
@@ -947,6 +947,7 @@ fn visit_node(
            let strides = attrs.take_ints("strides")?;
            let kernel_shape = attrs.take_ints("kernel_shape")?;
            let pads = attrs.take_ints("pads")?;
+           let ceil_mode = attrs.maybe_take_int("ceil_mode")?.unwrap_or(0) != 0;
            let auto_pad = attrs.maybe_take_string("auto_pad")?;
 
            assert_eq!(strides, kernel_shape, "Real strides not supported yet");
@@ -955,32 +956,68 @@ fn visit_node(
 
            // max pool the last N dimensions:
            //   split each pooled axis into (input_size/kernel_size, kernel_size), then max pool over all kernel sizes
-
-           let input_shape = &graph[input].shape;
-           let input_rank = input_shape.rank();
+           let raw_input_shape = &graph[input].shape;
+           let input_rank = raw_input_shape.rank();
            let kernel_rank = kernel_shape.len();
 
-           // calculate reshaped shape
-           let (batch_shape, active_shape) = input_shape.split(input_rank - kernel_rank);
+           let kept_rank = input_rank - kernel_rank;
+           let (batch_shape, active_shape) = raw_input_shape.split(kept_rank);
+
+           // calculate padding and reshaping
+           let mut pad_amounts = vec![(0, 0); kept_rank];
            let mut reshape = batch_shape.dims.clone();
            let mut pooled_dims = vec![];
+           let mut slices = vec![None; kept_rank];
+
            for i in 0..kernel_rank {
-               let kernel_size = Size::fixed(kernel_shape[i] as usize);
-               let active_size = active_shape.dims[i];
+               let kernel_size = kernel_shape[i] as usize;
+               let input_size = active_shape.dims[i];
 
-               // TODO support non-dividing cases
-               let left = (active_size / kernel_size)
-                   .ok_or_else(|| OnnxError::NonDividingPooling(node.to_owned(), input_shape.clone(), kernel_shape.to_vec()))?;
+               let div_rem = input_size.div_rem(kernel_size);
+               let (left, pad, slice) = match div_rem {
+                   DivResult::Exact(left) => {
+                       (left, (0, 0), None)
+                   }
+                   DivResult::Remainder(rem) => {
+                       if ceil_mode {
+                           let pad = kernel_size - rem;
+                           let left = ((input_size + pad).unwrap() / kernel_size).unwrap();
+                           (left, (0, pad), None)
+                       } else {
+                           let left = ((input_size - rem).unwrap() / kernel_size).unwrap();
+                           let end = left.unwrap_fixed("pool dim size") * kernel_size;
+                           let slice = SliceRange::new(0, end, 1);
+                           (left, (0, 0), Some(slice))
+                       }
+                   },
+                   DivResult::Impossible => {
+                       return Err(OnnxError::NonDividingPooling(node.to_owned(), raw_input_shape.clone(), kernel_shape.to_vec()));
+                   }
+               };
+               pad_amounts.push(pad);
                reshape.push(left);
                pooled_dims.push(reshape.len());
-               reshape.push(kernel_size);
+               reshape.push(Size::fixed(kernel_size));
+               slices.push(slice);
            }
 
            let reshape = Shape::new(reshape);
 
-           // reshape and pool
-           let mid = graph.view(input, reshape);
-           let result = graph.reduce(mid, pooled_dims, ReduceOp::Max);
+           let operation = ReduceOp::Max;
+           let pad_value = operation.identity(graph[input].dtype);
+
+           // add to graph
+           let pad_value = graph.scalar_dyn(pad_value);
+           let padded = graph.pad(input, &pad_amounts, pad_value);
+           let sliced = slices.iter().enumerate().fold(padded, |a, (i, &s)| {
+               if let Some(s) = s {
+                   graph.slice(a, i, s)
+               } else {
+                   a
+               }
+           });
+           let reshaped = graph.view(sliced, reshape);
+           let result = graph.reduce(reshaped, pooled_dims, operation);
 
            OnnxValue::Value(result)
        }
diff --git a/kn-graph/src/shape.rs b/kn-graph/src/shape.rs
index ed187d6..7ac9faa 100644
--- a/kn-graph/src/shape.rs
+++ b/kn-graph/src/shape.rs
@@ -224,6 +224,13 @@ impl From<usize> for Size {
     }
 }
 
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+pub enum DivResult {
+    Exact(Size),
+    Remainder(usize),
+    Impossible,
+}
+
 impl Size {
     pub const ZERO: Size = Size::new(0, 0);
     pub const ONE: Size = Size::new(0, 1);
@@ -292,6 +299,21 @@ impl Size {
             ))
         }
     }
+
+    pub fn div_rem(self, rhs: impl Into<Size>) -> DivResult {
+        let rhs = rhs.into();
+        let fixed_rem = self.fixed_factor % rhs.fixed_factor;
+        if self.batch_exp < rhs.batch_exp {
+            DivResult::Impossible
+        } else if fixed_rem != 0 {
+            DivResult::Remainder(fixed_rem)
+        } else {
+            DivResult::Exact(Size::new(
+                self.batch_exp - rhs.batch_exp,
+                self.fixed_factor / rhs.fixed_factor,
+            ))
+        }
+    }
 }
 
 impl ConcreteShape {
@@ -445,14 +467,21 @@ impl<R: Into<Size>> std::ops::Div<R> for Size {
     type Output = Option<Size>;
 
     fn div(self, rhs: R) -> Self::Output {
-        let rhs = rhs.into();
-        if self.batch_exp < rhs.batch_exp || self.fixed_factor % rhs.fixed_factor != 0 {
-            None
-        } else {
-            Some(Size::new(
-                self.batch_exp - rhs.batch_exp,
-                self.fixed_factor / rhs.fixed_factor,
-            ))
+        match self.div_rem(rhs) {
+            DivResult::Exact(s) => Some(s),
+            DivResult::Remainder(_) | DivResult::Impossible => None,
+        }
+    }
+}
+
+impl<R: Into<Size>> std::ops::Rem<R> for Size {
+    type Output = Option<usize>;
+
+    fn rem(self, rhs: R) -> Self::Output {
+        match self.div_rem(rhs) {
+            DivResult::Exact(_) => Some(0),
+            DivResult::Remainder(r) => Some(r),
+            DivResult::Impossible => None,
         }
     }
 }
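
Note: as a sanity check on the new ceil/floor handling in load.rs, here is a standalone
sketch of the per-axis decision the MaxPool loop makes when stride == kernel_size. It
uses plain usize arithmetic instead of kn-graph's symbolic Size, and the AxisPlan /
plan_axis names are illustrative, not part of the patch.

/// Per-axis plan mirroring the patch: either pad the axis on the right
/// (ceil_mode, pad value is the reduce identity, so -inf for Max) or
/// slice off the trailing partial window (floor mode).
#[derive(Debug, PartialEq)]
struct AxisPlan {
    out_len: usize,           // pooled output length ("left" in the patch)
    pad_right: usize,         // ceil mode: pad so the kernel divides the axis
    slice_end: Option<usize>, // floor mode: keep only [0, slice_end)
}

fn plan_axis(input_len: usize, kernel_size: usize, ceil_mode: bool) -> AxisPlan {
    let rem = input_len % kernel_size;
    if rem == 0 {
        // DivResult::Exact: the kernel divides the axis, reshape directly.
        AxisPlan { out_len: input_len / kernel_size, pad_right: 0, slice_end: None }
    } else if ceil_mode {
        // Pad on the right up to the next multiple of kernel_size.
        let pad = kernel_size - rem;
        AxisPlan { out_len: (input_len + pad) / kernel_size, pad_right: pad, slice_end: None }
    } else {
        // Floor: drop the trailing remainder by slicing before reshaping.
        let out_len = (input_len - rem) / kernel_size;
        AxisPlan { out_len, pad_right: 0, slice_end: Some(out_len * kernel_size) }
    }
}

fn main() {
    // 10 elements, kernel 3: ceil gives 4 windows (2 padded), floor gives 3.
    assert_eq!(plan_axis(10, 3, true), AxisPlan { out_len: 4, pad_right: 2, slice_end: None });
    assert_eq!(plan_axis(10, 3, false), AxisPlan { out_len: 3, pad_right: 0, slice_end: Some(9) });
    assert_eq!(plan_axis(9, 3, true), AxisPlan { out_len: 3, pad_right: 0, slice_end: None });
    println!("ok");
}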
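
Similarly, a minimal model of the new Size::div_rem semantics, reading a Size as
fixed_factor * B^batch_exp for an unknown batch size B. The standalone Size and
DivResult below are re-declared for illustration; only the field names and the
three-way result follow the patch.

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct Size {
    batch_exp: u32,
    fixed_factor: usize,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum DivResult {
    Exact(Size),
    Remainder(usize),
    Impossible,
}

fn div_rem(lhs: Size, rhs: Size) -> DivResult {
    let fixed_rem = lhs.fixed_factor % rhs.fixed_factor;
    if lhs.batch_exp < rhs.batch_exp {
        // Dividing by a higher power of B cannot yield a Size.
        DivResult::Impossible
    } else if fixed_rem != 0 {
        // The fixed factor leaves a remainder.
        DivResult::Remainder(fixed_rem)
    } else {
        DivResult::Exact(Size {
            batch_exp: lhs.batch_exp - rhs.batch_exp,
            fixed_factor: lhs.fixed_factor / rhs.fixed_factor,
        })
    }
}

fn main() {
    let size = |batch_exp, fixed_factor| Size { batch_exp, fixed_factor };
    // 8*B / 2 = 4*B: exact, so MaxPool can reshape directly.
    assert_eq!(div_rem(size(1, 8), size(0, 2)), DivResult::Exact(size(1, 4)));
    // 10 / 3 leaves remainder 1: pad (ceil) or slice (floor).
    assert_eq!(div_rem(size(0, 10), size(0, 3)), DivResult::Remainder(1));
    // 4 / B: impossible without knowing the batch size.
    assert_eq!(div_rem(size(0, 4), size(1, 1)), DivResult::Impossible);
    println!("ok");
}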