
Refactoring field interpolation and allow custom interpolation methods in Scipy mode #1816

Open
wants to merge 57 commits into base: main
Conversation

VeckoTheGecko
Contributor

@VeckoTheGecko VeckoTheGecko commented Jan 8, 2025

This PR refactors many of the indexing and interpolation methods out of field.py, moving the indexing code to a separate file to make things more manageable and to reduce coupling with the Field class.

This PR also allows users to easily override the behaviour of existing interpolation methods or (untested) define new interpolation methods in Scipy mode. This can be done via the register_2d_interpolator(...) and register_3d_interpolator(...) decorators. This behaviour is in beta and is subject to change.

This also makes the interpolation functions more easily testable by providing only the required data.

Fixes #1823

Contributor Author

@VeckoTheGecko VeckoTheGecko left a comment


Thoughts on the below, @erikvansebille? Let me know if you want me to break this into separate PRs.

@VeckoTheGecko
Contributor Author

VeckoTheGecko commented Jan 16, 2025

Note for review: de89311 includes testing showing that the refactored 3D interpolation is 100% equivalent to the previous version for all combinations of interp method, grid type, and cell location.

@VeckoTheGecko
Contributor Author

I'm quite aware that this is getting to be a big PR. That's to be expected for such a large refactor, but I think it would be good to continue in other PRs (for vector interpolation and indexing) to keep this one reviewable.

@VeckoTheGecko VeckoTheGecko changed the title Refactoring indexing and interpolation Refactoring field interpolation (and move indexing code) Jan 16, 2025
Comment on lines 68 to 93
    @pytest.mark.parametrize(
        "func, eta, xsi, expected",
        [
            pytest.param(interpolation._nearest_2d, 0.49, 0.49, 3.0, id="nearest_2d-1"),
            pytest.param(interpolation._nearest_2d, 0.49, 0.51, 4.0, id="nearest_2d-2"),
            pytest.param(interpolation._nearest_2d, 0.51, 0.49, 5.0, id="nearest_2d-3"),
            pytest.param(interpolation._nearest_2d, 0.51, 0.51, 6.0, id="nearest_2d-4"),
            pytest.param(interpolation._tracer_2d, None, None, 6.0, id="tracer_2d"),
            # pytest.param(interpolation._linear_2d, ...),
            # pytest.param(interpolation._linear_invdist_land_tracer_2d, ...),
        ],
    )
    def test_2d(self, data_2d, func, eta, xsi, expected):
        ctx = interpolation.InterpolationContext2D(data_2d, eta, xsi, self.ti, self.yi, self.xi)
        assert func(ctx) == expected

    @pytest.mark.parametrize(
        "func, eta, xsi, expected",
        [
            # pytest.param(interpolation._nearest_3d, ...),
            # pytest.param(interpolation._cgrid_velocity_3d, ...),
            # pytest.param(interpolation._linear_invdist_land_tracer_3d, ...),
            # pytest.param(interpolation._linear_3d, ...),
            # pytest.param(interpolation._tracer_3d, ...),
        ],
    )
Contributor Author

@VeckoTheGecko VeckoTheGecko Jan 16, 2025


Thoughts on test cases, @erikvansebille? I think it might be good to add some on the raw arrays.

Member


And yes, we absolutely also need 3D tests. Perhaps again comparing to JIT mode?

@VeckoTheGecko VeckoTheGecko marked this pull request as ready for review January 16, 2025 18:18
@VeckoTheGecko
Contributor Author

I added an extra commit in the history that shows 100% equivalence of the refactor. See updated #1816 (comment)

@VeckoTheGecko VeckoTheGecko changed the title Refactoring field interpolation (and move indexing code) Refactoring field interpolation and allow custom interpolation methods in Scipy mode Jan 17, 2025
Member

@erikvansebille erikvansebille left a comment


First batch of review comments (on the three main new files). The rest will come after the weekend.

Member

@erikvansebille erikvansebille left a comment


More comments, now on all files except for the tests/* files

Member


Should we not also move the _search_indices(), _search_indices_curvilinear() and _search_indices_rectilinear() methods to the _index_search.py file? Why are they still in field.py?

In general, I see that even in this PR, field.py still contains a lot of methods/functions that don't specifically need to be there. Moving them would really clean up the field.py file.

E.g. VectorField.dist, VectorField._is_land2D() and VectorField.jacobian could go to an interpolation_utils.py file (mimicking what is in the include folder for C)?

And the spatial interpolation for VectorFields could also go to _interpolation.py?

Contributor Author


Agreed - I have pushed these changes. I haven't made any changes to VectorField; I will make those in another PR.

VeckoTheGecko added a commit that referenced this pull request Jan 20, 2025
VeckoTheGecko added a commit that referenced this pull request Jan 20, 2025
No longer needed, since show_time isn't in the codebase anymore #1816 (comment)
parcels/_index_search.py Outdated Show resolved Hide resolved
parcels/_interpolation.py Outdated Show resolved Hide resolved
parcels/_interpolation.py Outdated Show resolved Hide resolved
parcels/field.py Outdated Show resolved Hide resolved
parcels/field.py Outdated Show resolved Hide resolved
Comment on lines 76 to 77
            # pytest.param(interpolation._linear_2d, ...),
            # pytest.param(interpolation._linear_invdist_land_tracer_2d, ...),
Member


I think it's crucial that we test all interpolation methods with these simple unit tests. But I must say I don't really understand how this test function is now set up. Why a class?

And indeed, I would not test just one point, but a few edge cases too.

Contributor Author


> I think it's crucial that we test all interpolation methods with these simple unit tests

Agreed

> But I must say I don't really understand how this test-function is now set up. Why a class?

Admittedly, classes aren't used much in pytest beyond being a way of grouping tests together (and even then, not often). I thought it would be a straightforward way to group the tests that run on ti, zi, yi, xi = 0, 1, 1, 1 against the data_2d and data_3d datasets.

> And indeed, I would not test just one point, but a few edge cases too.

Agreed - those can be functions outside of this class. We should also test the land interp method.

    )
    def test_2d(self, data_2d, func, eta, xsi, expected):
        ctx = interpolation.InterpolationContext2D(data_2d, eta, xsi, self.ti, self.yi, self.xi)
        assert func(ctx) == expected
Member


If we don't want to hand-code all the expected values, we could also compare to JIT mode? I know we will perhaps remove JIT mode in the long term, but until then it would be good to check that the two modes are consistent.

Contributor Author

@VeckoTheGecko VeckoTheGecko Jan 21, 2025


Can do. I think that would require going back up to constructing the Field class, and then testing the interpolation on that.

I guess that would be the only way of doing it? I assume that the JIT functions don't follow the same structure as the new scipy interpolation functions?

Member


Yes, I fear that the only way to test against JIT is going through pset.execute(). So make a FieldSet, a large ParticleSet, and then call a Field evaluation (particle.u = fieldset.U[particle.time, particle.depth, particle.lat, particle.lon]) in a custom kernel, and assert whether all particle.u have the same value as the new Scipy interpolation.

Not the cleanest code, and I can imagine you're slightly disappointed to require all these extra classes in this test function (it defeats the purpose of unit tests somewhat), but it's by far the most robust validation that our new Scipy interpolation works. When/if we move away from JIT, we can remove all that code ;-)

Contributor Author


> Not the cleanest code and I can imagine you're slightly disappointed to require all these extra classes in this test-function

Yes, but at the end of the day it's important to have proper validation; clean code can come gradually as we refactor and improve parts of the codebase :)
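The validation strategy discussed above (evaluate two independent implementations at many sample points and assert agreement within a tolerance) can be sketched in a self-contained form. The two bilinear functions below are illustrative stand-ins, not Parcels' actual JIT and Scipy code paths:

```python
import numpy as np

# Two independent implementations of 2D bilinear interpolation on the same
# float32 field; agreement across many random sample points is the same kind
# of cross-check as the Scipy-vs-JIT comparison in the PR.
rng = np.random.default_rng(0)
data = rng.random((4, 4)).astype(np.float32)

def bilinear_a(eta, xsi, yi, xi):
    # Weighted sum of the four corner nodes.
    return ((1 - eta) * (1 - xsi) * data[yi, xi]
            + (1 - eta) * xsi * data[yi, xi + 1]
            + eta * (1 - xsi) * data[yi + 1, xi]
            + eta * xsi * data[yi + 1, xi + 1])

def bilinear_b(eta, xsi, yi, xi):
    # Same interpolation written differently: interpolate rows, then columns.
    top = (1 - xsi) * data[yi, xi] + xsi * data[yi, xi + 1]
    bot = (1 - xsi) * data[yi + 1, xi] + xsi * data[yi + 1, xi + 1]
    return (1 - eta) * top + eta * bot

for _ in range(1000):
    eta, xsi = rng.random(2)
    yi, xi = rng.integers(0, 3, size=2)
    # float32 inputs mean the two code paths can differ by round-off, hence a
    # tolerance rather than exact equality (cf. the atol discussion below).
    assert np.isclose(bilinear_a(eta, xsi, yi, xi),
                      bilinear_b(eta, xsi, yi, xi), atol=1e-6)
```

The real test would replace `bilinear_b` with a fieldset.U evaluation inside a kernel run via pset.execute(), as described in the comment above.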


if gridindexingtype == "mom5" and z > 2 * grid.depth[0] - grid.depth[1]:
    return (-1, z / grid.depth[0])
else:
    _raise_field_out_of_bound_surface_error(z, 0, 0)
Member


It's not ideal that this function passes y=0 and x=0. I understand we don't know what x and y are at this moment, but printing them as zeros can be misleading to users. Would it not be better to print 'not available' or 'unknown'? Or is there some way to figure out what x and y actually are?

Contributor Author


The previous code also set them to 0, so this was just a continuation of that. It's an easy fix to set x and y to None!

Member


OK, and could we then print these Nones as 'unknown' in the actual warning message to the user? That would make more sense to them.
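A minimal sketch of the rendering suggested above. The helper and message wording below are hypothetical, not the actual code in parcels/tools/statuscodes.py:

```python
def _format_coord(value):
    # Render a missing coordinate as 'unknown' rather than a misleading 0
    # (or a bare 'None') in the user-facing message.
    return "unknown" if value is None else f"{value:g}"

def out_of_bound_surface_message(z, y=None, x=None):
    # Illustrative error text; the real Parcels message may differ.
    return (f"Field sampled out-of-bound at the surface "
            f"(depth={_format_coord(z)}, lat={_format_coord(y)}, lon={_format_coord(x)})")

print(out_of_bound_surface_message(1.5))
# -> Field sampled out-of-bound at the surface (depth=1.5, lat=unknown, lon=unknown)
```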

VeckoTheGecko added a commit that referenced this pull request Jan 29, 2025
@VeckoTheGecko
Contributor Author

Cleaned up now. We can merge either this PR or #1834 first; whichever comes first, I'll rebase the other.

@VeckoTheGecko
Contributor Author

I encountered this

=========================== short test summary info ============================
FAILED tests/test_interpolation.py::test_scipy_vs_jit[nearest] - assert not np.True_
 +  where np.True_ = <function isclose at 0x7f9f38c37130>(np.float32(0.75000733), np.float64(0.75), atol=1e-08)
 +    where <function isclose at 0x7f9f38c37130> = np.isclose
 +    and   np.float32(0.75000733) = P[118388](lon=0.166994, lat=0.335120, depth=0.750007, pid=148.000000, time=0.003000).depth
= 1 failed, 1181 passed, 1 skipped, 8 xfailed, 4087 warnings in 687.10s (0:11:27) =

Perhaps related to #1834, since the failure is in the depth dimension (though I could not recreate it locally).

@erikvansebille
Member

> I encountered this
>
> =========================== short test summary info ============================
> FAILED tests/test_interpolation.py::test_scipy_vs_jit[nearest] - assert not np.True_
>  +  where np.True_ = <function isclose at 0x7f9f38c37130>(np.float32(0.75000733), np.float64(0.75), atol=1e-08)
>  +    where <function isclose at 0x7f9f38c37130> = np.isclose
>  +    and   np.float32(0.75000733) = P[118388](lon=0.166994, lat=0.335120, depth=0.750007, pid=148.000000, time=0.003000).depth
> = 1 failed, 1181 passed, 1 skipped, 8 xfailed, 4087 warnings in 687.10s (0:11:27) =
>
> Perhaps related to #1834, since the failure is in the depth dimension (though could not recreate locally)

Hmm, it's a ~1e-6 error; it could also be due to accumulating round-off errors. I found the 1e-8 atol extremely tight already, but since it worked locally on my computer too, I didn't change it. I think 1e-6 is still totally acceptable, so I'll see if that fixes the breaking assert.

Successfully merging this pull request may close these issues.

Refactor interpolation and indexing methods out of field.py