Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for slice operator #171

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions wonnx/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ lazy_static! {
include_str!("../templates/endomorphism/broadcast.wgsl"),
)
.unwrap();
tera.add_raw_template(
"endomorphism/slice.wgsl",
include_str!("../templates/endomorphism/slice.wgsl"),
)
.unwrap();
tera
};
}
Expand Down Expand Up @@ -440,6 +445,85 @@ pub fn compile(
}
}

"Slice" => {
// TODO @ Raphael
// Goal: Implementing slice in version 11:
// https://onnx.ai/onnx/operators/onnx__Slice.html#slice-11
// 1. Create input buffer.
// 2. Create output buffer.
// 3. Provide attributes to compute shader.
// 4. Call compute shader over output.
// 5. Fill output in compute shader.

// There must be at least starts and ends defined in the inputs
let input_count = input_lengths.len();
if input_count < 3 {
return Err(CompileError::InvalidInputCount {
expected: 3,
actual: input_count,
});
}

// Print inputs.
println!("starts rank: {:?}", input_shapes[1].rank());
println!("starts dim: {:?}", input_shapes[1].dim(0));
println!("ends rank: {:?}", input_shapes[2].rank());
println!("ends dim: {:?}", input_shapes[2].dim(0));
if input_count > 3 {
println!("axes rank: {:?}", input_shapes[3].rank());
println!("axes dim: {:?}", input_shapes[3].dim(0));
}
if input_count > 4 {
println!("steps rank: {:?}", input_shapes[4].rank());
println!("steps dim: {:?}", input_shapes[4].dim(0));
}

// TODO: Create fallback input buffers for axes and steps if not provided.

// Copied from Gather to avoid compilation error.
// Input 0 is data, input 1 is indices
// Which axis to gather on. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).
// Default is 0. See https://github.com/onnx/onnx/blob/main/docs/Operators.md#attributes-25
let axis = node.get_attribute_value("axis", Some(0))?;
if axis != 0 {
return Err(CompileError::UnimplementedVariant {
variant: format!("axis={}", axis),
op: String::from("Gather"),
});
}

let elements_per_index = input_chunks[0][0];
let scalar_type = agreed_type(&input_shapes[0..1], output_shapes)?;
let chunk_type = MultiType::for_size(elements_per_index as usize, scalar_type);
let chunk_size = chunk_type.elements();

// The X dimension represents the indexes
let (x_threads, workgroup_size_x) = workgroup_size(
input_lengths[1],
MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
MAX_WORKGROUP_SIZE_X,
)?;

// The Y dimension represents the elements to copy for each index
let (y_threads, workgroup_size_y) = workgroup_size(
ceil(elements_per_index, chunk_size as u64),
MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
MAX_WORKGROUP_SIZE_Y,
)?;

context.insert("chunk_type", &chunk_type.wgsl_type_name());
context.insert("chunk_size", &chunk_size);
context.insert("workgroup_size_x", &workgroup_size_x);
context.insert("workgroup_size_y", &workgroup_size_y);

NodeTemplate {
scalar_type,
template: "endomorphism/slice.wgsl",
threads: (x_threads, y_threads, 1),
}
// Above to be removed / replaced.
}

"Cast" => {
let cast_to_type =
ScalarType::from_i32(node.get_attribute_value::<i64>("to", None)? as i32)?;
Expand Down
23 changes: 23 additions & 0 deletions wonnx/templates/endomorphism/slice.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{%- include "structs.wgsl" -%}

struct Indices {
data: array<i32>
};

struct Chunk {
data: array<{{ chunk_type }}>
};

@group(0) @binding(0)
var<storage, read> input_0: Chunk; // data

@group(0) @binding(1)
var<storage, read> input_1: Indices; // indices

@group(0) @binding(2)
var<storage, read_write> output_0: Chunk;

@compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }})
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
// TODO @ Raphael
}
92 changes: 92 additions & 0 deletions wonnx/tests/slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use std::{collections::HashMap, convert::TryInto};
use wonnx::{
onnx::ValueInfoProto,
utils::{graph, model, node, tensor},
};
mod common;

/// Builds a one-node model `(X, S, E, A?, P?) -> Slice -> Y` and asserts that
/// running it over `data` yields `output`.
///
/// * `data` / `data_shape` — input tensor and its dimensions.
/// * `output` / `output_shape` — expected result and its dimensions.
/// * `starts` / `ends` — per-axis slice boundaries (Slice-11 inputs).
/// * `axes` — optional axes the boundaries apply to.
/// * `steps` — optional per-axis step sizes.
///
/// NOTE(review): starts/ends/axes/steps are fed as f32 tensors here, while the
/// ONNX spec declares them as integer tensors — confirm against wonnx's input
/// handling.
fn assert_slice(
    data: &[f32],
    data_shape: &[i64],
    output: &[f32],
    output_shape: &[i64],
    starts: &[f32],
    ends: &[f32],
    axes: Option<Vec<f32>>,
    steps: Option<Vec<f32>>,
) {
    let mut input_data = HashMap::new();
    let mut input_shapes: Vec<ValueInfoProto> = vec![];
    let mut input_names: Vec<&str> = vec![];

    let starts_lengths = vec![starts.len() as i64];
    let ends_lengths = vec![ends.len() as i64];

    input_data.insert("X".to_string(), data.into());
    input_shapes.push(tensor("X", data_shape));
    input_names.push("X");

    input_data.insert("S".to_string(), starts.into());
    input_shapes.push(tensor("S", &starts_lengths));
    input_names.push("S");

    input_data.insert("E".to_string(), ends.into());
    input_shapes.push(tensor("E", &ends_lengths));
    input_names.push("E");

    // Optional inputs: borrow from the caller-owned `Option<Vec<f32>>`
    // parameters (alive for the whole function) instead of the previous
    // `is_some()` + `unwrap()` into function-scope scratch vectors.
    if let Some(axes) = axes.as_ref() {
        let axes_lengths = vec![axes.len() as i64];
        input_data.insert("A".to_string(), axes.as_slice().into());
        input_shapes.push(tensor("A", &axes_lengths));
        input_names.push("A");
    }

    if let Some(steps) = steps.as_ref() {
        let steps_lengths = vec![steps.len() as i64];
        input_data.insert("P".to_string(), steps.as_slice().into());
        input_shapes.push(tensor("P", &steps_lengths));
        input_names.push("P");
    }

    // Model: (X, S, E, A?, P?) -> Slice -> Y
    let bn_model = model(graph(
        input_shapes,
        vec![tensor("Y", output_shape)],
        vec![],
        vec![],
        vec![node(input_names, vec!["Y"], "mySlice", "Slice", vec![])],
    ));

    let session =
        pollster::block_on(wonnx::Session::from_model(bn_model)).expect("Session did not create");

    let result = pollster::block_on(session.run(&input_data)).unwrap();
    common::assert_eq_vector((&result["Y"]).try_into().unwrap(), output);
}

#[test]
fn slice() {
    let _ = env_logger::builder().is_test(true).try_init();

    // Example 1 from https://onnx.ai/onnx/operators/onnx__Slice.html#slice.
    // NOTE(review): the ONNX example uses a [2, 4] input yielding [1, 2];
    // the shapes here look transposed — verify once the kernel lands.
    let data = [1., 2., 3., 4., 5., 6., 7., 8.];
    let expected = [5., 7.];
    let starts = [1., 0.];
    let ends = [2., 3.];

    assert_slice(
        &data,
        &[4, 2],
        &expected,
        &[2, 1],
        &starts,
        &ends,
        Some(vec![0., 1.]),
        Some(vec![1., 2.]),
    );
}