From 5b121c959a180abb1f0e1caf0e604bb98f1a231c Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Fri, 16 Jun 2023 14:31:04 -0400 Subject: [PATCH 01/12] First attempted implementation of the API for toggling JAX. Needs lots of work still, including how to delay JIT Rebased from master --- pymbar/mbar.py | 12 +++ pymbar/mbar_solvers.py | 199 ++++++++++++++++++++++++++--------------- 2 files changed, 138 insertions(+), 73 deletions(-) diff --git a/pymbar/mbar.py b/pymbar/mbar.py index 893a46a9..cb38c024 100644 --- a/pymbar/mbar.py +++ b/pymbar/mbar.py @@ -96,6 +96,7 @@ def __init__( n_bootstraps=0, bootstrap_solver_protocol=None, rseed=None, + accelerator="numpy" ): """Initialize multistate Bennett acceptance ratio (MBAR) on a set of simulation data. @@ -186,6 +187,13 @@ def __init__( We usually just do steps of adaptive sampling without. "robust" would be the backup. Default: dict(method="adaptive", options=dict(min_sc_iter=0)), + accelerator: str, optional, default="jax" + Set the accelerator method to try. Attempts to use the named accelerator for the solvers, and then + stores the output accelerator after trying to set. Not case-sensitive. "numpy" is no-accelerators, + and will work fine. + (Valid options: jax, numpy) + + Notes ----- The reduced potential energy ``u_kn[k,n] = u_k(x_{ln})``, where the reduced potential energy ``u_l(x)`` is @@ -225,6 +233,10 @@ def __init__( """ + # Set the accelerator methods for the solvers + mbar_solvers.set_accelerator(accelerator) + self.accelerator = mbar_solvers.accelerator + # Store local copies of necessary data. # N_k[k] is the number of samples from state k, some of which might be zero. self.N_k = np.array(N_k, dtype=np.int64) diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index 6c36e654..1fa37066 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -11,15 +11,91 @@ logger = logging.getLogger(__name__) use_jit = False -force_no_jax = False # Temporary until we can make a proper setting to enable/disable by choice -try: - #### JAX related imports - if force_no_jax: - # Capture user-disabled JAX instead "JAX not found" - raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py") +accelerator = "numpy" + + +# Import the methods functionally +# This is admittedly non-standard, but solves the following use case: +# * Has JAX +# * Wants to use PyMBAR +# * Does NOT want JAX to be set to 64-bit mode +# Also solves the future use case of different accelerator, +# but want to selectively use them +def init_numpy(): + """Set the imports for the basic numpy methods""" + # Fallback/default solver methods + # NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY + global exp, sum, newaxis, diag, dot, s_, npad, lstsq, scipy_optimize, logsumexp + global jit, precondition_jit + global accelerator, use_jit + from numpy import exp, sum, newaxis, diag, dot, s_ + from numpy import pad as npad + from numpy.linalg import lstsq + import scipy.optimize as scipy_optimize # pylint: disable=reimported + from scipy.special import logsumexp + + # No jit, so make a passthrough decorator + def jit(fn): + return fn + + # Precondition if you need to do something different + def precondition_jit(fn): + return jit(fn) + + use_jit = False + accelerator = "numpy" + logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy") + + +def init_jax(): + """Set the imports for the JAX accelerated methods""" + # NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY + global exp, sum, newaxis, diag, dot, s_, npad, lstsq, scipy_optimize, logsumexp + global jit, precondition_jit + global accelerator, use_jit + global config try: from jax.config import config + from jax.numpy import exp, sum, newaxis, diag, dot, s_ + from jax.numpy import pad as npad + from jax.numpy.linalg import lstsq + import jax.scipy.optimize as scipy_optimize + from jax.scipy.special import logsumexp + + from jax import jit + def precondition_jit(jitable_fn): + """ + Attempt to set JAX precision if present. This does nothing if JAX is not present + + Parameters + ---------- + jitable_fn: function + A function which can be jit'd + """ + + @wraps( + jitable_fn + ) # Helper to ensure the decorated function still registers for docs and inspection + def staggered_jit(*args, **kwargs): + # This will only trigger if JAX is set + if use_jit and not config.x64_enabled: + # Warn that JAX 64-bit will being turned on + logger.warning( + "\n" + "******* JAX 64-bit mode is now on! *******\n" + "* JAX is now set to 64-bit mode! *\n" + "* This MAY cause problems with other *\n" + "* uses of JAX in the same code. *\n" + "******************************************\n" + ) + config.update("jax_enable_x64", True) + jited_fn = jit(jitable_fn) + return jited_fn(*args, **kwargs) + + return staggered_jit + + # Throw warning only if the whole of JAX is found if not config.x64_enabled: # Warn that we're going to be setting 64 bit jax logger.warning( @@ -36,15 +112,9 @@ "******************************************\n" ) - from jax.numpy import exp, sum, newaxis, diag, dot, s_ - from jax.numpy import pad as npad - from jax.numpy.linalg import lstsq - import jax.scipy.optimize as optimize_maybe_jax - from jax.scipy.special import logsumexp - - from jax import jit as jit_or_passthrough - use_jit = True + accelerator = "jax" + logger.info("JAX detected. Using JAX acceleration.") except ImportError: # Catch no JAX and throw a warning logger.warning( @@ -58,31 +128,46 @@ " conda install pymbar \n" "*********************************" ) - raise # Continue with the raised Import Error + # Fall back to NumPy import + init_numpy() -except ImportError: - # No JAX found, overlap imports - # These imports MUST align exactly - from numpy import exp, sum, newaxis, diag, dot, s_ - from numpy import pad as npad - from numpy.linalg import lstsq - import scipy.optimize as optimize_maybe_jax # pylint: disable=reimported - from scipy.special import logsumexp +# Accelerator map for the set method below +ACCELERATOR_MAP = { + "numpy": init_numpy, + "jax": init_jax +} + +# Try to set the initial/default accelerator +init_jax() - # No jit, so make a passthrough decorator - def jit_or_passthrough(fn): - return fn +# Helper function for toggling the solver method +def set_accelerator(accelerator_name: str): + """ + Set the accelerator in the namespace for this module + """ + global accelerator # We want to modify the current accelerator + # Saving it to new tag does not change since we're saving the immutable string object + accel = accelerator_name.lower() + if accel not in ACCELERATOR_MAP: + raise ValueError(f"No accelerator implementation for {accel}, please use one of the following:\n" + + "".join((f"* {a}\n" for a in ACCELERATOR_MAP.keys())) + + f"(case-insentive)" + ) + logger.info(f"Attempting to change accelerator to {accel}...") + old_accelerator = accelerator + ACCELERATOR_MAP[accelerator_name.lower()]() + new_accelerator = accelerator + if new_accelerator == old_accelerator: + logger.warning(f"Attempted to change accelerator from {old_accelerator} to {accel}," + f" but something went wrong. Please check the log outputs above.") + return + logger.info(f"Successfully changed to accelerator {accel}!") # Note on "pylint: disable=invalid-unary-operand-type" # Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax. # Can be mostly ignored -if use_jit is False: - logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy") -else: - logger.info("JAX detected. Using JAX acceleration.") - # Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving # the MBAR equations. # Note: we use tuples instead of lists to avoid accidental mutability. @@ -126,38 +211,6 @@ def jit_or_passthrough(fn): scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included -def jit_or_pass_after_bitsize(jitable_fn): - """ - Attempt to set JAX precision if present. This does nothing if JAX is not present - - Parameters - ---------- - jitable_fn: function - A function which can be jit'd - """ - - @wraps( - jitable_fn - ) # Helper to ensure the decorated function still registers for docs and inspection - def staggered_jit(*args, **kwargs): - # This will only trigger if JAX is set - if use_jit and not config.x64_enabled: - # Warn that JAX 64-bit will being turned on - logger.warning( - "\n" - "******* JAX 64-bit mode is now on! *******\n" - "* JAX is now set to 64-bit mode! *\n" - "* This MAY cause problems with other *\n" - "* uses of JAX in the same code. *\n" - "******************************************\n" - ) - config.update("jax_enable_x64", True) - jited_fn = jit_or_passthrough(jitable_fn) - return jited_fn(*args, **kwargs) - - return staggered_jit - - def validate_inputs(u_kn, N_k, f_k): """Check types and return inputs for MBAR calculations. @@ -215,7 +268,7 @@ def self_consistent_update(u_kn, N_k, f_k, states_with_samples=None): return jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=states_with_samples) -@jit_or_pass_after_bitsize +@precondition_jit def _jit_self_consistent_update(u_kn, N_k, f_k): """JAX version of self_consistent update. For parameters, see self_consistent_update. N_k must be float (should be cast at a higher level) @@ -268,7 +321,7 @@ def mbar_gradient(u_kn, N_k, f_k): return jax_mbar_gradient(u_kn, N_k, f_k) -@jit_or_pass_after_bitsize +@precondition_jit def jax_mbar_gradient(u_kn, N_k, f_k): """JAX version of MBAR gradient function. See documentation of mbar_gradient. N_k must be float (should be cast at a higher level) @@ -311,7 +364,7 @@ def mbar_objective(u_kn, N_k, f_k): return jax_mbar_objective(u_kn, N_k, f_k) -@jit_or_pass_after_bitsize +@precondition_jit def jax_mbar_objective(u_kn, N_k, f_k): """JAX version of mbar_objective. For parameters, mbar_objective_and_Gradient @@ -325,7 +378,7 @@ def jax_mbar_objective(u_kn, N_k, f_k): return obj -@jit_or_pass_after_bitsize +@precondition_jit def jax_mbar_objective_and_gradient(u_kn, N_k, f_k): """JAX version of mbar_objective_and_gradient. For parameters, mbar_objective_and_Gradient @@ -379,7 +432,7 @@ def mbar_objective_and_gradient(u_kn, N_k, f_k): return jax_mbar_objective_and_gradient(u_kn, N_k, f_k) -@jit_or_pass_after_bitsize +@precondition_jit def jax_mbar_hessian(u_kn, N_k, f_k): """JAX version of mbar_hessian. For parameters, see mbar_hessian @@ -423,7 +476,7 @@ def mbar_hessian(u_kn, N_k, f_k): return jax_mbar_hessian(u_kn, N_k, f_k) -@jit_or_pass_after_bitsize +@precondition_jit def jax_mbar_log_W_nk(u_kn, N_k, f_k): """JAX version of mbar_log_W_nk. For parameters, see mbar_log_W_nk @@ -460,7 +513,7 @@ def mbar_log_W_nk(u_kn, N_k, f_k): return jax_mbar_log_W_nk(u_kn, N_k, f_k) -@jit_or_pass_after_bitsize +@precondition_jit def jax_mbar_W_nk(u_kn, N_k, f_k): """JAX version of mbar_W_nk. For parameters, see mbar_W_nk @@ -654,7 +707,7 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None): return results -@jit_or_pass_after_bitsize +@precondition_jit def jax_core_adaptive(u_kn, N_k, f_k, gamma): """JAX version of adaptive inner loop. N_k must be float (should be cast at a higher level) @@ -681,7 +734,7 @@ def jax_core_adaptive(u_kn, N_k, f_k, gamma): return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr -@jit_or_pass_after_bitsize +@precondition_jit def jax_precondition_u_kn(u_kn, N_k, f_k): """JAX version of precondition_u_kn for parameters, see precondition_u_kn @@ -808,7 +861,7 @@ def solve_mbar_once( fpad = lambda x: npad(x, (1, 0)) obj = lambda x: mbar_objective(u_kn_nonzero, N_k_nonzero, fpad(x)) # objective function to be minimized (for derivative free methods, mostly jit) - jax_results = optimize_maybe_jax.minimize( + jax_results = scipy_optimize.minimize( obj, f_k_nonzero[1:], method=method, From 511d86764ba11fed2cd2afc1ab0c10a3c0ee8b5a Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Fri, 16 Jun 2023 16:52:38 -0400 Subject: [PATCH 02/12] Make using JAX a toggle This PR overhuals how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as if someone is just using the functional solver library itself. This is a re-thinking on how to handle different libraries with identical functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future. * Importing is all handled through an `init_accelerator` method with matching name (i.e. `init_numpy` or `init_jax`). * All items which need to be set in the `mbar_solvers` namespace are set through the `global` word of Python in the `init_X` method and therefore are cast up to the full `mbar_solvers` namespace. * The `mbar_solvers` module now has state of the whole module and exists as ONE OR THE OTHER at any given time depending on when the last time the accelerator was set. I.e. You cannot have one MBAR object set as numpy and another set as JAX in the same code and expect them to operate with different libraries. * Default is JAX * I'm calling numpy an "accelerator" even though its the fall back. --- pymbar/mbar.py | 2 +- pymbar/mbar_solvers.py | 48 +++++++++----- pymbar/tests/test_accelerators.py | 107 ++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 18 deletions(-) create mode 100644 pymbar/tests/test_accelerators.py diff --git a/pymbar/mbar.py b/pymbar/mbar.py index cb38c024..4cbd014d 100644 --- a/pymbar/mbar.py +++ b/pymbar/mbar.py @@ -96,7 +96,7 @@ def __init__( n_bootstraps=0, bootstrap_solver_protocol=None, rseed=None, - accelerator="numpy" + accelerator="jax", ): """Initialize multistate Bennett acceptance ratio (MBAR) on a set of simulation data. diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index 1fa37066..5894bd23 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -23,6 +23,12 @@ # but want to selectively use them def init_numpy(): """Set the imports for the basic numpy methods""" + # Disable the pylint problems for this block + # pylint: disable=global-variable-not-assigned + # pylint: disable=global-statement + # pylint: disable=unused-import + # pylint: disable=global-variable-undefined + # Fallback/default solver methods # NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY global exp, sum, newaxis, diag, dot, s_, npad, lstsq, scipy_optimize, logsumexp @@ -41,7 +47,7 @@ def jit(fn): # Precondition if you need to do something different def precondition_jit(fn): return jit(fn) - + use_jit = False accelerator = "numpy" logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy") @@ -49,6 +55,11 @@ def precondition_jit(fn): def init_jax(): """Set the imports for the JAX accelerated methods""" + # Disable the pylint problems for this block + # pylint: disable=global-variable-not-assigned + # pylint: disable=global-statement + # pylint: disable=unused-import + # pylint: disable=global-variable-undefined # NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY global exp, sum, newaxis, diag, dot, s_, npad, lstsq, scipy_optimize, logsumexp global jit, precondition_jit @@ -64,6 +75,7 @@ def init_jax(): from jax.scipy.special import logsumexp from jax import jit + def precondition_jit(jitable_fn): """ Attempt to set JAX precision if present. This does nothing if JAX is not present @@ -94,7 +106,7 @@ def staggered_jit(*args, **kwargs): return jited_fn(*args, **kwargs) return staggered_jit - + # Throw warning only if the whole of JAX is found if not config.x64_enabled: # Warn that we're going to be setting 64 bit jax @@ -131,12 +143,10 @@ def staggered_jit(*args, **kwargs): # Fall back to NumPy import init_numpy() + # Accelerator map for the set method below -ACCELERATOR_MAP = { - "numpy": init_numpy, - "jax": init_jax -} - +ACCELERATOR_MAP = {"numpy": init_numpy, "jax": init_jax} + # Try to set the initial/default accelerator init_jax() @@ -146,24 +156,28 @@ def set_accelerator(accelerator_name: str): """ Set the accelerator in the namespace for this module """ - global accelerator # We want to modify the current accelerator - # Saving it to new tag does not change since we're saving the immutable string object + # Saving accelerator to new tag does not change since we're saving the immutable string object accel = accelerator_name.lower() if accel not in ACCELERATOR_MAP: - raise ValueError(f"No accelerator implementation for {accel}, please use one of the following:\n" + - "".join((f"* {a}\n" for a in ACCELERATOR_MAP.keys())) + - f"(case-insentive)" - ) + raise ValueError( + f"No accelerator implementation for {accel}, please use one of the following:\n" + + "".join((f"* {a}\n" for a in ACCELERATOR_MAP.keys())) + + f"(case-insentive)" + ) logger.info(f"Attempting to change accelerator to {accel}...") old_accelerator = accelerator - ACCELERATOR_MAP[accelerator_name.lower()]() + # Check the accelerator map, call the accelerator init which will handle the accelerator at the top level + ACCELERATOR_MAP[accel]() new_accelerator = accelerator - if new_accelerator == old_accelerator: - logger.warning(f"Attempted to change accelerator from {old_accelerator} to {accel}," - f" but something went wrong. Please check the log outputs above.") + if new_accelerator == old_accelerator and accel != old_accelerator: + logger.warning( + f"Attempted to change accelerator from {old_accelerator} to {accel}," + f" but something went wrong. Please check the log outputs above." + ) return logger.info(f"Successfully changed to accelerator {accel}!") + # Note on "pylint: disable=invalid-unary-operand-type" # Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax. # Can be mostly ignored diff --git a/pymbar/tests/test_accelerators.py b/pymbar/tests/test_accelerators.py new file mode 100644 index 00000000..508622a8 --- /dev/null +++ b/pymbar/tests/test_accelerators.py @@ -0,0 +1,107 @@ +"""Test MBAR accelerators by ensuring they yield comperable results to the default (numpy) and can cycle between them +""" + +import numpy as np +import pytest + +from pymbar import MBAR +from pymbar.utils_for_testing import assert_equal, assert_allclose + +# Pylint doesn't like the interplay between pytest and importing fixtures. disabled the one problem. +from pymbar.tests.test_mbar import ( # pylint: disable=unused-import + system_generators, + N_k, + free_energies_almost_equal, + fixed_harmonic_sample, +) + +# Setup skip if conditions +has_jax = False +try: + # pylint: disable=unused-import + from jax import jit + from jax.numpy import ndarray as jax_ndarray + + has_jax = True +except ImportError: + pass + +# Establish marks +needs_jax = pytest.mark.skipif(not has_jax, reason="Needs Jax Accelerator") + + +# Required test function for testing that the accelerator worked correctly. +def check_numpy(mbar: MBAR): + assert isinstance(mbar.f_k, np.ndarray) + + +def check_jax(mbar: MBAR): + assert isinstance(mbar.f_k, jax_ndarray) + + +# Setup accelerator list. Each parameter is (string_of_accelerator, accelerator_check) +numpy_accel = pytest.param(("numpy", check_numpy), id="numpy") +jax_accel = pytest.param(("jax", check_jax), marks=needs_jax, id="jax") +accelerators = [numpy_accel, jax_accel] + + +@pytest.fixture +def fallback_accelerator(): + return "numpy", check_numpy + + +@pytest.fixture(scope="module", params=system_generators) +def only_test_data(request): + _, test = request.param() + x_n, u_kn, N_k_output, s_n = test.sample(N_k, mode="u_kn") + assert_equal(N_k, N_k_output) + yield_bundle = {"test": test, "x_n": x_n, "u_kn": u_kn} + yield yield_bundle + + +@pytest.fixture() +def static_ukn_nk(fixed_harmonic_sample): + _, u_kn, N_k_output, _ = fixed_harmonic_sample.sample(N_k, mode="u_kn") + assert_equal(N_k, N_k_output) + return u_kn, N_k_output + + +@pytest.mark.parametrize("accelerator", accelerators) +def test_mbar_accelerators_are_accurate(only_test_data, accelerator): + """Test that each accelerator is scientifically accurate""" + accelerator_name, accelerator_check = accelerator + test, x_n, u_kn = only_test_data["test"], only_test_data["x_n"], only_test_data["u_kn"] + x_n, u_kn, N_k_output, s_n = test.sample(N_k, mode="u_kn") + mbar = MBAR(u_kn, N_k, verbose=True, n_bootstraps=200, accelerator=accelerator_name) + results = mbar.compute_free_energy_differences() + fe = results["Delta_f"] + fe_sigma = results["dDelta_f"] + free_energies_almost_equal(fe, fe_sigma, test.analytical_free_energies()) + accelerator_check(mbar) + + +def build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check): + """Helper function to build an MBAR object""" + mbar = MBAR(u_kn, N_k, verbose=True, accelerator=accelerator_name) + assert mbar.accelerator == accelerator_name + accelerator_check(mbar) + return mbar + + +@pytest.mark.parametrize("accelerator", accelerators) +def test_mbar_accelerators_can_toggle(static_ukn_nk, accelerator, fallback_accelerator): + """ + Test that accelerator can toggle and the act of doing so doesn't corrupt each other's output. + """ + u_kn, N_k_output = static_ukn_nk + # Setup and check the accelerator + accelerator_name, accelerator_check = accelerator + mbar = build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check) + # Setup and check the fallback + fall_back_name, fall_back_check = fallback_accelerator + mbar_fallback = build_out_an_mbar(u_kn, N_k, fall_back_name, fall_back_check) + # Ensure fallback and accelerator match + assert_allclose(mbar.f_k, mbar_fallback.f_k) + # Rebuild the accelerated version again. + mbar_rebuild = build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check) + assert_allclose(mbar.f_k, mbar_rebuild.f_k) From a7945fe0a87f94ab4765733c499dfe2b9f30dd28 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Tue, 20 Jun 2023 14:40:59 -0400 Subject: [PATCH 03/12] First take on implementing a much more pythonic jax-toggle system. Relies on creating classes with an exposed API. The problem is that JAX doesn't like acting on class methods so I am having to build around it. --- pymbar/mbar.py | 11 +- pymbar/mbar_solvers.py | 1071 --------------------------- pymbar/mbar_solvers/__init__.py | 106 +++ pymbar/mbar_solvers/jax_solver.py | 146 ++++ pymbar/mbar_solvers/mbar_solver.py | 788 ++++++++++++++++++++ pymbar/mbar_solvers/numpy_solver.py | 95 +++ pymbar/mbar_solvers/solver_api.py | 157 ++++ 7 files changed, 1297 insertions(+), 1077 deletions(-) delete mode 100644 pymbar/mbar_solvers.py create mode 100644 pymbar/mbar_solvers/__init__.py create mode 100644 pymbar/mbar_solvers/jax_solver.py create mode 100644 pymbar/mbar_solvers/mbar_solver.py create mode 100644 pymbar/mbar_solvers/numpy_solver.py create mode 100644 pymbar/mbar_solvers/solver_api.py diff --git a/pymbar/mbar.py b/pymbar/mbar.py index 4cbd014d..0ca48485 100644 --- a/pymbar/mbar.py +++ b/pymbar/mbar.py @@ -234,8 +234,7 @@ def __init__( """ # Set the accelerator methods for the solvers - mbar_solvers.set_accelerator(accelerator) - self.accelerator = mbar_solvers.accelerator + self.solver = mbar_solvers.get_accelerator(accelerator)() # Store local copies of necessary data. # N_k[k] is the number of samples from state k, some of which might be zero. @@ -419,7 +418,7 @@ def __init__( else: np.random.seed(rseed) - self.f_k = mbar_solvers.solve_mbar_for_all_states( + self.f_k = self.solver.solve_mbar_for_all_states( self.u_kn, self.N_k, self.f_k, self.states_with_samples, solver_protocol ) @@ -443,7 +442,7 @@ def __init__( # If we initialized with BAR, then BAR, starting from the provided initial_f_k as well. if initialize == "BAR": f_k_init = self._initialize_with_bar(self.u_kn[:, rints], f_k_init=self.f_k) - self.f_k_boots[b, :] = mbar_solvers.solve_mbar_for_all_states( + self.f_k_boots[b, :] = self.solver.solve_mbar_for_all_states( self.u_kn[:, rints], self.N_k, f_k_init, @@ -461,7 +460,7 @@ def __init__( # bootstrapped weight matrices not generated here, but when expectations are needed # otherwise, it's too much memory to keep - self.Log_W_nk = mbar_solvers.mbar_log_W_nk(self.u_kn, self.N_k, self.f_k) + self.Log_W_nk = self.solver.mbar_log_W_nk(self.u_kn, self.N_k, self.f_k) # Print final dimensionless free energies. if self.verbose: @@ -916,7 +915,7 @@ def compute_expectations_inner( f_k[0:K] = self.f_k_boots[n - 1, :] ri = self.bootstrap_rints[n - 1] u_kn = self.u_kn[:, ri] - Log_W_nk[:, 0:K] = mbar_solvers.mbar_log_W_nk(u_kn, self.N_k, f_k[0:K]) + Log_W_nk[:, 0:K] = self.solver.mbar_log_W_nk(u_kn, self.N_k, f_k[0:K]) # Pre-calculate the log denominator: Eqns 13, 14 in MBAR paper states_with_samples = self.N_k > 0 diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py deleted file mode 100644 index 5894bd23..00000000 --- a/pymbar/mbar_solvers.py +++ /dev/null @@ -1,1071 +0,0 @@ -import logging -import warnings -from functools import wraps - -import numpy as np - -# Optimize imported here and below as the jax-optimized one is jax or passthrough, but this is required regardless -import scipy.optimize -from pymbar.utils import ensure_type, check_w_normalized, ParameterError - -logger = logging.getLogger(__name__) - -use_jit = False -accelerator = "numpy" - - -# Import the methods functionally -# This is admittedly non-standard, but solves the following use case: -# * Has JAX -# * Wants to use PyMBAR -# * Does NOT want JAX to be set to 64-bit mode -# Also solves the future use case of different accelerator, -# but want to selectively use them -def init_numpy(): - """Set the imports for the basic numpy methods""" - # Disable the pylint problems for this block - # pylint: disable=global-variable-not-assigned - # pylint: disable=global-statement - # pylint: disable=unused-import - # pylint: disable=global-variable-undefined - - # Fallback/default solver methods - # NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY - global exp, sum, newaxis, diag, dot, s_, npad, lstsq, scipy_optimize, logsumexp - global jit, precondition_jit - global accelerator, use_jit - from numpy import exp, sum, newaxis, diag, dot, s_ - from numpy import pad as npad - from numpy.linalg import lstsq - import scipy.optimize as scipy_optimize # pylint: disable=reimported - from scipy.special import logsumexp - - # No jit, so make a passthrough decorator - def jit(fn): - return fn - - # Precondition if you need to do something different - def precondition_jit(fn): - return jit(fn) - - use_jit = False - accelerator = "numpy" - logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy") - - -def init_jax(): - """Set the imports for the JAX accelerated methods""" - # Disable the pylint problems for this block - # pylint: disable=global-variable-not-assigned - # pylint: disable=global-statement - # pylint: disable=unused-import - # pylint: disable=global-variable-undefined - # NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY - global exp, sum, newaxis, diag, dot, s_, npad, lstsq, scipy_optimize, logsumexp - global jit, precondition_jit - global accelerator, use_jit - global config - try: - from jax.config import config - - from jax.numpy import exp, sum, newaxis, diag, dot, s_ - from jax.numpy import pad as npad - from jax.numpy.linalg import lstsq - import jax.scipy.optimize as scipy_optimize - from jax.scipy.special import logsumexp - - from jax import jit - - def precondition_jit(jitable_fn): - """ - Attempt to set JAX precision if present. This does nothing if JAX is not present - - Parameters - ---------- - jitable_fn: function - A function which can be jit'd - """ - - @wraps( - jitable_fn - ) # Helper to ensure the decorated function still registers for docs and inspection - def staggered_jit(*args, **kwargs): - # This will only trigger if JAX is set - if use_jit and not config.x64_enabled: - # Warn that JAX 64-bit will being turned on - logger.warning( - "\n" - "******* JAX 64-bit mode is now on! *******\n" - "* JAX is now set to 64-bit mode! *\n" - "* This MAY cause problems with other *\n" - "* uses of JAX in the same code. *\n" - "******************************************\n" - ) - config.update("jax_enable_x64", True) - jited_fn = jit(jitable_fn) - return jited_fn(*args, **kwargs) - - return staggered_jit - - # Throw warning only if the whole of JAX is found - if not config.x64_enabled: - # Warn that we're going to be setting 64 bit jax - logger.warning( - "\n" - "****** PyMBAR will use 64-bit JAX! *******\n" - "* JAX is currently set to 32-bit bitsize *\n" - "* which is its default. *\n" - "* *\n" - "* PyMBAR requires 64-bit mode and WILL *\n" - "* enable JAX's 64-bit mode when called. *\n" - "* *\n" - "* This MAY cause problems with other *\n" - "* Uses of JAX in the same code. *\n" - "******************************************\n" - ) - - use_jit = True - accelerator = "jax" - logger.info("JAX detected. Using JAX acceleration.") - except ImportError: - # Catch no JAX and throw a warning - logger.warning( - "\n" - "********* JAX NOT FOUND *********\n" - " PyMBAR can run faster with JAX \n" - " But will work fine without it \n" - "Either install with pip or conda:\n" - " pip install pybar[jax] \n" - " OR \n" - " conda install pymbar \n" - "*********************************" - ) - # Fall back to NumPy import - init_numpy() - - -# Accelerator map for the set method below -ACCELERATOR_MAP = {"numpy": init_numpy, "jax": init_jax} - -# Try to set the initial/default accelerator -init_jax() - - -# Helper function for toggling the solver method -def set_accelerator(accelerator_name: str): - """ - Set the accelerator in the namespace for this module - """ - # Saving accelerator to new tag does not change since we're saving the immutable string object - accel = accelerator_name.lower() - if accel not in ACCELERATOR_MAP: - raise ValueError( - f"No accelerator implementation for {accel}, please use one of the following:\n" - + "".join((f"* {a}\n" for a in ACCELERATOR_MAP.keys())) - + f"(case-insentive)" - ) - logger.info(f"Attempting to change accelerator to {accel}...") - old_accelerator = accelerator - # Check the accelerator map, call the accelerator init which will handle the accelerator at the top level - ACCELERATOR_MAP[accel]() - new_accelerator = accelerator - if new_accelerator == old_accelerator and accel != old_accelerator: - logger.warning( - f"Attempted to change accelerator from {old_accelerator} to {accel}," - f" but something went wrong. Please check the log outputs above." - ) - return - logger.info(f"Successfully changed to accelerator {accel}!") - - -# Note on "pylint: disable=invalid-unary-operand-type" -# Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax. -# Can be mostly ignored - -# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving -# the MBAR equations. -# Note: we use tuples instead of lists to avoid accidental mutability. -JAX_SOLVER_PROTOCOL = ( - dict(method="BFGS", continuation=True), - dict(method="adaptive", options=dict(min_sc_iter=0)), -) - -DEFAULT_SOLVER_PROTOCOL = ( - dict(method="hybr", continuation=True), - dict(method="adaptive", options=dict(min_sc_iter=0)), -) - -ROBUST_SOLVER_PROTOCOL = ( - dict(method="adaptive", options=dict(maxiter=1000)), - dict(method="L-BFGS-B", options=dict(maxiter=1000)), -) - -BOOTSTRAP_SOLVER_PROTOCOL = (dict(method="adaptive", options=dict(min_sc_iter=0)),) - -# Allows all of the gradient based methods, but not the non-gradient methods ["Nelder-Mead", "Powell", "COBYLA"]", -scipy_minimize_options = [ - "L-BFGS-B", - "dogleg", - "CG", - "BFGS", - "Newton-CG", - "TNC", - "trust-ncg", - "trust-krylov", - "trust-exact", - "SLSQP", -] -scipy_nohess_options = [ - "L-BFGS-B", - "BFGS", - "CG", - "TNC", - "SLSQP", -] # don't pass a hessian to these to avoid warnings to these. -scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included - - -def validate_inputs(u_kn, N_k, f_k): - """Check types and return inputs for MBAR calculations. - - Parameters - ---------- - u_kn or q_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies or unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - u_kn or q_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies or unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='float' - The number of samples in each state. Converted to float because this cast is required when log is calculated. - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - """ - n_states, n_samples = u_kn.shape - - u_kn = ensure_type(u_kn, "float", 2, "u_kn or Q_kn", shape=(n_states, n_samples)) - N_k = ensure_type( - N_k, "float", 1, "N_k", shape=(n_states,), warn_on_cast=False - ) # Autocast to float because will be eventually used in float calculations. - f_k = ensure_type(f_k, "float", 1, "f_k", shape=(n_states,)) - - return u_kn, N_k, f_k - - -def self_consistent_update(u_kn, N_k, f_k, states_with_samples=None): - """Return an improved guess for the dimensionless free energies - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - f_k : np.ndarray, shape=(n_states), dtype='float' - Updated estimate of f_k - - Notes - ----- - Equation C3 in MBAR JCP paper. - """ - - return jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=states_with_samples) - - -@precondition_jit -def _jit_self_consistent_update(u_kn, N_k, f_k): - """JAX version of self_consistent update. For parameters, see self_consistent_update. - N_k must be float (should be cast at a higher level) - - """ - # Asteroid - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - # All states can contribute to the numerator term. Check transpose - return -1.0 * logsumexp( - -log_denominator_n - u_kn, axis=1 - ) # pylint: disable=invalid-unary-operand-type - - -def jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=None): - """JAX version of self_consistent update. For parameters, see self_consistent_update. - N_k must be float (should be cast at a higher level) - - """ - # Only the states with samples can contribute to the denominator term. - # Precondition before feeding the op to the JIT'd function - # In theory, this can be computed with jax.lax.cond, but trying to reuse code for non-jax paths - states_with_samples = s_[:] if states_with_samples is None else states_with_samples - # Feed to the JIT'd function. Can't pass slice types, so slice here - return _jit_self_consistent_update( - u_kn[states_with_samples], N_k[states_with_samples], f_k[states_with_samples] - ) - - -def mbar_gradient(u_kn, N_k, f_k): - """Gradient of MBAR objective function. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - grad : np.ndarray, dtype=float, shape=(n_states) - Gradient of mbar_objective - - Notes - ----- - This is equation C6 in the JCP MBAR paper. - """ - return jax_mbar_gradient(u_kn, N_k, f_k) - - -@precondition_jit -def jax_mbar_gradient(u_kn, N_k, f_k): - """JAX version of MBAR gradient function. See documentation of mbar_gradient. - N_k must be float (should be cast at a higher level) - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1) - return -1 * N_k * (1.0 - exp(f_k + log_numerator_k)) - - -def mbar_objective(u_kn, N_k, f_k): - """Calculates objective function for MBAR. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - - Returns - ------- - obj : float - Objective function - - Notes - ----- - This objective function is essentially a doubly-summed partition function and is - quite sensitive to precision loss from both overflow and underflow. For optimal - results, u_kn can be preconditioned by subtracting out a `n` dependent - vector. - - More optimal precision, the objective function uses math.fsum for the - outermost sum and logsumexp for the inner sum. - """ - - return jax_mbar_objective(u_kn, N_k, f_k) - - -@precondition_jit -def jax_mbar_objective(u_kn, N_k, f_k): - """JAX version of mbar_objective. - For parameters, mbar_objective_and_Gradient - N_k must be float (should be cast at a higher level) - - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - obj = sum(log_denominator_n) - dot(N_k, f_k) - - return obj - - -@precondition_jit -def jax_mbar_objective_and_gradient(u_kn, N_k, f_k): - """JAX version of mbar_objective_and_gradient. - For parameters, mbar_objective_and_Gradient - N_k must be float (should be cast at a higher level) - - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1) - grad = -1 * N_k * (1.0 - exp(f_k + log_numerator_k)) - - obj = sum(log_denominator_n) - dot(N_k, f_k) - - return obj, grad - - -def mbar_objective_and_gradient(u_kn, N_k, f_k): - """Calculates both objective function and gradient for MBAR. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - - Returns - ------- - obj : float - Objective function - grad : np.ndarray, dtype=float, shape=(n_states) - Gradient of objective function - - Notes - ----- - This objective function is essentially a doubly-summed partition function and is - quite sensitive to precision loss from both overflow and underflow. For optimal - results, u_kn can be preconditioned by subtracting out a `n` dependent - vector. - - More optimal precision, the objective function uses math.fsum for the - outermost sum and logsumexp for the inner sum. - - The gradient is equation C6 in the JCP MBAR paper; the objective - function is its integral. - """ - - return jax_mbar_objective_and_gradient(u_kn, N_k, f_k) - - -@precondition_jit -def jax_mbar_hessian(u_kn, N_k, f_k): - """JAX version of mbar_hessian. - For parameters, see mbar_hessian - N_k must be float (should be cast at a higher level) - - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, newaxis] - W = exp(logW) - - H = dot(W.T, W) - H *= N_k - H *= N_k[:, newaxis] - H -= diag(W.sum(0) * N_k) - return -1.0 * H - - -def mbar_hessian(u_kn, N_k, f_k): - """Hessian of MBAR objective function. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - H : np.ndarray, dtype=float, shape=(n_states, n_states) - Hessian of mbar objective function. - - Notes - ----- - Equation (C9) in JCP MBAR paper. - """ - - return jax_mbar_hessian(u_kn, N_k, f_k) - - -@precondition_jit -def jax_mbar_log_W_nk(u_kn, N_k, f_k): - """JAX version of mbar_log_W_nk. - For parameters, see mbar_log_W_nk - N_k must be float (should be cast at a higher level) - - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, newaxis] - return logW - - -def mbar_log_W_nk(u_kn, N_k, f_k): - """Calculate the log weight matrix. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - logW_nk : np.ndarray, dtype='float', shape=(n_samples, n_states) - The normalized log weights. - - Notes - ----- - Equation (9) in JCP MBAR paper. - """ - return jax_mbar_log_W_nk(u_kn, N_k, f_k) - - -@precondition_jit -def jax_mbar_W_nk(u_kn, N_k, f_k): - """JAX version of mbar_W_nk. - For parameters, see mbar_W_nk - N_k must be float (should be cast at a higher level) - - """ - return exp(jax_mbar_log_W_nk(u_kn, N_k, f_k)) - - -def mbar_W_nk(u_kn, N_k, f_k): - """Calculate the weight matrix. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - W_nk : np.ndarray, dtype='float', shape=(n_samples, n_states) - The normalized weights. - - Notes - ----- - Equation (9) in JCP MBAR paper. - """ - return jax_mbar_W_nk(u_kn, N_k, f_k) - - -def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None): - """ - Determine dimensionless free energies by a combination of Newton-Raphson iteration and self-consistent iteration. - Picks whichever method gives the lowest gradient. - Is slower than NR since it calculates the log norms twice each iteration. - - OPTIONAL ARGUMENTS - tol (float between 0 and 1) - relative tolerance for convergence (default 1.0e-12) - - options : dictionary of options - gamma (float between 0 and 1) - incrementor for NR iterations (default 1.0). Usually not changed now, since adaptively switch. - maxiter (int) - maximum number of Newton-Raphson iterations (default 10000: either NR converges or doesn't, pretty quickly) - verbose (boolean) - verbosity level for debug output - - NOTES - - This method determines the dimensionless free energies by - minimizing a convex function whose solution is the desired - estimator. The original idea came from the construction of a - likelihood function that independently reproduced the work of - Geyer (see [1] and Section 6 of [2]). This can alternatively be - formulated as a root-finding algorithm for the Z-estimator. More - details of this procedure will follow in a subsequent paper. Only - those states with nonzero counts are include in the estimation - procedure. - - REFERENCES - See Appendix C.2 of [1]. - - """ - # put the defaults here in case we get passed an 'options' dictionary that is only partial - options.setdefault("verbose", False) - options.setdefault("maxiter", 10000) - options.setdefault("print_warning", False) - options.setdefault("gamma", 1.0) - options.setdefault("min_sc_iter", 2) # set a minimum number of self-consistent iterations - - gamma = options["gamma"] - - doneIterating = False - if options["verbose"] == True: - logger.info( - "Determining dimensionless free energies by Newton-Raphson / self-consistent iteration." - ) - - if tol < 4.0 * np.finfo(float).eps: - logger.info("Tolerance may be too close to machine precision to converge.") - - success = False # fail unless solution is found. - # keep track of Newton-Raphson and self-consistent iterations - nr_iter = 0 - sci_iter = 0 - - f_sci = np.zeros(len(f_k), dtype=np.float64) - f_nr = np.zeros(len(f_k), dtype=np.float64) - - # Perform Newton-Raphson iterations (with sci computed on the way) - - # usually calculated at the end of the loop and saved, but we need - # to calculate the first time. - g = mbar_gradient(u_kn, N_k, f_k) # Objective function gradient. - - maxiter = options["maxiter"] - min_sc_iter = options["min_sc_iter"] - warn = "Did not converge." - for iteration in range(0, maxiter): - if use_jit: - (f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr) = jax_core_adaptive( - u_kn, N_k, f_k, options["gamma"] - ) - else: - H = mbar_hessian(u_kn, N_k, f_k) # Objective function hessian - Hinvg = np.linalg.lstsq(H, g, rcond=-1)[0] - Hinvg -= Hinvg[0] - f_nr = f_k - gamma * Hinvg - - # self-consistent iteration gradient norm and saved log sums. - f_sci = self_consistent_update(u_kn, N_k, f_k) - f_sci = f_sci - f_sci[0] # zero out the minimum - g_sci = mbar_gradient(u_kn, N_k, f_sci) - gnorm_sci = dot(g_sci, g_sci) - - # newton raphson gradient norm and saved log sums. - g_nr = mbar_gradient(u_kn, N_k, f_nr) - gnorm_nr = dot(g_nr, g_nr) - - # we could save the gradient, for the next round, but it's not too expensive to - # compute since we are doing the Hessian anyway. - - if options["verbose"]: - logger.info( - "self consistent iteration gradient norm is %10.5g, Newton-Raphson gradient norm is %10.5g" - % (np.sqrt(gnorm_sci), np.sqrt(gnorm_nr)) - ) - # decide which directon to go depending on size of gradient norm - f_old = f_k - - if gnorm_sci < gnorm_nr or sci_iter < min_sc_iter: - f_k = f_sci - g = g_sci - sci_iter += 1 - if options["verbose"]: - if sci_iter < min_sc_iter: - logger.info( - f"Choosing self-consistent iteration on iteration {iteration:d} because min_sci_iter={min_sc_iter:d}" - ) - else: - logger.info( - f"Choosing self-consistent iteration for lower gradient on iteration {iteration:d}" - ) - else: - f_k = f_nr - g = g_nr - nr_iter += 1 - if options["verbose"]: - logger.info(f"Newton-Raphson used on iteration {iteration:}") - - div = np.abs(f_k[1:]) # what we will divide by to get relative difference - zeroed = np.abs(f_k[1:]) < np.min( - [10**-8, tol] - ) # check which values are near enough to zero, hard coded max for now. - div[zeroed] = 1.0 # for these values, use absolute values. - max_delta = np.max(np.abs(f_k[1:] - f_old[1:]) / div) - max_diff = np.max(np.abs(f_sci[1:] - f_nr[1:]) / div) - # add this just to make sure they are not too different. - # if we start with bad states, the f_k - f_k_old might be far off. - if np.isnan(max_delta) or ((max_delta < tol) and max_diff < np.sqrt(tol)): - doneIterating = True - success = True - warn = "Convergence achieved by change in f with respect to previous guess." - break - - if doneIterating: - if options["verbose"]: - logger.info(f"Converged to tolerance of {max_delta:e} in {iteration+1:d} iterations.") - logger.info( - f"Of {iteration+1:d} iterations, {nr_iter:d} were Newton-Raphson iterations and {sci_iter:d} were self-consistent iterations" - ) - if np.all(f_k == 0.0): - logger.info("WARNING: All f_k appear to be zero.") - else: - logger.warning("WARNING: Did not converge to within specified tolerance.") - - if maxiter <= 0: - logger.warning( - f"No iterations ran be cause maximum_iterations was <= 0 ({maxiter:s})!" - ) - else: - logger.warning( - f"max_delta = {max_delta:e}, tol = {tol:e}, maximum_iterations = {maxiter:d}, iterations completed = {iteration:d}" - ) - - results = dict() - results["success"] = success - results["message"] = warn - results["x"] = f_k - - return results - - -@precondition_jit -def jax_core_adaptive(u_kn, N_k, f_k, gamma): - """JAX version of adaptive inner loop. - N_k must be float (should be cast at a higher level) - - """ - - # Perform Newton-Raphson iterations (with sci computed on the way) - g = mbar_gradient(u_kn, N_k, f_k) # Objective function gradient - H = mbar_hessian(u_kn, N_k, f_k) # Objective function hessian - Hinvg = lstsq(H, g, rcond=-1)[0] - Hinvg -= Hinvg[0] - f_nr = f_k - gamma * Hinvg - - # self-consistent iteration gradient norm and saved log sums. - f_sci = self_consistent_update(u_kn, N_k, f_k) - f_sci = f_sci - f_sci[0] # zero out the minimum - g_sci = mbar_gradient(u_kn, N_k, f_sci) - gnorm_sci = dot(g_sci, g_sci) - - # newton raphson gradient norm and saved log sums. - g_nr = mbar_gradient(u_kn, N_k, f_nr) - gnorm_nr = dot(g_nr, g_nr) - - return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr - - -@precondition_jit -def jax_precondition_u_kn(u_kn, N_k, f_k): - """JAX version of precondition_u_kn - for parameters, see precondition_u_kn - N_k must be float (should be cast at a higher level) - - """ - - u_kn = u_kn - u_kn.min(0) - u_kn += (logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - dot(N_k, f_k) / N_k.sum() - return u_kn - - -def precondition_u_kn(u_kn, N_k, f_k): - """Subtract a sample-dependent constant from u_kn to improve precision - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - - Notes - ----- - Returns u_kn - x_n, where x_n is based on the current estimate of f_k. - Upon subtraction of x_n, the MBAR objective function changes by an - additive constant, but its derivatives remain unchanged. We choose - x_n such that the current objective function value is zero, which - should give maximum precision in the objective function. - """ - return jax_precondition_u_kn(u_kn, N_k, f_k) - - -def solve_mbar_once( - u_kn_nonzero, - N_k_nonzero, - f_k_nonzero, - method="adaptive", - tol=1e-12, - continuation=None, - options=None, -): - """Solve MBAR self-consistent equations using some form of equation solver. - - Parameters - ---------- - u_kn_nonzero : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - for the nonempty states - N_k_nonzero : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state for the nonempty states - f_k_nonzero : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies for the nonempty states - method : str, optional, default="hybr" - The optimization routine to use. This can be any of the methods - available via scipy.optimize.minimize() or scipy.optimize.root(). - tol : float, optional, default=1E-14 - The convergance tolerance for minimize() or root() - verbose: bool - Whether to print information about the solution method. - options: dict, optional, default=None - Optional dictionary of algorithm-specific parameters. See - scipy.optimize.root or scipy.optimize.minimize for details. - - Returns - ------- - f_k : np.ndarray - The converged reduced free energies. - results : dict - Dictionary containing entire results of optimization routine, may - be useful when debugging convergence. - - Notes - ----- - This function requires that N_k_nonzero > 0--that is, you should have - already dropped all the states for which you have no samples. - Internally, this function works in a reduced coordinate system defined - by subtracting off the first component of f_k and fixing that component - to be zero. - - For fast but precise convergence, we recommend calling this function - multiple times to polish the result. `solve_mbar()` facilitates this. - """ - - # we only validate at the outside of the call - u_kn_nonzero, N_k_nonzeo, f_k_nonzero = validate_inputs(u_kn_nonzero, N_k_nonzero, f_k_nonzero) - f_k_nonzero = f_k_nonzero - f_k_nonzero[0] # Work with reduced dimensions with f_k[0] := 0 - N_k_nonzero = 1.0 * N_k_nonzero # convert to float for acceleration. - u_kn_nonzero = precondition_u_kn(u_kn_nonzero, N_k_nonzero, f_k_nonzero) - - pad = lambda x: np.pad( - x, (1, 0), mode="constant" - ) # Helper function inserts zero before first element - unpad_second_arg = lambda obj, grad: ( - obj, - grad[1:], - ) # Helper function drops first element of gradient - - # Create objective functions / nonlinear equations to send to scipy.optimize, fixing f_0 = 0 - grad = lambda x: mbar_gradient(u_kn_nonzero, N_k_nonzero, pad(x))[ - 1: - ] # Objective function gradient - - grad_and_obj = lambda x: unpad_second_arg( - *mbar_objective_and_gradient(u_kn_nonzero, N_k_nonzero, pad(x)) - ) # Objective function gradient and objective function - - de_jax_grad_and_obj = lambda x: ( - *map(np.array, grad_and_obj(x)), # (...,) Casts to tuple instead of object - ) # Force any jax-based array output to normal numpy for scipy.optimize.minimize. np.asarray does not work. - - hess = lambda x: mbar_hessian(u_kn_nonzero, N_k_nonzero, pad(x))[1:][ - :, 1: - ] # Hessian of objective function - with warnings.catch_warnings(record=True) as w: - if use_jit and method == "BFGS": - fpad = lambda x: npad(x, (1, 0)) - obj = lambda x: mbar_objective(u_kn_nonzero, N_k_nonzero, fpad(x)) - # objective function to be minimized (for derivative free methods, mostly jit) - jax_results = scipy_optimize.minimize( - obj, - f_k_nonzero[1:], - method=method, - tol=tol, - options=dict(maxiter=options["maxiter"]), - ) - results = dict() # there should be a way to copy this. - results["x"] = jax_results[0] - f_k_nonzero = pad(results["x"]) - results["success"] = jax_results[1] - elif method in scipy_minimize_options: - if method in scipy_nohess_options: - hess = None # To suppress warning from passing a hessian function. - results = scipy.optimize.minimize( - de_jax_grad_and_obj, - f_k_nonzero[1:], - jac=True, - hess=hess, - method=method, - tol=tol, - options=options, - ) - f_k_nonzero = pad(results["x"]) - elif method == "adaptive": - results = adaptive(u_kn_nonzero, N_k_nonzero, f_k_nonzero, tol=tol, options=options) - f_k_nonzero = results["x"] - elif method in scipy_root_options: - # find the root in the gradient. - results = scipy.optimize.root( - grad, f_k_nonzero[1:], jac=hess, method=method, tol=tol, options=options - ) - f_k_nonzero = pad(results["x"]) - else: - raise ParameterError(f"Method {method} for solution of free energies not recognized") - - # If there were runtime warnings, show the messages - if len(w) > 0: - can_ignore = True - for warn_msg in w: - if "Unknown solver options" in str(warn_msg.message): - continue - warnings.showwarning( - warn_msg.message, - warn_msg.category, - warn_msg.filename, - warn_msg.lineno, - warn_msg.file, - "", - ) - can_ignore = False # If any warning is not just unknown options, can not skip check - if not can_ignore: - # Ensure MBAR solved correctly - w_nk_check = mbar_W_nk(u_kn_nonzero, N_k_nonzero, f_k_nonzero) - check_w_normalized(w_nk_check, N_k_nonzero) - logger.warning( - "MBAR weights converged within tolerance, despite the SciPy Warnings. Please validate your results." - ) - - return f_k_nonzero, results - - -def solve_mbar(u_kn_nonzero, N_k_nonzero, f_k_nonzero, solver_protocol=None): - """Solve MBAR self-consistent equations using some sequence of equation solvers. - - Parameters - ---------- - u_kn_nonzero : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - for the nonempty states - N_k_nonzero : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state for the nonempty states - f_k_nonzero : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies for the nonempty states - solver_protocol : tuple(dict()), optional, default=None - Optional list of dictionaries of steps in solver protocol. - If None, a default protocol will be used. - - Returns - ------- - f_k : np.ndarray - The converged reduced free energies. - all_results : list(dict()) - List of results from each step of solver_protocol. Each element in - list contains the results dictionary from solve_mbar_once() - for the corresponding step. - - Notes - ----- - This function requires that N_k_nonzero > 0--that is, you should have - already dropped all the states for which you have no samples. - Internally, this function works in a reduced coordinate system defined - by subtracting off the first component of f_k and fixing that component - to be zero. - - This function calls `solve_mbar_once()` multiple times to achieve - converged results. Generally, a single call to solve_mbar_once() - will not give fully converged answers because of limited numerical precision. - Each call to `solve_mbar_once()` re-conditions the nonlinear - equations using the current guess. - """ - - if solver_protocol is None: - solver_protocol = DEFAULT_SOLVER_PROTOCOL - - all_fks = [] - all_gnorms = [] - all_results = [] - - for solver in solver_protocol: - f_k_nonzero_result, results = solve_mbar_once( - u_kn_nonzero, N_k_nonzero, f_k_nonzero, **solver - ) - all_fks.append(f_k_nonzero_result) - all_gnorms.append( - np.linalg.norm(mbar_gradient(u_kn_nonzero, N_k_nonzero, f_k_nonzero_result)) - ) - all_results.append(results) - - if results["success"]: - success = True - best_gnorm = all_gnorms[-1] - logger.info(f"Reached a solution to within tolerance with {solver['method']}") - break - else: - logger.warning( - f"Failed to reach a solution to within tolerance with {solver['method']}: trying next method" - ) - logger.info(f"Ending gnorm of method {solver['method']} = {all_gnorms[-1]:e}") - if solver["continuation"]: - f_k_nonzero = f_k_nonzero_result - logger.info("Will continue with results from previous method") - - if results["success"]: - logger.info("Solution found within tolerance!") - else: - i_best_gnorm = np.argmin(all_gnorms) - logger.warning("No solution found to within tolerance.") - best_method = solver_protocol[i_best_gnorm]["method"] - best_gnorm = all_gnorms[i_best_gnorm] - logger.warning( - f"The solution with the smallest gradient {best_gnorm:e} norm is {best_method}" - ) - f_k_nonzero_result = all_fks[i_best_gnorm] - logger.warning( - "Please exercise caution with this solution and consider alternative methods or a different tolerance." - ) - - logger.info(f"Final gradient norm: {best_gnorm:.3g}") - - return f_k_nonzero_result, all_results - - -def solve_mbar_for_all_states(u_kn, N_k, f_k, states_with_samples, solver_protocol): - """Solve for free energies of states with samples, then calculate for - empty states. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - solver_protocol : tuple(dict()), optional, default=None - Sequence of dictionaries of steps in solver protocol for final - stage of refinement. - - Returns - ------- - f_k : np.ndarray, shape=(n_states), dtype='float' - The free energies of states - """ - - if len(states_with_samples) == 1: - f_k_nonzero = np.array([0.0]) - else: - f_k_nonzero, all_results = solve_mbar( - u_kn[states_with_samples], - N_k[states_with_samples], - f_k[states_with_samples], - solver_protocol=solver_protocol, - ) - - f_k[states_with_samples] = np.array(f_k_nonzero) - - # Update all free energies because those from states with zero samples are not correctly computed by solvers. - f_k = self_consistent_update(u_kn, N_k, f_k) - # This is necessary because state 0 might have had zero samples, - # but we still want that state to be the reference with free energy 0. - f_k -= f_k[0] - - return f_k diff --git a/pymbar/mbar_solvers/__init__.py b/pymbar/mbar_solvers/__init__.py new file mode 100644 index 00000000..c2e8055a --- /dev/null +++ b/pymbar/mbar_solvers/__init__.py @@ -0,0 +1,106 @@ +############################################################################## +# pymbar: A Python Library for MBAR +# +# Copyright 2017-2022 University of Colorado Boulder +# Copyright 2010-2017 Memorial Sloan-Kettering Cancer Center +# Portions of this software are Copyright (c) 2010-2016 University of Virginia +# Portions of this software are Copyright (c) 2006-2007 The Regents of the University of California. All Rights Reserved. +# Portions of this software are Copyright (c) 2007-2008 Stanford University and Columbia University. +# +# Authors: Michael Shirts, John Chodera +# Contributors: Kyle Beauchamp, Levi Naden +# +# pymbar is free software: you can redistribute it and/or modify +# it under the terms of the MIT License as +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. +# +# You should have received a copy of the MIT License along with pymbar. +############################################################################## + +""" +########### +pymbar.mbar_solvers +########### + +A module implementing the solvers array operations for the MBAR solvers with various code bases for acceleration. + +All methods have the same calls and returns, independent of their underlying codes for solution. + +Please reference the following if you use this code in your research: + +[1] Shirts MR and Chodera JD. Statistically optimal analysis of samples from multiple equilibrium states. +J. Chem. Phys. 129:124105, 2008. http://dx.doi.org/10.1063/1.2978177 + +""" + +import logging +from .mbar_solver import ( + validate_inputs, + JAX_SOLVER_PROTOCOL, + DEFAULT_SOLVER_PROTOCOL, + ROBUST_SOLVER_PROTOCOL, + BOOTSTRAP_SOLVER_PROTOCOL +) +from .numpy_solver import MBARSolverNumpy + +logger = logging.getLogger(__name__) + +default_solver = MBARSolverNumpy # Set fallback solver +ACCELERATOR_MAP = {"numpy": MBARSolverNumpy} +try: + from .jax_solver import MBARSolverJAX + default_solver = MBARSolverJAX + ACCELERATOR_MAP["jax"] = MBARSolverJAX + logger.info("JAX detected. Using JAX acceleration by default.") +except ImportError: + logger.warning( + "\n" + "********* JAX NOT FOUND *********\n" + " PyMBAR can run faster with JAX \n" + " But will work fine without it \n" + "Either install with pip or conda:\n" + " pip install pybar[jax] \n" + " OR \n" + " conda install pymbar \n" + "*********************************" + ) + + +# Helper function for toggling the solver method +def get_accelerator(accelerator_name: str): + """ + get the accelerator in the namespace for this module + """ + # Saving accelerator to new tag does not change since we're saving the immutable string object + accel = accelerator_name.lower() + if accel not in ACCELERATOR_MAP: + raise ValueError( + f"Accelerator {accel} is not implemented or did not load correctly. Please use one of the following:\n" + + "".join((f"* {a}\n" for a in ACCELERATOR_MAP.keys())) + + f"(case-insentive)\n" + + f"If you expected {accel} to load, please check the logs above for details." + ) + logger.info(f"Getting accelerator {accel}...") + return ACCELERATOR_MAP[accel] + + +# Imports done, handle initialization +module_solver = default_solver() + +# Establish API methods for 4.x consistency +self_consistent_update = module_solver.self_consistent_update +mbar_gradient = module_solver.mbar_gradient +mbar_objective = module_solver.mbar_objective +mbar_objective_and_gradient = module_solver.mbar_objective_and_gradient +mbar_hessian = module_solver.mbar_hessian +mbar_log_W_nk = module_solver.mbar_log_W_nk +mbar_W_nk = module_solver.mbar_W_nk +adaptive = module_solver.adaptive +precondition_u_kn = module_solver.precondition_u_kn +solve_mbar_once = module_solver.solve_mbar_once +solve_mbar = module_solver.solve_mbar +solve_mbar_for_all_states = module_solver.solve_mbar_for_all_states diff --git a/pymbar/mbar_solvers/jax_solver.py b/pymbar/mbar_solvers/jax_solver.py new file mode 100644 index 00000000..59cc0488 --- /dev/null +++ b/pymbar/mbar_solvers/jax_solver.py @@ -0,0 +1,146 @@ + +"""Set the imports for the JAX accelerated methods""" + +import logging +from functools import partial, wraps + +from jax.config import config + +import jax.numpy as jnp +from jax.numpy.linalg import lstsq +import jax.scipy.optimize +from jax.scipy.special import logsumexp + +from jax import jit + +from pymbar.mbar_solvers.mbar_solver import MBARSolver + +logger = logging.getLogger(__name__) + + +class MBARSolverJAX(MBARSolver): + """ + Solver methods for MBAR. Implementations use specific libraries/accelerators to solve the code paths. + + Default solver is the numpy solution + """ + + def __init__(self): + # Throw warning only if the whole of JAX is found + if not config.x64_enabled: + # Warn that we're going to be setting 64 bit jax + logger.warning( + "\n" + "****** PyMBAR will use 64-bit JAX! *******\n" + "* JAX is currently set to 32-bit bitsize *\n" + "* which is its default. *\n" + "* *\n" + "* PyMBAR requires 64-bit mode and WILL *\n" + "* enable JAX's 64-bit mode when called. *\n" + "* *\n" + "* This MAY cause problems with other *\n" + "* Uses of JAX in the same code. *\n" + "******************************************\n" + ) + super().__init__() + + @property + def exp(self): + return jnp.exp + + @property + def sum(self): + return jnp.sum + + @property + def diag(self): + return jnp.diag + + @property + def newaxis(self): + return jnp.newaxis + + @property + def dot(self): + return jnp.dot + + @property + def s_(self): + return jnp.s_ + + @property + def pad(self): + return jnp.pad + + @property + def lstsq(self): + return lstsq + + @property + def optimize(self): + return jax.scipy.optimize + + @property + def logsumexp(self): + return logsumexp + + @property + def jit(self): + return jit + + + class JitDecorators: + """ + Internal helper class to do any preconditioning of JIT operations + Uses this buried class to allow a decorator with access to the "self" within its precondition + """ + @classmethod + def precondition_jit(cls, jitable_fn): + @wraps( + jitable_fn + ) # Helper to ensure the decorated function still registers for docs and inspection + def staggered_jit(self, *args, **kwargs): + # This will only trigger if JAX is set + if not config.x64_enabled: + # Warn that JAX 64-bit will being turned on + logger.warning( + "\n" + "******* JAX 64-bit mode is now on! *******\n" + "* JAX is now set to 64-bit mode! *\n" + "* This MAY cause problems with other *\n" + "* uses of JAX in the same code. *\n" + "******************************************\n" + ) + config.update("jax_enable_x64", True) + + jited_fn = partial(self.jit, static_argnums=(0,))(jitable_fn) + return jited_fn(*args, **kwargs) + return staggered_jit + + precondition_jit = JitDecorators.precondition_jit + + @precondition_jit + def _adaptive_core(self, u_kn, N_k, f_k, gamma): + """JAX version of adaptive inner loop. + N_k must be float (should be cast at a higher level) + + """ + + # Perform Newton-Raphson iterations (with sci computed on the way) + g = self.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient + H = self.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian + Hinvg = lstsq(H, g, rcond=-1)[0] + Hinvg -= Hinvg[0] + f_nr = f_k - gamma * Hinvg + + # self-consistent iteration gradient norm and saved log sums. + f_sci = self.self_consistent_update(u_kn, N_k, f_k) + f_sci = f_sci - f_sci[0] # zero out the minimum + g_sci = self.mbar_gradient(u_kn, N_k, f_sci) + gnorm_sci = self.dot(g_sci, g_sci) + + # newton raphson gradient norm and saved log sums. + g_nr = self.mbar_gradient(u_kn, N_k, f_nr) + gnorm_nr = self.dot(g_nr, g_nr) + + return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr diff --git a/pymbar/mbar_solvers/mbar_solver.py b/pymbar/mbar_solvers/mbar_solver.py new file mode 100644 index 00000000..eb884e64 --- /dev/null +++ b/pymbar/mbar_solvers/mbar_solver.py @@ -0,0 +1,788 @@ +import logging +import warnings +from functools import wraps +from abc import ABC, abstractmethod + +import numpy as np + +# Optimize imported here and below as the jax-optimized one is jax or passthrough, but this is required regardless +import scipy.optimize +from pymbar.utils import ensure_type, check_w_normalized, ParameterError +from pymbar.mbar_solvers.solver_api import MBARSolverAPI, MBARSolverAcceleratorMethods + +logger = logging.getLogger(__name__) + +# Note on "pylint: disable=invalid-unary-operand-type" +# Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax. +# Can be mostly ignored + +# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving +# the MBAR equations. +# Note: we use tuples instead of lists to avoid accidental mutability. +JAX_SOLVER_PROTOCOL = ( + dict(method="BFGS", continuation=True), + dict(method="adaptive", options=dict(min_sc_iter=0)), +) + +DEFAULT_SOLVER_PROTOCOL = ( + dict(method="hybr", continuation=True), + dict(method="adaptive", options=dict(min_sc_iter=0)), +) + +ROBUST_SOLVER_PROTOCOL = ( + dict(method="adaptive", options=dict(maxiter=1000)), + dict(method="L-BFGS-B", options=dict(maxiter=1000)), +) + +BOOTSTRAP_SOLVER_PROTOCOL = (dict(method="adaptive", options=dict(min_sc_iter=0)),) + +# Allows all of the gradient based methods, but not the non-gradient methods ["Nelder-Mead", "Powell", "COBYLA"]", +scipy_minimize_options = [ + "L-BFGS-B", + "dogleg", + "CG", + "BFGS", + "Newton-CG", + "TNC", + "trust-ncg", + "trust-krylov", + "trust-exact", + "SLSQP", +] +scipy_nohess_options = [ + "L-BFGS-B", + "BFGS", + "CG", + "TNC", + "SLSQP", +] # don't pass a hessian to these to avoid warnings to these. +scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included + + +class MBARSolver(MBARSolverAPI, MBARSolverAcceleratorMethods): + """ + Solver methods for MBAR. Implementations use specific libraries/accelerators to solve the code paths. + + Default solver is the numpy solution + """ + + JITABLE_IMPLEMENTATION_METHODS = ( + "_jit_self_consistent_update", + ) + + def __init__(self): + """Apply the precondition to each of the JITED_METHODS""" + for method in ( + self.JITABLE_ACCELERATOR_METHODS + + self.JITABLE_API_METHODS + + self.JITABLE_IMPLEMENTATION_METHODS + ): + setattr(self, method, self._precondition_jit(getattr(self, method))) + + def self_consistent_update(self, u_kn, N_k, f_k, states_with_samples=None): + """Return an improved guess for the dimensionless free energies + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + f_k : np.ndarray, shape=(n_states), dtype='float' + Updated estimate of f_k + + Notes + ----- + Equation C3 in MBAR JCP paper. + """ + + # Only the states with samples can contribute to the denominator term. + # Precondition before feeding the op to the JIT'd function + # In theory, this can be computed with jax.lax.cond, but trying to reuse code for non-jax paths + states_with_samples = self.s_[:] if states_with_samples is None else states_with_samples + return self._jit_self_consistent_update( + u_kn[states_with_samples], N_k[states_with_samples], f_k[states_with_samples] + ) + + def _jit_self_consistent_update(self, u_kn, N_k, f_k): + """JAX version of self_consistent update. For parameters, see self_consistent_update. + N_k must be float (should be cast at a higher level) + + """ + # Asteroid + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + # All states can contribute to the numerator term. Check transpose + return -1.0 * self.logsumexp( + -log_denominator_n - u_kn, axis=1 + ) # pylint: disable=invalid-unary-operand-type + + def mbar_gradient(self, u_kn, N_k, f_k): + """Gradient of MBAR objective function. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + grad : np.ndarray, dtype=float, shape=(n_states) + Gradient of mbar_objective + + Notes + ----- + This is equation C6 in the JCP MBAR paper. + """ + + # N_k must be float (should be cast at a higher level) + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + log_numerator_k = self.logsumexp(-log_denominator_n - u_kn, axis=1) + return -1 * N_k * (1.0 - self.exp(f_k + log_numerator_k)) + + def mbar_objective(self, u_kn, N_k, f_k): + """Calculates objective function for MBAR. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + + Returns + ------- + obj : float + Objective function + + Notes + ----- + This objective function is essentially a doubly-summed partition function and is + quite sensitive to precision loss from both overflow and underflow. For optimal + results, u_kn can be preconditioned by subtracting out a `n` dependent + vector. + + More optimal precision, the objective function uses math.fsum for the + outermost sum and logsumexp for the inner sum. + """ + + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + obj = sum(log_denominator_n) - self.dot(N_k, f_k) + + return obj + + def mbar_objective_and_gradient(self, u_kn, N_k, f_k): + """Calculates both objective function and gradient for MBAR. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + + Returns + ------- + obj : float + Objective function + grad : np.ndarray, dtype=float, shape=(n_states) + Gradient of objective function + + Notes + ----- + This objective function is essentially a doubly-summed partition function and is + quite sensitive to precision loss from both overflow and underflow. For optimal + results, u_kn can be preconditioned by subtracting out a `n` dependent + vector. + + More optimal precision, the objective function uses math.fsum for the + outermost sum and logsumexp for the inner sum. + + The gradient is equation C6 in the JCP MBAR paper; the objective + function is its integral. + """ + + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + log_numerator_k = self.logsumexp(-log_denominator_n - u_kn, axis=1) + grad = -1 * N_k * (1.0 - self.exp(f_k + log_numerator_k)) + + obj = sum(log_denominator_n) - self.dot(N_k, f_k) + + return obj, grad + + def mbar_hessian(self, u_kn, N_k, f_k): + """Hessian of MBAR objective function. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + H : np.ndarray, dtype=float, shape=(n_states, n_states) + Hessian of mbar objective function. + + Notes + ----- + Equation (C9) in JCP MBAR paper. + """ + + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, self.newaxis] + W = self.exp(logW) + + H = self.dot(W.T, W) + H *= N_k + H *= N_k[:, self.newaxis] + H -= self.diag(W.sum(0) * N_k) + return -1.0 * H + + def mbar_log_W_nk(self, u_kn, N_k, f_k): + """Calculate the log weight matrix. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + logW_nk : np.ndarray, dtype='float', shape=(n_samples, n_states) + The normalized log weights. + + Notes + ----- + Equation (9) in JCP MBAR paper. + """ + + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, self.newaxis] + return logW + + def mbar_W_nk(self, u_kn, N_k, f_k): + """Calculate the weight matrix. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + W_nk : np.ndarray, dtype='float', shape=(n_samples, n_states) + The normalized weights. + + Notes + ----- + Equation (9) in JCP MBAR paper. + """ + + return self.exp(self.mbar_log_W_nk(u_kn, N_k, f_k)) + + def precondition_u_kn(self, u_kn, N_k, f_k): + """Subtract a sample-dependent constant from u_kn to improve precision + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + + Notes + ----- + Returns u_kn - x_n, where x_n is based on the current estimate of f_k. + Upon subtraction of x_n, the MBAR objective function changes by an + additive constant, but its derivatives remain unchanged. We choose + x_n such that the current objective function value is zero, which + should give maximum precision in the objective function. + """ + u_kn = u_kn - u_kn.min(0) + u_kn += (self.logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - self.dot(N_k, f_k) / N_k.sum() + return u_kn + + def adaptive(self, u_kn, N_k, f_k, tol=1.0e-8, options=None): + """ + Determine dimensionless free energies by a combination of Newton-Raphson iteration and self-consistent iteration. + Picks whichever method gives the lowest gradient. + Is slower than NR since it calculates the log norms twice each iteration. + + OPTIONAL ARGUMENTS + tol (float between 0 and 1) - relative tolerance for convergence (default 1.0e-12) + + options : dictionary of options + gamma (float between 0 and 1) - incrementor for NR iterations (default 1.0). Usually not changed now, since adaptively switch. + maxiter (int) - maximum number of Newton-Raphson iterations (default 10000: either NR converges or doesn't, pretty quickly) + verbose (boolean) - verbosity level for debug output + + NOTES + + This method determines the dimensionless free energies by + minimizing a convex function whose solution is the desired + estimator. The original idea came from the construction of a + likelihood function that independently reproduced the work of + Geyer (see [1] and Section 6 of [2]). This can alternatively be + formulated as a root-finding algorithm for the Z-estimator. More + details of this procedure will follow in a subsequent paper. Only + those states with nonzero counts are include in the estimation + procedure. + + REFERENCES + See Appendix C.2 of [1]. + + """ + # put the defaults here in case we get passed an 'options' dictionary that is only partial + options.setdefault("verbose", False) + options.setdefault("maxiter", 10000) + options.setdefault("print_warning", False) + options.setdefault("gamma", 1.0) + options.setdefault("min_sc_iter", 2) # set a minimum number of self-consistent iterations + + doneIterating = False + if options["verbose"] == True: + logger.info( + "Determining dimensionless free energies by Newton-Raphson / self-consistent iteration." + ) + + if tol < 4.0 * np.finfo(float).eps: + logger.info("Tolerance may be too close to machine precision to converge.") + + success = False # fail unless solution is found. + # keep track of Newton-Raphson and self-consistent iterations + nr_iter = 0 + sci_iter = 0 + + f_sci = np.zeros(len(f_k), dtype=np.float64) + f_nr = np.zeros(len(f_k), dtype=np.float64) + + # Perform Newton-Raphson iterations (with sci computed on the way) + + # usually calculated at the end of the loop and saved, but we need + # to calculate the first time. + g = self.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient. + + maxiter = options["maxiter"] + min_sc_iter = options["min_sc_iter"] + warn = "Did not converge." + for iteration in range(0, maxiter): + f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr = self._adaptive_core( + u_kn, N_k, f_k, g, options + ) + # we could save the gradient, for the next round, but it's not too expensive to + # compute since we are doing the Hessian anyway. + if options["verbose"]: + logger.info( + "self consistent iteration gradient norm is %10.5g, Newton-Raphson gradient norm is %10.5g" + % (np.sqrt(gnorm_sci), np.sqrt(gnorm_nr)) + ) + # decide which direction to go depending on size of gradient norm + f_old = f_k + + if gnorm_sci < gnorm_nr or sci_iter < min_sc_iter: + f_k = f_sci + g = g_sci + sci_iter += 1 + if options["verbose"]: + if sci_iter < min_sc_iter: + logger.info( + f"Choosing self-consistent iteration on iteration {iteration:d} because min_sci_iter={min_sc_iter:d}" + ) + else: + logger.info( + f"Choosing self-consistent iteration for lower gradient on iteration {iteration:d}" + ) + else: + f_k = f_nr + g = g_nr + nr_iter += 1 + if options["verbose"]: + logger.info(f"Newton-Raphson used on iteration {iteration:}") + + div = np.abs(f_k[1:]) # what we will divide by to get relative difference + zeroed = np.abs(f_k[1:]) < np.min( + [10**-8, tol] + ) # check which values are near enough to zero, hard coded max for now. + div[zeroed] = 1.0 # for these values, use absolute values. + max_delta = np.max(np.abs(f_k[1:] - f_old[1:]) / div) + max_diff = np.max(np.abs(f_sci[1:] - f_nr[1:]) / div) + # add this just to make sure they are not too different. + # if we start with bad states, the f_k - f_k_old might be far off. + if np.isnan(max_delta) or ((max_delta < tol) and max_diff < np.sqrt(tol)): + doneIterating = True + success = True + warn = "Convergence achieved by change in f with respect to previous guess." + break + + if doneIterating: + if options["verbose"]: + logger.info(f"Converged to tolerance of {max_delta:e} in {iteration+1:d} iterations.") + logger.info( + f"Of {iteration+1:d} iterations, {nr_iter:d} were Newton-Raphson iterations and {sci_iter:d} were self-consistent iterations" + ) + if np.all(f_k == 0.0): + logger.info("WARNING: All f_k appear to be zero.") + else: + logger.warning("WARNING: Did not converge to within specified tolerance.") + + if maxiter <= 0: + logger.warning( + f"No iterations ran be cause maximum_iterations was <= 0 ({maxiter:s})!" + ) + else: + logger.warning( + f"max_delta = {max_delta:e}, tol = {tol:e}, maximum_iterations = {maxiter:d}, iterations completed = {iteration:d}" + ) + + results = dict() + results["success"] = success + results["message"] = warn + results["x"] = f_k + + return results + + def solve_mbar_once( + self, + u_kn_nonzero, + N_k_nonzero, + f_k_nonzero, + method="adaptive", + tol=1e-12, + continuation=None, + options=None, + ): + """Solve MBAR self-consistent equations using some form of equation solver. + + Parameters + ---------- + u_kn_nonzero : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + for the nonempty states + N_k_nonzero : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state for the nonempty states + f_k_nonzero : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies for the nonempty states + method : str, optional, default="hybr" + The optimization routine to use. This can be any of the methods + available via scipy.optimize.minimize() or scipy.optimize.root(). + tol : float, optional, default=1E-14 + The convergance tolerance for minimize() or root() + verbose: bool + Whether to print information about the solution method. + options: dict, optional, default=None + Optional dictionary of algorithm-specific parameters. See + scipy.optimize.root or scipy.optimize.minimize for details. + + Returns + ------- + f_k : np.ndarray + The converged reduced free energies. + results : dict + Dictionary containing entire results of optimization routine, may + be useful when debugging convergence. + + Notes + ----- + This function requires that N_k_nonzero > 0--that is, you should have + already dropped all the states for which you have no samples. + Internally, this function works in a reduced coordinate system defined + by subtracting off the first component of f_k and fixing that component + to be zero. + + For fast but precise convergence, we recommend calling this function + multiple times to polish the result. `solve_mbar()` facilitates this. + """ + + # we only validate at the outside of the call + u_kn_nonzero, N_k_nonzeo, f_k_nonzero = validate_inputs(u_kn_nonzero, N_k_nonzero, f_k_nonzero) + f_k_nonzero = f_k_nonzero - f_k_nonzero[0] # Work with reduced dimensions with f_k[0] := 0 + N_k_nonzero = 1.0 * N_k_nonzero # convert to float for acceleration. + u_kn_nonzero = self.precondition_u_kn(u_kn_nonzero, N_k_nonzero, f_k_nonzero) + + pad = lambda x: np.pad( + x, (1, 0), mode="constant" + ) # Helper function inserts zero before first element + unpad_second_arg = lambda obj, grad: ( + obj, + grad[1:], + ) # Helper function drops first element of gradient + + # Create objective functions / nonlinear equations to send to scipy.optimize, fixing f_0 = 0 + grad = lambda x: self.mbar_gradient(u_kn_nonzero, N_k_nonzero, pad(x))[ + 1: + ] # Objective function gradient + + grad_and_obj = lambda x: unpad_second_arg( + *self.mbar_objective_and_gradient(u_kn_nonzero, N_k_nonzero, pad(x)) + ) # Objective function gradient and objective function + + de_jax_grad_and_obj = lambda x: ( + *map(np.array, grad_and_obj(x)), # (...,) Casts to tuple instead of object + ) # Force any jax-based array output to normal numpy for scipy.optimize.minimize. np.asarray does not work. + + hess = lambda x: self.mbar_hessian(u_kn_nonzero, N_k_nonzero, pad(x))[1:][ + :, 1: + ] # Hessian of objective function + with warnings.catch_warnings(record=True) as w: + if method == "BFGS": # Might be a way to fold this in now that accelerators are class-ified + fpad = lambda x: self.pad(x, (1, 0)) + obj = lambda x: self.mbar_objective(u_kn_nonzero, N_k_nonzero, fpad(x)) + # objective function to be minimized (for derivative free methods, mostly jit) + minimize_results = self.optimize.minimize( + obj, + f_k_nonzero[1:], + method=method, + tol=tol, + options=dict(maxiter=options["maxiter"]), + ) + results = dict() # there should be a way to copy this. + results["x"] = minimize_results[0] + f_k_nonzero = pad(results["x"]) + results["success"] = minimize_results[1] + elif method in scipy_minimize_options: + if method in scipy_nohess_options: + hess = None # To suppress warning from passing a hessian function. + # This needs to be stock scipy.optimize (at least it won't work for JAX) + results = scipy.optimize.minimize( + de_jax_grad_and_obj, + f_k_nonzero[1:], + jac=True, + hess=hess, + method=method, + tol=tol, + options=options, + ) + f_k_nonzero = pad(results["x"]) + elif method == "adaptive": + results = self.adaptive(u_kn_nonzero, N_k_nonzero, f_k_nonzero, tol=tol, options=options) + f_k_nonzero = results["x"] + elif method in scipy_root_options: + # find the root in the gradient. + results = scipy.optimize.root( + grad, f_k_nonzero[1:], jac=hess, method=method, tol=tol, options=options + ) + f_k_nonzero = pad(results["x"]) + else: + raise ParameterError(f"Method {method} for solution of free energies not recognized") + + # If there were runtime warnings, show the messages + if len(w) > 0: + can_ignore = True + for warn_msg in w: + if "Unknown solver options" in str(warn_msg.message): + continue + warnings.showwarning( + warn_msg.message, + warn_msg.category, + warn_msg.filename, + warn_msg.lineno, + warn_msg.file, + "", + ) + can_ignore = False # If any warning is not just unknown options, can not skip check + if not can_ignore: + # Ensure MBAR solved correctly + w_nk_check = self.mbar_W_nk(u_kn_nonzero, N_k_nonzero, f_k_nonzero) + check_w_normalized(w_nk_check, N_k_nonzero) + logger.warning( + "MBAR weights converged within tolerance, despite the SciPy Warnings. Please validate your results." + ) + + return f_k_nonzero, results + + def solve_mbar(self, u_kn_nonzero, N_k_nonzero, f_k_nonzero, solver_protocol=None): + """Solve MBAR self-consistent equations using some sequence of equation solvers. + + Parameters + ---------- + u_kn_nonzero : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + for the nonempty states + N_k_nonzero : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state for the nonempty states + f_k_nonzero : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies for the nonempty states + solver_protocol : tuple(dict()), optional, default=None + Optional list of dictionaries of steps in solver protocol. + If None, a default protocol will be used. + + Returns + ------- + f_k : np.ndarray + The converged reduced free energies. + all_results : list(dict()) + List of results from each step of solver_protocol. Each element in + list contains the results dictionary from solve_mbar_once() + for the corresponding step. + + Notes + ----- + This function requires that N_k_nonzero > 0--that is, you should have + already dropped all the states for which you have no samples. + Internally, this function works in a reduced coordinate system defined + by subtracting off the first component of f_k and fixing that component + to be zero. + + This function calls `solve_mbar_once()` multiple times to achieve + converged results. Generally, a single call to solve_mbar_once() + will not give fully converged answers because of limited numerical precision. + Each call to `solve_mbar_once()` re-conditions the nonlinear + equations using the current guess. + """ + + if solver_protocol is None: + solver_protocol = DEFAULT_SOLVER_PROTOCOL + + all_fks = [] + all_gnorms = [] + all_results = [] + + for solver in solver_protocol: + f_k_nonzero_result, results = self.solve_mbar_once( + u_kn_nonzero, N_k_nonzero, f_k_nonzero, **solver + ) + all_fks.append(f_k_nonzero_result) + all_gnorms.append( + np.linalg.norm(self.mbar_gradient(u_kn_nonzero, N_k_nonzero, f_k_nonzero_result)) + ) + all_results.append(results) + + if results["success"]: + success = True + best_gnorm = all_gnorms[-1] + logger.info(f"Reached a solution to within tolerance with {solver['method']}") + break + else: + logger.warning( + f"Failed to reach a solution to within tolerance with {solver['method']}: trying next method" + ) + logger.info(f"Ending gnorm of method {solver['method']} = {all_gnorms[-1]:e}") + if solver["continuation"]: + f_k_nonzero = f_k_nonzero_result + logger.info("Will continue with results from previous method") + + if results["success"]: + logger.info("Solution found within tolerance!") + else: + i_best_gnorm = np.argmin(all_gnorms) + logger.warning("No solution found to within tolerance.") + best_method = solver_protocol[i_best_gnorm]["method"] + best_gnorm = all_gnorms[i_best_gnorm] + logger.warning( + f"The solution with the smallest gradient {best_gnorm:e} norm is {best_method}" + ) + f_k_nonzero_result = all_fks[i_best_gnorm] + logger.warning( + "Please exercise caution with this solution and consider alternative methods or a different tolerance." + ) + + logger.info(f"Final gradient norm: {best_gnorm:.3g}") + + return f_k_nonzero_result, all_results + + def solve_mbar_for_all_states(self, u_kn, N_k, f_k, states_with_samples, solver_protocol): + """Solve for free energies of states with samples, then calculate for + empty states. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + solver_protocol : tuple(dict()), optional, default=None + Sequence of dictionaries of steps in solver protocol for final + stage of refinement. + + Returns + ------- + f_k : np.ndarray, shape=(n_states), dtype='float' + The free energies of states + """ + + if len(states_with_samples) == 1: + f_k_nonzero = np.array([0.0]) + else: + f_k_nonzero, all_results = self.solve_mbar( + u_kn[states_with_samples], + N_k[states_with_samples], + f_k[states_with_samples], + solver_protocol=solver_protocol, + ) + + f_k[states_with_samples] = np.array(f_k_nonzero) + + # Update all free energies because those from states with zero samples are not correctly computed by solvers. + f_k = self.self_consistent_update(u_kn, N_k, f_k) + # This is necessary because state 0 might have had zero samples, + # but we still want that state to be the reference with free energy 0. + f_k -= f_k[0] + + return f_k + + +def validate_inputs(u_kn, N_k, f_k): + """Check types and return inputs for MBAR calculations. + + Parameters + ---------- + u_kn or q_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies or unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + u_kn or q_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies or unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='float' + The number of samples in each state. Converted to float because this cast is required when log is calculated. + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + """ + n_states, n_samples = u_kn.shape + + u_kn = ensure_type(u_kn, "float", 2, "u_kn or Q_kn", shape=(n_states, n_samples)) + N_k = ensure_type( + N_k, "float", 1, "N_k", shape=(n_states,), warn_on_cast=False + ) # Autocast to float because will be eventually used in float calculations. + f_k = ensure_type(f_k, "float", 1, "f_k", shape=(n_states,)) + + return u_kn, N_k, f_k diff --git a/pymbar/mbar_solvers/numpy_solver.py b/pymbar/mbar_solvers/numpy_solver.py new file mode 100644 index 00000000..9f12de9d --- /dev/null +++ b/pymbar/mbar_solvers/numpy_solver.py @@ -0,0 +1,95 @@ +# Import the methods functionally +# This is admittedly non-standard, but solves the following use case: +# * Has JAX +# * Wants to use PyMBAR +# * Does NOT want JAX to be set to 64-bit mode +# Also solves the future use case of different accelerator, +# but want to selectively use them + +# Fallback/default solver methods +# NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY +import numpy as np +from numpy.linalg import lstsq +import scipy.optimize +from scipy.special import logsumexp + +from pymbar.mbar_solvers.mbar_solver import MBARSolver + + +class MBARSolverNumpy(MBARSolver): + """ + Solver methods for MBAR. Implementations use specific libraries/accelerators to solve the code paths. + + Default solver is the numpy solution + """ + + @property + def exp(self): + return np.exp + + @property + def sum(self): + return np.sum + + @property + def diag(self): + return np.diag + + @property + def newaxis(self): + return np.newaxis + + @property + def dot(self): + return np.dot + + @property + def s_(self): + return np.s_ + + @property + def pad(self): + return np.pad + + @property + def lstsq(self): + return lstsq + + @property + def optimize(self): + return scipy.optimize + + @property + def logsumexp(self): + return logsumexp + + @staticmethod + def _passthrough_jit(fn): + return fn + + @property + def jit(self): + """Passthrough JIT""" + return self._passthrough_jit + + def _adaptive_core(self, u_kn, N_k, f_k, g, options): + """ + Core function to execute per iteration of a method. + """ + gamma = options["gamma"] # Handle options + H = self.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian + Hinvg = np.linalg.lstsq(H, g, rcond=-1)[0] + Hinvg -= Hinvg[0] + f_nr = f_k - gamma * Hinvg + + # self-consistent iteration gradient norm and saved log sums. + f_sci = self.self_consistent_update(u_kn, N_k, f_k) + f_sci = f_sci - f_sci[0] # zero out the minimum + g_sci = self.mbar_gradient(u_kn, N_k, f_sci) + gnorm_sci = self.dot(g_sci, g_sci) + + # newton raphson gradient norm and saved log sums. + g_nr = self.mbar_gradient(u_kn, N_k, f_nr) + gnorm_nr = self.dot(g_nr, g_nr) + + return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr diff --git a/pymbar/mbar_solvers/solver_api.py b/pymbar/mbar_solvers/solver_api.py new file mode 100644 index 00000000..9996d40e --- /dev/null +++ b/pymbar/mbar_solvers/solver_api.py @@ -0,0 +1,157 @@ +""" +API Definitions of the solver module to be consistent with PyMBAR 4.0 +and for subclassing any solvers for implementation. +""" + +from functools import wraps +from abc import ABC, abstractmethod + + +class MBARSolverAPI(ABC): + """ + API for MBAR solvers + """ + + JITABLE_API_METHODS = ( + "mbar_gradient", + "mbar_objective", + "mbar_objective_and_gradient", + "mbar_hessian", + "mbar_log_W_nk", + "mbar_W_nk", + "precondition_u_kn" + ) + + @abstractmethod + def self_consistent_update(self, u_kn, N_k, f_k, states_with_samples=None): + pass + + @abstractmethod + def mbar_gradient(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_objective(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_objective_and_gradient(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_hessian(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_log_W_nk(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_W_nk(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def adaptive(self, u_kn, N_k, f_k, tol=1.0e-8, options=None): + pass + + @abstractmethod + def precondition_u_kn(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def solve_mbar_once( + self, + u_kn_nonzero, + N_k_nonzero, + f_k_nonzero, + method="adaptive", + tol=1e-12, + continuation=None, + options=None, + ): + pass + + @abstractmethod + def solve_mbar(self, u_kn_nonzero, N_k_nonzero, f_k_nonzero, solver_protocol=None): + pass + + @abstractmethod + def solve_mbar_for_all_states(self, u_kn, N_k, f_k, states_with_samples, solver_protocol): + pass + + +class MBARSolverAcceleratorMethods(ABC): + """ + Methods which have to be implemented by MBAR solver accelerators + """ + + JITABLE_ACCELERATOR_METHODS = ( + "_adaptive_core", + ) + + @property + @abstractmethod + def exp(self): + pass + + @property + @abstractmethod + def sum(self): + pass + + @property + @abstractmethod + def diag(self): + pass + + @property + @abstractmethod + def newaxis(self): + pass + + @property + @abstractmethod + def dot(self): + pass + + @property + @abstractmethod + def s_(self): + pass + + @property + @abstractmethod + def pad(self): + pass + + @property + @abstractmethod + def lstsq(self): + pass + + @property + @abstractmethod + def optimize(self): + pass + + @property + @abstractmethod + def logsumexp(self): + pass + + @property + @abstractmethod + def jit(self): + pass + + def _precondition_jit(self, jitable_fn): + @wraps(jitable_fn) # Helper to ensure the decorated function still registers for docs and inspection + def wrapped_precog_jit(self, *args, **kwargs): + # Uses "self" here as intercepted first arg for instance of MBARSolver + jited_fn = self.jit(jitable_fn) + return jited_fn(*args, **kwargs) + return wrapped_precog_jit + + @abstractmethod + def _adaptive_core(self, u_kn, N_k, f_k, g, options): + pass From 9216bad7b53c8f301d331451dc114f3a4b4cff42 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Tue, 20 Jun 2023 15:08:03 -0400 Subject: [PATCH 04/12] Next iterations, trying to JIT without causing all kinds of pain of recompile. Seems to run much slower in tests right now. --- pymbar/mbar_solvers/jax_solver.py | 65 +++++++++++++++---------------- pymbar/mbar_solvers/solver_api.py | 29 ++++++++++++++ 2 files changed, 60 insertions(+), 34 deletions(-) diff --git a/pymbar/mbar_solvers/jax_solver.py b/pymbar/mbar_solvers/jax_solver.py index 59cc0488..95acf276 100644 --- a/pymbar/mbar_solvers/jax_solver.py +++ b/pymbar/mbar_solvers/jax_solver.py @@ -88,44 +88,41 @@ def logsumexp(self): def jit(self): return jit - - class JitDecorators: - """ - Internal helper class to do any preconditioning of JIT operations - Uses this buried class to allow a decorator with access to the "self" within its precondition - """ - @classmethod - def precondition_jit(cls, jitable_fn): - @wraps( - jitable_fn - ) # Helper to ensure the decorated function still registers for docs and inspection - def staggered_jit(self, *args, **kwargs): - # This will only trigger if JAX is set - if not config.x64_enabled: - # Warn that JAX 64-bit will being turned on - logger.warning( - "\n" - "******* JAX 64-bit mode is now on! *******\n" - "* JAX is now set to 64-bit mode! *\n" - "* This MAY cause problems with other *\n" - "* uses of JAX in the same code. *\n" - "******************************************\n" - ) - config.update("jax_enable_x64", True) - - jited_fn = partial(self.jit, static_argnums=(0,))(jitable_fn) - return jited_fn(*args, **kwargs) - return staggered_jit - - precondition_jit = JitDecorators.precondition_jit - - @precondition_jit - def _adaptive_core(self, u_kn, N_k, f_k, gamma): + # def _precondition_jit(self, jitable_fn): + # @wraps(jitable_fn) # Helper to ensure the decorated function still registers for docs and inspection + # def wrapped_precog_jit(self, *args, **kwargs): + # # Uses "self" here as intercepted first arg for instance of MBARSolver + # jited_fn = self.jit(jitable_fn) + # return jited_fn(*args, **kwargs) + # return wrapped_precog_jit + + def _precondition_jit(self, jitable_fn): + @wraps( + jitable_fn + ) # Helper to ensure the decorated function still registers for docs and inspection + def staggered_jit(*args, **kwargs): + # This will only trigger if JAX is set + if not config.x64_enabled: + # Warn that JAX 64-bit will being turned on + logger.warning( + "\n" + "******* JAX 64-bit mode is now on! *******\n" + "* JAX is now set to 64-bit mode! *\n" + "* This MAY cause problems with other *\n" + "* uses of JAX in the same code. *\n" + "******************************************\n" + ) + config.update("jax_enable_x64", True) + jited_fn = self.jit(jitable_fn) + return jited_fn(*args, **kwargs) + return staggered_jit + + def _adaptive_core(self, u_kn, N_k, f_k, g, options): """JAX version of adaptive inner loop. N_k must be float (should be cast at a higher level) """ - + gamma = options["gamma"] # Perform Newton-Raphson iterations (with sci computed on the way) g = self.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient H = self.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian diff --git a/pymbar/mbar_solvers/solver_api.py b/pymbar/mbar_solvers/solver_api.py index 9996d40e..52aca04f 100644 --- a/pymbar/mbar_solvers/solver_api.py +++ b/pymbar/mbar_solvers/solver_api.py @@ -155,3 +155,32 @@ def wrapped_precog_jit(self, *args, **kwargs): @abstractmethod def _adaptive_core(self, u_kn, N_k, f_k, g, options): pass + + # def __hash__(self): + # return hash((self.exp, + # self.sum, + # self.diag, + # self.newaxis, + # self.dot, + # self.s_, + # self.pad, + # self.lstsq, + # self.optimize, + # self.logsumexp, + # self.jit + # )) + # + # def __eq__(self): + # return hash((self.exp, + # self.sum, + # self.diag, + # self.newaxis, + # self.dot, + # self.s_, + # self.pad, + # self.lstsq, + # self.optimize, + # self.logsumexp, + # self.jit + # )) + From e4b942fa66d1d7af3047e179fca5a77d9930a9eb Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Thu, 22 Jun 2023 09:43:12 -0400 Subject: [PATCH 05/12] Debug and try to use pytrees --- pymbar/mbar_solvers/__init__.py | 1 + pymbar/mbar_solvers/jax_solver.py | 27 ++++++++++++++++++ pymbar/mbar_solvers/solver_api.py | 46 ++++++++++++------------------- 3 files changed, 45 insertions(+), 29 deletions(-) diff --git a/pymbar/mbar_solvers/__init__.py b/pymbar/mbar_solvers/__init__.py index c2e8055a..b02d7ed4 100644 --- a/pymbar/mbar_solvers/__init__.py +++ b/pymbar/mbar_solvers/__init__.py @@ -54,6 +54,7 @@ try: from .jax_solver import MBARSolverJAX default_solver = MBARSolverJAX + # default_solver = MBARSolverNumpy # Set fallback solver ACCELERATOR_MAP["jax"] = MBARSolverJAX logger.info("JAX detected. Using JAX acceleration by default.") except ImportError: diff --git a/pymbar/mbar_solvers/jax_solver.py b/pymbar/mbar_solvers/jax_solver.py index 95acf276..be716283 100644 --- a/pymbar/mbar_solvers/jax_solver.py +++ b/pymbar/mbar_solvers/jax_solver.py @@ -17,7 +17,10 @@ logger = logging.getLogger(__name__) +# hell: https://github.com/google/jax/discussions/16020 + +@jax.tree_util.register_pytree_node_class class MBARSolverJAX(MBARSolver): """ Solver methods for MBAR. Implementations use specific libraries/accelerators to solve the code paths. @@ -44,6 +47,27 @@ def __init__(self): ) super().__init__() + def tree_flatten(self): + children = () # arrays / dynamic values + aux_data = { + "exp": self.exp, + "sum": self.sum, + "diag": self.diag, + "newaxis": self.newaxis, + "dot": self.dot, + "s_": self._s, + "pad": self.pad, + "lstsq": self.lstsq, + "optimize": self.optimize, + "logsumexp": self.logsumexp + } # static values + aux_data = {} + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls() + @property def exp(self): return jnp.exp @@ -114,6 +138,9 @@ def staggered_jit(*args, **kwargs): ) config.update("jax_enable_x64", True) jited_fn = self.jit(jitable_fn) + # jited_fn = partial(jit, static_argnums=(0,))(jitable_fn) + # breakpoint() + # print(jited_fn._cache_size()) return jited_fn(*args, **kwargs) return staggered_jit diff --git a/pymbar/mbar_solvers/solver_api.py b/pymbar/mbar_solvers/solver_api.py index 52aca04f..37ec95f6 100644 --- a/pymbar/mbar_solvers/solver_api.py +++ b/pymbar/mbar_solvers/solver_api.py @@ -146,7 +146,7 @@ def jit(self): def _precondition_jit(self, jitable_fn): @wraps(jitable_fn) # Helper to ensure the decorated function still registers for docs and inspection - def wrapped_precog_jit(self, *args, **kwargs): + def wrapped_precog_jit(*args, **kwargs): # Uses "self" here as intercepted first arg for instance of MBARSolver jited_fn = self.jit(jitable_fn) return jited_fn(*args, **kwargs) @@ -156,31 +156,19 @@ def wrapped_precog_jit(self, *args, **kwargs): def _adaptive_core(self, u_kn, N_k, f_k, g, options): pass - # def __hash__(self): - # return hash((self.exp, - # self.sum, - # self.diag, - # self.newaxis, - # self.dot, - # self.s_, - # self.pad, - # self.lstsq, - # self.optimize, - # self.logsumexp, - # self.jit - # )) - # - # def __eq__(self): - # return hash((self.exp, - # self.sum, - # self.diag, - # self.newaxis, - # self.dot, - # self.s_, - # self.pad, - # self.lstsq, - # self.optimize, - # self.logsumexp, - # self.jit - # )) - + def __hash__(self): + return hash((self.exp, + self.sum, + self.diag, + self.newaxis, + self.dot, + self.s_, + self.pad, + self.lstsq, + self.optimize, + self.logsumexp, + self.jit + )) + + def __eq__(self, other): + return isinstance(other, MBARSolverAcceleratorMethods) and self.__hash__ == other.__hash__ From adc30bdcaa44958a272ed257d5c63aa917fb879f Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Thu, 22 Jun 2023 10:57:46 -0400 Subject: [PATCH 06/12] was it just sum? --- pymbar/mbar_solvers/mbar_solver.py | 38 ++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/pymbar/mbar_solvers/mbar_solver.py b/pymbar/mbar_solvers/mbar_solver.py index eb884e64..9557e3cb 100644 --- a/pymbar/mbar_solvers/mbar_solver.py +++ b/pymbar/mbar_solvers/mbar_solver.py @@ -72,6 +72,7 @@ class MBARSolver(MBARSolverAPI, MBARSolverAcceleratorMethods): def __init__(self): """Apply the precondition to each of the JITED_METHODS""" + self._construct_static_methods() for method in ( self.JITABLE_ACCELERATOR_METHODS + self.JITABLE_API_METHODS + @@ -79,6 +80,12 @@ def __init__(self): ): setattr(self, method, self._precondition_jit(getattr(self, method))) + def _construct_static_methods(self): + """ + Hoo boy. + This function creates static methods for each of the + """ + def self_consistent_update(self, u_kn, N_k, f_k, states_with_samples=None): """Return an improved guess for the dimensionless free energies @@ -143,11 +150,6 @@ def mbar_gradient(self, u_kn, N_k, f_k): This is equation C6 in the JCP MBAR paper. """ - # N_k must be float (should be cast at a higher level) - log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = self.logsumexp(-log_denominator_n - u_kn, axis=1) - return -1 * N_k * (1.0 - self.exp(f_k + log_numerator_k)) - def mbar_objective(self, u_kn, N_k, f_k): """Calculates objective function for MBAR. @@ -177,11 +179,6 @@ def mbar_objective(self, u_kn, N_k, f_k): outermost sum and logsumexp for the inner sum. """ - log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - obj = sum(log_denominator_n) - self.dot(N_k, f_k) - - return obj - def mbar_objective_and_gradient(self, u_kn, N_k, f_k): """Calculates both objective function and gradient for MBAR. @@ -786,3 +783,24 @@ def validate_inputs(u_kn, N_k, f_k): f_k = ensure_type(f_k, "float", 1, "f_k", shape=(n_states,)) return u_kn, N_k, f_k + +def generate_static_mbar_gradient(solver: MBARSolver): + def mbar_gradient(u_kn, N_k, f_k): + # N_k must be float (should be cast at a higher level) + log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + log_numerator_k = solver.logsumexp(-log_denominator_n - u_kn, axis=1) + return -1 * N_k * (1.0 - solver.exp(f_k + log_numerator_k)) + return mbar_gradient + +def generate_static_mbar_objective(solver: MBARSolver) + def mbar_objective(u_kn, N_k, f_k): + log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + obj = solver.sum(log_denominator_n) - solver.dot(N_k, f_k) + + return obj +def generate_static_mbar_objective_and_gradient(solver: MBARSolver) +def generate_static_mbar_hessian(solver: MBARSolver) +def generate_static_mbar_log_W_nk(solver: MBARSolver) +def generate_static_mbar_mbar_W_nk(solver: MBARSolver) +def generate_static_precondition_u_kn(solver: MBARSolver) + From 758a02ae9bdd8b06d91d454a457281a73fce0749 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Thu, 22 Jun 2023 14:14:26 -0400 Subject: [PATCH 07/12] Solved the accelerator problem. Casts methods to static methods every time to ensure that JAX is not serializing the class itself as constants, dramatically slowing down the code execution. Makes for slightly more complicated method call jumping, but otherwise uses the same code paths. --- pymbar/mbar.py | 10 +- pymbar/mbar_solvers/__init__.py | 26 ++-- pymbar/mbar_solvers/jax_solver.py | 69 +++------ pymbar/mbar_solvers/mbar_solver.py | 216 +++++++++++++++++++--------- pymbar/mbar_solvers/numpy_solver.py | 3 +- pymbar/mbar_solvers/solver_api.py | 48 +++---- pymbar/tests/test_accelerators.py | 19 ++- 7 files changed, 227 insertions(+), 164 deletions(-) diff --git a/pymbar/mbar.py b/pymbar/mbar.py index 0ca48485..6c705d75 100644 --- a/pymbar/mbar.py +++ b/pymbar/mbar.py @@ -96,7 +96,7 @@ def __init__( n_bootstraps=0, bootstrap_solver_protocol=None, rseed=None, - accelerator="jax", + accelerator=None, ): """Initialize multistate Bennett acceptance ratio (MBAR) on a set of simulation data. @@ -187,10 +187,10 @@ def __init__( We usually just do steps of adaptive sampling without. "robust" would be the backup. Default: dict(method="adaptive", options=dict(min_sc_iter=0)), - accelerator: str, optional, default="jax" - Set the accelerator method to try. Attempts to use the named accelerator for the solvers, and then + accelerator: str, optional, default=None + Set the accelerator library. Attempts to use the named accelerator for the solvers, and then stores the output accelerator after trying to set. Not case-sensitive. "numpy" is no-accelerators, - and will work fine. + and will work fine. Default accelerator is JAX if nothing specified and JAX installed, else NumPy (Valid options: jax, numpy) @@ -234,7 +234,7 @@ def __init__( """ # Set the accelerator methods for the solvers - self.solver = mbar_solvers.get_accelerator(accelerator)() + self.solver = mbar_solvers.get_accelerator(accelerator) # Store local copies of necessary data. # N_k[k] is the number of samples from state k, some of which might be zero. diff --git a/pymbar/mbar_solvers/__init__.py b/pymbar/mbar_solvers/__init__.py index b02d7ed4..861ee14f 100644 --- a/pymbar/mbar_solvers/__init__.py +++ b/pymbar/mbar_solvers/__init__.py @@ -38,24 +38,29 @@ """ import logging +from typing import Union + from .mbar_solver import ( validate_inputs, JAX_SOLVER_PROTOCOL, DEFAULT_SOLVER_PROTOCOL, ROBUST_SOLVER_PROTOCOL, - BOOTSTRAP_SOLVER_PROTOCOL + BOOTSTRAP_SOLVER_PROTOCOL, ) +from .mbar_solver import MBARSolver from .numpy_solver import MBARSolverNumpy logger = logging.getLogger(__name__) -default_solver = MBARSolverNumpy # Set fallback solver +INSTANCED_ACCELERATORS = {} # Cache the accelerators to avoid re-jit on instancing ACCELERATOR_MAP = {"numpy": MBARSolverNumpy} +default_solver = "numpy" # Set fallback solver + try: from .jax_solver import MBARSolverJAX - default_solver = MBARSolverJAX - # default_solver = MBARSolverNumpy # Set fallback solver + ACCELERATOR_MAP["jax"] = MBARSolverJAX + default_solver = "jax" logger.info("JAX detected. Using JAX acceleration by default.") except ImportError: logger.warning( @@ -72,12 +77,16 @@ # Helper function for toggling the solver method -def get_accelerator(accelerator_name: str): +def get_accelerator(accelerator_name: Union[str, None]) -> MBARSolver: """ get the accelerator in the namespace for this module """ + if accelerator_name is None: + accelerator_name = default_solver # Saving accelerator to new tag does not change since we're saving the immutable string object accel = accelerator_name.lower() + if accel in INSTANCED_ACCELERATORS: + return INSTANCED_ACCELERATORS[accel] if accel not in ACCELERATOR_MAP: raise ValueError( f"Accelerator {accel} is not implemented or did not load correctly. Please use one of the following:\n" @@ -85,12 +94,13 @@ def get_accelerator(accelerator_name: str): + f"(case-insentive)\n" + f"If you expected {accel} to load, please check the logs above for details." ) - logger.info(f"Getting accelerator {accel}...") - return ACCELERATOR_MAP[accel] + logger.info(f"Instancing accelerator {accel}...") + INSTANCED_ACCELERATORS[accel] = ACCELERATOR_MAP[accel]() + return INSTANCED_ACCELERATORS[accel] # Imports done, handle initialization -module_solver = default_solver() +module_solver = get_accelerator(default_solver) # Establish API methods for 4.x consistency self_consistent_update = module_solver.self_consistent_update diff --git a/pymbar/mbar_solvers/jax_solver.py b/pymbar/mbar_solvers/jax_solver.py index be716283..36c8808c 100644 --- a/pymbar/mbar_solvers/jax_solver.py +++ b/pymbar/mbar_solvers/jax_solver.py @@ -1,8 +1,7 @@ - """Set the imports for the JAX accelerated methods""" import logging -from functools import partial, wraps +from functools import wraps from jax.config import config @@ -17,10 +16,7 @@ logger = logging.getLogger(__name__) -# hell: https://github.com/google/jax/discussions/16020 - -@jax.tree_util.register_pytree_node_class class MBARSolverJAX(MBARSolver): """ Solver methods for MBAR. Implementations use specific libraries/accelerators to solve the code paths. @@ -43,31 +39,17 @@ def __init__(self): "* *\n" "* This MAY cause problems with other *\n" "* Uses of JAX in the same code. *\n" + "* *\n" + "* If you want 32-bit JAX and PyMBAR *\n" + "* please set: *\n" + "* accelerator=numpy *\n" + "* when you instance the MBAR object *\n" "******************************************\n" ) + # Double __ in middle name intentional here + self._static__adaptive_core = generate_static_adaptive_core(self) super().__init__() - def tree_flatten(self): - children = () # arrays / dynamic values - aux_data = { - "exp": self.exp, - "sum": self.sum, - "diag": self.diag, - "newaxis": self.newaxis, - "dot": self.dot, - "s_": self._s, - "pad": self.pad, - "lstsq": self.lstsq, - "optimize": self.optimize, - "logsumexp": self.logsumexp - } # static values - aux_data = {} - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls() - @property def exp(self): return jnp.exp @@ -112,14 +94,6 @@ def logsumexp(self): def jit(self): return jit - # def _precondition_jit(self, jitable_fn): - # @wraps(jitable_fn) # Helper to ensure the decorated function still registers for docs and inspection - # def wrapped_precog_jit(self, *args, **kwargs): - # # Uses "self" here as intercepted first arg for instance of MBARSolver - # jited_fn = self.jit(jitable_fn) - # return jited_fn(*args, **kwargs) - # return wrapped_precog_jit - def _precondition_jit(self, jitable_fn): @wraps( jitable_fn @@ -138,33 +112,36 @@ def staggered_jit(*args, **kwargs): ) config.update("jax_enable_x64", True) jited_fn = self.jit(jitable_fn) - # jited_fn = partial(jit, static_argnums=(0,))(jitable_fn) - # breakpoint() - # print(jited_fn._cache_size()) return jited_fn(*args, **kwargs) + return staggered_jit - def _adaptive_core(self, u_kn, N_k, f_k, g, options): + def _adaptive_core(self, u_kn, N_k, f_k, g, gamma): """JAX version of adaptive inner loop. N_k must be float (should be cast at a higher level) """ - gamma = options["gamma"] + + +def generate_static_adaptive_core(solver: MBARSolver): + def _adaptive_core(u_kn, N_k, f_k, g, gamma): # Perform Newton-Raphson iterations (with sci computed on the way) - g = self.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient - H = self.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian + g = solver.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient + H = solver.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian Hinvg = lstsq(H, g, rcond=-1)[0] Hinvg -= Hinvg[0] f_nr = f_k - gamma * Hinvg # self-consistent iteration gradient norm and saved log sums. - f_sci = self.self_consistent_update(u_kn, N_k, f_k) + f_sci = solver.self_consistent_update(u_kn, N_k, f_k) f_sci = f_sci - f_sci[0] # zero out the minimum - g_sci = self.mbar_gradient(u_kn, N_k, f_sci) - gnorm_sci = self.dot(g_sci, g_sci) + g_sci = solver.mbar_gradient(u_kn, N_k, f_sci) + gnorm_sci = solver.dot(g_sci, g_sci) # newton raphson gradient norm and saved log sums. - g_nr = self.mbar_gradient(u_kn, N_k, f_nr) - gnorm_nr = self.dot(g_nr, g_nr) + g_nr = solver.mbar_gradient(u_kn, N_k, f_nr) + gnorm_nr = solver.dot(g_nr, g_nr) return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr + + return _adaptive_core diff --git a/pymbar/mbar_solvers/mbar_solver.py b/pymbar/mbar_solvers/mbar_solver.py index 9557e3cb..2acc297e 100644 --- a/pymbar/mbar_solvers/mbar_solver.py +++ b/pymbar/mbar_solvers/mbar_solver.py @@ -1,7 +1,5 @@ import logging import warnings -from functools import wraps -from abc import ABC, abstractmethod import numpy as np @@ -66,26 +64,59 @@ class MBARSolver(MBARSolverAPI, MBARSolverAcceleratorMethods): Default solver is the numpy solution """ - JITABLE_IMPLEMENTATION_METHODS = ( - "_jit_self_consistent_update", - ) + JITABLE_IMPLEMENTATION_METHODS = ("jit_self_consistent_update",) def __init__(self): - """Apply the precondition to each of the JITED_METHODS""" - self._construct_static_methods() + """ + Generate all the static methods to make JIT compile clean + + All the methods overwritten in this code are cast to static methods because JIT (at least in JAX) + suffers a massive performance loss if you try to JIT a bound method of a class (i.e. anything with + a reference to 'self'). See: https://github.com/google/jax/discussions/16020#discussioncomment-5915882 + + Marking self as static with a partial doesn't work because we're wrapping the function already once, and we + still need the functions/properties found in the class to make this an extensible class for other accelerators + in the future. + See https://jax.readthedocs.io/en/latest/faq.html#strategy-2-marking-self-as-static + + The PyTree approach didn't seem to work due to the same problem as marking self static because it still needs + properties so no gains were made. Its possible I (LNN) misinterpreted something here and this can be used to + simplify the code in the future to avoid writing all the static-method generators. + https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree + + If the default methods are used (which are written to call the static generator anyway), then the JIT cache + will re-compile them every time, which defeats the whole point. The calls are left in to leave a developer + breadcrumb as to what is supposed to go there, and to make linter's happy. + """ + # Dont use just _{method} because any result with leading __ mangles name + # E.g. __adaptive_core -> _{ClassName}_adaptive_core + self._static_mbar_gradient = generate_static_mbar_gradient(self) + self._static_mbar_objective = generate_static_mbar_objective(self) + self._static_mbar_objective_and_gradient = generate_static_mbar_objective_and_gradient( + self + ) + self._static_mbar_hessian = generate_static_mbar_hessian(self) + self._static_mbar_log_W_nk = generate_static_mbar_log_W_nk(self) + self._static_mbar_W_nk = generate_static_mbar_W_nk(self, self._static_mbar_log_W_nk) + self._static_jit_self_consistent_update = generate_jit_self_consistent_update(self) + self._static_precondition_u_kn = generate_static_precondition_u_kn(self) + # Apply the precondition to each of the JITABLE_METHODS for method in ( - self.JITABLE_ACCELERATOR_METHODS + - self.JITABLE_API_METHODS + - self.JITABLE_IMPLEMENTATION_METHODS + self.JITABLE_ACCELERATOR_METHODS + + self.JITABLE_API_METHODS + + self.JITABLE_IMPLEMENTATION_METHODS ): + # Attempt to staticfy if method "_{method}" exists + if hasattr(self, "_static_" + method): + doc = getattr(self, method).__doc__ + static = getattr(self, "_static_" + method) + # Replace with static name + setattr(self, method, static) + # Reset docstring + getattr(self, method).__doc__ = doc + # Jit setattr(self, method, self._precondition_jit(getattr(self, method))) - def _construct_static_methods(self): - """ - Hoo boy. - This function creates static methods for each of the - """ - def self_consistent_update(self, u_kn, N_k, f_k, states_with_samples=None): """Return an improved guess for the dimensionless free energies @@ -112,21 +143,15 @@ def self_consistent_update(self, u_kn, N_k, f_k, states_with_samples=None): # Precondition before feeding the op to the JIT'd function # In theory, this can be computed with jax.lax.cond, but trying to reuse code for non-jax paths states_with_samples = self.s_[:] if states_with_samples is None else states_with_samples - return self._jit_self_consistent_update( + return self.jit_self_consistent_update( u_kn[states_with_samples], N_k[states_with_samples], f_k[states_with_samples] ) - def _jit_self_consistent_update(self, u_kn, N_k, f_k): + def jit_self_consistent_update(self, u_kn, N_k, f_k): """JAX version of self_consistent update. For parameters, see self_consistent_update. N_k must be float (should be cast at a higher level) """ - # Asteroid - log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - # All states can contribute to the numerator term. Check transpose - return -1.0 * self.logsumexp( - -log_denominator_n - u_kn, axis=1 - ) # pylint: disable=invalid-unary-operand-type def mbar_gradient(self, u_kn, N_k, f_k): """Gradient of MBAR objective function. @@ -149,6 +174,7 @@ def mbar_gradient(self, u_kn, N_k, f_k): ----- This is equation C6 in the JCP MBAR paper. """ + return generate_static_mbar_gradient(self)(u_kn, N_k, f_k) def mbar_objective(self, u_kn, N_k, f_k): """Calculates objective function for MBAR. @@ -178,6 +204,7 @@ def mbar_objective(self, u_kn, N_k, f_k): More optimal precision, the objective function uses math.fsum for the outermost sum and logsumexp for the inner sum. """ + return generate_static_mbar_objective(self)(u_kn, N_k, f_k) def mbar_objective_and_gradient(self, u_kn, N_k, f_k): """Calculates both objective function and gradient for MBAR. @@ -212,17 +239,10 @@ def mbar_objective_and_gradient(self, u_kn, N_k, f_k): The gradient is equation C6 in the JCP MBAR paper; the objective function is its integral. """ + return generate_static_mbar_objective_and_gradient(self)(u_kn, N_k, f_k) - log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = self.logsumexp(-log_denominator_n - u_kn, axis=1) - grad = -1 * N_k * (1.0 - self.exp(f_k + log_numerator_k)) - - obj = sum(log_denominator_n) - self.dot(N_k, f_k) - - return obj, grad - - def mbar_hessian(self, u_kn, N_k, f_k): - """Hessian of MBAR objective function. + def mbar_hessian(self, u_kn, N_k, f_k) -> np.ndarray: + """Hessian of Mmbar_hessianBAR objective function. Parameters ---------- @@ -242,16 +262,7 @@ def mbar_hessian(self, u_kn, N_k, f_k): ----- Equation (C9) in JCP MBAR paper. """ - - log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, self.newaxis] - W = self.exp(logW) - - H = self.dot(W.T, W) - H *= N_k - H *= N_k[:, self.newaxis] - H -= self.diag(W.sum(0) * N_k) - return -1.0 * H + return generate_static_mbar_hessian(self)(u_kn, N_k, f_k) def mbar_log_W_nk(self, u_kn, N_k, f_k): """Calculate the log weight matrix. @@ -274,10 +285,7 @@ def mbar_log_W_nk(self, u_kn, N_k, f_k): ----- Equation (9) in JCP MBAR paper. """ - - log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, self.newaxis] - return logW + return generate_static_mbar_log_W_nk(self)(u_kn, N_k, f_k) def mbar_W_nk(self, u_kn, N_k, f_k): """Calculate the weight matrix. @@ -300,8 +308,7 @@ def mbar_W_nk(self, u_kn, N_k, f_k): ----- Equation (9) in JCP MBAR paper. """ - - return self.exp(self.mbar_log_W_nk(u_kn, N_k, f_k)) + return generate_static_mbar_W_nk(self, self.mbar_log_W_nk)(u_kn, N_k, f_k) def precondition_u_kn(self, u_kn, N_k, f_k): """Subtract a sample-dependent constant from u_kn to improve precision @@ -328,9 +335,7 @@ def precondition_u_kn(self, u_kn, N_k, f_k): x_n such that the current objective function value is zero, which should give maximum precision in the objective function. """ - u_kn = u_kn - u_kn.min(0) - u_kn += (self.logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - self.dot(N_k, f_k) / N_k.sum() - return u_kn + return generate_static_precondition_u_kn(self)(u_kn, N_k, f_k) def adaptive(self, u_kn, N_k, f_k, tol=1.0e-8, options=None): """ @@ -397,7 +402,7 @@ def adaptive(self, u_kn, N_k, f_k, tol=1.0e-8, options=None): warn = "Did not converge." for iteration in range(0, maxiter): f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr = self._adaptive_core( - u_kn, N_k, f_k, g, options + u_kn, N_k, f_k, g, options["gamma"] ) # we could save the gradient, for the next round, but it's not too expensive to # compute since we are doing the Hessian anyway. @@ -446,7 +451,9 @@ def adaptive(self, u_kn, N_k, f_k, tol=1.0e-8, options=None): if doneIterating: if options["verbose"]: - logger.info(f"Converged to tolerance of {max_delta:e} in {iteration+1:d} iterations.") + logger.info( + f"Converged to tolerance of {max_delta:e} in {iteration+1:d} iterations." + ) logger.info( f"Of {iteration+1:d} iterations, {nr_iter:d} were Newton-Raphson iterations and {sci_iter:d} were self-consistent iterations" ) @@ -524,7 +531,9 @@ def solve_mbar_once( """ # we only validate at the outside of the call - u_kn_nonzero, N_k_nonzeo, f_k_nonzero = validate_inputs(u_kn_nonzero, N_k_nonzero, f_k_nonzero) + u_kn_nonzero, N_k_nonzeo, f_k_nonzero = validate_inputs( + u_kn_nonzero, N_k_nonzero, f_k_nonzero + ) f_k_nonzero = f_k_nonzero - f_k_nonzero[0] # Work with reduced dimensions with f_k[0] := 0 N_k_nonzero = 1.0 * N_k_nonzero # convert to float for acceleration. u_kn_nonzero = self.precondition_u_kn(u_kn_nonzero, N_k_nonzero, f_k_nonzero) @@ -554,8 +563,11 @@ def solve_mbar_once( :, 1: ] # Hessian of objective function with warnings.catch_warnings(record=True) as w: - if method == "BFGS": # Might be a way to fold this in now that accelerators are class-ified + if ( + method == "BFGS" + ): # Might be a way to fold this in now that accelerators are class-ified fpad = lambda x: self.pad(x, (1, 0)) + # Make sure to use the static method here obj = lambda x: self.mbar_objective(u_kn_nonzero, N_k_nonzero, fpad(x)) # objective function to be minimized (for derivative free methods, mostly jit) minimize_results = self.optimize.minimize( @@ -584,7 +596,9 @@ def solve_mbar_once( ) f_k_nonzero = pad(results["x"]) elif method == "adaptive": - results = self.adaptive(u_kn_nonzero, N_k_nonzero, f_k_nonzero, tol=tol, options=options) + results = self.adaptive( + u_kn_nonzero, N_k_nonzero, f_k_nonzero, tol=tol, options=options + ) f_k_nonzero = results["x"] elif method in scipy_root_options: # find the root in the gradient. @@ -593,7 +607,9 @@ def solve_mbar_once( ) f_k_nonzero = pad(results["x"]) else: - raise ParameterError(f"Method {method} for solution of free energies not recognized") + raise ParameterError( + f"Method {method} for solution of free energies not recognized" + ) # If there were runtime warnings, show the messages if len(w) > 0: @@ -609,7 +625,9 @@ def solve_mbar_once( warn_msg.file, "", ) - can_ignore = False # If any warning is not just unknown options, can not skip check + can_ignore = ( + False # If any warning is not just unknown options, can not skip check + ) if not can_ignore: # Ensure MBAR solved correctly w_nk_check = self.mbar_W_nk(u_kn_nonzero, N_k_nonzero, f_k_nonzero) @@ -784,23 +802,87 @@ def validate_inputs(u_kn, N_k, f_k): return u_kn, N_k, f_k + def generate_static_mbar_gradient(solver: MBARSolver): def mbar_gradient(u_kn, N_k, f_k): # N_k must be float (should be cast at a higher level) log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) log_numerator_k = solver.logsumexp(-log_denominator_n - u_kn, axis=1) return -1 * N_k * (1.0 - solver.exp(f_k + log_numerator_k)) + return mbar_gradient - -def generate_static_mbar_objective(solver: MBARSolver) + + +def generate_static_mbar_objective(solver: MBARSolver): def mbar_objective(u_kn, N_k, f_k): log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) obj = solver.sum(log_denominator_n) - solver.dot(N_k, f_k) return obj -def generate_static_mbar_objective_and_gradient(solver: MBARSolver) -def generate_static_mbar_hessian(solver: MBARSolver) -def generate_static_mbar_log_W_nk(solver: MBARSolver) -def generate_static_mbar_mbar_W_nk(solver: MBARSolver) -def generate_static_precondition_u_kn(solver: MBARSolver) + return mbar_objective + + +def generate_static_mbar_objective_and_gradient(solver: MBARSolver): + def mbar_objective_and_gradient(u_kn, N_k, f_k): + log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + log_numerator_k = solver.logsumexp(-log_denominator_n - u_kn, axis=1) + grad = -1 * N_k * (1.0 - solver.exp(f_k + log_numerator_k)) + + obj = solver.sum(log_denominator_n) - solver.dot(N_k, f_k) + + return obj, grad + + return mbar_objective_and_gradient + + +def generate_static_mbar_hessian(solver: MBARSolver): + def mbar_hessian(u_kn, N_k, f_k): + log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, solver.newaxis] + W = solver.exp(logW) + + H = solver.dot(W.T, W) + H *= N_k + H *= N_k[:, solver.newaxis] + H -= solver.diag(W.sum(0) * N_k) + return -1.0 * H + + return mbar_hessian + + +def generate_static_mbar_log_W_nk(solver: MBARSolver): + def mbar_log_W_nk(u_kn, N_k, f_k): + log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, solver.newaxis] + return logW + + return mbar_log_W_nk + + +def generate_static_mbar_W_nk(solver: MBARSolver, static_mbar_log_W_nk: callable): + def mbar_W_nk(u_kn, N_k, f_k): + return solver.exp(static_mbar_log_W_nk(u_kn, N_k, f_k)) + + return mbar_W_nk + + +def generate_static_precondition_u_kn(solver: MBARSolver): + def precondition_u_kn(u_kn, N_k, f_k): + u_kn = u_kn - u_kn.min(0) + u_kn += (solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - solver.dot(N_k, f_k) / N_k.sum() + return u_kn + + return precondition_u_kn + + +def generate_jit_self_consistent_update(solver: MBARSolver): + def jit_self_consistent_update(u_kn, N_k, f_k): + # Asteroid + log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + # All states can contribute to the numerator term. Check transpose + return -1.0 * solver.logsumexp( + -log_denominator_n - u_kn, axis=1 + ) # pylint: disable=invalid-unary-operand-type + + return jit_self_consistent_update diff --git a/pymbar/mbar_solvers/numpy_solver.py b/pymbar/mbar_solvers/numpy_solver.py index 9f12de9d..bb000240 100644 --- a/pymbar/mbar_solvers/numpy_solver.py +++ b/pymbar/mbar_solvers/numpy_solver.py @@ -72,11 +72,10 @@ def jit(self): """Passthrough JIT""" return self._passthrough_jit - def _adaptive_core(self, u_kn, N_k, f_k, g, options): + def _adaptive_core(self, u_kn, N_k, f_k, g, gamma): """ Core function to execute per iteration of a method. """ - gamma = options["gamma"] # Handle options H = self.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian Hinvg = np.linalg.lstsq(H, g, rcond=-1)[0] Hinvg -= Hinvg[0] diff --git a/pymbar/mbar_solvers/solver_api.py b/pymbar/mbar_solvers/solver_api.py index 37ec95f6..d8725563 100644 --- a/pymbar/mbar_solvers/solver_api.py +++ b/pymbar/mbar_solvers/solver_api.py @@ -19,7 +19,7 @@ class MBARSolverAPI(ABC): "mbar_hessian", "mbar_log_W_nk", "mbar_W_nk", - "precondition_u_kn" + "precondition_u_kn", ) @abstractmethod @@ -60,15 +60,15 @@ def precondition_u_kn(self, u_kn, N_k, f_k): @abstractmethod def solve_mbar_once( - self, - u_kn_nonzero, - N_k_nonzero, - f_k_nonzero, - method="adaptive", - tol=1e-12, - continuation=None, - options=None, - ): + self, + u_kn_nonzero, + N_k_nonzero, + f_k_nonzero, + method="adaptive", + tol=1e-12, + continuation=None, + options=None, + ): pass @abstractmethod @@ -85,9 +85,7 @@ class MBARSolverAcceleratorMethods(ABC): Methods which have to be implemented by MBAR solver accelerators """ - JITABLE_ACCELERATOR_METHODS = ( - "_adaptive_core", - ) + JITABLE_ACCELERATOR_METHODS = ("_adaptive_core",) @property @abstractmethod @@ -145,30 +143,16 @@ def jit(self): pass def _precondition_jit(self, jitable_fn): - @wraps(jitable_fn) # Helper to ensure the decorated function still registers for docs and inspection + @wraps( + jitable_fn + ) # Helper to ensure the decorated function still registers for docs and inspection def wrapped_precog_jit(*args, **kwargs): - # Uses "self" here as intercepted first arg for instance of MBARSolver + # Uses "self" here as intercepted first arg for instance of the decorated class jited_fn = self.jit(jitable_fn) return jited_fn(*args, **kwargs) + return wrapped_precog_jit @abstractmethod def _adaptive_core(self, u_kn, N_k, f_k, g, options): pass - - def __hash__(self): - return hash((self.exp, - self.sum, - self.diag, - self.newaxis, - self.dot, - self.s_, - self.pad, - self.lstsq, - self.optimize, - self.logsumexp, - self.jit - )) - - def __eq__(self, other): - return isinstance(other, MBARSolverAcceleratorMethods) and self.__hash__ == other.__hash__ diff --git a/pymbar/tests/test_accelerators.py b/pymbar/tests/test_accelerators.py index 508622a8..4be8a002 100644 --- a/pymbar/tests/test_accelerators.py +++ b/pymbar/tests/test_accelerators.py @@ -5,6 +5,7 @@ import pytest from pymbar import MBAR +from pymbar.mbar_solvers import get_accelerator, default_solver from pymbar.utils_for_testing import assert_equal, assert_allclose # Pylint doesn't like the interplay between pytest and importing fixtures. disabled the one problem. @@ -72,7 +73,7 @@ def test_mbar_accelerators_are_accurate(only_test_data, accelerator): accelerator_name, accelerator_check = accelerator test, x_n, u_kn = only_test_data["test"], only_test_data["x_n"], only_test_data["u_kn"] x_n, u_kn, N_k_output, s_n = test.sample(N_k, mode="u_kn") - mbar = MBAR(u_kn, N_k, verbose=True, n_bootstraps=200, accelerator=accelerator_name) + mbar = build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check, boostraps=200) results = mbar.compute_free_energy_differences() fe = results["Delta_f"] fe_sigma = results["dDelta_f"] @@ -80,10 +81,10 @@ def test_mbar_accelerators_are_accurate(only_test_data, accelerator): accelerator_check(mbar) -def build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check): +def build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check, boostraps=0): """Helper function to build an MBAR object""" - mbar = MBAR(u_kn, N_k, verbose=True, accelerator=accelerator_name) - assert mbar.accelerator == accelerator_name + mbar = MBAR(u_kn, N_k, verbose=True, accelerator=accelerator_name, n_bootstraps=boostraps) + assert mbar.solver == get_accelerator(accelerator_name) accelerator_check(mbar) return mbar @@ -105,3 +106,13 @@ def test_mbar_accelerators_can_toggle(static_ukn_nk, accelerator, fallback_accel # Rebuild the accelerated version again. mbar_rebuild = build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check) assert_allclose(mbar.f_k, mbar_rebuild.f_k) + + +def test_default_acclerator_is_correct(static_ukn_nk): + u_kn, N_k_output = static_ukn_nk + + def blank_check(*args): + return True + + mbar = build_out_an_mbar(u_kn, N_k, default_solver, blank_check) + assert mbar.solver == get_accelerator(default_solver) From f392dca58622d3cd73ab49adaf40047380893b46 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Thu, 22 Jun 2023 16:03:00 -0400 Subject: [PATCH 08/12] Fix pytest doctest trying to import everything and let it skip (had to go into source code to find this flag) Fix lint complaining about jax import on no-jax systems by wrapping JAX and raising appropriately --- .github/workflows/CI.yaml | 2 +- pymbar/mbar_solvers/jax_solver.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 0e377a21..62410f9f 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -62,7 +62,7 @@ jobs: - name: Run tests (pytest) shell: bash -l {0} run: | - pytest -v --cov=$PACKAGE --cov-report=xml --color=yes --doctest-modules $PACKAGE/ + pytest -v --cov=$PACKAGE --cov-report=xml --color=yes --doctest-modules --doctest-ignore-import-errors $PACKAGE/ - name: Run examples shell: bash -l {0} diff --git a/pymbar/mbar_solvers/jax_solver.py b/pymbar/mbar_solvers/jax_solver.py index 36c8808c..048d89a8 100644 --- a/pymbar/mbar_solvers/jax_solver.py +++ b/pymbar/mbar_solvers/jax_solver.py @@ -3,14 +3,17 @@ import logging from functools import wraps -from jax.config import config +try: + from jax.config import config -import jax.numpy as jnp -from jax.numpy.linalg import lstsq -import jax.scipy.optimize -from jax.scipy.special import logsumexp + import jax.numpy as jnp + from jax.numpy.linalg import lstsq + import jax.scipy.optimize + from jax.scipy.special import logsumexp -from jax import jit + from jax import jit +except ImportError: + raise ImportError("JAX not found!") from pymbar.mbar_solvers.mbar_solver import MBARSolver From d65e882491c9d57575e675972889305debffb1ed Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Thu, 22 Jun 2023 16:38:45 -0400 Subject: [PATCH 09/12] It turns out that BFGS is really slow without actual JIT, and the calls are weird, so added a "real_jit" property that can be set on implementation. --- pymbar/mbar_solvers/jax_solver.py | 4 ++++ pymbar/mbar_solvers/mbar_solver.py | 4 ++-- pymbar/mbar_solvers/solver_api.py | 4 ++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pymbar/mbar_solvers/jax_solver.py b/pymbar/mbar_solvers/jax_solver.py index 048d89a8..f68a2598 100644 --- a/pymbar/mbar_solvers/jax_solver.py +++ b/pymbar/mbar_solvers/jax_solver.py @@ -97,6 +97,10 @@ def logsumexp(self): def jit(self): return jit + @property + def real_jit(self): + return True + def _precondition_jit(self, jitable_fn): @wraps( jitable_fn diff --git a/pymbar/mbar_solvers/mbar_solver.py b/pymbar/mbar_solvers/mbar_solver.py index 2acc297e..e1a249b2 100644 --- a/pymbar/mbar_solvers/mbar_solver.py +++ b/pymbar/mbar_solvers/mbar_solver.py @@ -564,7 +564,7 @@ def solve_mbar_once( ] # Hessian of objective function with warnings.catch_warnings(record=True) as w: if ( - method == "BFGS" + self.real_jit and method == "BFGS" ): # Might be a way to fold this in now that accelerators are class-ified fpad = lambda x: self.pad(x, (1, 0)) # Make sure to use the static method here @@ -578,7 +578,7 @@ def solve_mbar_once( options=dict(maxiter=options["maxiter"]), ) results = dict() # there should be a way to copy this. - results["x"] = minimize_results[0] + results["x"] = minimize_results.x f_k_nonzero = pad(results["x"]) results["success"] = minimize_results[1] elif method in scipy_minimize_options: diff --git a/pymbar/mbar_solvers/solver_api.py b/pymbar/mbar_solvers/solver_api.py index d8725563..06a5da54 100644 --- a/pymbar/mbar_solvers/solver_api.py +++ b/pymbar/mbar_solvers/solver_api.py @@ -142,6 +142,10 @@ def logsumexp(self): def jit(self): pass + @property + def real_jit(self): + return False + def _precondition_jit(self, jitable_fn): @wraps( jitable_fn From 16cc394ed4a4d756d5ef38bb9cdd1110eb728333 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Mon, 26 Jun 2023 11:10:35 -0400 Subject: [PATCH 10/12] After some testing, I found the speed gain from having pure static generated methods is negligible and my earlier testing was the fact I was re-instancing the solver, and thus re-JIT'ing everything each time a new MBAR was called (fixed in earlier commit). After testing, here are the results: Testing the timing of test_protocols test using static-generated methods as a relative baseline: The test is 99% as fast on average with PyTree registration. The test is 95% as fast on average without the PyTree registration. So I've opted to use the JAX PyTree registration method and simplify the code substantially by moving all methods back into self methods. Also updated the readme to reflect the new option. --- README.md | 7 ++ pymbar/mbar_solvers/jax_solver.py | 34 +++--- pymbar/mbar_solvers/mbar_solver.py | 175 +++++++++-------------------- 3 files changed, 78 insertions(+), 138 deletions(-) diff --git a/README.md b/README.md index 2005c383..fa5b4217 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,13 @@ PyMBAR needs 64-bit floats to provide reliable answers. JAX by default uses PyMBAR will turn on JAX's 64-bit mode, which may cause issues with some separate uses of JAX in the same code as PyMBAR, such as existing Neural Network (NN) Models for machine learning. +If you would like JAX in 32-bit mode, and PyMBAR in the same script, instance your MBAR with the `accelerator=numpy` +option, e.g. +```python +mbar = MBAR(..., accelerator="numpy") +``` +replacing `...` with your other options. + Authors ------- * Kyle A. Beauchamp diff --git a/pymbar/mbar_solvers/jax_solver.py b/pymbar/mbar_solvers/jax_solver.py index f68a2598..b9853b02 100644 --- a/pymbar/mbar_solvers/jax_solver.py +++ b/pymbar/mbar_solvers/jax_solver.py @@ -11,6 +11,8 @@ import jax.scipy.optimize from jax.scipy.special import logsumexp + from jax.tree_util import register_pytree_node_class + from jax import jit except ImportError: raise ImportError("JAX not found!") @@ -20,6 +22,7 @@ logger = logging.getLogger(__name__) +@register_pytree_node_class class MBARSolverJAX(MBARSolver): """ Solver methods for MBAR. Implementations use specific libraries/accelerators to solve the code paths. @@ -49,8 +52,6 @@ def __init__(self): "* when you instance the MBAR object *\n" "******************************************\n" ) - # Double __ in middle name intentional here - self._static__adaptive_core = generate_static_adaptive_core(self) super().__init__() @property @@ -128,27 +129,32 @@ def _adaptive_core(self, u_kn, N_k, f_k, g, gamma): N_k must be float (should be cast at a higher level) """ - - -def generate_static_adaptive_core(solver: MBARSolver): - def _adaptive_core(u_kn, N_k, f_k, g, gamma): # Perform Newton-Raphson iterations (with sci computed on the way) - g = solver.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient - H = solver.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian + g = self.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient + H = self.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian Hinvg = lstsq(H, g, rcond=-1)[0] Hinvg -= Hinvg[0] f_nr = f_k - gamma * Hinvg # self-consistent iteration gradient norm and saved log sums. - f_sci = solver.self_consistent_update(u_kn, N_k, f_k) + f_sci = self.self_consistent_update(u_kn, N_k, f_k) f_sci = f_sci - f_sci[0] # zero out the minimum - g_sci = solver.mbar_gradient(u_kn, N_k, f_sci) - gnorm_sci = solver.dot(g_sci, g_sci) + g_sci = self.mbar_gradient(u_kn, N_k, f_sci) + gnorm_sci = self.dot(g_sci, g_sci) # newton raphson gradient norm and saved log sums. - g_nr = solver.mbar_gradient(u_kn, N_k, f_nr) - gnorm_nr = solver.dot(g_nr, g_nr) + g_nr = self.mbar_gradient(u_kn, N_k, f_nr) + gnorm_nr = self.dot(g_nr, g_nr) return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr - return _adaptive_core + def tree_flatten(self): + """Required method for PyTree registration with JAX""" + children = () + aux_data = {} + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Required method for PyTree registration with JAX""" + return cls() diff --git a/pymbar/mbar_solvers/mbar_solver.py b/pymbar/mbar_solvers/mbar_solver.py index e1a249b2..5a760fb1 100644 --- a/pymbar/mbar_solvers/mbar_solver.py +++ b/pymbar/mbar_solvers/mbar_solver.py @@ -68,52 +68,34 @@ class MBARSolver(MBARSolverAPI, MBARSolverAcceleratorMethods): def __init__(self): """ - Generate all the static methods to make JIT compile clean + JIT the methods on instancing to avoid doing this at runtime. - All the methods overwritten in this code are cast to static methods because JIT (at least in JAX) - suffers a massive performance loss if you try to JIT a bound method of a class (i.e. anything with - a reference to 'self'). See: https://github.com/google/jax/discussions/16020#discussioncomment-5915882 + In theory, you want all JIT methods to be static (at least in JAX) because otherwise you can suffer a massive + performance loss if you try to JIT a bound method of a class (i.e. anything with a reference to 'self'). + See: https://github.com/google/jax/discussions/16020#discussioncomment-5915882 + + For this use case however, we do not appear to suffer a performance loss of note due to the simplicity + of this class, and the use of the exact methods we need in all the @property decorators and the PyTree + recommendation of JAX itself. + See: https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree + + Testing the timing of test_protocols test using static-generated methods as a relative baseline: + The test is 99% as fast on average with PyTree registration. + The test is 95% as fast on average without the PyTree registration. + See commit hash d65e882 to view the static-generated methods for this code. Marking self as static with a partial doesn't work because we're wrapping the function already once, and we still need the functions/properties found in the class to make this an extensible class for other accelerators in the future. See https://jax.readthedocs.io/en/latest/faq.html#strategy-2-marking-self-as-static - The PyTree approach didn't seem to work due to the same problem as marking self static because it still needs - properties so no gains were made. Its possible I (LNN) misinterpreted something here and this can be used to - simplify the code in the future to avoid writing all the static-method generators. - https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree - - If the default methods are used (which are written to call the static generator anyway), then the JIT cache - will re-compile them every time, which defeats the whole point. The calls are left in to leave a developer - breadcrumb as to what is supposed to go there, and to make linter's happy. """ - # Dont use just _{method} because any result with leading __ mangles name - # E.g. __adaptive_core -> _{ClassName}_adaptive_core - self._static_mbar_gradient = generate_static_mbar_gradient(self) - self._static_mbar_objective = generate_static_mbar_objective(self) - self._static_mbar_objective_and_gradient = generate_static_mbar_objective_and_gradient( - self - ) - self._static_mbar_hessian = generate_static_mbar_hessian(self) - self._static_mbar_log_W_nk = generate_static_mbar_log_W_nk(self) - self._static_mbar_W_nk = generate_static_mbar_W_nk(self, self._static_mbar_log_W_nk) - self._static_jit_self_consistent_update = generate_jit_self_consistent_update(self) - self._static_precondition_u_kn = generate_static_precondition_u_kn(self) # Apply the precondition to each of the JITABLE_METHODS for method in ( self.JITABLE_ACCELERATOR_METHODS + self.JITABLE_API_METHODS + self.JITABLE_IMPLEMENTATION_METHODS ): - # Attempt to staticfy if method "_{method}" exists - if hasattr(self, "_static_" + method): - doc = getattr(self, method).__doc__ - static = getattr(self, "_static_" + method) - # Replace with static name - setattr(self, method, static) - # Reset docstring - getattr(self, method).__doc__ = doc # Jit setattr(self, method, self._precondition_jit(getattr(self, method))) @@ -152,6 +134,12 @@ def jit_self_consistent_update(self, u_kn, N_k, f_k): N_k must be float (should be cast at a higher level) """ + # Asteroid + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + # All states can contribute to the numerator term. Check transpose + return -1.0 * self.logsumexp( + -log_denominator_n - u_kn, axis=1 + ) # pylint: disable=invalid-unary-operand-type def mbar_gradient(self, u_kn, N_k, f_k): """Gradient of MBAR objective function. @@ -174,7 +162,10 @@ def mbar_gradient(self, u_kn, N_k, f_k): ----- This is equation C6 in the JCP MBAR paper. """ - return generate_static_mbar_gradient(self)(u_kn, N_k, f_k) + # N_k must be float (should be cast at a higher level) + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + log_numerator_k = self.logsumexp(-log_denominator_n - u_kn, axis=1) + return -1 * N_k * (1.0 - self.exp(f_k + log_numerator_k)) def mbar_objective(self, u_kn, N_k, f_k): """Calculates objective function for MBAR. @@ -204,7 +195,10 @@ def mbar_objective(self, u_kn, N_k, f_k): More optimal precision, the objective function uses math.fsum for the outermost sum and logsumexp for the inner sum. """ - return generate_static_mbar_objective(self)(u_kn, N_k, f_k) + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + obj = self.sum(log_denominator_n) - self.dot(N_k, f_k) + + return obj def mbar_objective_and_gradient(self, u_kn, N_k, f_k): """Calculates both objective function and gradient for MBAR. @@ -239,7 +233,13 @@ def mbar_objective_and_gradient(self, u_kn, N_k, f_k): The gradient is equation C6 in the JCP MBAR paper; the objective function is its integral. """ - return generate_static_mbar_objective_and_gradient(self)(u_kn, N_k, f_k) + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + log_numerator_k = self.logsumexp(-log_denominator_n - u_kn, axis=1) + grad = -1 * N_k * (1.0 - self.exp(f_k + log_numerator_k)) + + obj = self.sum(log_denominator_n) - self.dot(N_k, f_k) + + return obj, grad def mbar_hessian(self, u_kn, N_k, f_k) -> np.ndarray: """Hessian of Mmbar_hessianBAR objective function. @@ -262,7 +262,15 @@ def mbar_hessian(self, u_kn, N_k, f_k) -> np.ndarray: ----- Equation (C9) in JCP MBAR paper. """ - return generate_static_mbar_hessian(self)(u_kn, N_k, f_k) + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, self.newaxis] + W = self.exp(logW) + + H = self.dot(W.T, W) + H *= N_k + H *= N_k[:, self.newaxis] + H -= self.diag(W.sum(0) * N_k) + return -1.0 * H def mbar_log_W_nk(self, u_kn, N_k, f_k): """Calculate the log weight matrix. @@ -285,7 +293,9 @@ def mbar_log_W_nk(self, u_kn, N_k, f_k): ----- Equation (9) in JCP MBAR paper. """ - return generate_static_mbar_log_W_nk(self)(u_kn, N_k, f_k) + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, self.newaxis] + return logW def mbar_W_nk(self, u_kn, N_k, f_k): """Calculate the weight matrix. @@ -308,7 +318,7 @@ def mbar_W_nk(self, u_kn, N_k, f_k): ----- Equation (9) in JCP MBAR paper. """ - return generate_static_mbar_W_nk(self, self.mbar_log_W_nk)(u_kn, N_k, f_k) + return self.exp(self.mbar_log_W_nk(u_kn, N_k, f_k)) def precondition_u_kn(self, u_kn, N_k, f_k): """Subtract a sample-dependent constant from u_kn to improve precision @@ -335,7 +345,9 @@ def precondition_u_kn(self, u_kn, N_k, f_k): x_n such that the current objective function value is zero, which should give maximum precision in the objective function. """ - return generate_static_precondition_u_kn(self)(u_kn, N_k, f_k) + u_kn = u_kn - u_kn.min(0) + u_kn += (self.logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - self.dot(N_k, f_k) / N_k.sum() + return u_kn def adaptive(self, u_kn, N_k, f_k, tol=1.0e-8, options=None): """ @@ -801,88 +813,3 @@ def validate_inputs(u_kn, N_k, f_k): f_k = ensure_type(f_k, "float", 1, "f_k", shape=(n_states,)) return u_kn, N_k, f_k - - -def generate_static_mbar_gradient(solver: MBARSolver): - def mbar_gradient(u_kn, N_k, f_k): - # N_k must be float (should be cast at a higher level) - log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = solver.logsumexp(-log_denominator_n - u_kn, axis=1) - return -1 * N_k * (1.0 - solver.exp(f_k + log_numerator_k)) - - return mbar_gradient - - -def generate_static_mbar_objective(solver: MBARSolver): - def mbar_objective(u_kn, N_k, f_k): - log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - obj = solver.sum(log_denominator_n) - solver.dot(N_k, f_k) - - return obj - - return mbar_objective - - -def generate_static_mbar_objective_and_gradient(solver: MBARSolver): - def mbar_objective_and_gradient(u_kn, N_k, f_k): - log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = solver.logsumexp(-log_denominator_n - u_kn, axis=1) - grad = -1 * N_k * (1.0 - solver.exp(f_k + log_numerator_k)) - - obj = solver.sum(log_denominator_n) - solver.dot(N_k, f_k) - - return obj, grad - - return mbar_objective_and_gradient - - -def generate_static_mbar_hessian(solver: MBARSolver): - def mbar_hessian(u_kn, N_k, f_k): - log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, solver.newaxis] - W = solver.exp(logW) - - H = solver.dot(W.T, W) - H *= N_k - H *= N_k[:, solver.newaxis] - H -= solver.diag(W.sum(0) * N_k) - return -1.0 * H - - return mbar_hessian - - -def generate_static_mbar_log_W_nk(solver: MBARSolver): - def mbar_log_W_nk(u_kn, N_k, f_k): - log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, solver.newaxis] - return logW - - return mbar_log_W_nk - - -def generate_static_mbar_W_nk(solver: MBARSolver, static_mbar_log_W_nk: callable): - def mbar_W_nk(u_kn, N_k, f_k): - return solver.exp(static_mbar_log_W_nk(u_kn, N_k, f_k)) - - return mbar_W_nk - - -def generate_static_precondition_u_kn(solver: MBARSolver): - def precondition_u_kn(u_kn, N_k, f_k): - u_kn = u_kn - u_kn.min(0) - u_kn += (solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - solver.dot(N_k, f_k) / N_k.sum() - return u_kn - - return precondition_u_kn - - -def generate_jit_self_consistent_update(solver: MBARSolver): - def jit_self_consistent_update(u_kn, N_k, f_k): - # Asteroid - log_denominator_n = solver.logsumexp(f_k - u_kn.T, b=N_k, axis=1) - # All states can contribute to the numerator term. Check transpose - return -1.0 * solver.logsumexp( - -log_denominator_n - u_kn, axis=1 - ) # pylint: disable=invalid-unary-operand-type - - return jit_self_consistent_update From f69b80cce16ee3a3112be18d215a6ac2fcdcebf1 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Mon, 26 Jun 2023 12:27:10 -0400 Subject: [PATCH 11/12] Upstream method change with scipy>=1.9 for this one call. They say the pre 1.9 behavior was a bug, hence no depreciation warning. --- pymbar/tests/test_timeseries.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pymbar/tests/test_timeseries.py b/pymbar/tests/test_timeseries.py index e5455ba7..1a57fd11 100644 --- a/pymbar/tests/test_timeseries.py +++ b/pymbar/tests/test_timeseries.py @@ -130,7 +130,12 @@ def test_compare_detectEquil(show_hist=False): bs_de = timeseries.detect_equilibration_binary_search(D_t, bs_nodes=10) std_de = timeseries.detect_equilibration(D_t, fast=False, nskip=1) t_res.append(bs_de[0] - std_de[0]) - t_res_mode = float(stats.mode(t_res)[0][0]) + try: + # scipy<1.9 + t_res_mode = float(stats.mode(t_res)[0][0]) + except IndexError: + # scipy>=1.9 + t_res_mode = float(stats.mode(t_res, keepdims=True)[0][0]) assert_almost_equal(t_res_mode, 0.0, decimal=1) if show_hist: import matplotlib.pyplot as plt From 00a9cd7240e6bc81d0a46966502035f82858e8a1 Mon Sep 17 00:00:00 2001 From: Levi Naden Date: Tue, 27 Jun 2023 09:53:34 -0400 Subject: [PATCH 12/12] Fix the docs, may conflict with main due to #510 but I can fix that too. --- devtools/conda-envs/test_env_jax.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/devtools/conda-envs/test_env_jax.yaml b/devtools/conda-envs/test_env_jax.yaml index d466eeae..c2dd795e 100644 --- a/devtools/conda-envs/test_env_jax.yaml +++ b/devtools/conda-envs/test_env_jax.yaml @@ -22,4 +22,5 @@ dependencies: - xlrd # Docs - numpydoc + - sphinx <7 - sphinxcontrib-bibtex