Skip to content

Commit

Permalink
Fix Concat for larger inputs and add an additional test
Browse files Browse the repository at this point in the history
The old calculation of actual_idx was not correct because it didn't
consider the total number of executions on the x axis.
  • Loading branch information
mayjs committed Aug 5, 2023
1 parent 07b7a9e commit c04d389
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
4 changes: 2 additions & 2 deletions wonnx/templates/matrix/concat.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(num_workgr
let gidx = global_id.x;
let gidy = global_id.y;

let nx = num_workgroups.x;
let x_executions = num_workgroups.x * 16u;

let actual_idx = gidx + gidy * nx;
let actual_idx = gidx + gidy * x_executions;

{% for input in i_lens %}
{% if loop.first %}
Expand Down
33 changes: 33 additions & 0 deletions wonnx/tests/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,39 @@ fn test_concat() {
common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
}

#[test]
fn test_concat_long() {
let n: usize = 100000;

let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
let input_dims = vec![n as i64];
let output_dims = vec![(n * 2) as i64];

let input_data = HashMap::from([
("X".into(), xdata.as_slice().into()),
("Y".into(), ydata.as_slice().into()),
]);

let model = model(graph(
vec![tensor("X", &input_dims), tensor("Y", &input_dims)],
vec![tensor("Z", &output_dims)],
vec![],
vec![],
vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])],
));

let session =
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

let result = pollster::block_on(session.run(&input_data)).unwrap();

let mut expected_result = xdata.clone();
expected_result.append(&mut ydata);

common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
}

#[test]
fn test_concat4() {
let n: usize = 13;
Expand Down

0 comments on commit c04d389

Please sign in to comment.