Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MBAR bootstrap error #1077

Merged
merged 12 commits into from
Jan 30, 2025
Merged

Use MBAR bootstrap error #1077

merged 12 commits into from
Jan 30, 2025

Conversation

jthorton
Copy link
Collaborator

@jthorton jthorton commented Jan 14, 2025

Fixes #1012 by using the bootstrap error from pymbar3/4.

Would this be a good time to switch to only supporting pymbar4 so we only have to maintain a single interface for MBAR?

Note:

  • the full pymbar4 package brings in JAX
  • I found that 1000 iterations of bootstrapping only takes around 1 min for the default protocol (using jax)
  • For the extended charge changing protocol this can take up to 15 mins (using jax)
  • The variability in the dDGs between test runs was larger which meant I had to relax the relative tolerance on the tests

Checklist

  • Added a news entry

Developers certificate of origin

@jthorton jthorton requested review from IAlibay and atravitz January 14, 2025 17:21
Copy link

codecov bot commented Jan 14, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.84%. Comparing base (00445dc) to head (40f266c).
Report is 13 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1077      +/-   ##
==========================================
- Coverage   94.64%   91.84%   -2.80%     
==========================================
  Files         137      137              
  Lines       10112    10105       -7     
==========================================
- Hits         9570     9281     -289     
- Misses        542      824     +282     
Flag Coverage Δ
fast-tests 91.84% <100.00%> (?)
slow-tests ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

Would this be a good time to switch to only supporting pymbar4 so we only have to maintain a single interface for MBAR?

Yes I think that would be a good idea - if we think it's stable (we might have to benchmark a bit), we should make the jump.
If we go by spec0 rules pymbar 3 is > 2 years old.

PyMBAR 3 also has all kinds of stability issues we should try to avoid.

the full pymbar4 package brings in JAX

:/ how big of a dependency is JAX? It might be that we don't really have an option here. I know you can use pymbar 4 without JAX (that's how it gets deployed on PyPi). cc @atravitz

For the extended charge changing protocol this can take up to 15 mins (using jax)

Oof that's quite long. I guess as long as we're only doing that once in a multi-hour simulation it doesn't matter too much.

@jthorton
Copy link
Collaborator Author

JAX is around 60MB, but we can use pymbar-core which is the non-JAX version that should be a bit slower, how much slower, I am not sure but compared to a multi-hour simulation it should still be negligible!

  + jax                     0.4.35  pyhd8ed1ab_1         conda-forge/noarch        1MB
  + jaxlib                  0.4.35  cpu_py312hadfe8e1_0  conda-forge/osx-64       56MB

On the other hand, adding JAX is not too noticeable compared to the cudatoolkit?

Copy link
Member

@IAlibay IAlibay left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couple of todos:

  • Could you add a news entry please?
  • Could you make the necessary changes to switch to pymbar 4 please?

np.array([0.07471 , 0.052914, 0.041508, 0.036613, 0.032827, 0.030489,
0.028154, 0.026529, 0.025284, 0.023968]),
rtol=1e-04,
np.array([0.077645, 0.054695, 0.044680, 0.03947, 0.034822,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating these - the new error values are expected to be different, so we should update things where we can.

rtol=1e-04,
np.array([0.077645, 0.054695, 0.044680, 0.03947, 0.034822,
0.033443, 0.030793, 0.028777, 0.026683, 0.026199]),
rtol=1e-01,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little bit loose as a tolerance, but I guess it's fine given the bootstraps are stochastic.

except AttributeError:
r = mbar.compute_free_energy_differences()
# pymbar 4
mbar = MBAR(u_ln, N_l, solver_protocol="robust", n_bootstraps=1000, bootstrap_solver_protocol="robust")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is most of the cost in the forward & reverse analysis?

Copy link
Collaborator Author

@jthorton jthorton Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah running the bootstrapping on repeat is expensive! One thought on the forward and backward estimates should we be subsampling using g_t calculated for this subset of data? In the industry benchmarking I calculated it 3 ways no subsampling, subsample based on the % of data and subsample using the g_t calculated for the full set of data. https://github.com/OpenFreeEnergy/IndustryBenchmarks2024/blob/fb60d7a971cb5d04787d796b6adcf257d905786a/industry_benchmarks/analysis/1_download_and_extract_data.py#L464-L552

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

On the other hand, adding JAX is not too noticeable compared to the cudatoolkit?

Yeah - I also suspect we're picking up a ton of dependencies elsewhere.

Long term maybe we should look into an openfe-base version that has the very minimal set of dependencies for everything.

I'll let @atravitz weigh in, but generally I'm ok / would very much like it if we pushed for pymbar4 w/ JAX.

@IAlibay
Copy link
Member

IAlibay commented Jan 15, 2025

Completely forgot to ask @jthorton - could you have a look through our docs and see if there's anywhere we can make it clear that this is now the bootstrap error? I know some folks got confused by it all.

@jthorton
Copy link
Collaborator Author

Currently blocked by perses=0.10.3 which pins to pymbar3.

Copy link
Member

@IAlibay IAlibay left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jthorton I'm going to approve with the expectation that we can merge once tests pass once there's a new release of perses.

@jthorton
Copy link
Collaborator Author

We might need to delay the import of pymbar, while testing out the CLI for the partial charge generation I see that it prints a lot of info to the terminal, this gets even worse when using multiprocessing.

openfe charge-molecules -M malt1_ligands.sdf -o charged_ligands.sdf -w 6
Warning on use of the timeseries module: If the inherent timescales of the system are long compared to those being analyzed, this statistical inefficiency may be an underestimate.  The estimate presumes the use of many statistically independent samples.  Tests should be performed to assess whether this condition is satisfied.   Be cautious in the interpretation of the data.

****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************

SMALL MOLECULE PARTIAL CHARGE GENERATOR
_________________________________________

@atravitz
Copy link
Contributor

We might need to delay the import of pymbar, while testing out the CLI for the partial charge generation I see that it prints a lot of info to the terminal, this gets even worse when using multiprocessing.

openfe charge-molecules -M malt1_ligands.sdf -o charged_ligands.sdf -w 6
Warning on use of the timeseries module: If the inherent timescales of the system are long compared to those being analyzed, this statistical inefficiency may be an underestimate.  The estimate presumes the use of many statistically independent samples.  Tests should be performed to assess whether this condition is satisfied.   Be cautious in the interpretation of the data.

****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************

SMALL MOLECULE PARTIAL CHARGE GENERATOR
_________________________________________

Are we able to suppress just this warning? Otherwise, delaying the import sounds good to me.

@atravitz
Copy link
Contributor

Also, adding JAX as a dependency is fine by me

@IAlibay
Copy link
Member

IAlibay commented Jan 17, 2025

We might need to delay the import of pymbar, while testing out the CLI for the partial charge generation I see that it prints a lot of info to the terminal, this gets even worse when using multiprocessing.

openfe charge-molecules -M malt1_ligands.sdf -o charged_ligands.sdf -w 6
Warning on use of the timeseries module: If the inherent timescales of the system are long compared to those being analyzed, this statistical inefficiency may be an underestimate.  The estimate presumes the use of many statistically independent samples.  Tests should be performed to assess whether this condition is satisfied.   Be cautious in the interpretation of the data.

****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************

SMALL MOLECULE PARTIAL CHARGE GENERATOR
_________________________________________

Are we able to suppress just this warning? Otherwise, delaying the import sounds good to me.

Oof yeah @jthorton if you don't do the warning supression here, could you open an issue about doing it later? This is OpenMMTools levels of noisy.

@jthorton
Copy link
Collaborator Author

jthorton commented Jan 20, 2025

Hiding the import warnings from pymbar is proving to be tricking as importing anything from openfe triggers them before I can catch them! The import of the message filter from openfe.utils.logging_filter import MsgIncludesStringFilter triggers it. They are also triggered when the process pool is created which causes a lot of repeated messages to be printed to screen!

@jthorton
Copy link
Collaborator Author

Adding the logging filter to the main init of openfe seems to have done the trick!

@IAlibay
Copy link
Member

IAlibay commented Jan 20, 2025

Adding the logging filter to the main init of openfe seems to have done the trick!

@atravitz could you weigh in on this? I am in favour of it because the alternative is super super noisy, but I do also see the case of "if we do this, then anyone that imports OpenFE at runtime will never see that warning again".

I don't know if there's a way to be more explicit and just catch child process?

@atravitz
Copy link
Contributor

Adding the logging filter to the main init of openfe seems to have done the trick!

@atravitz could you weigh in on this? I am in favour of it because the alternative is super super noisy, but I do also see the case of "if we do this, then anyone that imports OpenFE at runtime will never see that warning again".

I don't know if there's a way to be more explicit and just catch child process?

I looked into it a bit and haven't come up with a nicer way to handle these warnings yet. Curious if @mikemhenry has thoughts as someone who has worked with pymbar.

@jthorton is the current behavior that it does it not show the warning at all?

I don't want this to block this PR, so please open a follow-up issue and I'll look more into it.

@jthorton
Copy link
Collaborator Author

@jthorton is the current behavior that it does it not show the warning at all?

Yes this now hides the warning completely, I think we could fix this if we changed to only importing pymbar when its used but this is tied to openmmtools as well which might not possible. We could re-do the warning only when we use pymbar so it appears in the protocol execution logs if its critical that we pass on the warning?


**Changed:**

* The MBAR bootstrap (1000 iterations) error is used to estimate protocol uncertainty and pymbar3 is no longer supported `PR#1077 <https://github.com/OpenFreeEnergy/openfe/pull/1077>`_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jthorton could you add an "instead of" here too maybe? (i.e. tell them what we used to do for MBAR estimation).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 44f1b28

@jthorton
Copy link
Collaborator Author

From the OpenFE meeting today, we will drop the perses atom mapper testing to unblock pymbar4 rollout and will come back to this once we have a new perses release.

@jthorton jthorton enabled auto-merge January 30, 2025 10:57
Copy link

No API break detected ✅

@jthorton jthorton merged commit 870463b into main Jan 30, 2025
12 checks passed
@jthorton jthorton deleted the bootstrap_error branch January 30, 2025 11:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Switch to bootstrapping for MBAR errors.
3 participants