Skip to content

Commit

Permalink
update pytest functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Sophia Maedler committed Oct 26, 2023
1 parent dac93cb commit ce80e02
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/sparcscore/processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ def test_shift_labels():
[ 0, 0, 13]])
expected_edge_labels = [1, 3]

shifted_map, edge_labels = shift_labels(input_map, shift)
shifted_map, edge_labels = shift_labels(input_map, shift, remove_edge_labels = False)

expected_edge_labels_with_shift = np.array(expected_edge_labels) + shift

shifted_map_with_shift, edge_labels_with_shift = shift_labels(input_map, shift, return_shifted_labels = True)

shifted_map_with_shift, edge_labels_with_shift = shift_labels(input_map, shift, return_shifted_labels = True, remove_edge_labels = False)
assert np.array_equal(shifted_map, expected_shifted_map)
assert np.array_equal(shifted_map, shifted_map_with_shift)
assert set(edge_labels) == set(expected_edge_labels)
Expand All @@ -122,10 +122,18 @@ def test_shift_labels():
[ 0, 13, 0]]])
expected_edge_labels = [1,2, 3]

shifted_map_3d, edge_labels_3d = shift_labels(input_map_3d, shift)
shifted_map_3d, edge_labels_3d = shift_labels(input_map_3d, shift, remove_edge_labels = False)


assert np.array_equal(shifted_map_3d, expected_shifted_map_3d)
assert set(edge_labels_3d) == set(expected_edge_labels)

#test if removing edge labels works
shifted_map_removed_edge_labels, edge_labels = shift_labels(input_map, shift, remove_edge_labels = True)
expected_shifted_map_removed_edge_labels = np.array([[0, 0, 0],
[ 0, 12, 0],
[ 0, 0, 0]])
assert np.array_equal(shifted_map_removed_edge_labels, expected_shifted_map_removed_edge_labels)



Expand Down Expand Up @@ -376,7 +384,7 @@ def test_processing_step_get_directory():
config = {'setting1': 'value1'}
with tempfile.TemporaryDirectory() as temp_dir:
processing_step = ProcessingStep(config, f"{temp_dir}/test_step", temp_dir)
assert temp_dir == processing_step.get_directory()
assert f"{temp_dir}/test_step" == processing_step.get_directory()


#general test to check that testing is working
Expand Down

0 comments on commit ce80e02

Please sign in to comment.