diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index bed2b865..1a93cb52 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -135,6 +135,11 @@ lazy_static! { include_str!("../templates/endomorphism/broadcast.wgsl"), ) .unwrap(); + tera.add_raw_template( + "endomorphism/slice.wgsl", + include_str!("../templates/endomorphism/slice.wgsl"), + ) + .unwrap(); tera }; } @@ -440,6 +445,85 @@ pub fn compile( } } + "Slice" => { + // TODO @ Raphael + // Goal: Implementing slice in version 11: + // https://onnx.ai/onnx/operators/onnx__Slice.html#slice-11 + // 1. Create input buffer. + // 2. Create output buffer. + // 3. Provide attributes to compute shader. + // 4. Call compute shader over output. + // 5. Fill output in compute shader. + + // There must be at least starts and ends defined in the inputs + let input_count = input_lengths.len(); + if input_count < 3 { + return Err(CompileError::InvalidInputCount { + expected: 3, + actual: input_count, + }); + } + + // Print inputs. + println!("starts rank: {:?}", input_shapes[1].rank()); + println!("starts dim: {:?}", input_shapes[1].dim(0)); + println!("ends rank: {:?}", input_shapes[2].rank()); + println!("ends dim: {:?}", input_shapes[2].dim(0)); + if input_count > 3 { + println!("axes rank: {:?}", input_shapes[3].rank()); + println!("axes dim: {:?}", input_shapes[3].dim(0)); + } + if input_count > 4 { + println!("steps rank: {:?}", input_shapes[4].rank()); + println!("steps dim: {:?}", input_shapes[4].dim(0)); + } + + // TODO: Create fallback input buffers for axes and steps if not provided. + + // Copied from Gather to avoid compilation error. + // Input 0 is data, input 1 is indices + // Which axis to gather on. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data). + // Default is 0. See https://github.com/onnx/onnx/blob/main/docs/Operators.md#attributes-25 + let axis = node.get_attribute_value("axis", Some(0))?; + if axis != 0 { + return Err(CompileError::UnimplementedVariant { + variant: format!("axis={}", axis), + op: String::from("Gather"), + }); + } + + let elements_per_index = input_chunks[0][0]; + let scalar_type = agreed_type(&input_shapes[0..1], output_shapes)?; + let chunk_type = MultiType::for_size(elements_per_index as usize, scalar_type); + let chunk_size = chunk_type.elements(); + + // The X dimension represents the indexes + let (x_threads, workgroup_size_x) = workgroup_size( + input_lengths[1], + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_X, + )?; + + // The Y dimension represents the elements to copy for each index + let (y_threads, workgroup_size_y) = workgroup_size( + ceil(elements_per_index, chunk_size as u64), + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_Y, + )?; + + context.insert("chunk_type", &chunk_type.wgsl_type_name()); + context.insert("chunk_size", &chunk_size); + context.insert("workgroup_size_x", &workgroup_size_x); + context.insert("workgroup_size_y", &workgroup_size_y); + + NodeTemplate { + scalar_type, + template: "endomorphism/slice.wgsl", + threads: (x_threads, y_threads, 1), + } + // Above to be removed / replaced. + } + "Cast" => { let cast_to_type = ScalarType::from_i32(node.get_attribute_value::("to", None)? as i32)?; diff --git a/wonnx/templates/endomorphism/slice.wgsl b/wonnx/templates/endomorphism/slice.wgsl new file mode 100644 index 00000000..2234dc03 --- /dev/null +++ b/wonnx/templates/endomorphism/slice.wgsl @@ -0,0 +1,23 @@ +{%- include "structs.wgsl" -%} + +struct Indices { + data: array +}; + +struct Chunk { + data: array<{{ chunk_type }}> +}; + +@group(0) @binding(0) +var input_0: Chunk; // data + +@group(0) @binding(1) +var input_1: Indices; // indices + +@group(0) @binding(2) +var output_0: Chunk; + +@compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}) +fn main(@builtin(global_invocation_id) global_id: vec3) { + // TODO @ Raphael +} \ No newline at end of file diff --git a/wonnx/tests/slice.rs b/wonnx/tests/slice.rs new file mode 100644 index 00000000..53b1cddc --- /dev/null +++ b/wonnx/tests/slice.rs @@ -0,0 +1,92 @@ +use std::{collections::HashMap, convert::TryInto}; +use wonnx::{ + onnx::ValueInfoProto, + utils::{graph, model, node, tensor}, +}; +mod common; + +fn assert_slice( + data: &[f32], + data_shape: &[i64], + output: &[f32], + output_shape: &[i64], + starts: &[f32], + ends: &[f32], + axes: Option>, + steps: Option> +) { + let mut input_data = HashMap::new(); + let mut input_shapes: Vec = vec![]; + let mut input_names: Vec<&str> = vec![]; + + let starts_lengths = vec![starts.len() as i64]; + let ends_lengths = vec![ends.len() as i64]; + + input_data.insert("X".to_string(), data.into()); + input_shapes.push(tensor("X", data_shape)); + input_names.push("X"); + + input_data.insert("S".to_string(), starts.into()); + input_shapes.push(tensor("S", &starts_lengths[..])); + input_names.push("S"); + + input_data.insert("E".to_string(), ends.into()); + input_shapes.push(tensor("E", &ends_lengths[..])); + input_names.push("E"); + + let mut axes_unwraped: Vec = vec![]; + if axes.is_some() { + axes_unwraped = axes.unwrap(); + let axes_lengths = vec![axes_unwraped.len() as i64]; + input_data.insert("A".to_string(), (&axes_unwraped[..]).into()); + input_shapes.push(tensor("A", &axes_lengths[..])); + input_names.push("A"); + } + + let mut steps_unwraped: Vec = vec![]; + if steps.is_some() { + steps_unwraped = steps.unwrap(); + let steps_lengths = vec![steps_unwraped.len() as i64]; + input_data.insert("P".to_string(), (&steps_unwraped[..]).into()); + input_shapes.push(tensor("P", &steps_lengths[..])); + input_names.push("P"); + } + + // Model: (X, S, E, A?, P?) -> Slice -> Y + let bn_model = model(graph( + input_shapes, + vec![tensor("Y", output_shape)], + vec![], + vec![], + vec![node( + input_names, + vec!["Y"], + "mySlice", + "Slice", + vec![] + )], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(bn_model)).expect("Session did not create"); + + let result = pollster::block_on(session.run(&input_data)).unwrap(); + common::assert_eq_vector((&result["Y"]).try_into().unwrap(), output); +} + +#[test] +fn slice() { + let _ = env_logger::builder().is_test(true).try_init(); + + // Example 1 from https://onnx.ai/onnx/operators/onnx__Slice.html#slice. + assert_slice( + &[1., 2., 3., 4., 5., 6., 7., 8.], + &[4, 2], + &[5., 7.], + &[2, 1], + &[1., 0.], + &[2., 3.], + Some(vec![0., 1.]), + Some(vec![1., 2.]) + ); +}