Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add protections against duplicate frames #545

Merged
merged 15 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

- Implement code linting and automatic formatting. [#544]

- Refactor ``WCS`` to use a ``Pipeline`` base class which adds basic checks to ensure that the pipeline is valid. These
include checking for duplicate frame names and that the last transform is ``None``. [#545]


0.22.0 (2024-12-19)
-------------------
Expand Down
10 changes: 2 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import importlib.metadata

try:
from sphinx_astropy.conf.v1 import * # noqa: F403
from sphinx_astropy.conf.v2 import * # noqa: F403
except ImportError:
print( # noqa: T201
"ERROR: the documentation requires the sphinx-astropy package to be installed"
Expand Down Expand Up @@ -108,13 +108,6 @@
# name of a builtin theme or the name of a custom theme in html_theme_path.
# html_theme = None

# See sphinx-bootstrap-theme for documentation of these options
# https://github.com/ryan-roemer/sphinx-bootstrap-theme
html_theme_options = {
"logotext1": "g", # white, semi-bold
"logotext2": "wcs", # orange, light
"logotext3": ":docs", # white, light
}

# Custom sidebar templates, maps document names to template names.
# html_sidebars = {}
Expand Down Expand Up @@ -156,6 +149,7 @@
nitpicky = True
nitpick_ignore = [
("py:class", "gwcs.api.GWCSAPIMixin"),
("py:class", "gwcs.wcs._pipeline.Pipeline"),
("py:obj", "astropy.modeling.projections.projcodes"),
("py:attr", "gwcs.WCS.bounding_box"),
("py:meth", "gwcs.WCS.footprint"),
Expand Down
10 changes: 10 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,18 @@ Reference/API
-------------

.. automodapi:: gwcs.wcs
:inherited-members:

.. automodapi:: gwcs.coordinate_frames
:inherited-members:

.. automodapi:: gwcs.wcstools

.. automodapi:: gwcs.selector
:inherited-members:

.. automodapi:: gwcs.spectroscopy
:inherited-members:

.. automodapi:: gwcs.geometry
:inherited-members:
43 changes: 22 additions & 21 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def pixel_to_world_values(self, *pixel_arrays):
def array_index_to_world_values(self, *index_arrays):
"""
Convert array indices to world coordinates.
This is the same as `~BaseLowLevelWCS.pixel_to_world_values` except that
the indices should be given in ``(i, j)`` order, where for an image
This is the same as `~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_to_world_values`
except that the indices should be given in ``(i, j)`` order, where for an image
``i`` is the row and ``j`` is the column (i.e. the opposite order to
`~BaseLowLevelWCS.pixel_to_world_values`).
`~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_to_world_values`).
"""
pixel_arrays = index_arrays[::-1]
return self.pixel_to_world_values(*pixel_arrays)
Expand All @@ -127,11 +127,11 @@ def world_to_pixel_values(self, *world_arrays):
def world_to_array_index_values(self, *world_arrays):
"""
Convert world coordinates to array indices.
This is the same as `~BaseLowLevelWCS.world_to_pixel_values` except that
the indices should be returned in ``(i, j)`` order, where for an image
``i`` is the row and ``j`` is the column (i.e. the opposite order to
`~BaseLowLevelWCS.pixel_to_world_values`). The indices should be
returned as rounded integers.
This is the same as `~astropy.wcs.wcsapi.BaseLowLevelWCS.world_to_pixel_values`
except that the indices should be returned in ``(i, j)`` order, where for an
image ``i`` is the row and ``j`` is the column (i.e. the opposite order to
`~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_to_world_values`). The indices should
be returned as rounded integers.
"""
results = self.world_to_pixel_values(*world_arrays)
results = (results,) if self.pixel_n_dim == 1 else results[::-1]
Expand All @@ -143,7 +143,7 @@ def world_to_array_index_values(self, *world_arrays):
def array_shape(self):
"""
The shape of the data that the WCS applies to as a tuple of
length `~BaseLowLevelWCS.pixel_n_dim`.
length `~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_n_dim`.
If the WCS is valid in the context of a dataset with a particular
shape, then this property can be used to store the shape of the
data. This can be used for example if implementing slicing of WCS
Expand All @@ -167,12 +167,13 @@ def array_shape(self, value):
def pixel_bounds(self):
"""
The bounds (in pixel coordinates) inside which the WCS is defined,
as a list with `~BaseLowLevelWCS.pixel_n_dim` ``(min, max)`` tuples.
The bounds should be given in ``[(xmin, xmax), (ymin, ymax)]``
order. WCS solutions are sometimes only guaranteed to be accurate
within a certain range of pixel values, for example when defining a
WCS that includes fitted distortions. This is an optional property,
and it should return `None` if a shape is not known or relevant.
as a list with `~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_n_dim`
``(min, max)`` tuples. The bounds should be given in
``[(xmin, xmax), (ymin, ymax)]`` order. WCS solutions are sometimes
only guaranteed to be accurate within a certain range of pixel values,
for example when defining a WCS that includes fitted distortions. This
is an optional property, and it should return `None` if a shape is not
known or relevant.
"""
bounding_box = self.bounding_box
if bounding_box is None:
Expand Down Expand Up @@ -225,12 +226,12 @@ def pixel_shape(self, value):
@property
def axis_correlation_matrix(self):
"""
Returns an (`~BaseLowLevelWCS.world_n_dim`,
`~BaseLowLevelWCS.pixel_n_dim`) matrix that indicates using booleans
whether a given world coordinate depends on a given pixel coordinate.
This defaults to a matrix where all elements are `True` in the absence of
any further information. For completely independent axes, the diagonal
would be `True` and all other entries `False`.
Returns an (`~astropy.wcs.wcsapi.BaseLowLevelWCS.world_n_dim`,
`~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_n_dim`) matrix that indicates
using booleans whether a given world coordinate depends on a given pixel
coordinate. This defaults to a matrix where all elements are `True` in
the absence of any further information. For completely independent axes,
the diagonal would be `True` and all other entries `False`.
"""
return separable.separability_matrix(self.forward_transform)

Expand Down
14 changes: 8 additions & 6 deletions gwcs/converters/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ def _assert_frame_equal(a, b):
return a == b

assert a.name == b.name # nosec
assert a.axes_order == b.axes_order # nosec
assert a.axes_names == b.axes_names # nosec
assert a.unit == b.unit # nosec
assert a.reference_frame == b.reference_frame # nosec
if not isinstance(a, cf.EmptyFrame):
assert a.axes_order == b.axes_order # nosec
assert a.axes_names == b.axes_names # nosec
assert a.unit == b.unit # nosec
assert a.reference_frame == b.reference_frame # nosec
return None


Expand Down Expand Up @@ -155,12 +156,13 @@ def test_references(tmp_path):
m1 = models.Shift(12.4) & models.Shift(-2)
icrs = cf.CelestialFrame(name="icrs", reference_frame=coord.ICRS())
det = cf.Frame2D(name="detector", axes_order=(0, 1))
det2 = cf.Frame2D(name="detector2", axes_order=(0, 1))
focal = cf.Frame2D(name="focal", axes_order=(0, 1))

pipe1 = [(det, m1), (focal, m1), (icrs, None)]
gw1 = wcs.WCS(pipe1)

pipe2 = [(det, m1), (det, m1), (icrs, None)]
pipe2 = [(det, m1), (det2, m1), (icrs, None)]
gw2 = wcs.WCS(pipe2)

tree = {"wcs1": gw1, "wcs2": gw2}
Expand All @@ -173,4 +175,4 @@ def test_references(tmp_path):
gw2 = af.tree["wcs2"]
assert gw1.pipeline[0].transform is gw1.pipeline[1].transform
assert gw2.pipeline[0].transform is gw2.pipeline[1].transform
assert gw2.pipeline[0].frame is gw2.pipeline[1].frame
assert gw2.pipeline[0].frame is gw1.pipeline[0].frame
9 changes: 8 additions & 1 deletion gwcs/converters/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ def from_yaml_tree(self, node, tag, ctx):
return Step(frame=node["frame"], transform=node.get("transform", None))

def to_yaml_tree(self, step, tag, ctx):
return {"frame": step.frame, "transform": step.transform}
from gwcs.coordinate_frames import EmptyFrame

return {
"frame": step.frame.name
if isinstance(step.frame, EmptyFrame)
else step.frame,
"transform": step.transform,
}


class FrameConverter(Converter):
Expand Down
75 changes: 75 additions & 0 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
"CelestialFrame",
"CompositeFrame",
"CoordinateFrame",
"EmptyFrame",
"Frame2D",
"SpectralFrame",
"StokesFrame",
Expand Down Expand Up @@ -608,6 +609,80 @@
return values


class EmptyFrame(CoordinateFrame):
"""
Represents a "default" detector frame. This is for use as the default value
for input frame by the WCS object.
"""

def __init__(self, name=None):
self._name = "detector" if name is None else name

def __repr__(self):
return f'<{type(self).__name__}(name="{self.name}")>'

Check warning on line 622 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L622

Added line #L622 was not covered by tests

def __str__(self):
if self._name is not None:
return self._name
return type(self).__name__

Check warning on line 627 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L627

Added line #L627 was not covered by tests

@property
def name(self):
"""A custom name of this frame."""
return self._name

@name.setter
def name(self, val):
"""A custom name of this frame."""
self._name = val

Check warning on line 637 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L637

Added line #L637 was not covered by tests

def _raise_error(self) -> None:
msg = "EmptyFrame does not have any information"

Check warning on line 640 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L640

Added line #L640 was not covered by tests
raise NotImplementedError(msg)

@property
def naxes(self):
self._raise_error()

Check warning on line 645 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L645

Added line #L645 was not covered by tests

@property
def unit(self):
self._raise_error()

Check warning on line 649 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L649

Added line #L649 was not covered by tests

@property
def axes_names(self):
self._raise_error()

Check warning on line 653 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L653

Added line #L653 was not covered by tests

@property
def axes_order(self):
self._raise_error()

Check warning on line 657 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L657

Added line #L657 was not covered by tests

@property
def reference_frame(self):
self._raise_error()

Check warning on line 661 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L661

Added line #L661 was not covered by tests

@property
def axes_type(self):
self._raise_error()

Check warning on line 665 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L665

Added line #L665 was not covered by tests

@property
def axis_physical_types(self):
self._raise_error()

Check warning on line 669 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L669

Added line #L669 was not covered by tests

@property
def world_axis_object_classes(self):
self._raise_error()

Check warning on line 673 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L673

Added line #L673 was not covered by tests

@property
def _native_world_axis_object_components(self):
self._raise_error()

Check warning on line 677 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L677

Added line #L677 was not covered by tests

def to_high_level_coordinates(self, *values):
self._raise_error()

Check warning on line 680 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L680

Added line #L680 was not covered by tests

def from_high_level_coordinates(self, *high_level_coords):
self._raise_error()

Check warning on line 683 in gwcs/coordinate_frames.py

View check run for this annotation

Codecov / codecov/patch

gwcs/coordinate_frames.py#L683

Added line #L683 was not covered by tests


class CelestialFrame(CoordinateFrame):
"""
Representation of a Celesital coordinate system.
Expand Down
46 changes: 35 additions & 11 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
stokes = cf.StokesFrame(axes_order=(2,))

pipe = [wcs.Step(detector, m1), wcs.Step(focal, m2), wcs.Step(icrs, None)]
pipe_copy = pipe.copy()

# Create some data.
nx, ny = (5, 2)
Expand Down Expand Up @@ -104,28 +105,28 @@ def test_init_no_transform():
"""
gw = wcs.WCS(output_frame="icrs")
assert len(gw._pipeline) == 2
assert gw.pipeline[0].frame == "detector"
assert gw.pipeline[0].frame.name == "detector"
with pytest.warns(
DeprecationWarning, match="Indexing a WCS.pipeline step is deprecated."
):
assert gw.pipeline[0][0] == "detector"
assert gw.pipeline[1].frame == "icrs"
assert gw.pipeline[0][0].name == "detector"
assert gw.pipeline[1].frame.name == "icrs"
with pytest.warns(
DeprecationWarning, match="Indexing a WCS.pipeline step is deprecated."
):
assert gw.pipeline[1][0] == "icrs"
assert gw.pipeline[1][0].name == "icrs"
assert np.isin(gw.available_frames, ["detector", "icrs"]).all()
gw = wcs.WCS(output_frame=icrs, input_frame=detector)
assert gw._pipeline[0].frame == "detector"
assert gw._pipeline[0].frame.name == "detector"
with pytest.warns(
DeprecationWarning, match="Indexing a WCS.pipeline step is deprecated."
):
assert gw._pipeline[0][0] == "detector"
assert gw._pipeline[1].frame == "icrs"
assert gw._pipeline[0][0].name == "detector"
assert gw._pipeline[1].frame.name == "icrs"
with pytest.warns(
DeprecationWarning, match="Indexing a WCS.pipeline step is deprecated."
):
assert gw._pipeline[1][0] == "icrs"
assert gw._pipeline[1][0].name == "icrs"
assert np.isin(gw.available_frames, ["detector", "icrs"]).all()
with pytest.raises(NotImplementedError):
gw(1, 2)
Expand Down Expand Up @@ -732,7 +733,7 @@ def test_units(self):
assert self.wcs.unit == (u.degree, u.degree)

def test_get_transform(self):
with pytest.raises(wcs.CoordinateFrameError):
with pytest.raises(CoordinateFrameError):
assert (
self.wcs.get_transform("x_translation", "sky_rotation").submodel_names
== self.wcs.forward_transform[1:].submodel_names
Expand Down Expand Up @@ -1385,8 +1386,8 @@ def test_initialize_wcs_with_list():
shift2 = models.Shift(3 * u.pix)
pipeline = [("detector", shift1), wcs.Step("extra_step", shift2)]

extra_step = ("extra_step", None)
pipeline.append(extra_step)
end_step = ("end_step", None)
pipeline.append(end_step)

# make sure no warnings occur when creating wcs with this pipeline
with warnings.catch_warnings():
Expand Down Expand Up @@ -1735,3 +1736,26 @@ def test_high_level_objects_in_pipeline_backward(gwcs_with_pipeline_celestial):
with_units=True,
)
assert isinstance(intermediate_world, coord.SkyCoord)


def test_error_with_duplicate_frames():
"""
Test that an error is raised if a frame is used more than once in the pipeline.
"""
pipeline = [(detector, m1), (detector, m2), (focal, None)]

with pytest.raises(ValueError, match="Frame detector is already in the pipeline."):
wcs.WCS(pipeline)


def test_error_with_not_none_last():
"""
Test that an error is raised if the last transform is not None
"""

pipeline = [(detector, m1), (focal, m2)]

with pytest.raises(
ValueError, match="The last step in the pipeline must have a None transform."
):
wcs.WCS(pipeline)
10 changes: 10 additions & 0 deletions gwcs/wcs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ._exception import GwcsBoundingBoxWarning, NoConvergence
from ._step import Step
from ._wcs import WCS

__all__ = [
"WCS",
"GwcsBoundingBoxWarning",
"NoConvergence",
"Step",
]
Loading
Loading