From ee04fcef45315b7443d3a5cab692602b8a9b89ce Mon Sep 17 00:00:00 2001 From: Dan Baston Date: Fri, 10 Jan 2025 14:24:13 -0500 Subject: [PATCH] VRT processed dataset: unscale source raster values (#11623) --- .../processed_OutputBands_FROM_LAST_STEP.vrt | 2 +- autotest/gdrivers/vrtprocesseddataset.py | 135 +++++++++++++++++- .../drivers/raster/vrt_processed_dataset.rst | 2 +- frmts/vrt/data/gdalvrt.xsd | 12 +- frmts/vrt/vrtdataset.h | 3 + frmts/vrt/vrtprocesseddataset.cpp | 69 +++++++++ 6 files changed, 218 insertions(+), 5 deletions(-) diff --git a/autotest/gdrivers/data/vrt/processed_OutputBands_FROM_LAST_STEP.vrt b/autotest/gdrivers/data/vrt/processed_OutputBands_FROM_LAST_STEP.vrt index 0f68c6c8e697..c8284f55c73b 100644 --- a/autotest/gdrivers/data/vrt/processed_OutputBands_FROM_LAST_STEP.vrt +++ b/autotest/gdrivers/data/vrt/processed_OutputBands_FROM_LAST_STEP.vrt @@ -1,5 +1,5 @@ - + ../byte.tif diff --git a/autotest/gdrivers/vrtprocesseddataset.py b/autotest/gdrivers/vrtprocesseddataset.py index 633f2158a654..cf6c5f038351 100755 --- a/autotest/gdrivers/vrtprocesseddataset.py +++ b/autotest/gdrivers/vrtprocesseddataset.py @@ -10,18 +10,22 @@ # SPDX-License-Identifier: MIT ############################################################################### +import os + import gdaltest import pytest from osgeo import gdal +from .vrtderived import _validate + pytestmark = pytest.mark.skipif( not gdaltest.vrt_has_open_support(), reason="VRT driver open missing", ) np = pytest.importorskip("numpy") -pytest.importorskip("osgeo.gdal_array") +gdal_array = pytest.importorskip("osgeo.gdal_array") ############################################################################### # Test error cases in general VRTProcessedDataset XML structure @@ -73,6 +77,16 @@ def test_vrtprocesseddataset_errors(tmp_vsimem): src_ds.GetRasterBand(3).Fill(3) src_ds.Close() + with pytest.raises(Exception, match="Invalid value of 'unscale'"): + gdal.Open( + f""" + + {src_filename} + + + """ + ) + with pytest.raises(Exception, match="ProcessingSteps element missing"): gdal.Open( f""" @@ -1216,7 +1230,7 @@ def test_vrtprocesseddataset_serialize(tmp_vsimem): vrt_filename = str(tmp_vsimem / "the.vrt") content = f""" - + {src_filename} @@ -1517,3 +1531,120 @@ def test_vrtprocesseddataset_RasterIO(tmp_vsimem): assert ds.GetRasterBand(1).GetBlockSize() == [1, 1] with pytest.raises(Exception): ds.ReadAsArray() + + +############################################################################### +# Validate processed datasets according to xsd + + +@pytest.mark.parametrize( + "fname", + [ + f + for f in os.listdir(os.path.join(os.path.dirname(__file__), "data/vrt")) + if f.startswith("processed") + ], +) +def test_vrt_processeddataset_validate(fname): + with open(os.path.join("data/vrt", fname)) as f: + _validate(f.read()) + + +############################################################################### +# Test reading input datasets with scale and offset + + +@pytest.mark.parametrize( + "input_scaled", (True, False), ids=lambda x: f"input scaled={x}" +) +@pytest.mark.parametrize("unscale", (True, False, "auto"), ids=lambda x: f"unscale={x}") +@pytest.mark.parametrize( + "dtype", (gdal.GDT_Int16, gdal.GDT_Float32), ids=gdal.GetDataTypeName +) +def test_vrtprocesseddataset_scaled_inputs(tmp_vsimem, input_scaled, dtype, unscale): + + src_filename = tmp_vsimem / "src.tif" + + nx = 2 + ny = 3 + nz = 2 + + if dtype == gdal.GDT_Float32: + nodata = float("nan") + else: + nodata = 99 + + np_type = gdal_array.GDALTypeCodeToNumericTypeCode(dtype) + + data = np.arange(nx * ny * nz, dtype=np_type).reshape(nz, ny, nx) + data[:, 2, 1] = nodata + + if input_scaled: + offsets = [i + 2 for i in range(nz)] + scales = [(i + 1) / 4 for i in range(nz)] + else: + offsets = [0 for i in range(nz)] + scales = [1 for i in range(nz)] + + with gdal.GetDriverByName("GTiff").Create( + src_filename, nx, ny, nz, eType=dtype + ) as src_ds: + src_ds.WriteArray(data) + for i in range(src_ds.RasterCount): + bnd = src_ds.GetRasterBand(i + 1) + bnd.SetOffset(offsets[i]) + bnd.SetScale(scales[i]) + bnd.SetNoDataValue(nodata) + + ds = gdal.Open( + f""" + + + {src_filename} + + + + BandAffineCombination + 0,1,0 + 0,0,1 + + + """ + ) + + assert ds.RasterCount == nz + + if unscale is True or (unscale == "auto" and input_scaled): + for i in range(ds.RasterCount): + bnd = ds.GetRasterBand(i + 1) + assert bnd.DataType == gdal.GDT_Float64 + assert bnd.GetScale() in (None, 1) + assert bnd.GetOffset() in (None, 0) + else: + for i in range(ds.RasterCount): + bnd = ds.GetRasterBand(i + 1) + assert bnd.DataType == dtype + assert bnd.GetScale() == scales[i] + assert bnd.GetOffset() == offsets[i] + assert ( + np.isnan(bnd.GetNoDataValue()) + if np.isnan(nodata) + else bnd.GetNoDataValue() == nodata + ) + + result = np.ma.stack( + [ds.GetRasterBand(i + 1).ReadAsMaskedArray() for i in range(ds.RasterCount)] + ) + + if unscale: + expected = np.ma.masked_array( + np.stack([data[i, :, :] * scales[i] + offsets[i] for i in range(nz)]), + np.isnan(data) if np.isnan(nodata) else data == nodata, + ) + else: + expected = np.ma.masked_array( + data, np.isnan(data) if np.isnan(nodata) else data == nodata + ) + + np.testing.assert_array_equal(result.mask, expected.mask) + np.testing.assert_array_equal(result[~result.mask], expected[~expected.mask]) diff --git a/doc/source/drivers/raster/vrt_processed_dataset.rst b/doc/source/drivers/raster/vrt_processed_dataset.rst index 28f001904af7..b3383126763b 100644 --- a/doc/source/drivers/raster/vrt_processed_dataset.rst +++ b/doc/source/drivers/raster/vrt_processed_dataset.rst @@ -121,7 +121,7 @@ The following child elements of ``VRTDataset`` may be defined: ``SRS``, ``GeoTra The ``VRTDataset`` root element must also have the 2 following child elements: -- ``Input``, which must have one and only one of the following ``SourceFilename`` or ``VRTDataset`` as child elements, to define the input dataset to which to apply the processing steps. +- ``Input``, which must have one and only one of the following ``SourceFilename`` or ``VRTDataset`` as child elements, to define the input dataset to which to apply the processing steps. Starting with GDAL 3.11, values from the input dataset will be automatically unscaled; this can be disabled by setting the ``unscale`` attribute of ``Input`` to ``false``. - ``ProcessingSteps``, with at least one child ``Step`` element. diff --git a/frmts/vrt/data/gdalvrt.xsd b/frmts/vrt/data/gdalvrt.xsd index ad9d0aa63902..6cc250615091 100644 --- a/frmts/vrt/data/gdalvrt.xsd +++ b/frmts/vrt/data/gdalvrt.xsd @@ -198,6 +198,16 @@ + + + + YES, NO, or AUTO. + If not specified, AUTO is the default and will result in + unscaling all input bands to Float64 if any input band has + a defined scale/offset. + + + @@ -249,7 +259,7 @@ - Allowed names are specific of each processing function + Allowed names are specific to each processing function diff --git a/frmts/vrt/vrtdataset.h b/frmts/vrt/vrtdataset.h index 50fbac8e0a46..aa6852b6b80c 100644 --- a/frmts/vrt/vrtdataset.h +++ b/frmts/vrt/vrtdataset.h @@ -724,6 +724,9 @@ class VRTProcessedDataset final : public VRTDataset //! Directory of the VRT std::string m_osVRTPath{}; + //! Source of source dataset generated with GDALTranslate + std::unique_ptr m_poVRTSrcDS{}; + //! Source dataset std::unique_ptr m_poSrcDS{}; diff --git a/frmts/vrt/vrtprocesseddataset.cpp b/frmts/vrt/vrtprocesseddataset.cpp index 318da09023aa..ce84a6ac5c57 100644 --- a/frmts/vrt/vrtprocesseddataset.cpp +++ b/frmts/vrt/vrtprocesseddataset.cpp @@ -12,6 +12,7 @@ #include "cpl_minixml.h" #include "cpl_string.h" +#include "gdal_utils.h" #include "vrtdataset.h" #include @@ -206,6 +207,27 @@ CPLErr VRTProcessedDataset::XMLInit(const CPLXMLNode *psTree, return CE_None; } +static bool HasScaleOffset(GDALDataset &oSrcDS) +{ + for (int i = 1; i <= oSrcDS.GetRasterCount(); i++) + { + int pbSuccess; + GDALRasterBand &oBand = *oSrcDS.GetRasterBand(i); + double scale = oBand.GetScale(&pbSuccess); + if (pbSuccess && scale != 1) + { + return true; + } + double offset = oBand.GetOffset(&pbSuccess); + if (pbSuccess && offset != 0) + { + return true; + } + } + + return false; +} + /** Instantiate object from XML tree */ CPLErr VRTProcessedDataset::Init(const CPLXMLNode *psTree, const char *pszVRTPathIn, @@ -260,6 +282,53 @@ CPLErr VRTProcessedDataset::Init(const CPLXMLNode *psTree, if (!m_poSrcDS) return CE_Failure; + const char *pszUnscale = CPLGetXMLValue(psInput, "unscale", "AUTO"); + bool bUnscale = false; + if (EQUAL(pszUnscale, "AUTO")) + { + if (HasScaleOffset(*m_poSrcDS)) + { + bUnscale = true; + } + } + else if (EQUAL(pszUnscale, "YES") || EQUAL(pszUnscale, "ON") || + EQUAL(pszUnscale, "TRUE") || EQUAL(pszUnscale, "1")) + { + bUnscale = true; + } + else if (!(EQUAL(pszUnscale, "NO") || EQUAL(pszUnscale, "OFF") || + EQUAL(pszUnscale, "FALSE") || EQUAL(pszUnscale, "0"))) + { + CPLError(CE_Failure, CPLE_AppDefined, "Invalid value of 'unscale'"); + return CE_Failure; + } + + if (bUnscale) + { + CPLStringList oArgs; + oArgs.AddString("-unscale"); + oArgs.AddString("-ot"); + oArgs.AddString("Float64"); + oArgs.AddString("-of"); + oArgs.AddString("VRT"); + oArgs.AddString("-a_nodata"); + oArgs.AddString("nan"); + auto *poArgs = GDALTranslateOptionsNew(oArgs.List(), nullptr); + int pbUsageError; + CPLAssert(poArgs); + m_poVRTSrcDS.reset(m_poSrcDS.release()); + // https://trac.cppcheck.net/ticket/11325 + // cppcheck-suppress accessMoved + m_poSrcDS.reset(GDALDataset::FromHandle( + GDALTranslate("", m_poVRTSrcDS.get(), poArgs, &pbUsageError))); + GDALTranslateOptionsFree(poArgs); + + if (pbUsageError || !m_poSrcDS) + { + return CE_Failure; + } + } + if (nRasterXSize == 0 && nRasterYSize == 0) { nRasterXSize = m_poSrcDS->GetRasterXSize();