diff --git a/CHANGES.rst b/CHANGES.rst index 6d8a997af..f8727b6d8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,8 +6,11 @@ SCICO Release Notes Version 0.0.4 (unreleased) ---------------------------- -• New methods and a function for computing Jacobian-vector products for - `Operator` objects. +• New `Function` class for representing array-to-array mappings with more than + one input. +• New methods and a function for computing Jacobian-vector products for `Operator` + objects. +• New proximal ADMM solvers. diff --git a/data b/data index aa85087b4..198923b6a 160000 --- a/data +++ b/data @@ -1 +1 @@ -Subproject commit aa85087b471f09a4165f1e465ed69ae33ec95183 +Subproject commit 198923b6ab894b14fbf4be87f5f84988927c8564 diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 6901e9312..b42b94e72 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -72,6 +72,7 @@ Miscellaneous examples/denoise_tv_admm examples/denoise_tv_pgm examples/denoise_tv_multi + examples/denoise_cplx_tv_nlpadmm examples/denoise_cplx_tv_pdhg examples/denoise_dncnn_universal examples/video_rpca_admm @@ -120,6 +121,7 @@ Total Variation examples/denoise_tv_admm examples/denoise_tv_pgm examples/denoise_tv_multi + examples/denoise_cplx_tv_nlpadmm examples/denoise_cplx_tv_pdhg @@ -182,6 +184,25 @@ Linearized ADMM examples/denoise_tv_multi +Proximal ADMM +^^^^^^^^^^^^^ + +.. toctree:: + :maxdepth: 1 + + examples/denoise_tv_multi + examples/denoise_cplx_tv_nlpadmm + + +Non-linear Proximal ADMM +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. toctree:: + :maxdepth: 1 + + examples/denoise_cplx_tv_nlpadmm + + PDHG ^^^^ diff --git a/docs/source/references.bib b/docs/source/references.bib index 8b598d432..53e7c728b 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -83,8 +83,7 @@ @InCollection{beck-2010-gradient publisher = {Cambridge University Press}, year = 2010, doi = {10.1017/CBO9780511804458.003}, - url = - {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf} + url = {http://www.math.tau.ac.il/~teboulle/papers/gradient_chapter.pdf} } @Book {beck-2017-first, @@ -110,6 +109,17 @@ @Software {bradbury-2018-jax year = {2018} } +@InProceedings {benning-2016-preconditioned, + title = {Preconditioned {ADMM} with nonlinear operator + constraint}, + author = {Benning, Martin and Knoll, Florian and + Sch{\"o}nlieb, Carola-Bibiane and Valkonen, Tuomo}, + booktitle = {IFIP Conference on System Modeling and Optimization (CSMO) 2015}, + pages = {117--126}, + year = 2016, + doi = {10.1007/978-3-319-55795-3_10} +} + @Article {boyd-2010-distributed, title = {Distributed optimization and statistical learning via the alternating direction method of multipliers}, @@ -206,6 +216,20 @@ @Article {daubechies-2004-iterative doi = {10.1002/cpa.20042} } +@Article {deng-2015-global, + author = {Wei Deng and Wotao Yin}, + title = {On the Global and Linear Convergence of the + Generalized Alternating Direction Method of + Multipliers}, + journal = {Journal of Scientific Computing}, + year = 2015, + month = May, + volume = 66, + number = 3, + pages = {889--916}, + doi = {10.1007/s10915-015-0048-x}, +} + @Article {esser-2010-general, author = {Ernie Esser and Xiaoqun Zhang and Tony F. Chan}, title = {A General Framework for a Class of First Order @@ -345,10 +369,10 @@ @Article {kamilov-2022-plug T. 
Buzzard and Brendt Wohlberg}, title = {Plug-and-Play Methods for Integrating Physical and Learned Models in Computational Imaging}, - journal = {IEEE Signal Processing Magazine}, + journal = {IEEE Signal Processing Magazine}, year = 2022, eprint = {arXiv:2203.17061}, - note = {To appear.} + note = {To appear.} } @Article {liu-2018-first, @@ -375,7 +399,7 @@ @Article {maggioni-2012-nonlocal number = 1, pages = {119--133}, year = 2012, - doi = {10.1109/TIP.2012.2210725} + doi = {10.1109/TIP.2012.2210725} } @InProceedings {makinen-2019-exact, @@ -428,7 +452,7 @@ @Book {nocedal-2006-numerical @Book {paganin-2006-coherent, doi = {10.1093/acprof:oso/9780198567288.001.0001}, - isbn = 9780198567288, + isbn = {9780198567288}, year = 2006, month = Jan, publisher = {Oxford University Press}, @@ -541,7 +565,7 @@ @Article {valkonen-2014-primal journal = {Inverse Problems}, volume = 30, number = 5, - pages = 055012, + pages = {055012}, year = 2014, doi = {10.1088/0266-5611/30/5/055012} } diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst index 75e0fd064..5480bbe7c 100644 --- a/examples/scripts/README.rst +++ b/examples/scripts/README.rst @@ -76,8 +76,10 @@ Miscellaneous Total Variation Denoising with Constraint (APGM) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising + `denoise_cplx_tv_nlpadmm.py `_ + Complex Total Variation Denoising with NLPADMM Solver `denoise_cplx_tv_pdhg.py `_ - Complex Total Variation Denoising + Complex Total Variation Denoising with PDHG Solver `denoise_dncnn_universal.py `_ Comparison of DnCNN Variants for Image Denoising `video_rpca_admm.py `_ @@ -140,8 +142,10 @@ Total Variation Total Variation Denoising with Constraint (APGM) `denoise_tv_multi.py `_ Comparison of Optimization Algorithms for Total Variation Denoising + `denoise_cplx_tv_nlpadmm.py `_ + Complex Total Variation Denoising with NLPADMM Solver `denoise_cplx_tv_pdhg.py `_ - Complex Total Variation Denoising + Complex Total Variation Denoising with PDHG Solver Sparsity diff --git a/examples/scripts/denoise_cplx_tv_nlpadmm.py b/examples/scripts/denoise_cplx_tv_nlpadmm.py new file mode 100644 index 000000000..5d33769dd --- /dev/null +++ b/examples/scripts/denoise_cplx_tv_nlpadmm.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# This file is part of the SCICO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +r""" +Complex Total Variation Denoising with NLPADMM Solver +===================================================== + +This example demonstrates solution of a problem of the form + +$$\argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{z}) \; \text{such that}\; +H(\mb{x}, \mb{z}) = 0 \;,$$ + +where $H$ is a nonlinear function, via a variant of the proximal ADMM +algorithm for problems with a non-linear operator constraint +:cite:`benning-2016-preconditioned`. The example problem represents +total variation (TV) denoising applied to a complex image with +piece-wise smooth magnitude and non-smooth phase. (This example is rather +contrived, and was not constructed to represent a specific real imaging +problem, but it does have some properties in common with synthetic +aperture radar single look complex data in which the magnitude has much +more discernible structure than the phase.) 
The appropriate TV denoising +formulation for this problem is + +$$\argmin_{\mb{x}} \; (1/2) \| \mb{y} - \mb{x} \|_2^2 + \lambda +\| C(\mb{x}) \|_{2,1} \;,$$ + +where $\mb{y}$ is the measurement, $\|\cdot\|_{2,1}$ is the +$\ell_{2,1}$ mixed norm, and $C$ is a non-linear operator consisting of +a linear difference operator applied to the magnitude of a complex array. +This problem is represented in the form above by taking $H(\mb{x}, +\mb{z}) = C(\mb{x}) - \mb{z}$. The standard TV solution, which is +also computed for comparison purposes, gives very poor results since +the difference is applied independently to real and imaginary +components of the complex image. +""" + + +from mpl_toolkits.axes_grid1 import make_axes_locatable +from xdesign import SiemensStar, discrete_phantom + +import scico.numpy as snp +import scico.random +from scico import function, functional, linop, loss, metric, operator, plot +from scico.examples import phase_diff +from scico.optimize import NonLinearPADMM, ProximalADMM +from scico.util import device_info + +""" +Create a ground truth image. +""" +N = 256 # image size +phantom = SiemensStar(16) +x_mag = snp.pad(discrete_phantom(phantom, N - 16), 8) + 1.0 +x_mag /= x_mag.max() +# Create reference image with structured magnitude and random phase +x_gt = x_mag * snp.exp(-1j * scico.random.randn(x_mag.shape, seed=0)[0]) + + +""" +Add noise to create a noisy test image. +""" +σ = 0.25 # noise standard deviation +noise, key = scico.random.randn(x_gt.shape, seed=1, dtype=snp.complex64) +y = x_gt + σ * noise + + +""" +Denoise with standard total variation. +""" +λ_tv = 6e-2 +f = loss.SquaredL2Loss(y=y) +g = λ_tv * functional.L21Norm() +# The append=0 option makes the results of horizontal and vertical finite +# differences the same shape, which is required for the L21Norm. +C = linop.FiniteDifference(input_shape=y.shape, input_dtype=snp.complex64, append=0) + +solver_tv = ProximalADMM( + f=f, + g=g, + A=C, + B=None, + rho=1.0, + mu=8.0, + nu=1.0, + maxiter=200, + itstat_options={"display": True, "period": 20}, +) +print(f"Solving on {device_info()}\n") +x_tv = solver_tv.solve() +hist_tv = solver_tv.itstat_object.history(transpose=True) + + +""" +Denoise with total variation applied to the magnitude of a complex image. +""" +λ_nltv = 2e-1 +g = λ_nltv * functional.L21Norm() +# Redefine C for real input (now applied to magnitude of a complex array) +C = linop.FiniteDifference(input_shape=y.shape, input_dtype=snp.float32, append=0) +# Operator computing differences of absolute values +D = C @ operator.Abs(input_shape=x_gt.shape, input_dtype=snp.complex64) +# Constraint function imposing z = D(x) constraint +H = function.Function( + (C.shape[1], C.shape[0]), + output_shape=C.shape[0], + eval_fn=lambda x, z: D(x) - z, + input_dtypes=(snp.complex64, snp.float32), + output_dtype=snp.float32, +) + +solver_nltv = NonLinearPADMM( + f=f, + g=g, + H=H, + rho=5.0, + mu=6.0, + nu=1.0, + maxiter=200, + itstat_options={"display": True, "period": 20}, +) +x_nltv = solver_nltv.solve() +hist_nltv = solver_nltv.itstat_object.history(transpose=True) + + +""" +Plot results. 
+""" +fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6)) +plot.plot( + snp.vstack((hist_tv.Objective, hist_nltv.Objective)).T, + ptyp="semilogy", + title="Objective function", + xlbl="Iteration", + lgnd=("Standard TV", "Magnitude TV"), + fig=fig, + ax=ax[0], +) +plot.plot( + snp.vstack((hist_tv.Prml_Rsdl, hist_nltv.Prml_Rsdl)).T, + ptyp="semilogy", + title="Primal residual", + xlbl="Iteration", + lgnd=("Standard TV", "Magnitude TV"), + fig=fig, + ax=ax[1], +) +plot.plot( + snp.vstack((hist_tv.Dual_Rsdl, hist_nltv.Dual_Rsdl)).T, + ptyp="semilogy", + title="Dual residual", + xlbl="Iteration", + lgnd=("Standard TV", "Magnitude TV"), + fig=fig, + ax=ax[2], +) +fig.show() + + +fig, ax = plot.subplots(nrows=2, ncols=4, figsize=(20, 10)) +norm = plot.matplotlib.colors.Normalize( + vmin=min(snp.abs(x_gt).min(), snp.abs(y).min(), snp.abs(x_tv).min(), snp.abs(x_nltv).min()), + vmax=max(snp.abs(x_gt).max(), snp.abs(y).max(), snp.abs(x_tv).max(), snp.abs(x_nltv).max()), +) +plot.imview(snp.abs(x_gt), title="Ground truth", cbar=None, fig=fig, ax=ax[0, 0], norm=norm) +plot.imview( + snp.abs(y), + title="Measured: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(y)), + cbar=None, + fig=fig, + ax=ax[0, 1], + norm=norm, +) +plot.imview( + snp.abs(x_tv), + title="Standard TV: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(x_tv)), + cbar=None, + fig=fig, + ax=ax[0, 2], + norm=norm, +) +plot.imview( + snp.abs(x_nltv), + title="Magnitude TV: PSNR %.2f (dB)" % metric.psnr(snp.abs(x_gt), snp.abs(x_nltv)), + cbar=None, + fig=fig, + ax=ax[0, 3], + norm=norm, +) +divider = make_axes_locatable(ax[0, 3]) +cax = divider.append_axes("right", size="5%", pad=0.2) +fig.colorbar(ax[0, 3].get_images()[0], cax=cax) +norm = plot.matplotlib.colors.Normalize( + vmin=min(snp.angle(x_gt).min(), snp.angle(x_tv).min(), snp.angle(x_nltv).min()), + vmax=max(snp.angle(x_gt).max(), snp.angle(x_tv).max(), snp.angle(x_nltv).max()), +) +plot.imview( + snp.angle(x_gt), + title="Ground truth", + cbar=None, + fig=fig, + ax=ax[1, 0], + norm=norm, +) +plot.imview( + snp.angle(y), + title="Measured: Mean phase diff. %.2f" % phase_diff(snp.angle(x_gt), snp.angle(y)).mean(), + cbar=None, + fig=fig, + ax=ax[1, 1], + norm=norm, +) +plot.imview( + snp.angle(x_tv), + title="Standard TV: Mean phase diff. %.2f" + % phase_diff(snp.angle(x_gt), snp.angle(x_tv)).mean(), + cbar=None, + fig=fig, + ax=ax[1, 2], + norm=norm, +) +plot.imview( + snp.angle(x_nltv), + title="Magnitude TV: Mean phase diff. %.2f" + % phase_diff(snp.angle(x_gt), snp.angle(x_nltv)).mean(), + cbar=None, + fig=fig, + ax=ax[1, 3], + norm=norm, +) +divider = make_axes_locatable(ax[1, 3]) +cax = divider.append_axes("right", size="5%", pad=0.2) +fig.colorbar(ax[1, 3].get_images()[0], cax=cax) +ax[0, 0].set_ylabel("Magnitude") +ax[1, 0].set_ylabel("Phase") +fig.tight_layout() +fig.show() + + +input("\nWaiting for input to close figures and exit") diff --git a/examples/scripts/denoise_cplx_tv_pdhg.py b/examples/scripts/denoise_cplx_tv_pdhg.py index 880f5839d..8bb6a02a1 100644 --- a/examples/scripts/denoise_cplx_tv_pdhg.py +++ b/examples/scripts/denoise_cplx_tv_pdhg.py @@ -5,8 +5,8 @@ # with the package. 
r""" -Complex Total Variation Denoising -================================= +Complex Total Variation Denoising with PDHG Solver +================================================== This example demonstrates solution of a problem of the form diff --git a/examples/scripts/denoise_tv_multi.py b/examples/scripts/denoise_tv_multi.py index 6f9f4aa66..194832896 100644 --- a/examples/scripts/denoise_tv_multi.py +++ b/examples/scripts/denoise_tv_multi.py @@ -8,8 +8,10 @@ Comparison of Optimization Algorithms for Total Variation Denoising =================================================================== -This example compares the performance of ADMM, Linearized ADMM, and PDHG -in solving the isotropic total variation (TV) denoising problem +This example compares the performance of alternating direction method of +multipliers (ADMM), linearized ADMM, proximal ADMM, and primal–dual +hybrid gradient (PDHG) in solving the isotropic total variation (TV) +denoising problem $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x} \|_2^2 + \lambda R(\mathbf{x}) \;,$$ @@ -25,7 +27,7 @@ import scico.numpy as snp import scico.random from scico import functional, linop, loss, plot -from scico.optimize import PDHG, LinearizedADMM +from scico.optimize import PDHG, LinearizedADMM, ProximalADMM from scico.optimize.admm import ADMM, LinearSubproblemSolver from scico.util import device_info @@ -58,10 +60,15 @@ """ -For reasons that are not entirely clear, the first step of the first-run -solver is much slower than the following steps. Perform a preliminary -solver step, the result of which is discarded, to avoid this bias in the -timing results. +The first step of the first-run solver is much slower than the +following steps, presumably due to just-in-time compilation of +relevant operators in first use. The code below performs a preliminary +solver step, the result of which is discarded, to reduce this bias in +the timing results. The precise cause of the remaining differences in +time required to compute the first step of each algorithm is unknown, +but it is worth noting that this difference becomes negligible when +just-in-time compilation is disabled (e.g. via the JAX_DISABLE_JIT +environment variable). """ solver_admm = ADMM( f=f, @@ -90,6 +97,7 @@ itstat_options={"display": True, "period": 10}, ) print(f"Solving on {device_info()}\n") +print("ADMM solver") solver_admm.solve() hist_admm = solver_admm.itstat_object.history(transpose=True) @@ -107,22 +115,46 @@ maxiter=200, itstat_options={"display": True, "period": 10}, ) +print("Linearized ADMM solver") solver_ladmm.solve() hist_ladmm = solver_ladmm.itstat_object.history(transpose=True) +""" +Solve via Proximal ADMM. +""" +mu, nu = ProximalADMM.estimate_parameters(C) +solver_padmm = ProximalADMM( + f=f, + g=g, + A=C, + B=None, + rho=1e0, + mu=mu, + nu=nu, + x0=y, + maxiter=200, + itstat_options={"display": True, "period": 10}, +) +print("Proximal ADMM solver") +solver_padmm.solve() +hist_padmm = solver_padmm.itstat_object.history(transpose=True) + + """ Solve via PDHG. """ +tau, sigma = PDHG.estimate_parameters(C, factor=1.5) solver_pdhg = PDHG( f=f, g=g, C=C, - tau=4e-1, - sigma=4e-1, + tau=tau, + sigma=sigma, maxiter=200, itstat_options={"display": True, "period": 10}, ) +print("PDHG solver") solver_pdhg.solve() hist_pdhg = solver_pdhg.itstat_object.history(transpose=True) @@ -131,37 +163,45 @@ Plot results. It is worth noting that: 1. PDHG outperforms ADMM both with respect to iterations and time. -2. 
ADMM greatly outperforms Linearized ADMM with respect to iterations. -3. ADMM slightly outperforms Linearized ADMM with respect to time. This is +2. Proximal ADMM has similar performance to PDHG with respect to iterations, + but is slightly inferior with respect to time. +3. ADMM greatly outperforms Linearized ADMM with respect to iterations. +4. ADMM slightly outperforms Linearized ADMM with respect to time. This is possible because the ADMM $\mathbf{x}$-update can be solved relatively cheaply, with only 2 CG iterations. If more CG iterations were required, the time comparison would be favorable to Linearized ADMM. """ fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6)) plot.plot( - snp.vstack((hist_admm.Objective, hist_ladmm.Objective, hist_pdhg.Objective)).T, + snp.vstack( + (hist_admm.Objective, hist_ladmm.Objective, hist_padmm.Objective, hist_pdhg.Objective) + ).T, ptyp="semilogy", title="Objective function", xlbl="Iteration", - lgnd=("ADMM", "LinADMM", "PDHG"), + lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"), fig=fig, ax=ax[0], ) plot.plot( - snp.vstack((hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)).T, + snp.vstack( + (hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_padmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl) + ).T, ptyp="semilogy", title="Primal residual", xlbl="Iteration", - lgnd=("ADMM", "LinADMM", "PDHG"), + lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"), fig=fig, ax=ax[1], ) plot.plot( - snp.vstack((hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)).T, + snp.vstack( + (hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_padmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl) + ).T, ptyp="semilogy", title="Dual residual", xlbl="Iteration", - lgnd=("ADMM", "LinADMM", "PDHG"), + lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"), fig=fig, ax=ax[2], ) @@ -169,32 +209,38 @@ fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(27, 6)) plot.plot( - snp.vstack((hist_admm.Objective, hist_ladmm.Objective, hist_pdhg.Objective)).T, - snp.vstack((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T, + snp.vstack( + (hist_admm.Objective, hist_ladmm.Objective, hist_padmm.Objective, hist_pdhg.Objective) + ).T, + snp.vstack((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T, ptyp="semilogy", title="Objective function", xlbl="Time (s)", - lgnd=("ADMM", "LinADMM", "PDHG"), + lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"), fig=fig, ax=ax[0], ) plot.plot( - snp.vstack((hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl)).T, - snp.vstack((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T, + snp.vstack( + (hist_admm.Prml_Rsdl, hist_ladmm.Prml_Rsdl, hist_padmm.Prml_Rsdl, hist_pdhg.Prml_Rsdl) + ).T, + snp.vstack((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T, ptyp="semilogy", title="Primal residual", xlbl="Time (s)", - lgnd=("ADMM", "LinADMM", "PDHG"), + lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"), fig=fig, ax=ax[1], ) plot.plot( - snp.vstack((hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl)).T, - snp.vstack((hist_admm.Time, hist_ladmm.Time, hist_pdhg.Time)).T, + snp.vstack( + (hist_admm.Dual_Rsdl, hist_ladmm.Dual_Rsdl, hist_padmm.Dual_Rsdl, hist_pdhg.Dual_Rsdl) + ).T, + snp.vstack((hist_admm.Time, hist_ladmm.Time, hist_padmm.Time, hist_pdhg.Time)).T, ptyp="semilogy", title="Dual residual", xlbl="Time (s)", - lgnd=("ADMM", "LinADMM", "PDHG"), + lgnd=("ADMM", "LinADMM", "ProxADMM", "PDHG"), fig=fig, ax=ax[2], ) diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst 
index 488b5e553..067b49015 100644 --- a/examples/scripts/index.rst +++ b/examples/scripts/index.rst @@ -50,6 +50,7 @@ Miscellaneous - denoise_tv_admm.py - denoise_tv_pgm.py - denoise_tv_multi.py + - denoise_cplx_tv_nlpadmm.py - denoise_cplx_tv_pdhg.py - denoise_dncnn_universal.py - video_rpca_admm.py @@ -89,6 +90,7 @@ Total Variation - denoise_tv_admm.py - denoise_tv_pgm.py - denoise_tv_multi.py + - denoise_cplx_tv_nlpadmm.py - denoise_cplx_tv_pdhg.py @@ -139,6 +141,19 @@ Linearized ADMM - denoise_tv_multi.py +Proximal ADMM +^^^^^^^^^^^^^ + + - denoise_tv_multi.py + - denoise_cplx_tv_nlpadmm.py + + +Non-linear Proximal ADMM +^^^^^^^^^^^^^^^^^^^^^^^^ + + - denoise_cplx_tv_nlpadmm.py + + PDHG ^^^^ diff --git a/scico/function.py b/scico/function.py new file mode 100644 index 000000000..ce41d9e5f --- /dev/null +++ b/scico/function.py @@ -0,0 +1,242 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2022 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Function class.""" + +from typing import Any, Callable, Optional, Sequence, Tuple, Union + +import jax + +import scico.numpy as snp +from scico.linop import LinearOperator, jacobian +from scico.numpy import BlockArray +from scico.operator import Operator +from scico.typing import BlockShape, DType, JaxArray, Shape + + +class Function: + r"""Function class. + + A :class:`Function` maps multiple :code:`array-like` arguments to + another :code:`array-like`. It is more general than both + :class:`.Functional`, which is a mapping to a scalar, and + :class:`.Operator`, which takes a single argument. + """ + + def __init__( + self, + input_shapes: Sequence[Union[Shape, BlockShape]], + output_shape: Optional[Union[Shape, BlockShape]] = None, + eval_fn: Optional[Callable] = None, + input_dtypes: Union[DType, Sequence[DType]] = snp.float32, + output_dtype: Optional[DType] = None, + jit: bool = False, + ): + """ + Args: + input_shapes: Shapes of input arrays. + output_shape: Shape of output array. Defaults to ``None``. + If ``None``, `output_shape` is determined by evaluating + `self.__call__` on input arrays of zeros. + eval_fn: Function used in evaluating this :class:`Function`. + Defaults to ``None``. Required unless `__init__` is being + called from a derived class with an `_eval` method. + input_dtypes: `dtype` for input argument. If a single `dtype` + is specified, it implies a common `dtype` for all inputs, + otherwise a list or tuple of values should be provided, + one per input. Defaults to ``float32``. + output_dtype: `dtype` for output argument. Defaults to + ``None``. If ``None``, `output_dtype` is determined by + evaluating `self.__call__` on an input arrays of zeros. + jit: If ``True``, jit the evaluation function. + """ + self.jit = jit + self.input_shapes = input_shapes + if isinstance(input_dtypes, (list, tuple)): + self.input_dtypes = input_dtypes + else: + self.input_dtypes = (input_dtypes,) * len(input_shapes) + + if eval_fn is not None: + self._eval = jax.jit(eval_fn) if jit else eval_fn + elif not hasattr(self, "_eval"): + raise NotImplementedError( + "Function is an abstract base class when the eval_fn parameter is not specified." + ) + + # If the output shape or dtype isn't specified, it can be + # inferred by calling the evaluation function. 
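+        # For instance (illustrative values only): with
+        # input_shapes=((3, 4), (3, 4)) and eval_fn=lambda x, z: x - z, the
+        # evaluation on zeros below yields an array of shape (3, 4) and dtype
+        # float32, which are then used as output_shape and output_dtype.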
+ if output_shape is None or output_dtype is None: + zeros = [ + snp.zeros(shape, dtype=dtype) + for (shape, dtype) in zip(self.input_shapes, self.input_dtypes) + ] + tmp = self._eval(*zeros) + if output_shape is None: + self.output_shape = tmp.shape # type: ignore + else: + self.output_shape = output_shape + if output_dtype is None: + self.output_dtype = tmp.dtype + else: + self.output_dtype = output_dtype + + def __repr__(self): + return f"""{type(self)} +input_shapes : {self.input_shapes} +input_dtypes : {self.input_dtypes} +output_shape : {self.output_shape} +output_dtype : {self.output_dtype} + """ + + def __call__(self, *args: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: + """Evaluate this function with the specified parameters. + + Args: + *args: Parameters at which to evaluate the function. + + Returns: + Value of function with specified parameters. + """ + return self._eval(*args) + + def slice(self, index: int, *fix_args: Union[JaxArray, BlockArray]) -> Operator: + """Fix all but one parameter, returning a :class:`.Operator`. + + Args: + index: Index of parameter that remains free. + *fix_args: Fixed values for remaining parameters. + + Returns: + An :class:`.Operator` taking the free parameter of the + :class:`Function` as its input. + """ + + def pfunc(var_arg): + args = fix_args[0:index] + (var_arg,) + fix_args[index:] + return self._eval(*args) + + return Operator( + self.input_shapes[index], + output_shape=self.output_shape, + eval_fn=pfunc, + input_dtype=self.input_dtypes[index], + output_dtype=self.output_dtype, + jit=self.jit, + ) + + def join(self) -> Operator: + """Combine inputs into a :class:`.BlockArray`. + + Construct an equivalent :class:`.Operator` taking a single + :class:`.BlockArray` input combining all inputs of this + :class:`Function`. + + Returns: + An :class:`.Operator` taking a :class:`.BlockArray` as its + input. + """ + for dtype in self.input_dtypes[1:]: + if dtype != self.input_dtypes[0]: + raise ValueError( + "The join method may only be applied to Functions that have " + "homogenous input dtypes." + ) + + def jfunc(blkarr): + return self._eval(*blkarr.arrays) + + return Operator( + self.input_shapes, # type: ignore + output_shape=self.output_shape, + eval_fn=jfunc, + input_dtype=self.input_dtypes[0], + output_dtype=self.output_dtype, + jit=self.jit, + ) + + def jvp( + self, index: int, v: Union[JaxArray, BlockArray], *args: Union[JaxArray, BlockArray] + ) -> Tuple[Union[JaxArray, BlockArray], Union[JaxArray, BlockArray]]: + """Jacobian-vector product with respect to a single parameter. + + Compute a Jacobian-vector product with respect to a single + parameter of a :class:`Function`. Note that the order of the + parameters specifying where to evaluate the Jacobian and the + vector in the product is reverse with respect to :func:`jax.jvp`. + + Args: + index: Index of parameter with respect to which the Jacobian + is to be computed. + v: Vector against which the Jacobian-vector product is to be + computed. + *args: Values of function parameters at which Jacobian is to + be computed. + + Returns: + A pair consisting of the operator evaluated at the parameters + specified by `*args` and the Jacobian-vector product. 
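+            For instance (a purely illustrative case), with
+            ``F = Function(((3,), (3,)), eval_fn=lambda x, z: x * z)``, the
+            call ``F.jvp(0, v, x, z)`` returns the pair ``(x * z, z * v)``.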
+ """ + var_arg = args[index] + fix_args = args[0:index] + args[(index + 1) :] + F = self.slice(index, *fix_args) + return F.jvp(var_arg, v) + + def vjp( + self, index: int, *args: Union[JaxArray, BlockArray], conjugate: Optional[bool] = True + ) -> Tuple[Tuple[Any, ...], Callable]: + """Vector-Jacobian product with respect to a single parameter. + + Compute a vector-Jacobian product with respect to a single + parameter of a :class:`Function`. + + Args: + index: Index of parameter with respect to which the Jacobian + is to be computed. + *args: Values of function parameters at which Jacobian is to + be computed. + conjugate: If ``True``, compute the product using the + conjugate (Hermitian) transpose. + + Returns: + A pair consisting of the operator evaluated at the parameters + specified by `*args` and a function that computes the + vector-Jacobian product. + """ + var_arg = args[index] + fix_args = args[0:index] + args[(index + 1) :] + F = self.slice(index, *fix_args) + return F.vjp(var_arg, conjugate=conjugate) + + def jacobian( + self, index: int, *args: Union[JaxArray, BlockArray], include_eval: Optional[bool] = False + ) -> LinearOperator: + """Construct Jacobian linear operator for the function. + + Construct a Jacobian :class:`.LinearOperator` that computes + vector products with the Jacobian with respect to a specified + variable of the function. + + Args: + index: Index of parameter with respect to which the Jacobian + is to be computed. + *args: Values of function parameters at which Jacobian is to + be computed. + include_eval: Flag indicating whether the result of evaluating + the :class:`.Operator` should be included (as the first + component of a :class:`.BlockArray`) in the output of the + Jacobian :class:`.LinearOperator` constructed by this + function. + + Returns: + A :class:`.LinearOperator` capable of computing Jacobian-vector + products. + """ + var_arg = args[index] + fix_args = args[0:index] + args[(index + 1) :] + F = self.slice(index, *fix_args) + return jacobian(F, var_arg, include_eval=include_eval) diff --git a/scico/linop/_util.py b/scico/linop/_util.py index f16effc1e..a0c6fd521 100644 --- a/scico/linop/_util.py +++ b/scico/linop/_util.py @@ -46,8 +46,13 @@ def power_iteration(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey for i in range(maxiter): Av = A @ v + normAv = snp.linalg.norm(Av) + if normAv == 0.0: # Assume that ||Av|| == 0 implies A is a zero operator + mu = 0.0 + v = Av + break mu = snp.sum(v.conj() * Av) / snp.linalg.norm(v) ** 2 - v = Av / snp.linalg.norm(Av) + v = Av / normAv return mu, v @@ -81,7 +86,7 @@ def operator_norm(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] float: Norm of operator :math:`A`. 
""" - return snp.sqrt(power_iteration(A.H @ A, maxiter, key)[0]) + return snp.sqrt(power_iteration(A.H @ A, maxiter, key)[0].real) def valid_adjoint( diff --git a/scico/optimize/__init__.py b/scico/optimize/__init__.py index 6f6399a5b..2c6ae0db1 100644 --- a/scico/optimize/__init__.py +++ b/scico/optimize/__init__.py @@ -14,9 +14,18 @@ from ._ladmm import LinearizedADMM from .pgm import PGM, AcceleratedPGM from ._primaldual import PDHG +from ._padmm import ProximalADMM, NonLinearPADMM -__all__ = ["ADMM", "LinearizedADMM", "PGM", "AcceleratedPGM", "PDHG"] +__all__ = [ + "ADMM", + "LinearizedADMM", + "ProximalADMM", + "NonLinearPADMM", + "PGM", + "AcceleratedPGM", + "PDHG", +] # Imported items in __all__ appear to originate in top-level linop module for name in __all__: diff --git a/scico/optimize/_padmm.py b/scico/optimize/_padmm.py new file mode 100644 index 000000000..61e2317aa --- /dev/null +++ b/scico/optimize/_padmm.py @@ -0,0 +1,655 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2022 by SCICO Developers +# All rights reserved. BSD 3-clause License. +# This file is part of the SCICO package. Details of the copyright and +# user license can be found in the 'LICENSE' file distributed with the +# package. + +"""Proximal ADMM solvers.""" + +# Needed to annotate a class method that returns the encapsulating class; +# see https://www.python.org/dev/peps/pep-0563/ +from __future__ import annotations + +from typing import Callable, List, Optional, Tuple, Union + +import scico.numpy as snp +from scico import cvjp, jvp +from scico.function import Function +from scico.functional import Functional +from scico.linop import Identity, LinearOperator, operator_norm +from scico.numpy import BlockArray +from scico.numpy.linalg import norm +from scico.numpy.util import ensure_on_device +from scico.typing import JaxArray, PRNGKey +from scico.util import Timer + +from ._common import itstat_func_and_object + +# mypy: disable-error-code=override + + +class ProximalADMM: + r"""Proximal alternating direction method of multipliers. + + | + + Solve an optimization problem of the form + + .. math:: + \argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{z}) \; + \text{such that}\; A \mb{x} + B \mb{z} = \mb{c} \;, + + where :math:`f` and :math:`g` are instances of :class:`.Functional`, + (in most cases :math:`f` will, more specifically be an an instance + of :class:`.Loss`), and :math:`A` and :math:`B` are instances of + :class:`LinearOperator`. + + The optimization problem is solved via a variant of the proximal ADMM + algorithm :cite:`deng-2015-global`, consisting of the iterations + (see :meth:`step`) + + .. math:: + \begin{aligned} + \mb{x}^{(k+1)} &= \mathrm{prox}_{\rho^{-1} \mu^{-1} f} \left( + \mb{x}^{(k)} - \mu^{-1} A^T \left(2 \mb{u}^{(k)} - + \mb{u}^{(k-1)} \right) \right) \\ + \mb{z}^{(k+1)} &= \mathrm{prox}_{\rho^{-1} \nu^{-1} g} \left( + \mb{z}^{(k)} - \nu^{-1} B^T \left( + B \mb{x}^{(k+1)} + A \mb{z}^{(k)} - \mb{c} + \mb{u}^{(k)} + \right) \right) \\ + \mb{u}^{(k+1)} &= \mb{u}^{(k)} + A \mb{x}^{(k+1)} + B + \mb{z}^{(k+1)} - \mb{c} \;. + \end{aligned} + + Parameters :math:`\mu` and :math:`\nu` are required to satisfy + + .. math:: + \mu > \norm{ A }_2^2 \quad \text{and} \quad \nu > \norm{ B }_2^2 \;. + + + Attributes: + f (:class:`.Functional`): Functional :math:`f` (usually a + :class:`.Loss`). + g (:class:`.Functional`): Functional :math:`g`. + A (:class:`.LinearOperator`): :math:`A` linear operator. + B (:class:`.LinearOperator`): :math:`B` linear operator. + c (array-like): constant :math:`\mb{c}`. 
+ itnum (int): Iteration counter. + maxiter (int): Number of linearized ADMM outer-loop iterations. + timer (:class:`.Timer`): Iteration timer. + rho (scalar): Penalty parameter. + mu (scalar): First algorithm parameter. + nu (scalar): Second algorithm parameter. + x (array-like): Solution variable. + z (array-like): Auxiliary variables :math:`\mb{z}` at current + iteration. + z_old (array-like): Auxiliary variables :math:`\mb{z}` at + previous iteration. + u (array-like): Scaled Lagrange multipliers :math:`\mb{u}` at + current iteration. + u_old (array-like): Scaled Lagrange multipliers :math:`\mb{u}` at + previous iteration. + """ + + def __init__( + self, + f: Functional, + g: Functional, + A: LinearOperator, + B: Optional[LinearOperator], + rho: float, + mu: float, + nu: float, + c: Optional[Union[float, JaxArray, BlockArray]] = None, + x0: Optional[Union[JaxArray, BlockArray]] = None, + z0: Optional[Union[JaxArray, BlockArray]] = None, + u0: Optional[Union[JaxArray, BlockArray]] = None, + maxiter: int = 100, + fast_dual_residual: bool = True, + itstat_options: Optional[dict] = None, + ): + r"""Initialize a :class:`ProximalADMM` object. + + Args: + f: Functional :math:`f` (usually a loss function). + g: Functional :math:`g`. + A: Linear operator :math:`A`. + B: Linear operator :math:`B` (if ``None``, :math:`B = -I` + where :math:`I` is the identity operator). + rho: Penalty parameter. + mu: First algorithm parameter. + nu: Second algorithm parameter. + c: Constant :math:`\mb{c}`. If ``None``, defaults to zero. + x0: Starting value for :math:`\mb{x}`. If ``None``, defaults + to an array of zeros. + z0: Starting value for :math:`\mb{z}`. If ``None``, defaults + to an array of zeros. + u0: Starting value for :math:`\mb{u}`. If ``None``, defaults + to an array of zeros. + maxiter: Number of main algorithm iterations. Default: 100. + fast_dual_residual: Flag indicating whether to use fast + approximation to the dual residual, or a slower but more + accurate calculation. + itstat_options: A dict of named parameters to be passed to + the :class:`.diagnostics.IterationStats` initializer. The + dict may also include an additional key "itstat_func" + with the corresponding value being a function with two + parameters, an integer and a :class:`ProximalADMM` + object, responsible for constructing a tuple ready for + insertion into the :class:`.diagnostics.IterationStats` + object. If ``None``, default values are used for the dict + entries, otherwise the default dict is updated with the + dict specified by this parameter. + """ + self.f: Functional = f + self.g: Functional = g + self.A: LinearOperator = A + if B is None: + self.B = -Identity(self.A.output_shape, self.A.output_dtype) + else: + self.B = B + if c is None: + self.c = 0.0 + else: + self.c = c + self.rho: float = rho + self.mu: float = mu + self.nu: float = nu + self.itnum: int = 0 + self.maxiter: int = maxiter + self.fast_dual_residual: bool = fast_dual_residual + self.timer: Timer = Timer() + + if x0 is None: + x0 = snp.zeros(self.A.input_shape, dtype=self.A.input_dtype) + self.x = ensure_on_device(x0) + if z0 is None: + z0 = snp.zeros(self.B.input_shape, dtype=self.B.input_dtype) + self.z = ensure_on_device(z0) + self.z_old = self.z + if u0 is None: + u0 = snp.zeros(self.A.output_shape, dtype=self.A.output_dtype) + self.u = ensure_on_device(u0) + self.u_old = self.u + + self._itstat_init(itstat_options) + + def _itstat_init(self, itstat_options: Optional[dict] = None): + """Initialize iteration statistics mechanism. 
+ + Args: + itstat_options: A dict of named parameters to be passed to + the :class:`.diagnostics.IterationStats` initializer. The + dict may also include an additional key "itstat_func" + with the corresponding value being a function with two + parameters, an integer and a :class:`PDHG` object, + responsible for constructing a tuple ready for insertion + into the :class:`.diagnostics.IterationStats` object. If + ``None``, default values are used for the dict entries, + otherwise the default dict is updated with the dict + specified by this parameter. + """ + # iteration number and time fields + itstat_fields = { + "Iter": "%d", + "Time": "%8.2e", + } + itstat_attrib = ["itnum", "timer.elapsed()"] + # objective function can be evaluated if 'g' function can be evaluated + if self.g.has_eval: + itstat_fields.update({"Objective": "%9.3e"}) + itstat_attrib.append("objective()") + # primal and dual residual fields + itstat_fields.update({"Prml Rsdl": "%9.3e", "Dual Rsdl": "%9.3e"}) + itstat_attrib.extend(["norm_primal_residual()", "norm_dual_residual()"]) + + self.itstat_insert_func, self.itstat_object = itstat_func_and_object( + itstat_fields, itstat_attrib, itstat_options + ) + + def objective( + self, + x: Optional[Union[JaxArray, BlockArray]] = None, + z: Optional[List[Union[JaxArray, BlockArray]]] = None, + ) -> float: + r"""Evaluate the objective function. + + Evaluate the objective function + + .. math:: + f(\mb{x}) + g(\mb{z}) \;. + + + Args: + x: Point at which to evaluate objective function. If + ``None``, the objective is evaluated at the current + iterate :code:`self.x`. + z: Point at which to evaluate objective function. If + ``None``, the objective is evaluated at the current + iterate :code:`self.z`. + + Returns: + scalar: Current value of the objective function. + """ + if (x is None) != (z is None): + raise ValueError("Both or neither of x and z must be supplied") + if x is None: + x = self.x + z = self.z + out = 0.0 + if self.f: + out += self.f(x) + out += self.g(z) + return out + + def norm_primal_residual( + self, + x: Optional[Union[JaxArray, BlockArray]] = None, + z: Optional[List[Union[JaxArray, BlockArray]]] = None, + ) -> float: + r"""Compute the :math:`\ell_2` norm of the primal residual. + + Compute the :math:`\ell_2` norm of the primal residual + + .. math:: + \norm{A \mb{x} + B \mb{z} - \mb{c}}_2 \;. + + Args: + x: Point at which to evaluate primal residual. If ``None``, + the primal residual is evaluated at the current iterate + :code:`self.x`. + z: Point at which to evaluate primal residual. If ``None``, + the primal residual is evaluated at the current iterate + :code:`self.z`. + + Returns: + Norm of primal residual. + """ + if (x is None) != (z is None): + raise ValueError("Both or neither of x and z must be supplied") + if x is None: + x = self.x + z = self.z + + return norm(self.A(x) + self.B(z) - self.c) + + def norm_dual_residual(self) -> float: + r"""Compute the :math:`\ell_2` norm of the dual residual. + + Compute the :math:`\ell_2` norm of the dual residual. If the flag + requesting a fast approximate calculation is set, it is computed + as + + .. math:: + \norm{\mb{z}^{(k+1)} - \mb{z}^{(k)}}_2 \;, + + otherwise it is computed as + + .. math:: + \norm{A^T B ( \mb{z}^{(k+1)} - \mb{z}^{(k)} ) }_2 \;. + + Returns: + Current norm of dual residual. 
+ """ + if self.fast_dual_residual: + rsdl = self.z - self.z_old # fast but poor approximation + else: + rsdl = self.A.H(self.B(self.z - self.z_old)) + return norm(rsdl) + + def step(self): + r"""Perform a single algorithm iteration. + + Perform a single algorithm iteration. + """ + proxarg = self.x - (1.0 / self.mu) * self.A.H(2.0 * self.u - self.u_old) + self.x = self.f.prox(proxarg, (1.0 / (self.rho * self.mu)), v0=self.x) + proxarg = self.z - (1.0 / self.nu) * self.B.H( + self.A(self.x) + self.B(self.z) - self.c + self.u + ) + self.z_old = self.z + self.z = self.g.prox(proxarg, (1.0 / (self.rho * self.nu)), v0=self.z) + self.u_old = self.u + self.u = self.u + self.A(self.x) + self.B(self.z) - self.c + + def solve( + self, + callback: Optional[Callable[[ProximalADMM], None]] = None, + ) -> Union[JaxArray, BlockArray]: + r"""Initialize and run the optimization algorithm. + + Initialize and run the opimization algorithm for a total of + `self.maxiter` iterations. + + Args: + callback: An optional callback function, taking an a single + argument of type :class:`ProximalADMM`, that is called + at the end of every iteration. + + Returns: + Computed solution. + """ + self.timer.start() + for self.itnum in range(self.itnum, self.itnum + self.maxiter): + self.step() + self.itstat_object.insert(self.itstat_insert_func(self)) + if callback: + self.timer.stop() + callback(self) + self.timer.start() + self.timer.stop() + self.itnum += 1 + self.itstat_object.end() + return self.x + + @staticmethod + def estimate_parameters( + A: LinearOperator, + B: Optional[LinearOperator] = None, + factor: Optional[float] = 1.01, + maxiter: int = 100, + key: Optional[PRNGKey] = None, + ) -> Tuple[float, float]: + r"""Estimate `mu` and `nu` parameters of :class:`ProximalADMM`. + + Find values of the `mu` and `nu` parameters of :class:`ProximalADMM` + that respect the constraints + + .. math:: + \mu > \norm{ A }_2^2 \quad \text{and} \quad \nu > + \norm{ B }_2^2 \;. + + Args: + A: Linear operator :math:`A`. + B: Linear operator :math:`B` (if ``None``, :math:`B = -I` + where :math:`I` is the identity operator). + factor: Safety factor with which to multiply estimated + operator norms to ensure strict inequality compliance. If + ``None``, return the estimated squared operator norms. + maxiter: Maximum number of power iterations to use in operator + norm estimation (see :func:`.operator_norm`). Default: 100. + key: Jax PRNG key to use in operator norm estimation (see + :func:`.operator_norm`). Defaults to ``None``, in which + case a new key is created. + + Returns: + A tuple (`mu`, `nu`) representing the estimated parameter + values or corresponding squared operator norm values, + depending on the value of the `factor` parameter. + """ + if B is None: + B = -Identity(A.output_shape, A.output_dtype) # type: ignore + assert isinstance(B, LinearOperator) + mu = operator_norm(A, maxiter=maxiter, key=key) ** 2 + nu = operator_norm(B, maxiter=maxiter, key=key) ** 2 + if factor is None: + return (mu, nu) + else: + return (factor * mu, factor * nu) + + +class NonLinearPADMM(ProximalADMM): + r"""Non-linear proximal alternating direction method of multipliers. + + | + + Solve an optimization problem of the form + + .. math:: + \argmin_{\mb{x}} \; f(\mb{x}) + g(\mb{z}) \; + \text{such that}\; H(\mb{x}, \mb{z}) = 0 \;, + + where :math:`f` and :math:`g` are instances of :class:`.Functional`, + (in most cases :math:`f` will, more specifically be an an instance + of :class:`.Loss`), and :math:`H` is a function. 
+ + The optimization problem is solved via a variant of the proximal ADMM + algorithm for problems with a non-linear operator constraint + :cite:`benning-2016-preconditioned`, consisting of the + iterations (see :meth:`step`) + + .. math:: + \begin{aligned} + A^{(k)} &= J_{\mb{x}} H(\mb{x}^{(k)}, \mb{z}^{(k)}) \\ + \mb{x}^{(k+1)} &= \mathrm{prox}_{\rho^{-1} \mu^{-1} f} \left( + \mb{x}^{(k)} - \mu^{-1} (A^{(k)})^T \left(2 \mb{u}^{(k)} - + \mb{u}^{(k-1)} \right) \right) \\ + B^{(k)} &= J_{\mb{z}} H(\mb{x}^{(k+1)}, \mb{z}^{(k)}) \\ + \mb{z}^{(k+1)} &= \mathrm{prox}_{\rho^{-1} \nu^{-1} g} \left( + \mb{z}^{(k)} - \nu^{-1} (B^{(k)})^T \left( + H(\mb{x}^{(k+1)}, \mb{z}^{(k)}) + \mb{u}^{(k)} \right) \right) \\ + \mb{u}^{(k+1)} &= \mb{u}^{(k)} + H(\mb{x}^{(k+1)}, + \mb{z}^{(k+1)}) \;. + \end{aligned} + + Parameters :math:`\mu` and :math:`\nu` are required to satisfy + + .. math:: + \mu > \norm{ A^{(k)} }_2^2 \quad \text{and} \quad \nu > \norm{ B^{(k)} }_2^2 + + for all :math:`A^{(k)}` and :math:`B^{(k)}`. + + + Attributes: + f (:class:`.Functional`): Functional :math:`f` (usually a + :class:`.Loss`). + g (:class:`.Functional`): Functional :math:`g`. + H (:class:`.Function`): :math:`H` function. + itnum (int): Iteration counter. + maxiter (int): Number of linearized ADMM outer-loop iterations. + timer (:class:`.Timer`): Iteration timer. + rho (scalar): Penalty parameter. + mu (scalar): First algorithm parameter. + nu (scalar): Second algorithm parameter. + x (array-like): Solution variable. + z (array-like): Auxiliary variables :math:`\mb{z}` at current + iteration. + z_old (array-like): Auxiliary variables :math:`\mb{z}` at + previous iteration. + u (array-like): Scaled Lagrange multipliers :math:`\mb{u}` at + current iteration. + u_old (array-like): Scaled Lagrange multipliers :math:`\mb{u}` at + previous iteration. + """ + + def __init__( + self, + f: Functional, + g: Functional, + H: Function, + rho: float, + mu: float, + nu: float, + x0: Optional[Union[JaxArray, BlockArray]] = None, + z0: Optional[Union[JaxArray, BlockArray]] = None, + u0: Optional[Union[JaxArray, BlockArray]] = None, + maxiter: int = 100, + fast_dual_residual: bool = True, + itstat_options: Optional[dict] = None, + ): + r"""Initialize a :class:`NonLinearPADMM` object. + + Args: + f: Functional :math:`f` (usually a loss function). + g: Functional :math:`g`. + H: Function :math:`H`. + rho: Penalty parameter. + mu: First algorithm parameter. + nu: Second algorithm parameter. + x0: Starting value for :math:`\mb{x}`. If ``None``, defaults + to an array of zeros. + z0: Starting value for :math:`\mb{z}`. If ``None``, defaults + to an array of zeros. + u0: Starting value for :math:`\mb{u}`. If ``None``, defaults + to an array of zeros. + maxiter: Number of main algorithm iterations. Default: 100. + fast_dual_residual: Flag indicating whether to use fast + approximation to the dual residual, or a slower but more + accurate calculation. + itstat_options: A dict of named parameters to be passed to + the :class:`.diagnostics.IterationStats` initializer. The + dict may also include an additional key "itstat_func" + with the corresponding value being a function with two + parameters, an integer and a :class:`NonLinearPADMM` + object, responsible for constructing a tuple ready for + insertion into the :class:`.diagnostics.IterationStats` + object. If ``None``, default values are used for the dict + entries, otherwise the default dict is updated with the + dict specified by this parameter. 
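+
+        A minimal construction sketch (assuming ``f`` and ``g`` are existing
+        functionals and ``C`` is a :class:`.LinearOperator`; see the
+        denoise_cplx_tv_nlpadmm.py example script for a complete usage
+        example)::
+
+            H = Function(
+                (C.input_shape, C.output_shape),
+                eval_fn=lambda x, z: C(x) - z,
+            )
+            mu, nu = NonLinearPADMM.estimate_parameters(H)
+            solver = NonLinearPADMM(f=f, g=g, H=H, rho=1.0, mu=mu, nu=nu)
+            x = solver.solve()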
+ """ + self.f: Functional = f + self.g: Functional = g + self.H: Function = H + self.rho: float = rho + self.mu: float = mu + self.nu: float = nu + self.itnum: int = 0 + self.maxiter: int = maxiter + self.fast_dual_residual: bool = fast_dual_residual + self.timer: Timer = Timer() + + if x0 is None: + x0 = snp.zeros(H.input_shapes[0], dtype=H.input_dtypes[0]) + self.x = ensure_on_device(x0) + if z0 is None: + z0 = snp.zeros(H.input_shapes[1], dtype=H.input_dtypes[1]) + self.z = ensure_on_device(z0) + self.z_old = self.z + if u0 is None: + u0 = snp.zeros(H.output_shape, dtype=H.output_dtype) + self.u = ensure_on_device(u0) + self.u_old = self.u + + self._itstat_init(itstat_options) + + def norm_primal_residual( + self, + x: Optional[Union[JaxArray, BlockArray]] = None, + z: Optional[List[Union[JaxArray, BlockArray]]] = None, + ) -> float: + r"""Compute the :math:`\ell_2` norm of the primal residual. + + Compute the :math:`\ell_2` norm of the primal residual + + .. math:: + \norm{H(\mb{x}, \mb{z})}_2 \;. + + Args: + x: Point at which to evaluate primal residual. If ``None``, + the primal residual is evaluated at the current iterate + :code:`self.x`. + z: Point at which to evaluate primal residual. If ``None``, + the primal residual is evaluated at the current iterate + :code:`self.z`. + + Returns: + Norm of primal residual. + """ + if (x is None) != (z is None): + raise ValueError("Both or neither of x and z must be supplied") + if x is None: + x = self.x + z = self.z + + return norm(self.H(x, z)) + + def norm_dual_residual(self) -> float: + r"""Compute the :math:`\ell_2` norm of the dual residual. + + Compute the :math:`\ell_2` norm of the dual residual. If the flag + requesting a fast approximate calculation is set, it is computed + as + + .. math:: + \norm{\mb{z}^{(k+1)} - \mb{z}^{(k)}}_2 \;, + + otherwise it is computed as + + .. math:: + \norm{A^T B ( \mb{z}^{(k+1)} - \mb{z}^{(k)} ) }_2 \;, + + where + + .. math:: + A &= J_{\mb{x}} H(\mb{x}^{(k+1)}, \mb{z}^{(k+1)}) \\ + B &= J_{\mb{z}} H(\mb{x}^{(k+1)}, \mb{z}^{(k+1)}) \;. + + Returns: + Current norm of dual residual. + """ + if self.fast_dual_residual: + rsdl = self.z - self.z_old # fast but poor approximation + else: + Hz = lambda z: self.H(self.x, z) + B = lambda u: jvp(Hz, (self.z,), (u,))[1] + Hx = lambda x: self.H(x, self.z) + AH = cvjp(Hx, self.x)[1] + rsdl = AH(B(self.z - self.z_old)) + return norm(rsdl) + + def step(self): + r"""Perform a single algorithm iteration. + + Perform a single algorithm iteration. + """ + AH = self.H.vjp(0, self.x, self.z, conjugate=True)[1] + proxarg = self.x - (1.0 / self.mu) * AH(2.0 * self.u - self.u_old) + self.x = self.f.prox(proxarg, (1.0 / (self.rho * self.mu)), v0=self.x) + BH = self.H.vjp(1, self.x, self.z, conjugate=True)[1] + proxarg = self.z - (1.0 / self.nu) * BH(self.H(self.x, self.z) + self.u) + self.z_old = self.z + self.z = self.g.prox(proxarg, (1.0 / (self.rho * self.nu)), v0=self.z) + self.u_old = self.u + self.u = self.u + self.H(self.x, self.z) + + @staticmethod + def estimate_parameters( + H: Function, + x: Optional[Union[JaxArray, BlockArray]] = None, + z: Optional[Union[JaxArray, BlockArray]] = None, + factor: Optional[float] = 1.01, + maxiter: int = 100, + key: Optional[PRNGKey] = None, + ) -> Tuple[float, float]: + r"""Estimate `mu` and `nu` parameters of :class:`NonLinearPADMM`. + + Find values of the `mu` and `nu` parameters of :class:`NonLinearPADMM` + that respect the constraints + + .. 
math:: + \mu > \norm{ J_x H(\mb{x}, \mb{z}) }_2^2 \quad \text{and} \quad + \nu > \norm{ J_z H(\mb{x}, \mb{z}) }_2^2 \;. + + Args: + H: Constraint function :math:`H`. + x: Value of :math:`\mb{x}` at which to evaluate the Jacobian. + If ``None``, defaults to an array of zeros. + z: Value of :math:`\mb{z}` at which to evaluate the Jacobian. + If ``None``, defaults to an array of zeros. + factor: Safety factor with which to multiply estimated + operator norms to ensure strict inequality compliance. If + ``None``, return the estimated squared operator norms. + maxiter: Maximum number of power iterations to use in operator + norm estimation (see :func:`.operator_norm`). Default: 100. + key: Jax PRNG key to use in operator norm estimation (see + :func:`.operator_norm`). Defaults to ``None``, in which + case a new key is created. + + Returns: + A tuple (`mu`, `nu`) representing the estimated parameter + values or corresponding squared operator norm values, + depending on the value of the `factor` parameter. + """ + if x is None: + x = snp.zeros(H.input_shapes[0], dtype=H.input_dtypes[0]) + if z is None: + z = snp.zeros(H.input_shapes[1], dtype=H.input_dtypes[1]) + Jx = H.jacobian(0, x, z) + Jz = H.jacobian(1, x, z) + mu = operator_norm(Jx, maxiter=maxiter, key=key) ** 2 + nu = operator_norm(Jz, maxiter=maxiter, key=key) ** 2 + if factor is None: + return (mu, nu) + else: + return (factor * mu, factor * nu) diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index 40362eed7..0ea731d45 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -15,12 +15,12 @@ import scico.numpy as snp from scico.functional import Functional -from scico.linop import LinearOperator +from scico.linop import LinearOperator, jacobian, operator_norm from scico.numpy import BlockArray from scico.numpy.linalg import norm from scico.numpy.util import ensure_on_device from scico.operator import Operator -from scico.typing import JaxArray +from scico.typing import JaxArray, PRNGKey from scico.util import Timer from ._common import itstat_func_and_object @@ -73,7 +73,7 @@ class PDHG: .. math:: \mb{x}^{(k+1)} = \mathrm{prox}_{\tau f} \left( \mb{x}^{(k)} - - \tau [\nabla C(\mb{x}^{(k)})]^T \mb{z}^{(k)} \right) \;. + \tau [J_x C(\mb{x}^{(k)})]^T \mb{z}^{(k)} \right) \;. Attributes: @@ -288,3 +288,58 @@ def solve( self.itnum += 1 self.itstat_object.end() return self.x + + @staticmethod + def estimate_parameters( + C: Operator, + x: Optional[Union[JaxArray, BlockArray]] = None, + ratio: float = 1.0, + factor: Optional[float] = 1.01, + maxiter: int = 100, + key: Optional[PRNGKey] = None, + ): + r"""Estimate `tau` and `sigma` parameters of :class:`PDHG`. + + Find values of the `tau` and `sigma` parameters of :class:`PDHG` + that respect the constraint + + .. math:: + \tau \sigma < \| C \|_2^{-2} \quad \text{or} \quad + \tau \sigma < \| J_x C(\mb{x}) \|_2^{-2} \;, + + depending on whether :math:`C` is a :class:`.LinearOperator` or + not. + + Args: + C: Operator :math:`C`. + x: Value of :math:`\mb{x}` at which to evaluate the Jacobian + of :math:`C` (when it is not a :class:`.LinearOperator`). + If ``None``, defaults to an array of zeros. + ratio: Desired ratio between return :math:`\tau` and + :math:`\sigma` values (:math:`\sigma = \mathrm{ratio} + \tau`). + factor: Safety factor with which to multiply :math:`\| C + \|_2^{-2}` to ensure strict inequality compliance. If + ``None``, the value is set to 1.0. 
+ maxiter: Maximum number of power iterations to use in operator + norm estimation (see :func:`.operator_norm`). Default: 100. + key: Jax PRNG key to use in operator norm estimation (see + :func:`.operator_norm`). Defaults to ``None``, in which + case a new key is created. + + Returns: + A tuple (`tau`, `sigma`) representing the estimated parameter + values. + """ + if x is None: + x = snp.zeros(C.input_shape, dtype=C.input_dtype) + if factor is None: + factor = 1.0 + if isinstance(C, LinearOperator): + J = C + else: + J = jacobian(C, x) + Cnrm = operator_norm(J, maxiter=maxiter, key=key) + tau = snp.sqrt(factor / ratio) / Cnrm + sigma = ratio * tau + return (tau, sigma) diff --git a/scico/test/linop/test_linop_util.py b/scico/test/linop/test_linop_util.py index 4dc3239d7..4f618c90a 100644 --- a/scico/test/linop/test_linop_util.py +++ b/scico/test/linop/test_linop_util.py @@ -74,6 +74,9 @@ def test_operator_norm(): D = linop.Diagonal(d) Dnorm = linop.operator_norm(D) assert np.abs(Dnorm - snp.abs(d).max()) < 1e-5 + Zop = linop.MatrixOperator(snp.zeros((3, 3))) + Znorm = linop.operator_norm(Zop) + assert np.abs(Znorm) < 1e-6 @pytest.mark.parametrize("dtype", [snp.float32, snp.complex64]) diff --git a/scico/test/optimize/test_padmm.py b/scico/test/optimize/test_padmm.py new file mode 100644 index 000000000..9033a4830 --- /dev/null +++ b/scico/test/optimize/test_padmm.py @@ -0,0 +1,323 @@ +import numpy as np + +import jax + +import scico.numpy as snp +from scico import function, functional, linop, loss, random +from scico.numpy import BlockArray +from scico.optimize import NonLinearPADMM, ProximalADMM + + +class TestMisc: + def setup_method(self, method): + np.random.seed(12345) + self.y = jax.device_put(np.random.randn(32, 33).astype(np.float32)) + self.maxiter = 2 + self.ρ = 1e0 + self.μ = 1e0 + self.ν = 1e0 + self.A = linop.Identity(self.y.shape) + self.f = loss.SquaredL2Loss(y=self.y, A=self.A) + self.g = functional.DnCNN() + self.H = function.Function( + (self.A.input_shape, self.A.input_shape), + output_shape=self.A.input_shape, + eval_fn=lambda x, z: x - z, + input_dtypes=np.float32, + output_dtype=np.float32, + ) + self.x0 = snp.zeros(self.A.input_shape, dtype=snp.float32) + + def test_itstat_padmm(self): + itstat_fields = {"Iter": "%d", "Time": "%8.2e"} + + def itstat_func(obj): + return (obj.itnum, obj.timer.elapsed()) + + padmm_ = ProximalADMM( + f=self.f, + g=self.g, + A=self.A, + B=None, + rho=self.ρ, + mu=self.μ, + nu=self.ν, + x0=self.x0, + z0=self.x0, + u0=self.x0, + maxiter=self.maxiter, + ) + assert len(padmm_.itstat_object.fieldname) == 4 + assert snp.sum(padmm_.x) == 0.0 + + padmm_ = ProximalADMM( + f=self.f, + g=self.g, + A=self.A, + B=None, + rho=self.ρ, + mu=self.μ, + nu=self.ν, + maxiter=self.maxiter, + itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False}, + ) + assert len(padmm_.itstat_object.fieldname) == 2 + + def test_itstat_nlpadmm(self): + itstat_fields = {"Iter": "%d", "Time": "%8.2e"} + + def itstat_func(obj): + return (obj.itnum, obj.timer.elapsed()) + + nlpadmm_ = NonLinearPADMM( + f=self.f, + g=self.g, + H=self.H, + rho=self.ρ, + mu=self.μ, + nu=self.ν, + x0=self.x0, + z0=self.x0, + u0=self.x0, + maxiter=self.maxiter, + ) + assert len(nlpadmm_.itstat_object.fieldname) == 4 + assert snp.sum(nlpadmm_.x) == 0.0 + + nlpadmm_ = NonLinearPADMM( + f=self.f, + g=self.g, + H=self.H, + rho=self.ρ, + mu=self.μ, + nu=self.ν, + maxiter=self.maxiter, + itstat_options={"fields": itstat_fields, "itstat_func": itstat_func, "display": False}, 
+ ) + assert len(nlpadmm_.itstat_object.fieldname) == 2 + + def test_callback(self): + padmm_ = ProximalADMM( + f=self.f, + g=self.g, + A=self.A, + B=None, + rho=self.ρ, + mu=self.μ, + nu=self.ν, + maxiter=self.maxiter, + ) + padmm_.test_flag = False + + def callback(obj): + obj.test_flag = True + + x = padmm_.solve(callback=callback) + assert padmm_.test_flag + + +class TestBlockArray: + def setup_method(self, method): + np.random.seed(12345) + self.y = snp.blockarray( + ( + np.random.randn(32, 33).astype(np.float32), + np.random.randn( + 17, + ).astype(np.float32), + ) + ) + self.λ = 1e0 + self.maxiter = 1 + self.ρ = 1e0 + self.μ = 1e0 + self.ν = 1e0 + self.A = linop.Identity(self.y.shape) + self.f = loss.SquaredL2Loss(y=self.y, A=self.A) + self.g = (self.λ / 2) * functional.L2Norm() + self.H = function.Function( + (self.A.input_shape, self.A.input_shape), + output_shape=self.A.input_shape, + eval_fn=lambda x, z: x - z, + input_dtypes=np.float32, + output_dtype=np.float32, + ) + self.x0 = snp.zeros(self.A.input_shape, dtype=snp.float32) + + def test_blockarray_padmm(self): + padmm_ = ProximalADMM( + f=self.f, + g=self.g, + A=self.A, + B=None, + rho=self.ρ, + mu=self.μ, + nu=self.ν, + maxiter=self.maxiter, + ) + x = padmm_.solve() + assert isinstance(x, BlockArray) + + def test_blockarray_nlpadmm(self): + nlpadmm_ = NonLinearPADMM( + f=self.f, + g=self.g, + H=self.H, + rho=self.ρ, + mu=self.μ, + nu=self.ν, + maxiter=self.maxiter, + ) + x = nlpadmm_.solve() + assert isinstance(x, BlockArray) + + +class TestReal: + def setup_method(self, method): + np.random.seed(12345) + N = 8 + MB = 10 + # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2 + Amx = np.diag(np.random.randn(N).astype(np.float32)) + Bmx = np.random.randn(MB, N).astype(np.float32) + y = np.random.randn(N).astype(np.float32) + λ = 1e0 + self.Amx = Amx + self.Bmx = Bmx + self.y = jax.device_put(y) + self.λ = λ + # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y + self.grdA = lambda x: (Amx.T @ Amx + λ * Bmx.T @ Bmx) @ x + self.grdb = Amx.T @ y + + def test_padmm(self): + maxiter = 200 + ρ = 1e0 + μ = 5e1 + ν = 1e0 + A = linop.Diagonal(snp.diag(self.Amx)) + f = loss.SquaredL2Loss(y=self.y, A=A) + g = (self.λ / 2) * functional.SquaredL2Norm() + C = linop.MatrixOperator(self.Bmx) + padmm_ = ProximalADMM( + f=f, + g=g, + A=C, + B=None, + rho=ρ, + mu=μ, + nu=ν, + maxiter=maxiter, + ) + x = padmm_.solve() + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 + + def test_nlpadmm(self): + maxiter = 200 + ρ = 1e0 + μ = 5e1 + ν = 1e0 + A = linop.Diagonal(snp.diag(self.Amx)) + f = loss.SquaredL2Loss(y=self.y, A=A) + g = (self.λ / 2) * functional.SquaredL2Norm() + C = linop.MatrixOperator(self.Bmx) + H = function.Function( + (C.input_shape, C.output_shape), + output_shape=C.output_shape, + eval_fn=lambda x, z: C(x) - z, + input_dtypes=snp.float32, + output_dtype=snp.float32, + ) + nlpadmm_ = NonLinearPADMM( + f=f, + g=g, + H=H, + rho=ρ, + mu=μ, + nu=ν, + maxiter=maxiter, + ) + x = nlpadmm_.solve() + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 + + +class TestComplex: + def setup_method(self, method): + N = 8 + MB = 10 + # Set up arrays for problem argmin (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2 + Amx, key = random.randn((N,), dtype=np.complex64, key=None) + Amx = snp.diag(Amx) + Bmx, key = random.randn((MB, N), dtype=np.complex64, key=key) + y, key = random.randn((N,), dtype=np.complex64, key=key) + λ = 1e0 + 
self.Amx = Amx + self.Bmx = Bmx + self.y = jax.device_put(y) + self.λ = λ + # Solution of problem is given by linear system (A^T A + λ B^T B) x = A^T y + self.grdA = lambda x: (Amx.conj().T @ Amx + λ * Bmx.conj().T @ Bmx) @ x + self.grdb = Amx.conj().T @ y + + def test_nlpadmm(self): + maxiter = 300 + ρ = 1e0 + μ = 3e1 + ν = 1e0 + A = linop.Diagonal(snp.diag(self.Amx)) + f = loss.SquaredL2Loss(y=self.y, A=A) + g = (self.λ / 2) * functional.SquaredL2Norm() + C = linop.MatrixOperator(self.Bmx) + H = function.Function( + (C.input_shape, C.output_shape), + output_shape=C.output_shape, + eval_fn=lambda x, z: C(x) - z, + input_dtypes=snp.complex64, + output_dtype=snp.complex64, + ) + nlpadmm_ = NonLinearPADMM( + f=f, + g=g, + H=H, + rho=ρ, + mu=μ, + nu=ν, + maxiter=maxiter, + ) + x = nlpadmm_.solve() + assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 1e-4 + + +class TestEstimateParameters: + def setup_method(self): + shape = (32, 33) + self.A = linop.Identity(shape) + self.Hr = function.Function( + (shape, shape), + output_shape=shape, + eval_fn=lambda x, z: x - z, + input_dtypes=np.float32, + output_dtype=np.float32, + ) + self.Hc = function.Function( + (shape, shape), + output_shape=shape, + eval_fn=lambda x, z: x - z, + input_dtypes=np.complex64, + output_dtype=np.complex64, + ) + + def test_padmm(self): + mu, nu = ProximalADMM.estimate_parameters(self.A, factor=1.0) + assert snp.abs(mu - 1.0) < 1e-6 + assert snp.abs(nu - 1.0) < 1e-6 + + def test_real(self): + mu, nu = NonLinearPADMM.estimate_parameters(self.Hr, factor=1.0) + assert snp.abs(mu - 1.0) < 1e-6 + assert snp.abs(nu - 1.0) < 1e-6 + + def test_complex(self): + mu, nu = NonLinearPADMM.estimate_parameters(self.Hc, factor=1.0) + assert snp.abs(mu - 1.0) < 1e-6 + assert snp.abs(nu - 1.0) < 1e-6 diff --git a/scico/test/optimize/test_pdhg.py b/scico/test/optimize/test_pdhg.py index 15aad0b38..c5ed25ec5 100644 --- a/scico/test/optimize/test_pdhg.py +++ b/scico/test/optimize/test_pdhg.py @@ -197,3 +197,33 @@ def test_pdhg(self): ) x = pdhg_.solve() assert (snp.linalg.norm(self.grdA(x) - self.grdb) / snp.linalg.norm(self.grdb)) < 5e-4 + + +class TestEstimateParameters: + def setup_method(self): + shape = (32, 33) + A = linop.Identity(shape, input_dtype=np.float32) + B = linop.Identity(shape, input_dtype=np.complex64) + opcls = operator.operator_from_function(lambda x: snp.abs(x), "op") + C = opcls(input_shape=shape, input_dtype=np.float32) + D = opcls(input_shape=shape, input_dtype=np.complex64) + self.operators = [A, B, C, D] + + def test_operators_dlft(self): + for op in self.operators[0:2]: + tau, sigma = PDHG.estimate_parameters(op, factor=1.0) + assert snp.abs(tau - sigma) < 1e-6 + assert snp.abs(tau - 1.0) < 1e-6 + + def test_operators(self): + for op in self.operators: + x = snp.ones(op.input_shape, op.input_dtype) + tau, sigma = PDHG.estimate_parameters(op, x=x, factor=None) + assert snp.abs(tau - sigma) < 1e-6 + assert snp.abs(tau - 1.0) < 1e-6 + + def test_ratio(self): + op = self.operators[0] + tau, sigma = PDHG.estimate_parameters(op, factor=1.0, ratio=10.0) + assert snp.abs(tau * sigma - 1.0) < 1e-6 + assert snp.abs(sigma - 10.0 * tau) < 1e-6 diff --git a/scico/test/test_function.py b/scico/test/test_function.py new file mode 100644 index 000000000..af657bb83 --- /dev/null +++ b/scico/test/test_function.py @@ -0,0 +1,92 @@ +import numpy as np + +import pytest + +import scico.numpy as snp +from scico.function import Function +from scico.linop import jacobian +from scico.random import randn + + 
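+# Illustrative note (based only on the tests below, not an exhaustive
+# specification): a two-input Function F represents a mapping F(x, y), and
+# the tests exercise the relationships
+#
+#     F.slice(0, y)(x) == F(x, y)                    # fix all inputs except index 0
+#     F.join()(snp.blockarray((x, y))) == F(x, y)    # single BlockArray argument
+#
+# with F.join() raising ValueError when the input dtypes differ, and
+# F.jacobian(i, ...) matching linop.jacobian applied to the corresponding
+# sliced operator.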
+class TestFunction: + def setup_method(self): + key = None + self.shape = (7, 8) + self.dtype = snp.float32 + self.x, key = randn(self.shape, key=key, dtype=self.dtype) + self.y, key = randn(self.shape, key=key, dtype=self.dtype) + self.func = lambda x, y: snp.abs(x) + snp.abs(y) + + def test_init(self): + F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func) + assert F.output_shape == self.shape + assert len(F.input_dtypes) == 2 + assert F.output_dtype == self.dtype + + def test_eval(self): + F = Function( + (self.shape, self.shape), + output_shape=self.shape, + eval_fn=self.func, + input_dtypes=(self.dtype, self.dtype), + output_dtype=self.dtype, + ) + np.testing.assert_allclose(self.func(self.x, self.y), F(self.x, self.y)) + + def test_eval_jit(self): + F = Function( + (self.shape, self.shape), + output_shape=self.shape, + eval_fn=self.func, + input_dtypes=(self.dtype, self.dtype), + output_dtype=self.dtype, + jit=True, + ) + np.testing.assert_allclose(self.func(self.x, self.y), F(self.x, self.y)) + + def test_slice(self): + F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func) + Op = F.slice(0, self.y) + np.testing.assert_allclose(Op(self.x), F(self.x, self.y)) + + def test_join(self): + F = Function((self.shape, self.shape), input_dtypes=self.dtype, eval_fn=self.func) + Op = F.join() + np.testing.assert_allclose(Op(snp.blockarray((self.x, self.y))), F(self.x, self.y)) + + def test_join_raise(self): + F = Function( + (self.shape, self.shape), input_dtypes=(snp.float32, snp.complex64), eval_fn=self.func + ) + with pytest.raises(ValueError): + Op = F.join() + + +@pytest.mark.parametrize("dtype", [snp.float32, snp.complex64]) +def test_jacobian(dtype): + N = 7 + M = 8 + key = None + fmx, key = randn((M, N), key=key, dtype=dtype) + gmx, key = randn((M, N), key=key, dtype=dtype) + F = Function(((N, 1), (N, 1)), input_dtypes=dtype, eval_fn=lambda x, y: fmx @ x + gmx @ y) + u0, key = randn((N, 1), key=key, dtype=dtype) + u1, key = randn((N, 1), key=key, dtype=dtype) + v, key = randn((N, 1), key=key, dtype=dtype) + w, key = randn((M, 1), key=key, dtype=dtype) + + op = F.slice(0, u1) + J0op = jacobian(op, u0) + np.testing.assert_allclose(J0op(v), F.jvp(0, v, u0, u1)[1]) + np.testing.assert_allclose(J0op.H(w), F.vjp(0, u0, u1)[1](w)) + J0fn = F.jacobian(0, u0, u1) + np.testing.assert_allclose(J0op(v), J0fn(v)) + np.testing.assert_allclose(J0op.H(w), J0fn.H(w)) + + op = F.slice(1, u0) + J1op = jacobian(op, u1) + np.testing.assert_allclose(J1op(v), F.jvp(1, v, u0, u1)[1]) + np.testing.assert_allclose(J1op.H(w), F.vjp(1, u0, u1)[1](w)) + J1fn = F.jacobian(1, u0, u1) + np.testing.assert_allclose(J1op(v), J1fn(v)) + np.testing.assert_allclose(J1op.H(w), J1fn.H(w))
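+
+
+# Worked relation for the parameter estimates (illustrative note): the
+# estimate_parameters logic added earlier in this patch sets
+#
+#     tau   = sqrt(factor / ratio) / ||C||
+#     sigma = ratio * tau
+#
+# so that tau * sigma * ||C||**2 == factor and sigma == ratio * tau, which is
+# what TestEstimateParameters.test_ratio in test_pdhg.py asserts for the
+# identity operator (factor=1.0, ratio=10.0, ||C|| == 1).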
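+
+
+# Minimal usage sketch for the new NonLinearPADMM solver (illustrative only;
+# the helper name below is hypothetical and the code mirrors
+# TestReal.test_nlpadmm in test_padmm.py above). The coupling between the
+# split variables is expressed as a two-input Function H(x, z) = B x - z, and
+# the problem
+#     argmin_x (1/2) ||A x - y||_2^2 + (λ/2) ||B x||_2^2
+# is solved subject to the constraint H(x, z) = 0.
+def _nlpadmm_usage_sketch():  # hypothetical helper, illustration only
+    from scico import functional, linop, loss
+    from scico.optimize import NonLinearPADMM
+
+    N, M = 8, 10
+    Amx, key = randn((N,), dtype=np.float32, key=None)   # diagonal of forward operator
+    Bmx, key = randn((M, N), dtype=np.float32, key=key)  # regularization operator
+    y, key = randn((N,), dtype=np.float32, key=key)      # observed data
+    λ = 1e0
+
+    A = linop.Diagonal(Amx)
+    B = linop.MatrixOperator(Bmx)
+    f = loss.SquaredL2Loss(y=y, A=A)
+    g = (λ / 2) * functional.SquaredL2Norm()
+    H = Function(
+        (B.input_shape, B.output_shape),
+        output_shape=B.output_shape,
+        eval_fn=lambda x, z: B(x) - z,
+        input_dtypes=np.float32,
+        output_dtype=np.float32,
+    )
+    solver = NonLinearPADMM(f=f, g=g, H=H, rho=1e0, mu=5e1, nu=1e0, maxiter=200)
+    return solver.solve()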