From ce80e02b7c193173deda35bc84eed07467bbc1cc Mon Sep 17 00:00:00 2001 From: Sophia Maedler Date: Thu, 26 Oct 2023 16:22:16 +0200 Subject: [PATCH] update pytest functions --- src/sparcscore/processing_test.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/sparcscore/processing_test.py b/src/sparcscore/processing_test.py index 069aca3..f6246d6 100644 --- a/src/sparcscore/processing_test.py +++ b/src/sparcscore/processing_test.py @@ -96,12 +96,12 @@ def test_shift_labels(): [ 0, 0, 13]]) expected_edge_labels = [1, 3] - shifted_map, edge_labels = shift_labels(input_map, shift) + shifted_map, edge_labels = shift_labels(input_map, shift, remove_edge_labels = False) expected_edge_labels_with_shift = np.array(expected_edge_labels) + shift - shifted_map_with_shift, edge_labels_with_shift = shift_labels(input_map, shift, return_shifted_labels = True) - + shifted_map_with_shift, edge_labels_with_shift = shift_labels(input_map, shift, return_shifted_labels = True, remove_edge_labels = False) + assert np.array_equal(shifted_map, expected_shifted_map) assert np.array_equal(shifted_map, shifted_map_with_shift) assert set(edge_labels) == set(expected_edge_labels) @@ -122,10 +122,18 @@ def test_shift_labels(): [ 0, 13, 0]]]) expected_edge_labels = [1,2, 3] - shifted_map_3d, edge_labels_3d = shift_labels(input_map_3d, shift) + shifted_map_3d, edge_labels_3d = shift_labels(input_map_3d, shift, remove_edge_labels = False) + assert np.array_equal(shifted_map_3d, expected_shifted_map_3d) assert set(edge_labels_3d) == set(expected_edge_labels) + + #test if removing edge labels works + shifted_map_removed_edge_labels, edge_labels = shift_labels(input_map, shift, remove_edge_labels = True) + expected_shifted_map_removed_edge_labels = np.array([[0, 0, 0], + [ 0, 12, 0], + [ 0, 0, 0]]) + assert np.array_equal(shifted_map_removed_edge_labels, expected_shifted_map_removed_edge_labels) @@ -376,7 +384,7 @@ def test_processing_step_get_directory(): config = {'setting1': 'value1'} with tempfile.TemporaryDirectory() as temp_dir: processing_step = ProcessingStep(config, f"{temp_dir}/test_step", temp_dir) - assert temp_dir == processing_step.get_directory() + assert f"{temp_dir}/test_step" == processing_step.get_directory() #general test to check that testing is working