Skip to content

Commit

Permalink
add a more complex pad reflect test
Browse files Browse the repository at this point in the history
  • Loading branch information
redthing1 committed Nov 5, 2023
1 parent 866a5c4 commit 77cc055
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions wonnx/tests/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,42 @@ fn test_pad_reflect() {
assert_eq!(actual, &test_y);
}

#[test]
fn test_pad_reflect_complex() {
let mut input_data = HashMap::new();
#[rustfmt::skip]
let data = [
1.0, 1.2, 1.3,
2.3, 3.4, 4.5,
4.5, 5.7, 6.8,
].to_vec();
input_data.insert("X".to_string(), data.as_slice().into());

let model = model(graph(
vec![tensor("X", &[3, 3])],
vec![tensor("Y", &[3, 7])],
vec![],
vec![initializer_int64("pads", vec![0, 2, 0, 2], vec![4])],
vec![node(vec!["X", "pads"], vec!["Y"], "Pad", "Pad", vec![
attribute("mode", "reflect"),
])],
));

let session =
pollster::block_on(wonnx::Session::from_model(model)).expect("session did not create");
let result = pollster::block_on(session.run(&input_data)).unwrap();

#[rustfmt::skip]
let test_y = vec![
1.3, 1.2, 1.0, 1.2, 1.3, 1.2, 1.0,
4.5, 3.4, 2.3, 3.4, 4.5, 3.4, 2.3,
6.8, 5.7, 4.5, 5.7, 6.8, 5.7, 4.5,
];
let actual: &[_] = (&result["Y"]).try_into().unwrap();
// No arithmetic is done, so `assert_eq!` can be used.
assert_eq!(actual, &test_y);
}

#[test]
fn test_resize() {
let _ = env_logger::builder().is_test(true).try_init();
Expand Down

0 comments on commit 77cc055

Please sign in to comment.