diff --git a/wonnx/templates/matrix/concat.wgsl b/wonnx/templates/matrix/concat.wgsl index 4c282723..dfba2016 100644 --- a/wonnx/templates/matrix/concat.wgsl +++ b/wonnx/templates/matrix/concat.wgsl @@ -17,9 +17,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3, @builtin(num_workgr let gidx = global_id.x; let gidy = global_id.y; - let nx = num_workgroups.x; + let x_executions = num_workgroups.x * 16u; - let actual_idx = gidx + gidy * nx; + let actual_idx = gidx + gidy * x_executions; {% for input in i_lens %} {% if loop.first %} diff --git a/wonnx/tests/concat.rs b/wonnx/tests/concat.rs index 980ad039..db13d708 100644 --- a/wonnx/tests/concat.rs +++ b/wonnx/tests/concat.rs @@ -35,6 +35,39 @@ fn test_concat() { common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result); } +#[test] +fn test_concat_long() { + let n: usize = 100000; + + let xdata: Vec = (0..n).map(|x| x as f32).collect(); + let mut ydata: Vec = (n..2 * n).map(|x| x as f32).collect(); + let input_dims = vec![n as i64]; + let output_dims = vec![(n * 2) as i64]; + + let input_data = HashMap::from([ + ("X".into(), xdata.as_slice().into()), + ("Y".into(), ydata.as_slice().into()), + ]); + + let model = model(graph( + vec![tensor("X", &input_dims), tensor("Y", &input_dims)], + vec![tensor("Z", &output_dims)], + vec![], + vec![], + vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); + + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + let mut expected_result = xdata.clone(); + expected_result.append(&mut ydata); + + common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result); +} + #[test] fn test_concat4() { let n: usize = 13;