Skip to content

Commit

Permalink
Fix Python container and pickle API (#1627)
Browse files Browse the repository at this point in the history
  • Loading branch information
franzpoeschel authored Jun 11, 2024
1 parent 5d9fb34 commit 87cdc96
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
6 changes: 4 additions & 2 deletions include/openPMD/binding/python/Container.H
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include <cstddef>
#include <memory>
#include <pybind11/attr.h>
#include <sstream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -118,11 +119,12 @@ Class_ finalize_container(Class_ cl)
// keep same policy as Container class: missing keys are created
cl.def(
"__getitem__",
[](Map &m, KeyType const &k) -> MappedType & { return m[k]; },
[](Map &m, KeyType const &k) -> MappedType { return m[k]; },
// copy + keepalive
// All objects in the openPMD object model are handles, so using a copy
// is safer and still performant.
py::return_value_policy::copy);
py::return_value_policy::move,
py::keep_alive<0, 1>());

// Assignment provided only if the type is copyable
py::detail::map_assignment<Map, Class_>(cl);
Expand Down
2 changes: 1 addition & 1 deletion include/openPMD/binding/python/Pickle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ add_pickle(pybind11::class_<T_Args...> &cl, T_SeriesAccessor &&seriesAccessor)
},

// __setstate__
[&seriesAccessor](py::tuple t) {
[&seriesAccessor](py::tuple const &t) {
// our tuple has exactly two elements: filePath & group
if (t.size() != 2)
throw std::runtime_error("Invalid state!");
Expand Down
5 changes: 4 additions & 1 deletion src/binding/python/MeshRecordComponent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ void init_MeshRecordComponent(py::module &m)
add_pickle(
cl, [](openPMD::Series &series, std::vector<std::string> const &group) {
uint64_t const n_it = std::stoull(group.at(1));
return series.iterations[n_it].meshes[group.at(3)][group.at(4)];
return series.iterations[n_it]
.meshes[group.at(3)]
[group.size() < 5 ? MeshRecordComponent::SCALAR
: group.at(4)];
});

finalize_container<PyMeshRecordComponentContainer>(py_mrc_cnt);
Expand Down
4 changes: 2 additions & 2 deletions src/binding/python/RecordComponent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1124,8 +1124,8 @@ void init_RecordComponent(py::module &m)
add_pickle(
cl, [](openPMD::Series &series, std::vector<std::string> const &group) {
uint64_t const n_it = std::stoull(group.at(1));
return series.iterations[n_it]
.particles[group.at(3)][group.at(4)][group.at(5)];
return series.iterations[n_it].particles[group.at(3)][group.at(
4)][group.size() < 6 ? RecordComponent::SCALAR : group.at(5)];
});

addRecordComponentSetGet(cl);
Expand Down

0 comments on commit 87cdc96

Please sign in to comment.