Skip to content

Commit

Permalink
Merge pull request #183 from mayjs/fix_concat
Browse files Browse the repository at this point in the history
Fix Concat for larger inputs and add an additional test
  • Loading branch information
pixelspark authored Aug 7, 2023
2 parents 07b7a9e + c04d389 commit fbb7ab1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
4 changes: 2 additions & 2 deletions wonnx/templates/matrix/concat.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(num_workgr
let gidx = global_id.x;
let gidy = global_id.y;

let nx = num_workgroups.x;
let x_executions = num_workgroups.x * 16u;

let actual_idx = gidx + gidy * nx;
let actual_idx = gidx + gidy * x_executions;

{% for input in i_lens %}
{% if loop.first %}
Expand Down
33 changes: 33 additions & 0 deletions wonnx/tests/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,39 @@ fn test_concat() {
common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
}

#[test]
fn test_concat_long() {
let n: usize = 100000;

let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
let input_dims = vec![n as i64];
let output_dims = vec![(n * 2) as i64];

let input_data = HashMap::from([
("X".into(), xdata.as_slice().into()),
("Y".into(), ydata.as_slice().into()),
]);

let model = model(graph(
vec![tensor("X", &input_dims), tensor("Y", &input_dims)],
vec![tensor("Z", &output_dims)],
vec![],
vec![],
vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])],
));

let session =
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

let result = pollster::block_on(session.run(&input_data)).unwrap();

let mut expected_result = xdata.clone();
expected_result.append(&mut ydata);

common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
}

#[test]
fn test_concat4() {
let n: usize = 13;
Expand Down

0 comments on commit fbb7ab1

Please sign in to comment.