Skip to content

Commit

Permalink
Bugfixes in load to enable loading "ocl_net".
Browse files Browse the repository at this point in the history
  • Loading branch information
KarelPeeters committed Jul 11, 2024
1 parent 9fff208 commit 877997a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 28 deletions.
84 changes: 57 additions & 27 deletions kn-graph/src/onnx/load.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::path::PathBuf;

use byteorder::{ByteOrder, LittleEndian};
use itertools::Itertools;
use itertools::{Itertools, zip_eq};
use ndarray::{Axis, azip};
use prost::Message;

Expand Down Expand Up @@ -79,26 +79,20 @@ pub fn graph_from_onnx_bytes(buf: &[u8], external: &mut dyn ExternalDataLoader,
op_type: node_proto.op_type.as_str(),
};

if node_proto.output.len() != 1 {
return Err(OnnxError::UnsupportedMultipleOutputs(
node.to_owned(),
node_proto.output.clone(),
));
}
let output_name = &node_proto.output[0];

let mut attrs = Attributes::from(node, &node_proto.attribute);
let mut inputs = Inputs::from(node, &node_proto.input, &nodes)?;

let value: OnnxValue = visit_node(&mut graph, external, node, &mut inputs, &mut attrs)?;
let values: Vec<OnnxValue> = visit_node(&mut graph, external, node, &mut inputs, &mut attrs)?;

// set debug id for all newly created nodes to the current node name
for value in graph.take_new_values() {
graph.set_debug_id(value, node.name.to_owned())
}

// check that the value if only a size if necessary
value.assert_valid();
for value in &values {
value.assert_valid();
}

// check that we used all attributes and inputs
let leftover_attributes = attrs.leftover();
Expand All @@ -110,8 +104,12 @@ pub fn graph_from_onnx_bytes(buf: &[u8], external: &mut dyn ExternalDataLoader,
return Err(OnnxError::LeftoverInputs(node.to_owned(), leftover_inputs));
}

// actually define the current node
nodes.define(output_name, value);
// actually define the result values
let output_names = &node_proto.output;
assert_eq!(output_names.len(), values.len(), "Expected {:?} outputs, got {}", output_names, values.len());
for (name, value) in zip_eq(output_names, values) {
nodes.define(name, value);
}
}

for output in &model_graph.output {
Expand All @@ -131,19 +129,21 @@ fn visit_node(
node: Node<&str>,
inputs: &mut Inputs,
attrs: &mut Attributes,
) -> OnnxResult<OnnxValue> {
let result = match node.op_type {
) -> OnnxResult<Vec<OnnxValue>> {
let result_single = match node.op_type {
"Conv" => {
let input = inputs.required(0)?.unwrap_value().unwrap();
let filter = inputs.required(1)?.unwrap_value().unwrap();
let bias_raw = inputs.optional(2).map(|v| v.unwrap_value().unwrap());

let groups = attrs.take_int("group")?;
let groups = attrs.maybe_take_int("group")?.unwrap_or(1);
let kernel_shape = attrs.take_ints("kernel_shape")?;
let strides = attrs.take_ints("strides")?;
let dilations = attrs.take_ints("dilations")?;

let conv_rank = kernel_shape.len();
let strides = attrs.maybe_take_ints("strides")?
.map_or(vec![1; conv_rank], |strides| strides.to_vec());
let dilations = attrs.maybe_take_ints("dilations")?
.map_or(vec![1; conv_rank], |strides| strides.to_vec());

let auto_pad = attrs.maybe_take_string("auto_pad")?;

let padding = match auto_pad {
Expand All @@ -153,11 +153,11 @@ fn visit_node(
}
Some("SAME_UPPER") => {
// input and output same size, excess on upper side of dim
calculate_auto_padding(graph, conv_rank, input, filter, strides, dilations, true)?
calculate_auto_padding(graph, conv_rank, input, filter, &strides, &dilations, true)?
}
Some("SAME_LOWER") => {
// input and output same size, excess on lower side of dim
calculate_auto_padding(graph, conv_rank, input, filter, strides, dilations, false)?
calculate_auto_padding(graph, conv_rank, input, filter, &strides, &dilations, false)?
}
Some("VALID") => {
// no padding
Expand All @@ -184,8 +184,8 @@ fn visit_node(
1 => {
let kernel_size0 = unwrap_1(kernel_shape);
let [padding_0, padding_1] = unwrap_2(&padding);
let stride = unwrap_1(strides);
let dilation = unwrap_1(dilations);
let stride = unwrap_1(&strides);
let dilation = unwrap_1(&dilations);

let [_, _, kernel_size1] = filter_shape.unwrap_3();

Expand All @@ -207,8 +207,8 @@ fn visit_node(
2 => {
let [kernel_h0, kernel_w0] = unwrap_2(kernel_shape);
let [padding_y0, padding_x0, padding_y1, padding_x1] = unwrap_4(&padding);
let [stride_y, stride_x] = unwrap_2(strides);
let [dilation_y, dilation_x] = unwrap_2(dilations);
let [stride_y, stride_x] = unwrap_2(&strides);
let [dilation_y, dilation_x] = unwrap_2(&dilations);

let [_, _, kernel_h1, kernel_w1] = filter_shape.unwrap_4();

Expand Down Expand Up @@ -835,6 +835,36 @@ fn visit_node(
OnnxValue::Value(result)
}
}
"Split" => {
// TODO support "num_outputs" and "split" attribute/input
let input = inputs.required(0)?;
let shape = input.shape(graph);

let axis = attrs.take_int("axis")?;
let axis = abs_axis(axis, shape.rank());

let num_outputs = 2;
let size = shape[axis].unwrap_fixed("Split axis length");

let len_first = (size + num_outputs - 1) / num_outputs;

let result = match input {
&OnnxValue::Value(input) => {
vec![
OnnxValue::Value(graph.slice(input, axis, SliceRange::simple(0, len_first))),
OnnxValue::Value(graph.slice(input, axis, SliceRange::simple(len_first, size))),
]
}
OnnxValue::Size(input) => {
vec![
OnnxValue::new_size(input.slice_axis(Axis(axis), ndarray::Slice::from(..len_first)).into_owned().into_shared(), graph),
OnnxValue::new_size(input.slice_axis(Axis(axis), ndarray::Slice::from(len_first..)).into_owned().into_shared(), graph),
]
}
};

return Ok(result);
}
"Pad" => {
// operands
let input = inputs.required(0)?.unwrap_value().unwrap();
Expand Down Expand Up @@ -938,7 +968,7 @@ fn visit_node(
let result_shape = if keep_dims {
input_shape.replace_all(&axes, shape![1])
} else {
input_shape
input_shape.replace_all(&axes, shape![])
};

let result = graph.reduce(input, axes, op);
Expand Down Expand Up @@ -1119,7 +1149,7 @@ fn visit_node(
}
};

Ok(result)
Ok(vec![result_single])
}

fn calculate_auto_padding(graph: &Graph, conv_rank: usize, input: Value, filter: Value, strides: &[i64], dilations: &[i64], up: bool) -> OnnxResult<Vec<i64>> {
Expand Down
2 changes: 1 addition & 1 deletion kn-graph/src/onnx/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub struct Node<S = String> {
pub op_type: S,
}

// TODO remove variants that are never constructed
#[derive(Debug)]
pub enum OnnxError {
IO(PathBuf, io::Error),
Expand All @@ -39,7 +40,6 @@ pub enum OnnxError {

UnsupportedOperation(Node),

UnsupportedMultipleOutputs(Node, Vec<String>),
UnsupportedNonFloatOutput(String),
UnsupportedType(String, DataType),

Expand Down

0 comments on commit 877997a

Please sign in to comment.