diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h index af3a47f39c..cc6dd574b6 100644 --- a/include/pybind11/stl_bind.h +++ b/include/pybind11/stl_bind.h @@ -569,6 +569,82 @@ class_ bind_vector(handle scope, std::string const &name, A return cl; } +// +// std::set +// +template , typename... Args> +class_ bind_set(handle scope, std::string const &name, Args &&...args) { + using Class_ = class_; + using T = typename Set::value_type; + using ItType = typename Set::iterator; + + auto vtype_info = detail::get_type_info(typeid(T)); + bool local = !vtype_info || vtype_info->module_local; + + Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); + cl.def(init<>()); + cl.def(init(), "Copy constructor"); + cl.def(init([](iterable it) { + auto s = std::unique_ptr(new Set()); + for (handle h : it) + s->insert(h.cast()); + return s.release(); + })); + cl.def(self == self); + cl.def(self != self); + cl.def( + "remove", + [](Set &s, const T &x) { + auto p = s.find(x); + if (p != s.end()) + s.erase(p); + else + throw value_error(); + }, + arg("x"), + "Remove the item from the set whose value is x. " + "It is an error if there is no such item."); + cl.def( + "__contains__", + [](const Set &s, const T &x) { return s.find(x) != s.end(); }, + arg("x"), + "Return true if the container contains ``x``."); + cl.def( + "add", + [](Set &s, const T &value) { s.insert(value); }, + arg("x"), + "Add an item to the set."); + cl.def("clear", [](Set &s) { s.clear(); }, "Clear the contents."); + cl.def( + "__iter__", + [](Set &s) { + return make_iterator(s.begin(), s.end()); + }, + keep_alive<0, 1>() /* Essential: keep set alive while iterator exists */ + ); + cl.def( + "__repr__", + [name](Set &s) { + std::ostringstream os; + os << name << '{'; + for (auto it = s.begin(); it != s.end(); ++it) { + if (it != s.begin()) + os << ", "; + os << *it; + } + os << '}'; + return os.str(); + }, + "Return the canonical string representation of this set."); + cl.def( + "__bool__", + [](const Set &s) -> bool { return !s.empty(); }, + "Check whether the set is nonempty"); + cl.def("__len__", &Set::size); + + return cl; +} + // // std::map, std::unordered_map // diff --git a/tests/test_stl_binders.cpp b/tests/test_stl_binders.cpp index f846ae8482..583b605d30 100644 --- a/tests/test_stl_binders.cpp +++ b/tests/test_stl_binders.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -183,6 +184,9 @@ TEST_SUBMODULE(stl_binders, m) { py::bind_vector>(m, "VectorEl"); py::bind_vector>>(m, "VectorVectorEl"); + // test_set_int + py::bind_set>(m, "SetInt"); + // test_map_string_double py::bind_map>(m, "MapStringDouble"); py::bind_map>(m, "UnorderedMapStringDouble"); diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py index 9856ba462b..470bcf2091 100644 --- a/tests/test_stl_binders.py +++ b/tests/test_stl_binders.py @@ -149,6 +149,29 @@ def test_vector_custom(): assert str(vv_b) == "VectorEl[El{1}, El{2}]" +def test_set_int(): + s_a = m.SetInt() + s_b = m.SetInt() + + assert len(s_a) == 0 + assert s_a == s_b + + s_a.add(1) + + assert 1 in s_a + assert str(s_a) == "SetInt{1}" + assert s_a != s_b + + for i in range(5): + s_a.add(i) + + assert sorted(s_a) == [0, 1, 2, 3, 4] + + s_a.clear() + assert len(s_a) == 0 + assert str(s_a) == "SetInt{}" + + def test_map_string_double(): mm = m.MapStringDouble() mm["a"] = 1