From 4a01e1c98ebdeb689cb66aee6c45e2643b401967 Mon Sep 17 00:00:00 2001 From: Raphael Menges Date: Thu, 1 Jun 2023 22:27:57 +0200 Subject: [PATCH 1/5] Understanding how to add another operator. --- wonnx/src/compiler.rs | 52 +++++++++++++++++++++++++ wonnx/templates/endomorphism/slice.wgsl | 23 +++++++++++ 2 files changed, 75 insertions(+) create mode 100644 wonnx/templates/endomorphism/slice.wgsl diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index bed2b865..410fe234 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -135,6 +135,11 @@ lazy_static! { include_str!("../templates/endomorphism/broadcast.wgsl"), ) .unwrap(); + tera.add_raw_template( + "endomorphism/slice.wgsl", + include_str!("../templates/endomorphism/slice.wgsl"), + ) + .unwrap(); tera }; } @@ -440,6 +445,53 @@ pub fn compile( } } + "Slice" => { + // TODO @ Raphael + + // Copied from Gather. + // Input 0 is data, input 1 is indices + // Which axis to gather on. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data). + // Default is 0. See https://github.com/onnx/onnx/blob/main/docs/Operators.md#attributes-25 + let axis = node.get_attribute_value("axis", Some(0))?; + if axis != 0 { + return Err(CompileError::UnimplementedVariant { + variant: format!("axis={}", axis), + op: String::from("Gather"), + }); + } + + let elements_per_index = input_chunks[0][0]; + let scalar_type = agreed_type(&input_shapes[0..1], output_shapes)?; + let chunk_type = MultiType::for_size(elements_per_index as usize, scalar_type); + let chunk_size = chunk_type.elements(); + + // The X dimension represents the indexes + let (x_threads, workgroup_size_x) = workgroup_size( + input_lengths[1], + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_X, + )?; + + // The Y dimension represents the elements to copy for each index + let (y_threads, workgroup_size_y) = workgroup_size( + ceil(elements_per_index, chunk_size as u64), + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_Y, + )?; + + context.insert("chunk_type", &chunk_type.wgsl_type_name()); + context.insert("chunk_size", &chunk_size); + context.insert("workgroup_size_x", &workgroup_size_x); + context.insert("workgroup_size_y", &workgroup_size_y); + + + NodeTemplate { + scalar_type, + template: "endomorphism/slice.wgsl", + threads: (x_threads, y_threads, 1), + } + } + "Cast" => { let cast_to_type = ScalarType::from_i32(node.get_attribute_value::("to", None)? as i32)?; diff --git a/wonnx/templates/endomorphism/slice.wgsl b/wonnx/templates/endomorphism/slice.wgsl new file mode 100644 index 00000000..2234dc03 --- /dev/null +++ b/wonnx/templates/endomorphism/slice.wgsl @@ -0,0 +1,23 @@ +{%- include "structs.wgsl" -%} + +struct Indices { + data: array +}; + +struct Chunk { + data: array<{{ chunk_type }}> +}; + +@group(0) @binding(0) +var input_0: Chunk; // data + +@group(0) @binding(1) +var input_1: Indices; // indices + +@group(0) @binding(2) +var output_0: Chunk; + +@compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}) +fn main(@builtin(global_invocation_id) global_id: vec3) { + // TODO @ Raphael +} \ No newline at end of file From 05202639b25d4020478135dbdc717a04ea08a16d Mon Sep 17 00:00:00 2001 From: Raphael Menges Date: Fri, 2 Jun 2023 10:17:31 +0200 Subject: [PATCH 2/5] Attempt to get attributes from slice operator node. --- wonnx/src/compiler.rs | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index 410fe234..da27d662 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -447,8 +447,31 @@ pub fn compile( "Slice" => { // TODO @ Raphael - - // Copied from Gather. + // Goal: Implementing slice in version 11: + // https://onnx.ai/onnx/operators/onnx__Slice.html#slice-11 + // 1. Get attributes out of node. + // 2. Create input and output buffer. + // 3. Provide attributes to compute shader. + // 4. Call compute shader over output. + // 5. Fill output in compute shader. + + println!("Slice by Raphael"); + println!("op_type: {:?}", &node.get_op_type()); + println!("opset_version: {:?}", &opset_version); + + // Get attributes from node. + let starts = node.get_attribute_value("starts", Some(vec![0]))?; + let ends = node.get_attribute_value("ends", Some(vec![0]))?; + let axes = node.get_attribute_value("axes", Some(vec![0]))?; + let steps = node.get_attribute_value("steps", Some(vec![1]))?; + + // Print attributes. + println!("starts: {:?}", &starts); + println!("ends: {:?}", &ends); + println!("axes: {:?}", &axes); + println!("steps: {:?}", &steps); + + // Copied from Gather to avoid compilation error. // Input 0 is data, input 1 is indices // Which axis to gather on. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data). // Default is 0. See https://github.com/onnx/onnx/blob/main/docs/Operators.md#attributes-25 @@ -484,12 +507,12 @@ pub fn compile( context.insert("workgroup_size_x", &workgroup_size_x); context.insert("workgroup_size_y", &workgroup_size_y); - NodeTemplate { scalar_type, template: "endomorphism/slice.wgsl", threads: (x_threads, y_threads, 1), } + // Above to be removed / replaced. } "Cast" => { From 2cfc0af9448beab75b2483860abfe683ce46fe4e Mon Sep 17 00:00:00 2001 From: raphaelmenges Date: Fri, 2 Jun 2023 16:03:36 +0200 Subject: [PATCH 3/5] Add test example for slice operator. --- wonnx/tests/slice.rs | 70 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 wonnx/tests/slice.rs diff --git a/wonnx/tests/slice.rs b/wonnx/tests/slice.rs new file mode 100644 index 00000000..7481c1df --- /dev/null +++ b/wonnx/tests/slice.rs @@ -0,0 +1,70 @@ +use std::{collections::HashMap, convert::TryInto}; +use wonnx::{ + onnx::AttributeProto, + utils::{attribute, graph, model, node, tensor}, +}; +mod common; + +fn assert_slice( + data: &[f32], + data_shape: &[i64], + output: &[f32], + output_shape: &[i64], + starts: &[i64], + ends: &[i64], + axes: Option>, + steps: Option> +) { + let mut input_data = HashMap::new(); + + input_data.insert("X".to_string(), data.into()); + + let mut attributes: Vec = vec![ + attribute("starts", starts.to_vec()), + attribute("ends", ends.to_vec()) + ]; + if let Some(axes) = axes { + attributes.push(attribute("axes", axes)); + } + if let Some(steps) = steps { + attributes.push(attribute("steps", steps)); + } + + // Model: (X) -> Slice -> Y + let bn_model = model(graph( + vec![tensor("X", data_shape)], + vec![tensor("Y", output_shape)], + vec![], + vec![], + vec![node( + vec!["X"], + vec!["Y"], + "mySlice", + "Slice", + attributes + )], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(bn_model)).expect("Session did not create"); + + let result = pollster::block_on(session.run(&input_data)).unwrap(); + common::assert_eq_vector((&result["Y"]).try_into().unwrap(), output); +} + +#[test] +fn slice() { + let _ = env_logger::builder().is_test(true).try_init(); + + // Example 1 from https://onnx.ai/onnx/operators/onnx__Slice.html#slice. + assert_slice( + &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + &[4, 2], + &[5.0, 7.0], + &[2 ,1], + &[1, 0], + &[2, 3], + Some(vec![0, 1]), + Some(vec![1, 2]) + ); +} From a066e267588c8273b93cbfa88e5a090c08933bed Mon Sep 17 00:00:00 2001 From: Raphael Menges Date: Mon, 5 Jun 2023 10:16:12 +0200 Subject: [PATCH 4/5] Change parameters in Slice operator to be inputs. --- wonnx/src/compiler.rs | 43 +++++++++++++++++------------ wonnx/tests/slice.rs | 64 +++++++++++++++++++++++++++---------------- 2 files changed, 67 insertions(+), 40 deletions(-) diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index da27d662..1a93cb52 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -449,27 +449,36 @@ pub fn compile( // TODO @ Raphael // Goal: Implementing slice in version 11: // https://onnx.ai/onnx/operators/onnx__Slice.html#slice-11 - // 1. Get attributes out of node. - // 2. Create input and output buffer. + // 1. Create input buffer. + // 2. Create output buffer. // 3. Provide attributes to compute shader. // 4. Call compute shader over output. // 5. Fill output in compute shader. - println!("Slice by Raphael"); - println!("op_type: {:?}", &node.get_op_type()); - println!("opset_version: {:?}", &opset_version); - - // Get attributes from node. - let starts = node.get_attribute_value("starts", Some(vec![0]))?; - let ends = node.get_attribute_value("ends", Some(vec![0]))?; - let axes = node.get_attribute_value("axes", Some(vec![0]))?; - let steps = node.get_attribute_value("steps", Some(vec![1]))?; - - // Print attributes. - println!("starts: {:?}", &starts); - println!("ends: {:?}", &ends); - println!("axes: {:?}", &axes); - println!("steps: {:?}", &steps); + // There must be at least starts and ends defined in the inputs + let input_count = input_lengths.len(); + if input_count < 3 { + return Err(CompileError::InvalidInputCount { + expected: 3, + actual: input_count, + }); + } + + // Print inputs. + println!("starts rank: {:?}", input_shapes[1].rank()); + println!("starts dim: {:?}", input_shapes[1].dim(0)); + println!("ends rank: {:?}", input_shapes[2].rank()); + println!("ends dim: {:?}", input_shapes[2].dim(0)); + if input_count > 3 { + println!("axes rank: {:?}", input_shapes[3].rank()); + println!("axes dim: {:?}", input_shapes[3].dim(0)); + } + if input_count > 4 { + println!("steps rank: {:?}", input_shapes[4].rank()); + println!("steps dim: {:?}", input_shapes[4].dim(0)); + } + + // TODO: Create fallback input buffers for axes and steps if not provided. // Copied from Gather to avoid compilation error. // Input 0 is data, input 1 is indices diff --git a/wonnx/tests/slice.rs b/wonnx/tests/slice.rs index 7481c1df..8b4f5d15 100644 --- a/wonnx/tests/slice.rs +++ b/wonnx/tests/slice.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, convert::TryInto}; use wonnx::{ - onnx::AttributeProto, - utils::{attribute, graph, model, node, tensor}, + onnx::ValueInfoProto, + utils::{graph, model, node, tensor}, }; mod common; @@ -10,38 +10,56 @@ fn assert_slice( data_shape: &[i64], output: &[f32], output_shape: &[i64], - starts: &[i64], - ends: &[i64], - axes: Option>, - steps: Option> + starts: &[f32], + ends: &[f32], + axes: Option>, + steps: Option> ) { let mut input_data = HashMap::new(); + let mut input_shapes: Vec = vec![]; + let mut input_names: Vec<&str> = vec![]; + + let starts_lengths = vec![starts.len() as i64]; + let ends_lengths = vec![ends.len() as i64]; input_data.insert("X".to_string(), data.into()); + input_shapes.push(tensor("X", data_shape)); + input_names.push("X"); + + input_data.insert("S".to_string(), starts.into()); + input_shapes.push(tensor("S", &starts_lengths[..])); + input_names.push("S"); + + input_data.insert("E".to_string(), ends.into()); + input_shapes.push(tensor("E", &ends_lengths[..])); + input_names.push("E"); - let mut attributes: Vec = vec![ - attribute("starts", starts.to_vec()), - attribute("ends", ends.to_vec()) - ]; if let Some(axes) = axes { - attributes.push(attribute("axes", axes)); + let axes_lengths = vec![axes.len() as i64]; + input_data.insert("A".to_string(), (&axes[..]).into()); // TODO: Lifetime issues + input_shapes.push(tensor("A", &axes_lengths[..])); + input_names.push("A"); } + if let Some(steps) = steps { - attributes.push(attribute("steps", steps)); + let steps_lengths = vec![steps.len() as i64]; + input_data.insert("P".to_string(), (&steps[..]).into()); // TODO: Lifetime issues + input_shapes.push(tensor("P", &steps_lengths[..])); + input_names.push("P"); } - // Model: (X) -> Slice -> Y + // Model: (X, S, E, A?, P?) -> Slice -> Y let bn_model = model(graph( - vec![tensor("X", data_shape)], + input_shapes, vec![tensor("Y", output_shape)], vec![], vec![], vec![node( - vec!["X"], + input_names, vec!["Y"], "mySlice", "Slice", - attributes + vec![] )], )); @@ -58,13 +76,13 @@ fn slice() { // Example 1 from https://onnx.ai/onnx/operators/onnx__Slice.html#slice. assert_slice( - &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + &[1., 2., 3., 4., 5., 6., 7., 8.], &[4, 2], - &[5.0, 7.0], - &[2 ,1], - &[1, 0], - &[2, 3], - Some(vec![0, 1]), - Some(vec![1, 2]) + &[5., 7.], + &[2, 1], + &[1., 0.], + &[2., 3.], + Some(vec![0., 1.]), + Some(vec![1., 2.]) ); } From dac263fab7386729098ed988bd0d27bcfbde2f4b Mon Sep 17 00:00:00 2001 From: Raphael Menges Date: Wed, 7 Jun 2023 15:57:23 +0200 Subject: [PATCH 5/5] Fix lifetime issues in slice test. --- wonnx/tests/slice.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/wonnx/tests/slice.rs b/wonnx/tests/slice.rs index 8b4f5d15..53b1cddc 100644 --- a/wonnx/tests/slice.rs +++ b/wonnx/tests/slice.rs @@ -34,16 +34,20 @@ fn assert_slice( input_shapes.push(tensor("E", &ends_lengths[..])); input_names.push("E"); - if let Some(axes) = axes { - let axes_lengths = vec![axes.len() as i64]; - input_data.insert("A".to_string(), (&axes[..]).into()); // TODO: Lifetime issues + let mut axes_unwraped: Vec = vec![]; + if axes.is_some() { + axes_unwraped = axes.unwrap(); + let axes_lengths = vec![axes_unwraped.len() as i64]; + input_data.insert("A".to_string(), (&axes_unwraped[..]).into()); input_shapes.push(tensor("A", &axes_lengths[..])); input_names.push("A"); } - if let Some(steps) = steps { - let steps_lengths = vec![steps.len() as i64]; - input_data.insert("P".to_string(), (&steps[..]).into()); // TODO: Lifetime issues + let mut steps_unwraped: Vec = vec![]; + if steps.is_some() { + steps_unwraped = steps.unwrap(); + let steps_lengths = vec![steps_unwraped.len() as i64]; + input_data.insert("P".to_string(), (&steps_unwraped[..]).into()); input_shapes.push(tensor("P", &steps_lengths[..])); input_names.push("P"); }