diff --git a/python/lsst/drp/tasks/assemble_cell_coadd.py b/python/lsst/drp/tasks/assemble_cell_coadd.py new file mode 100644 index 00000000..1de1b317 --- /dev/null +++ b/python/lsst/drp/tasks/assemble_cell_coadd.py @@ -0,0 +1,377 @@ +# This file is part of drp_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +__all__ = ["AssembleCellCoaddTask", "AssembleCellCoaddConfig"] + + +import lsst.afw.image as afwImage +import lsst.afw.math as afwMath +import numpy as np +from lsst.cell_coadds import ( + CellIdentifiers, + CoaddUnits, + CommonComponents, + GridContainer, + MultipleCellCoadd, + OwnedImagePlanes, + PatchIdentifiers, + SingleCellCoadd, + UniformGrid, +) +from lsst.meas.algorithms import AccumulatorMeanStack, CoaddPsf, CoaddPsfConfig +from lsst.pex.config import ConfigField, ConfigurableField, Field, ListField, RangeField +from lsst.pipe.base import NoWorkFound, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output +from lsst.pipe.tasks.coaddBase import makeSkyInfo +from lsst.pipe.tasks.coaddInputRecorder import CoaddInputRecorderTask +from lsst.pipe.tasks.interpImage import InterpImageTask +from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask +from lsst.skymap import BaseSkyMap + + +class AssembleCellCoaddConnections( + PipelineTaskConnections, + dimensions=("tract", "patch", "band", "skymap"), + defaultTemplates={"inputWarpName": "deep", "outputCoaddSuffix": "Cell"}, +): + inputWarps = Input( + doc="Input warps", + name="{inputWarpName}Coadd_directWarp", + storageClass="ExposureF", + dimensions=("tract", "patch", "skymap", "visit", "instrument"), + deferLoad=True, + multiple=True, + ) + + skyMap = Input( + doc="Input definition of geometry/bbox and projection/wcs. This must be cell-based.", + name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, + storageClass="SkyMap", + dimensions=("skymap",), + ) + + multipleCellCoadd = Output( + doc="Output multiple cell coadd", + name="{inputWarpName}Coadd{outputCoaddSuffix}", + storageClass="MultipleCellCoadd", + dimensions=("tract", "patch", "band", "skymap"), + ) + + +class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCellCoaddConnections): + do_interpolate_coadd = Field[bool](doc="Interpolate over pixels with NO_DATA mask set?", default=False) + interpolate_coadd = ConfigurableField( + target=InterpImageTask, + doc="Task to interpolate (and extrapolate) over pixels with NO_DATA mask on cell coadds", + ) + scale_zero_point = ConfigurableField( + target=ScaleZeroPointTask, + doc="Task to scale warps to a common zero point", + ) + bad_mask_planes = ListField[str]( + doc="Mask planes that count towards the masked fraction within a cell.", + default=("BAD", "NO_DATA", "SAT"), + ) + calc_error_from_input_variance = Field[bool]( + doc="Calculate coadd variance from input variance by stacking " + "statistic. Passed to AccumulatorMeanStack.", + default=False, + ) + max_maskfrac = RangeField[float]( + doc="Maximum fraction of masked pixels in a cell. This is currently " + "just a placeholder and is not used now", + default=0.99, + min=0.0, + max=1.0, + inclusiveMin=True, + inclusiveMax=False, + ) + # The following config options are specific to the CoaddPsf. + coadd_psf = ConfigField( + doc="Configuration for CoaddPsf", + dtype=CoaddPsfConfig, + ) + input_recorder = ConfigurableField( + doc="Subtask that helps fill CoaddInputs catalogs added to the final Exposure", + target=CoaddInputRecorderTask, + ) + + +class AssembleCellCoaddTask(PipelineTask): + """Assemble a cell-based coadded image from a set of warps. + + This task reads in the warp one at a time, and accumulates it in all the + cells that it completely overlaps with. This is the optimal I/O pattern but + this also implies that it is not possible to build one or only a few cells. + + Each cell coadds is guaranteed to have a well-defined PSF. This is done by + 1) excluding warps that only partially overlap a cell from that cell coadd; + 2) interpolating bad pixels in the warps rather than excluding them; + 3) by computing the coadd as a weighted mean of the warps without clipping; + 4) by computing the coadd PSF as the weighted mean of the PSF of the warps + with the same weights. + + The cells are (and must be) defined in the skymap, and cannot be configured + or redefined here. The cells are assumed to be small enough that the PSF is + assumed to be spatially constant within a cell. + + Raises + ------ + NoWorkFound + Raised if no input warps are provided. + RuntimeError + Raised if the skymap is not cell-based. + + Notes + ----- + This is not yet a part of the standard DRP pipeline. As such, the Task and + especially its Config and Connections are experimental and subject to + change any time without a formal RFC or standard deprecation procedures + until it is included in the DRP pipeline. + """ + + ConfigClass = AssembleCellCoaddConfig + _DefaultName = "assembleCellCoadd" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.makeSubtask("input_recorder") + self.makeSubtask("interpolate_coadd") + self.makeSubtask("scale_zero_point") + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + # Docstring inherited. + inputData = butlerQC.get(inputRefs) + + if not inputData["inputWarps"]: + raise NoWorkFound("No input warps provided for co-addition") + self.log.info("Found %d input warps", len(inputData["inputWarps"])) + + # Construct skyInfo expected by run + # Do not remove skyMap from inputData in case _makeSupplementaryData + # needs it + skyMap = inputData["skyMap"] + + if not skyMap.config.tractBuilder.name == "cells": + raise RuntimeError("AssembleCellCoaddTask requires a cell-based skymap.") + + outputDataId = butlerQC.quantum.dataId + + inputData["skyInfo"] = makeSkyInfo( + skyMap, tractId=outputDataId["tract"], patchId=outputDataId["patch"] + ) + + self.common = CommonComponents( + units=CoaddUnits.legacy, # until the ScaleZeroPointTask can scale it to nJy. + wcs=inputData["skyInfo"].patchInfo.wcs, + band=outputDataId.get("band", None), + identifiers=PatchIdentifiers.from_data_id(outputDataId), + ) + + returnStruct = self.run(**inputData) + butlerQC.put(returnStruct, outputRefs) + return returnStruct + + @staticmethod + def _compute_weight(maskedImage, statsCtrl): + """Compute a weight for a masked image. + + Parameters + ---------- + maskedImage : `~lsst.afw.image.MaskedImage` + The masked image to compute the weight. + statsCtrl : `~lsst.afw.math.StatisticsControl` + A control (config-like) object for StatisticsStack. + + Returns + ------- + weight : `float` + Inverse of the clipped mean variance of the masked image. + """ + statObj = afwMath.makeStatistics( + maskedImage.getVariance(), maskedImage.getMask(), afwMath.MEANCLIP, statsCtrl + ) + meanVar, _ = statObj.getResult(afwMath.MEANCLIP) + weight = 1.0 / float(meanVar) + return weight + + @staticmethod + def _construct_grid(skyInfo): + """Construct a UniformGrid object from a SkyInfo struct. + + Parameters + ---------- + skyInfo : `~lsst.pipe.base.Struct` + A Struct object + + Returns + ------- + grid : `~lsst.cell_coadds.UniformGrid` + A UniformGrid object. + """ + # grid has no notion about border or inner/outer boundaries. + # So we have to clip the outermost border when constructing the grid. + grid_bbox = skyInfo.patchInfo.outer_bbox.erodedBy(skyInfo.patchInfo.getCellBorder()) + grid = UniformGrid.from_bbox_cell_size(grid_bbox, skyInfo.patchInfo.getCellInnerDimensions()) + return grid + + def _construct_grid_container(self, skyInfo, statsCtrl): + """Construct a grid of AccumulatorMeanStack instances. + + Parameters + ---------- + skyInfo : `~lsst.pipe.base.Struct` + A Struct object + statsCtrl : `~lsst.afw.math.StatisticsControl` + A control (config-like) object for StatisticsStack. + + Returns + ------- + gc : `~lsst.cell_coadds.GridContainer` + A GridContainer object container one AccumulatorMeanStack per cell. + """ + grid = self._construct_grid(skyInfo) + + # Initialize the grid container with AccumulatorMeanStacks + gc = GridContainer[AccumulatorMeanStack](grid.shape) + for cellInfo in skyInfo.patchInfo: + stacker = AccumulatorMeanStack( + # The shape is for the numpy arrays, hence transposed. + shape=(cellInfo.outer_bbox.height, cellInfo.outer_bbox.width), + bit_mask_value=afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes), + calc_error_from_input_variance=self.config.calc_error_from_input_variance, + compute_n_image=False, + ) + gc[cellInfo.index] = stacker + + return gc + + def _construct_stats_control(self): + statsCtrl = afwMath.StatisticsControl() + statsCtrl.setAndMask(afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes)) + statsCtrl.setNanSafe(True) + return statsCtrl + + def run(self, inputWarps, skyInfo, **kwargs): + statsCtrl = self._construct_stats_control() + + gc = self._construct_grid_container(skyInfo, statsCtrl) + coadd_inputs_gc = GridContainer(gc.shape) + for cellInfo in skyInfo.patchInfo: + coadd_inputs = self.input_recorder.makeCoaddInputs() + # Reserve the absolute maximum of how many ccds, visits + # we could potentially have. + coadd_inputs.ccds.reserve(len(inputWarps)) + coadd_inputs.visits.reserve(len(inputWarps)) + coadd_inputs_gc[cellInfo.index] = coadd_inputs + # Read in one warp at a time, and accumulate it in all the cells that + # it completely overlaps. + + for warpRef in inputWarps: + warp = warpRef.get() + + # Pre-process the warp before coadding. + # Each Warp that goes into a coadd will typically have an + # independent photometric zero-point. Therefore, we must scale each + # Warp to set it to a common photometric zeropoint. + self.scale_zero_point.run(exposure=warp, dataRef=warpRef) + + # Coadd the warp onto the cells it completely overlaps. + edge = afwImage.Mask.getPlaneBitMask("EDGE") + for cellInfo in skyInfo.patchInfo: + bbox = cellInfo.outer_bbox + stacker = gc[cellInfo.index] + mi = warp[bbox].getMaskedImage() + + if (mi.getMask().array & edge).any(): + self.log.debug( + "Skipping %s in cell %s because it has an EDGE", warpRef.dataId, cellInfo.index + ) + continue + + weight = self._compute_weight(mi, statsCtrl) + if not np.isfinite(weight): + # Log at the debug level, because this can be quite common. + self.log.debug( + "Non-finite weight for %s in cell %s: skipping", warpRef.dataId, cellInfo.index + ) + continue + + stacker.add_masked_image(mi, weight=weight) + + coadd_inputs = coadd_inputs_gc[cellInfo.index] + self.input_recorder.addVisitToCoadd(coadd_inputs, warp[bbox], weight) + + del warp + + cells: list[SingleCellCoadd] = [] + for cellInfo in skyInfo.patchInfo: + stacker = gc[cellInfo.index] + cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox) + stacker.fill_stacked_masked_image(cell_masked_image) + + # Post-process the coadd before converting to new data structures. + if self.config.do_interpolate_coadd: + self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA") + # The variance must be positive; work around for DM-3201. + varArray = cell_masked_image.variance.array + with np.errstate(invalid="ignore"): + varArray[:] = np.where(varArray > 0, varArray, np.inf) + + # Finalize the PSF on the cell coadds. + coadd_inputs = coadd_inputs_gc[cellInfo.index] + coadd_inputs.ccds.sort() + coadd_inputs.visits.sort() + cell_coadd_psf = CoaddPsf(coadd_inputs.ccds, skyInfo.wcs, self.config.coadd_psf.makeControl()) + + image_planes = OwnedImagePlanes.from_masked_image(cell_masked_image) + identifiers = CellIdentifiers( + cell=cellInfo.index, + skymap=self.common.identifiers.skymap, + tract=self.common.identifiers.tract, + patch=self.common.identifiers.patch, + band=self.common.identifiers.band, + ) + + singleCellCoadd = SingleCellCoadd( + outer=image_planes, + psf=cell_coadd_psf.computeKernelImage(cell_coadd_psf.getAveragePosition()), + inner_bbox=cellInfo.inner_bbox, + inputs=None, # TODO + common=self.common, + identifiers=identifiers, + ) + # TODO: Attach transmission curve when they become available. + cells.append(singleCellCoadd) + + grid = self._construct_grid(skyInfo) + multipleCellCoadd = MultipleCellCoadd( + cells, + grid=grid, + outer_cell_size=cellInfo.outer_bbox.getDimensions(), + inner_bbox=None, + common=self.common, + psf_image_size=cells[0].psf_image.getDimensions(), + ) + + return Struct( + multipleCellCoadd=multipleCellCoadd, + )