diff --git a/setup.py b/setup.py index d00bfbe..c1d6683 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def run(self): ] cfg = 'Debug' if self.debug else 'Release' - #cfg = 'Debug' + cfg = 'Debug' build_args = ['--config', cfg] cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg] diff --git a/stator/symbolic/binary_ops_rt.hpp b/stator/symbolic/binary_ops_rt.hpp index 3e16483..f35fd72 100644 --- a/stator/symbolic/binary_ops_rt.hpp +++ b/stator/symbolic/binary_ops_rt.hpp @@ -35,7 +35,8 @@ namespace sym { } bool operator==(const BinaryOp& o) const { - return (_l == o._l) && (_r == o._r); + //Shortcut comparison before proceeding with item by item + return (this == &o) || ((_l == o._l) && (_r == o._r)); } template diff --git a/stator/symbolic/dict_rt.hpp b/stator/symbolic/dict_rt.hpp index 99b7db1..4faf4ae 100644 --- a/stator/symbolic/dict_rt.hpp +++ b/stator/symbolic/dict_rt.hpp @@ -31,6 +31,7 @@ namespace sym { } typedef Store::key_type key_type; + typedef Store::value_type value_type; typedef Store::reference reference; typedef Store::const_reference const_reference; typedef typename Store::iterator iterator; @@ -54,18 +55,76 @@ namespace sym { bool empty() const noexcept { return _store.empty(); } bool operator==(const Dict& o) const { - return _store == o._store; + //Shortcut comparison before proceeding with item by item + return (this == &o) || (_store == o._store); } - + template constexpr bool operator==(const RHS&) const { return false; } + std::pair insert( const value_type& value ) { + return _store.insert(value); + } + + std::pair insert( value_type&& value ) { + return _store.insert(std::move(value)); + } + private: Store _store; }; + auto operator+(const Dict& l, const Dict& r) { + auto out_ptr = Dict::create(); + auto& out = *out_ptr; + + for (const auto& item : l) + out.insert(item); + + for (const auto& item : r) { + auto it = out.find(item.first); + if (it != out.end()) + out[item.first] = out[item.first] + item.second; + else + out.insert(item); + } + + return out_ptr; + } + + auto operator-(const Dict& l, const Dict& r) { + auto out_ptr = Dict::create(); + auto& out = *out_ptr; + + for (const auto& item : l) + out.insert(item); + + for (const auto& item : r) { + auto it = out.find(item.first); + if (it != out.end()) + out[item.first] = out[item.first] - item.second; + else + out[item.first] = -item.second; + } + + return out_ptr; + } + + auto operator*(const Dict& l, const Dict& r) { + auto out_ptr = Dict::create(); + auto& out = *out_ptr; + + for (const auto& item : l) { + auto it = r.find(item.first); + if (it != r.end()) + out[item.first] = item.second * it->second; + } + + return out_ptr; + } + namespace detail { template<> struct Type_index { static const int value = 16; }; } @@ -85,6 +144,9 @@ namespace sym { return out; } + std::pair BP(const Dict& v) + { return std::make_pair(std::numeric_limits::max(), std::numeric_limits::max()); } + template inline std::string repr(const sym::Dict& f) { diff --git a/stator/symbolic/list_rt.hpp b/stator/symbolic/list_rt.hpp index 1c8851f..ff46b8a 100644 --- a/stator/symbolic/list_rt.hpp +++ b/stator/symbolic/list_rt.hpp @@ -53,7 +53,8 @@ namespace sym { bool empty() const noexcept { return _store.empty(); } bool operator==(const List& o) const { - return _store == o._store; + //Shortcut comparison before proceeding with item by item + return (this == &o) || (_store == o._store); } template @@ -79,7 +80,7 @@ namespace sym { return out; } - Expr operator+(const List& l, const List& r) { + auto operator+(const List& l, const List& r) { if (l.size() != r.size()) stator_throw() << "Mismatched list size for: \n" << l << "\n and\n" << r; @@ -92,6 +93,45 @@ namespace sym { return out; } + auto operator-(const List& l, const List& r) { + if (l.size() != r.size()) + stator_throw() << "Mismatched list size for: \n" << l << "\n and\n" << r; + + auto out = List::create(); + out->resize(l.size()); + + for (size_t idx(0); idx < l.size(); ++idx) + (*out)[idx] = l[idx] - r[idx]; + + return out; + } + + auto operator*(const List& l, const List& r) { + if (l.size() != r.size()) + stator_throw() << "Mismatched list size for: \n" << l << "\n and\n" << r; + + auto out = List::create(); + out->resize(l.size()); + + for (size_t idx(0); idx < l.size(); ++idx) + (*out)[idx] = l[idx] * r[idx]; + + return out; + } + + auto operator/(const List& l, const List& r) { + if (l.size() != r.size()) + stator_throw() << "Mismatched list size for: \n" << l << "\n and\n" << r; + + auto out = List::create(); + out->resize(l.size()); + + for (size_t idx(0); idx < l.size(); ++idx) + (*out)[idx] = l[idx] / r[idx]; + + return out; + } + Expr simplify(const List& in) { auto out_ptr = List::create(); auto& out = *out_ptr; @@ -104,8 +144,12 @@ namespace sym { return out_ptr; } + std::pair BP(const List& v) + { return std::make_pair(std::numeric_limits::max(), std::numeric_limits::max()); } + + template - inline std::string repr(const sym::List& f) + inline std::string repr(const List& f) { std::string out = std::string((Config::Latex_output) ? "\\left[" : "["); const std::string end = std::string((Config::Latex_output) ? "\\right]" : "]"); diff --git a/stator/symbolic/runtime.hpp b/stator/symbolic/runtime.hpp index 0fa0b59..b0bc4a9 100644 --- a/stator/symbolic/runtime.hpp +++ b/stator/symbolic/runtime.hpp @@ -279,16 +279,14 @@ namespace sym { comparisons between RTBase derived types. */ template - struct ComparisonVisitor : public sym::detail::VisitorHelper > { + struct ComparisonVisitor : public sym::detail::VisitorHelper, bool> { ComparisonVisitor(const LHS& l): _l(l) {} - template sym::Expr apply(const RHS& r) { - _result = _l == r; - return Expr(); + template bool apply(const RHS& r) { + return _l == r; } const LHS& _l; - bool _result; }; } @@ -303,20 +301,16 @@ namespace sym { virtual bool compare(const Expr& rhs) const { const Derived& lhs(*static_cast(this)); detail::ComparisonVisitor visitor(lhs); - rhs->visit(visitor); - return visitor._result; + return rhs->visit(visitor); } }; namespace detail { /*! \brief Binding power visitor for sym::detail::BP(const Expr&). */ - struct BPVisitor : public sym::detail::VisitorHelper { - template sym::Expr apply(const T& rhs) { - _BP = BP(rhs); - return Expr(); + struct BPVisitor : public sym::detail::VisitorHelper> { + template std::pair apply(const T& rhs) { + return BP(rhs); } - - std::pair _BP; }; } @@ -325,8 +319,7 @@ namespace sym { */ inline std::pair BP(const Expr& v) { detail::BPVisitor vis; - v->visit(vis); - return vis._BP; + return v->visit(vis); } } @@ -497,6 +490,16 @@ namespace sym { return Expr(store(Op::apply(l, r))); } + template + Expr dd_visit(const Dict& l, const Dict& r, Op) { + return Expr(simplify(Op::apply(l, r))); + } + + template + Expr dd_visit(const List& l, const List& r, Op) { + return Expr(simplify(Op::apply(l, r))); + } + //Direct evaluation of doubles template Expr dd_visit(const double& l, const T2& r, detail::Subtract) { @@ -873,21 +876,17 @@ namespace sym { namespace detail { template - struct ReprVisitor : public sym::detail::VisitorHelper > { - template sym::Expr apply(const T& rhs) { - _repr = repr(rhs); - return sym::Expr(); + struct ReprVisitor : public sym::detail::VisitorHelper, std::string> { + template std::string apply(const T& rhs) { + return repr(rhs); } - - std::string _repr; }; } template std::string repr(const sym::RTBase& b) { detail::ReprVisitor visitor; - b.visit(visitor); - return visitor._repr; + return b.visit(visitor); } /*! \brief Give a representation of an Expr. diff --git a/stator/symbolic/unary_ops_rt.hpp b/stator/symbolic/unary_ops_rt.hpp index 33229c8..ce3fd40 100644 --- a/stator/symbolic/unary_ops_rt.hpp +++ b/stator/symbolic/unary_ops_rt.hpp @@ -35,8 +35,9 @@ namespace sym { return std::shared_ptr(new UnaryOp(arg)); } - bool operator==(const UnaryOp& o) const { - return (_arg == o._arg); + bool operator==(const UnaryOp& o) const { + //Shortcut comparison before proceeding with item by item + return (this == &o) || (_arg == o._arg); } template diff --git a/stator/symbolic/units.hpp b/stator/symbolic/units.hpp index 91b0635..78c0699 100644 --- a/stator/symbolic/units.hpp +++ b/stator/symbolic/units.hpp @@ -18,3 +18,29 @@ */ #pragma once + +namespace sym { + namespace detail { + struct Units { + static constexpr int leftBindingPower = 60; + static constexpr auto associativity = Associativity::LEFT; + static constexpr bool commutative = true; + static constexpr bool associative = true; + typedef NoIdentity left_identity; + typedef Unity right_identity; + typedef Null left_zero; + typedef NoIdentity right_zero; + static inline std::string l_repr() { return ""; } + static inline std::string repr() { return "{"; } + static inline std::string r_repr() { return "}"; } + static inline std::string l_latex_repr() { return ""; } + static inline std::string latex_repr() { return "\left\{"; } + static inline std::string r_latex_repr() { return "\right\}"; } + template static auto apply(const L& l, const R& r) { + return BinaryOp::create(l, r); + } + static constexpr int type_index = 18; + }; + } + +} diff --git a/stator/symbolic/variable.hpp b/stator/symbolic/variable.hpp index 7ea9a85..41a2765 100644 --- a/stator/symbolic/variable.hpp +++ b/stator/symbolic/variable.hpp @@ -87,8 +87,7 @@ namespace sym { namespace detail { template - struct Type_index > - { static const int value = 1; }; + struct Type_index > { static const int value = 1; }; } } diff --git a/stator/symbolic/variable_rt.hpp b/stator/symbolic/variable_rt.hpp index 5723d06..8f191b6 100644 --- a/stator/symbolic/variable_rt.hpp +++ b/stator/symbolic/variable_rt.hpp @@ -48,7 +48,8 @@ namespace sym { inline bool operator==(const VarRT& o) const { - return _name == o._name; + //Shortcut comparison before proceeding with string comparison + return (this == &o) || (_name == o._name); } template diff --git a/tests/stator_test.py b/tests/stator_test.py index 909a0db..374d7f9 100644 --- a/tests/stator_test.py +++ b/tests/stator_test.py @@ -13,7 +13,7 @@ def test_interface(self): self.assertEqual(repr(Expr("1=1")),"Expr('1=1')") self.assertEqual(repr(Expr("x[1]")),"Expr('x[1]')") - def test_array(self): + def test_list(self): self.assertEqual(repr(Expr("[]")),"Expr('[]')") self.assertEqual(repr(Expr("[1]")),"Expr('[1]')") self.assertEqual(repr(Expr("[1, 2]")),"Expr('[1, 2]')") @@ -22,11 +22,17 @@ def test_array(self): df = derivative(f, x) self.assertEqual(repr(simplify(df)), "Expr('[0, 1, 4*x]')") self.assertEqual(sub(df, Expr('x=2')), [0, 1, 8]) + self.assertEqual(Expr("[1,2,3,4]") + Expr("[0,1,2,3]"), Expr('[1, 3, 5, 7]').to_python()) + self.assertEqual(Expr("[1,2,3,4]") - Expr("[0,1,2,3]"), Expr('[1, 1, 1, 1]').to_python()) + self.assertEqual(Expr("[1,2,3,4]") * Expr("[0,1,2,3]"), Expr('[0, 2, 6, 12]').to_python()) + self.assertEqual(Expr("[1,2,3,4]") / Expr("[1,2,3,2]"), Expr('[1, 1, 1, 2]').to_python()) - def test_dict(self): self.assertEqual(repr(Expr("{}")),"Expr('{}')") self.assertEqual(repr(Expr("{x:1}")),"Expr('{x:1}')") + self.assertEqual(Expr("{x:1}") + Expr("{x:1, y:2}"), Expr('{x:2.0, y:2.0}').to_python()) + self.assertEqual(Expr("{x:1}") - Expr("{x:1, y:2}"), Expr('{x:0, y:-2.0}').to_python()) + self.assertEqual(Expr("{x:3}") * Expr("{x:2, y:2}"), Expr('{x:6}').to_python()) def test_sub_generic(self): self.assertEqual(sub(Expr("x"), Expr('{x:2}')), 2) diff --git a/tests/symbolic_runtime_test.cpp b/tests/symbolic_runtime_test.cpp index e3bfa6c..5d87a59 100644 --- a/tests/symbolic_runtime_test.cpp +++ b/tests/symbolic_runtime_test.cpp @@ -222,6 +222,7 @@ UNIT_TEST( symbolic_list_basic ) auto& x = *x_ptr; UNIT_TEST_CHECK_EQUAL(sub(Expr("[1, x, y]"), Expr(x=2)), Expr("[1,2,y]")); UNIT_TEST_CHECK_EQUAL(derivative(Expr("[1, x, y]"), x), Expr("[0,1,0]")); + UNIT_TEST_CHECK_EQUAL(simplify(Expr("[1, 1, 1]")+Expr("[0, 1, 2]")), Expr("[1,2,3]")); } UNIT_TEST( symbolic_dict_basic ) @@ -237,7 +238,7 @@ UNIT_TEST( symbolic_dict_basic ) v[y] = Expr(3); UNIT_TEST_CHECK_EQUAL(sub(Expr("x"), v), Expr("2")); UNIT_TEST_CHECK_EQUAL(sub(Expr("y"), v), Expr("3")); - + UNIT_TEST_CHECK_EQUAL(sub(Expr("[1, x, y]"), Expr("{x:2, y:3}")), Expr("[1,2,3]")); UNIT_TEST_CHECK_EQUAL(sub(Expr("[1, x, y]"), v), Expr("[1,2,3]")); }