Skip to content

Commit

Permalink
Implemented Danish algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
jfcrenshaw committed Jan 12, 2024
1 parent b72cfe0 commit 008bc36
Show file tree
Hide file tree
Showing 7 changed files with 400 additions and 52 deletions.
9 changes: 9 additions & 0 deletions policy/estimation/danish.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Default parameters for the Danish algorithm
# see the docstring for definitions of parameter

lstsqKwargs:
ftol: 1.0e-3
xtol: 1.0e-3
gtol: 1.0e-3
max_nfev: 20
saveHistory: False
288 changes: 288 additions & 0 deletions python/lsst/ts/wep/estimation/danish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
# This file is part of ts_wep.
#
# Developed for the LSST Telescope and Site Systems.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["DanishAlgorithm"]

import warnings
from typing import Optional, Tuple, Union

import danish
import numpy as np
from lsst.ts.wep import Image, Instrument
from lsst.ts.wep.estimation.wfAlgorithm import WfAlgorithm
from scipy.optimize import least_squares


class DanishAlgorithm(WfAlgorithm):
"""Wavefront estimation algorithm class for Danish.
Parameters
----------
configFile : str, optional
Path to file specifying values for the other parameters. If the
path starts with "policy/", it will look in the policy directory.
Any explicitly passed parameters override values found in this file
(the default is policy/estimation/danish.yaml)
lstsqKwargs : dict, optional
A dictionary containing any of the keyword arguments for
scipy.optimize.least_squares, except `fun`, `jac`, or `args`.
Note that if x0 is not provided, it will default to an initial
guess of PSF FWHM = 1 arcsec and all other initial guesses = 0.
saveHistory : bool, optional
Whether to save the algorithm history in the self.history attribute.
If True, then self.history contains information about the most recent
time the algorithm was run.
"""

def __init__(
self,
configFile: Union[str, None] = "policy/estimation/danish.yaml",
lstsqKwargs: Optional[dict] = None,
saveHistory: Optional[bool] = None,
) -> None:
super().__init__(
configFile=configFile,
lstsqKwargs=lstsqKwargs,
saveHistory=saveHistory,
)

@property
def lstsqKwargs(self) -> dict:
"""Keyword arguments for scipy.optimize.least_squares"""
return self._lstsqKwargs

@lstsqKwargs.setter
def lstsqKwargs(self, value: Union[dict, None]) -> None:
"""Set the keyword arguments for scipy.optimize.least_squares.
Parameters
----------
value : dict, optional
A dictionary containing any of the keyword arguments for
scipy.optimize.least_squares, except `fun`, `x0`, `jac`, or `args`.
(the default is an empty dictionary)
"""
# If None, get empty dictionary
if value is None:
value = dict()

# Cast to a dict
value = dict(value)

# Make sure these keys are not provided
notAllowed = ["fun", "x0", "jac", "args"]
for key in notAllowed:
if key in value:
raise KeyError(f"Please do not provide '{key}' in lstsqKwargs.")

self._lstsqKwargs = value

@property
def history(self) -> dict:
"""The algorithm history.
The history is a dictionary that contains intermediate products
from the Zernike fitting. The dict contains entries for "intra"
and/or "extra". Each of these is a dictionary that contains the
following entries
- "image" - the image that is being fit
- "variance" - the background variance that was used for fitting
- "lstsqResult" - dictionary of results returned by least_squares
- "zk" - the Zernike coefficients fit to the donut
- "model" - the final forward modeled donut image
"""
return super().history

def _estimateSingleZk(
self,
image: Image,
instrument: Instrument,
factory: danish.DonutFactory,
x0: list,
) -> Tuple[np.ndarray, dict]:
"""Estimate Zernikes (in meters) for a single donut stamp.
Parameters
----------
image : Image
The ts_wep image of the donut stamp
instrument : Instrument
The ts_wep Instrument
factory : danish.DonutFactory
The Danish donut factory
x0 : list
The initial guess for the model parameters
Returns
-------
np.ndarray
The Zernike coefficients (in meters) for Noll indices >= 4
dict
The single-stamp history. This is empty if saveHistory is False.
"""
# Warn about using Danish for blended donuts
if image.blendOffsets.size > 0:
warnings.warn("Danish is currently only setup for non-blended donuts.")

# Get the Zernike coefficients for the off-axis model
zkRef = instrument.getOffAxisCoeff(
*image.fieldAngle,
image.defocalType,
image.bandLabel,
return4Up=False,
)

# Get the image array
img = image.image

# If size of image is even, cut off final row/column
if img.shape[0] % 2 == 0:
img = img[:-1, :-1]

# Create the Danish donut model
model = danish.SingleDonutModel(
factory,
z_ref=zkRef,
z_terms=np.arange(4, len(x0) + 1),
thx=np.deg2rad(image.fieldAngle[0]),
thy=np.deg2rad(image.fieldAngle[1]),
npix=img.shape[0],
)

# Get variance of background from first few rows of image
var = img[:4].std() ** 2

# Use scipy to optimize the parameters
result = least_squares(
model.chi,
jac=model.jac,
x0=x0,
args=(img, var),
**self.lstsqKwargs,
)

# Unpack the parameters
dx, dy, fwhm, *zk = result.x

if self.saveHistory:
# Save the image and variance
hist = {
"image": img,
"variance": var,
}

# Save the least squares result
hist["lstsqResult"] = dict(result)

# Save the Zernike coefficients in meters
hist["zk"] = zk

# Also add the forward modeled image
hist["model"] = model.model(dx, dy, fwhm, zk)

else:
hist = {}

return zk, hist

def estimateZk(
self,
I1: Image,
I2: Optional[Image] = None,
jmax: int = 22,
instrument: Instrument = Instrument(),
) -> np.ndarray:
"""Return the wavefront Zernike coefficients in meters.
Parameters
----------
I1 : Image
An Image object containing an intra- or extra-focal donut image.
I2 : Image
A second image, on the opposite side of focus from I1.
jmax : int, optional
The maximum Zernike Noll index to estimate.
(the default is 22)
instrument : Instrument, optional
The Instrument object associated with the DonutStamps.
(the default is the default Instrument)
Returns
-------
np.ndarray
Zernike coefficients (for Noll indices >= 4) estimated from
the images, in meters.
"""
# Validate the inputs
self._validateInputs(I1, I2, jmax, instrument)

# Create the Danish donut factory
factory = danish.DonutFactory(
R_outer=instrument.radius,
R_inner=instrument.radius * instrument.obscuration,
obsc_radii={
key.replace("Inner", "_inner"): val["radius"]
for key, val in instrument.maskParams.items()
},
obsc_centers={
key.replace("Inner", "_inner"): val["center"]
for key, val in instrument.maskParams.items()
},
obsc_th_mins={
key.replace("Inner", "_inner"): val["thetaMin"]
for key, val in instrument.maskParams.items()
},
focal_length=instrument.focalLength,
pixel_scale=instrument.pixelSize,
)

# Create the initial guess for the model parameters
x0 = [0.0, 0.0, 1.0] + [0.0] * (jmax - 3)

# Create an empty history
hist = {}

# Estimate for I1
zk1, hist[I1.defocalType.value] = self._estimateSingleZk(
I1,
instrument,
factory,
x0,
)

if I2 is not None:
# If I2 provided, estimate for that donut as well
zk2, hist[I2.defocalType.value] = self._estimateSingleZk(
I2,
instrument,
factory,
x0,
)

# Average the Zernikes
zk = np.mean([zk1, zk2], axis=0)
else:
zk = np.array(zk1)

if self.saveHistory:
self._history = hist

return zk
44 changes: 2 additions & 42 deletions python/lsst/ts/wep/estimation/tie.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
__all__ = ["TieAlgorithm"]

import inspect
import warnings
from typing import Iterable, Optional, Union

import numpy as np
Expand Down Expand Up @@ -80,7 +79,6 @@ class TieAlgorithm(WfAlgorithm):
Dictionary of mask keyword arguments to pass to mask creation.
To see possibilities, see the docstring for
lsst.ts.wep.imageMapper.ImageMapper.createPupilMask().
(the default is an empty dictionary)
saveHistory : bool, optional
Whether to save the algorithm history in the self.history attribute.
If True, then self.history contains information about the most recent
Expand Down Expand Up @@ -115,12 +113,9 @@ def __init__(
saveHistory=saveHistory,
)

# Instantiate an empty history
self._history = {} # type: ignore

@property
def opticalModel(self) -> str:
"""The optical model to use for"""
"""The optical model to use for mapping the image to the pupil."""
return self._opticalModel

@opticalModel.setter
Expand Down Expand Up @@ -374,35 +369,6 @@ def maskKwargs(self, value: Union[dict, None]) -> None:

self._maskKwargs = value

@property
def saveHistory(self) -> bool:
"""Whether the algorithm history is saved."""
return self._saveHistory

@saveHistory.setter
def saveHistory(self, value: bool) -> None:
"""Set boolean that determines whether algorithm history is saved.
Parameters
----------
value : bool
Boolean that determines whether the algorithm history is saved.
Raises
------
TypeError
If the value is not a boolean
"""
if not isinstance(value, bool):
raise TypeError("saveHistory must be a boolean.")

self._saveHistory = value

# If we are turning history-saving off, delete any old history
# This is to avoid confusion
if value is False:
self._history = {}

@property
def history(self) -> dict:
"""The algorithm history.
Expand All @@ -426,13 +392,7 @@ def history(self) -> dict:
Note the units for all Zernikes are in meters, and the z-derivative
in dIdz is also in meters.
"""
if not self._saveHistory:
warnings.warn(
"saveHistory is False. If you want the history to be saved, "
"run self.config(saveHistory=True)."
)

return self._history
return super().history

def _expSolve(
self,
Expand Down
Loading

0 comments on commit 008bc36

Please sign in to comment.