Skip to content

Commit

Permalink
Improve error messages when asserting shape
Browse files Browse the repository at this point in the history
Adds context for array shape mismatch such as unit operation index,
solution type etc.
  • Loading branch information
schmoelder committed Dec 3, 2024
1 parent 7fdae15 commit 0816e1b
Showing 1 changed file with 72 additions and 10 deletions.
82 changes: 72 additions & 10 deletions tests/test_dll.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Use this to specify custom cadet_roots if you require it.
cadet_root = None


def setup_model(
cadet_root,
use_dll=True,
Expand Down Expand Up @@ -875,6 +876,39 @@ def __repr__(self):
_2dgrm_split_all
]


def assert_shape(array_shape, expected_shape, context, key, unit_id=None):
"""
Assert that the shape of an array matches the expected shape.
Parameters
----------
array_shape : tuple
The shape of the actual array to validate.
expected_shape : tuple
The expected shape to compare against.
context : str
High-level context for the assertion,
e.g., 'last_state', 'coordinates', 'solution'.
key : str
Specific key or identifier within the context.
unit_id : str, optional
Unit identifier, e.g., 'unit_000'. If not provided, it is assumed the context
does not require unit-specific validation.
Raises
------
AssertionError
If the actual shape does not match the expected shape, including detailed context.
"""
unit_info = f"in unit '{unit_id}'" if unit_id else ""
assert array_shape == expected_shape, (
f"Shape mismatch {unit_info} for {context}[{key}]. "
f"Expected {expected_shape}, but got {array_shape}."
)


@pytest.mark.parametrize("use_dll", use_dll)
@pytest.mark.parametrize("test_case", test_cases)
def test_simulator_options(use_dll, test_case):
Expand All @@ -886,27 +920,55 @@ def test_simulator_options(use_dll, test_case):
use_dll, model_options, solution_recorder_options
)

assert model.root.output.last_state_y.shape == expected_results['last_state_y']
assert model.root.output.last_state_ydot.shape == expected_results['last_state_ydot']
# Assert last_state shapes
assert_shape(
model.root.output.last_state_y.shape,
expected_results['last_state_y'],
context="last_state",
key="y"
)
assert_shape(
model.root.output.last_state_ydot.shape,
expected_results['last_state_ydot'],
context="last_state",
key="ydot"
)

# Check coordinates for unit_000
for key, value in expected_results['coordinates_unit_000'].items():
assert model.root.output.coordinates.unit_000[key].shape == value

assert model.root.output.solution.solution_times.shape == expected_results['solution_times']
coordinates_shape = model.root.output.coordinates.unit_000[key].shape
assert_shape(coordinates_shape, value, context="coordinates", key=key, unit_id="unit_000")

# Assert solution_times shape
assert_shape(
model.root.output.solution.solution_times.shape,
expected_results['solution_times'],
context="solution",
key="solution_times"
)

# Check solution for unit_000
for key, value in expected_results['solution_unit_000'].items():
assert model.root.output.solution.unit_000[key].shape == value
shape = model.root.output.solution.unit_000[key].shape
assert_shape(shape, value, context="solution", key=key, unit_id="unit_000")

# Check solution for unit_001
for key, value in expected_results['solution_unit_001'].items():
assert model.root.output.solution.unit_001[key].shape == value
shape = model.root.output.solution.unit_001[key].shape
assert_shape(shape, value, context="solution", key=key, unit_id="unit_001")

if model_options['include_sensitivity']:
for key, value in expected_results['sens_param_000_unit_000'].items():
assert model.root.output.sensitivity.param_000.unit_000[key].shape == value
shape = model.root.output.sensitivity.param_000.unit_000[key].shape
assert_shape(
shape, value, context="sensitivity", key=key, unit_id="unit_000"
)

if model_options['include_sensitivity']:
for key, value in expected_results['sens_param_000_unit_001'].items():
assert model.root.output.sensitivity.param_000.unit_001[key].shape == value
shape = model.root.output.sensitivity.param_000.unit_001[key].shape
assert_shape(
shape, value, context="sensitivity", key=key, unit_id="unit_001"
)


@pytest.mark.parametrize("use_dll", use_dll)
Expand Down

0 comments on commit 0816e1b

Please sign in to comment.