Add ceil and floor for non-multiple MaxPooling.
KarelPeeters committed Mar 9, 2024
1 parent 77f7e1b commit fb9016f
Showing 4 changed files with 99 additions and 30 deletions.
2 changes: 1 addition & 1 deletion kn-cuda-eval/src/planner.rs
```diff
@@ -481,7 +481,7 @@ impl<'a> Planner<'a> {
         let input = self.visit(input)?;
         let output = self.alloc_tensor_shared(result_shape, result_dtype, Some(value));
 
-        let identity = op.identity();
+        let identity = op.identity_t();
         let (operation, is_mean) = op.operation();
 
         let post_process = if is_mean {
```
15 changes: 9 additions & 6 deletions kn-graph/src/graph.rs
```diff
@@ -1739,15 +1739,18 @@ impl ReduceOp {
         ReduceOp::Max,
     ];
 
-    pub fn identity<T: IntoDScalar>(self) -> T {
-        let specials = T::DTYPE.specials();
-        let result = match self {
+    pub fn identity(self, dtype: DType) -> DScalar {
+        let specials = dtype.specials();
+        match self {
             ReduceOp::Sum | ReduceOp::Mean => specials.zero,
             ReduceOp::Prod => specials.one,
             ReduceOp::Min => specials.max,
             ReduceOp::Max => specials.min,
-        };
-        T::from_dscalar(result).unwrap()
+        }
+    }
+
+    pub fn identity_t<T: IntoDScalar>(self) -> T {
+        T::from_dscalar(self.identity(T::DTYPE)).unwrap()
     }
 
     pub fn operation(self) -> (BinaryOp, bool) {
@@ -1764,7 +1767,7 @@
         let (op, is_mean) = self.operation();
 
         let mut count = 0;
-        let total = seq.into_iter().fold(self.identity(), |acc, x| {
+        let total = seq.into_iter().fold(self.identity_t(), |acc, x| {
             count += 1;
             op.map_t(acc, x)
         });
```
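Note the split: `identity` is now dtype-erased, taking a runtime `DType` and returning a `DScalar` (which the ONNX loader below needs, since it only knows the dtype at runtime), while `identity_t` keeps the old statically typed behavior; the one-line change in `kn-cuda-eval`'s planner above is just the rename. For context, the fold-from-the-identity pattern that `reduce_t` uses looks like this; a standalone sketch with simplified stand-in types, not kn-graph's actual `DScalar` machinery:

```rust
// Standalone sketch of reducing via a fold from the identity element,
// mirroring `reduce_t` above (stand-in types, not the crate's own).
#[derive(Copy, Clone)]
enum Op {
    Sum,
    Max,
}

impl Op {
    // analogous to `identity_t`: the element that leaves the reduction unchanged
    fn identity(self) -> f32 {
        match self {
            Op::Sum => 0.0,
            Op::Max => f32::NEG_INFINITY,
        }
    }

    fn apply(self, a: f32, b: f32) -> f32 {
        match self {
            Op::Sum => a + b,
            Op::Max => a.max(b),
        }
    }
}

fn reduce(op: Op, seq: impl IntoIterator<Item = f32>) -> f32 {
    // starting from the identity makes the empty sequence well-defined
    seq.into_iter().fold(op.identity(), |acc, x| op.apply(acc, x))
}

fn main() {
    assert_eq!(reduce(Op::Max, [1.0, 3.0, 2.0]), 3.0);
    assert_eq!(reduce(Op::Sum, []), 0.0);
}
```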
67 changes: 52 additions & 15 deletions kn-graph/src/onnx/load.rs
```diff
@@ -22,7 +22,7 @@ use crate::onnx::result::{Node, OnnxError, OnnxResult, UnwrapProto};
 use crate::onnx::store::Store;
 use crate::onnx::typed_value::{OnnxValue, SignedSize};
 use crate::shape;
-use crate::shape::{Shape, Size};
+use crate::shape::{DivResult, Shape, Size};
 
 // TODO we should switch to taking an extra `HashMap<String, Size>` parameter,
 // so the user can decide which named axes match to what size or even the batch size
@@ -947,6 +947,7 @@ fn visit_node(
             let strides = attrs.take_ints("strides")?;
             let kernel_shape = attrs.take_ints("kernel_shape")?;
             let pads = attrs.take_ints("pads")?;
+            let ceil_mode = attrs.maybe_take_int("ceil_mode")?.unwrap_or(0) != 0;
             let auto_pad = attrs.maybe_take_string("auto_pad")?;
 
             assert_eq!(strides, kernel_shape, "Real strides not supported yet");
@@ -955,32 +956,68 @@ fn visit_node(
 
             // max pool the last N dimensions:
             // split each pooled axis into (input_size/kernel_size, kernel_size), then max pool over all kernel sizes
 
-            let input_shape = &graph[input].shape;
-            let input_rank = input_shape.rank();
+            let raw_input_shape = &graph[input].shape;
+            let input_rank = raw_input_shape.rank();
             let kernel_rank = kernel_shape.len();
 
-            // calculate reshaped shape
-            let (batch_shape, active_shape) = input_shape.split(input_rank - kernel_rank);
+            let kept_rank = input_rank - kernel_rank;
+            let (batch_shape, active_shape) = raw_input_shape.split(kept_rank);
+
+            // calculate padding and reshaping
+            let mut pad_amounts = vec![(0, 0); kept_rank];
             let mut reshape = batch_shape.dims.clone();
             let mut pooled_dims = vec![];
+            let mut slices = vec![None; kept_rank];
 
             for i in 0..kernel_rank {
-                let kernel_size = Size::fixed(kernel_shape[i] as usize);
-                let active_size = active_shape.dims[i];
+                let kernel_size = kernel_shape[i] as usize;
+                let input_size = active_shape.dims[i];
 
-                // TODO support non-dividing cases
-                let left = (active_size / kernel_size)
-                    .ok_or_else(|| OnnxError::NonDividingPooling(node.to_owned(), input_shape.clone(), kernel_shape.to_vec()))?;
+                let div_rem = input_size.div_rem(kernel_size);
+                let (left, pad, slice) = match div_rem {
+                    DivResult::Exact(left) => {
+                        (left, (0, 0), None)
+                    }
+                    DivResult::Remainder(rem) => {
+                        if ceil_mode {
+                            let pad = kernel_size - rem;
+                            let left = ((input_size + pad).unwrap() / kernel_size).unwrap();
+                            (left, (0, pad), None)
+                        } else {
+                            let left = ((input_size - rem).unwrap() / kernel_size).unwrap();
+                            let end = left.unwrap_fixed("pool dim size") * kernel_size;
+                            let slice = SliceRange::new(0, end, 1);
+                            (left, (0, 0), Some(slice))
+                        }
+                    },
+                    DivResult::Impossible => {
+                        return Err(OnnxError::NonDividingPooling(node.to_owned(), raw_input_shape.clone(), kernel_shape.to_vec()));
+                    }
+                };
 
+                pad_amounts.push(pad);
                 reshape.push(left);
                 pooled_dims.push(reshape.len());
-                reshape.push(kernel_size);
+                reshape.push(Size::fixed(kernel_size));
+                slices.push(slice);
             }
             let reshape = Shape::new(reshape);
 
-            // reshape and pool
-            let mid = graph.view(input, reshape);
-            let result = graph.reduce(mid, pooled_dims, ReduceOp::Max);
+            let operation = ReduceOp::Max;
+            let pad_value = operation.identity(graph[input].dtype);
+
+            // add to graph
+            let pad_value = graph.scalar_dyn(pad_value);
+            let padded = graph.pad(input, &pad_amounts, pad_value);
+            let sliced = slices.iter().enumerate().fold(padded, |a, (i, &s)| {
+                if let Some(s) = s {
+                    graph.slice(a, i, s)
+                } else {
+                    a
+                }
+            });
+            let reshaped = graph.view(sliced, reshape);
+            let result = graph.reduce(reshaped, pooled_dims, operation);
 
             OnnxValue::Value(result)
         }
```
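The behavior matches the ONNX `ceil_mode` attribute: in floor mode (the default, `ceil_mode = 0`) the trailing partial window is sliced off, while in ceil mode the axis is padded up to the next multiple of the kernel with the reduction's identity (negative infinity for Max), so padding can never win a comparison. A standalone sketch of the same pad-or-truncate, view-as-`(out, kernel)`, reduce strategy on a flat vector; plain `Vec<f32>`, not the crate's graph ops:

```rust
// Standalone sketch of the pooling strategy above: pad (ceil) or truncate
// (floor) the axis to a multiple of the kernel, view it as (out, kernel),
// then reduce each window with max. Assumes kernel > 0 and stride == kernel,
// matching the assert in the diff.
fn max_pool_1d(input: &[f32], kernel: usize, ceil_mode: bool) -> Vec<f32> {
    let rem = input.len() % kernel;
    let mut data = input.to_vec();
    if rem != 0 {
        if ceil_mode {
            // pad with the Max identity (-inf) so padding never affects the result
            data.extend(std::iter::repeat(f32::NEG_INFINITY).take(kernel - rem));
        } else {
            // floor: slice off the partial window
            data.truncate(input.len() - rem);
        }
    }
    // view as (out, kernel) and reduce over the kernel axis
    data.chunks(kernel)
        .map(|w| w.iter().copied().fold(f32::NEG_INFINITY, f32::max))
        .collect()
}

fn main() {
    let x = [1.0, 5.0, 2.0, 4.0, 3.0, 0.0, 9.0];
    assert_eq!(max_pool_1d(&x, 3, false), vec![5.0, 4.0]); // floor drops the tail
    assert_eq!(max_pool_1d(&x, 3, true), vec![5.0, 4.0, 9.0]); // ceil keeps it
}
```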
45 changes: 37 additions & 8 deletions kn-graph/src/shape.rs
```diff
@@ -224,6 +224,13 @@ impl From<usize> for Size {
     }
 }
 
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+pub enum DivResult {
+    Exact(Size),
+    Remainder(usize),
+    Impossible,
+}
+
 impl Size {
     pub const ZERO: Size = Size::new(0, 0);
     pub const ONE: Size = Size::new(0, 1);
@@ -292,6 +299,21 @@ impl Size {
             ))
         }
     }
+
+    pub fn div_rem(self, rhs: impl Into<Size>) -> DivResult {
+        let rhs = rhs.into();
+        let fixed_rem = self.fixed_factor % rhs.fixed_factor;
+        if self.batch_exp < rhs.batch_exp {
+            DivResult::Impossible
+        } else if fixed_rem != 0 {
+            DivResult::Remainder(fixed_rem)
+        } else {
+            DivResult::Exact(Size::new(
+                self.batch_exp - rhs.batch_exp,
+                self.fixed_factor / rhs.fixed_factor,
+            ))
+        }
+    }
 }
 
 impl ConcreteShape {
@@ -445,14 +467,21 @@ impl<R: Into<Size>> std::ops::Div<R> for Size {
     type Output = Option<Size>;
 
     fn div(self, rhs: R) -> Self::Output {
-        let rhs = rhs.into();
-        if self.batch_exp < rhs.batch_exp || self.fixed_factor % rhs.fixed_factor != 0 {
-            None
-        } else {
-            Some(Size::new(
-                self.batch_exp - rhs.batch_exp,
-                self.fixed_factor / rhs.fixed_factor,
-            ))
+        match self.div_rem(rhs) {
+            DivResult::Exact(s) => Some(s),
+            DivResult::Remainder(_) | DivResult::Impossible => None,
         }
     }
 }
+
+impl<R: Into<Size>> std::ops::Rem<R> for Size {
+    type Output = Option<usize>;
+
+    fn rem(self, rhs: R) -> Self::Output {
+        match self.div_rem(rhs) {
+            DivResult::Exact(_) => Some(0),
+            DivResult::Remainder(r) => Some(r),
+            DivResult::Impossible => None,
+        }
+    }
+}
```
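`div_rem` now carries everything both `Div` and the new `Rem` need: `Impossible` when the divisor has a higher batch exponent than the dividend, `Remainder` when the fixed factors do not divide evenly, and `Exact` otherwise. A sketch of those semantics, written as it might look in a test inside `shape.rs`; recall that `Size::new(batch_exp, fixed_factor)` encodes `fixed_factor * batch^batch_exp`. The test itself is illustrative, not part of the commit:

```rust
// Illustrative test for the new Size arithmetic; not from the commit itself.
#[test]
fn size_div_rem_examples() {
    let x = Size::new(1, 16); // 16 * batch

    // exact division: (16 * batch) / 8 == 2 * batch
    assert_eq!(x.div_rem(8usize), DivResult::Exact(Size::new(1, 2)));
    assert_eq!(x / 8usize, Some(Size::new(1, 2)));

    // fixed remainder: 16 % 3 == 1, independent of the batch size
    assert_eq!(x.div_rem(3usize), DivResult::Remainder(1));
    assert_eq!(x % 3usize, Some(1));

    // dividing a fixed size by a batch-dependent one cannot be represented
    assert_eq!(Size::fixed(16).div_rem(Size::new(1, 1)), DivResult::Impossible);
    assert_eq!(Size::fixed(16) % Size::new(1, 1), None);
}
```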
