Skip to content

Commit

Permalink
Added operators for dicts and lists
Browse files Browse the repository at this point in the history
  • Loading branch information
toastedcrumpets committed May 13, 2021
1 parent d2514ec commit 22ae8fd
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 38 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def run(self):
]

cfg = 'Debug' if self.debug else 'Release'
#cfg = 'Debug'
cfg = 'Debug'
build_args = ['--config', cfg]

cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
Expand Down
3 changes: 2 additions & 1 deletion stator/symbolic/binary_ops_rt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ namespace sym {
}

bool operator==(const BinaryOp& o) const {
return (_l == o._l) && (_r == o._r);
//Shortcut comparison before proceeding with item by item
return (this == &o) || ((_l == o._l) && (_r == o._r));
}

template<class RHS>
Expand Down
66 changes: 64 additions & 2 deletions stator/symbolic/dict_rt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace sym {
}

typedef Store::key_type key_type;
typedef Store::value_type value_type;
typedef Store::reference reference;
typedef Store::const_reference const_reference;
typedef typename Store::iterator iterator;
Expand All @@ -54,18 +55,76 @@ namespace sym {
bool empty() const noexcept { return _store.empty(); }

bool operator==(const Dict& o) const {
return _store == o._store;
//Shortcut comparison before proceeding with item by item
return (this == &o) || (_store == o._store);
}

template<class RHS>
constexpr bool operator==(const RHS&) const {
return false;
}

std::pair<iterator,bool> insert( const value_type& value ) {
return _store.insert(value);
}

std::pair<iterator,bool> insert( value_type&& value ) {
return _store.insert(std::move(value));
}

private:
Store _store;
};

auto operator+(const Dict& l, const Dict& r) {
auto out_ptr = Dict::create();
auto& out = *out_ptr;

for (const auto& item : l)
out.insert(item);

for (const auto& item : r) {
auto it = out.find(item.first);
if (it != out.end())
out[item.first] = out[item.first] + item.second;
else
out.insert(item);
}

return out_ptr;
}

auto operator-(const Dict& l, const Dict& r) {
auto out_ptr = Dict::create();
auto& out = *out_ptr;

for (const auto& item : l)
out.insert(item);

for (const auto& item : r) {
auto it = out.find(item.first);
if (it != out.end())
out[item.first] = out[item.first] - item.second;
else
out[item.first] = -item.second;
}

return out_ptr;
}

auto operator*(const Dict& l, const Dict& r) {
auto out_ptr = Dict::create();
auto& out = *out_ptr;

for (const auto& item : l) {
auto it = r.find(item.first);
if (it != r.end())
out[item.first] = item.second * it->second;
}

return out_ptr;
}

namespace detail {
template<> struct Type_index<Dict> { static const int value = 16; };
}
Expand All @@ -85,6 +144,9 @@ namespace sym {
return out;
}

std::pair<int, int> BP(const Dict& v)
{ return std::make_pair(std::numeric_limits<int>::max(), std::numeric_limits<int>::max()); }

template<class Config = DefaultReprConfig>
inline std::string repr(const sym::Dict& f)
{
Expand Down
50 changes: 47 additions & 3 deletions stator/symbolic/list_rt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ namespace sym {
bool empty() const noexcept { return _store.empty(); }

bool operator==(const List& o) const {
return _store == o._store;
//Shortcut comparison before proceeding with item by item
return (this == &o) || (_store == o._store);
}

template<class RHS>
Expand All @@ -79,7 +80,7 @@ namespace sym {
return out;
}

Expr operator+(const List& l, const List& r) {
auto operator+(const List& l, const List& r) {
if (l.size() != r.size())
stator_throw() << "Mismatched list size for: \n" << l << "\n and\n" << r;

Expand All @@ -92,6 +93,45 @@ namespace sym {
return out;
}

auto operator-(const List& l, const List& r) {
if (l.size() != r.size())
stator_throw() << "Mismatched list size for: \n" << l << "\n and\n" << r;

auto out = List::create();
out->resize(l.size());

for (size_t idx(0); idx < l.size(); ++idx)
(*out)[idx] = l[idx] - r[idx];

return out;
}

auto operator*(const List& l, const List& r) {
if (l.size() != r.size())
stator_throw() << "Mismatched list size for: \n" << l << "\n and\n" << r;

auto out = List::create();
out->resize(l.size());

for (size_t idx(0); idx < l.size(); ++idx)
(*out)[idx] = l[idx] * r[idx];

return out;
}

auto operator/(const List& l, const List& r) {
if (l.size() != r.size())
stator_throw() << "Mismatched list size for: \n" << l << "\n and\n" << r;

auto out = List::create();
out->resize(l.size());

for (size_t idx(0); idx < l.size(); ++idx)
(*out)[idx] = l[idx] / r[idx];

return out;
}

Expr simplify(const List& in) {
auto out_ptr = List::create();
auto& out = *out_ptr;
Expand All @@ -104,8 +144,12 @@ namespace sym {
return out_ptr;
}

std::pair<int, int> BP(const List& v)
{ return std::make_pair(std::numeric_limits<int>::max(), std::numeric_limits<int>::max()); }


template<class Config = DefaultReprConfig>
inline std::string repr(const sym::List& f)
inline std::string repr(const List& f)
{
std::string out = std::string((Config::Latex_output) ? "\\left[" : "[");
const std::string end = std::string((Config::Latex_output) ? "\\right]" : "]");
Expand Down
45 changes: 22 additions & 23 deletions stator/symbolic/runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,16 +279,14 @@ namespace sym {
comparisons between RTBase derived types.
*/
template<class LHS>
struct ComparisonVisitor : public sym::detail::VisitorHelper<ComparisonVisitor<LHS> > {
struct ComparisonVisitor : public sym::detail::VisitorHelper<ComparisonVisitor<LHS>, bool> {
ComparisonVisitor(const LHS& l): _l(l) {}

template<class RHS> sym::Expr apply(const RHS& r) {
_result = _l == r;
return Expr();
template<class RHS> bool apply(const RHS& r) {
return _l == r;
}

const LHS& _l;
bool _result;
};
}

Expand All @@ -303,20 +301,16 @@ namespace sym {
virtual bool compare(const Expr& rhs) const {
const Derived& lhs(*static_cast<const Derived*>(this));
detail::ComparisonVisitor<Derived> visitor(lhs);
rhs->visit(visitor);
return visitor._result;
return rhs->visit(visitor);
}
};

namespace detail {
/*! \brief Binding power visitor for sym::detail::BP(const Expr&). */
struct BPVisitor : public sym::detail::VisitorHelper<BPVisitor> {
template<class T> sym::Expr apply(const T& rhs) {
_BP = BP(rhs);
return Expr();
struct BPVisitor : public sym::detail::VisitorHelper<BPVisitor, std::pair<int, int>> {
template<class T> std::pair<int, int> apply(const T& rhs) {
return BP(rhs);
}

std::pair<int, int> _BP;
};
}

Expand All @@ -325,8 +319,7 @@ namespace sym {
*/
inline std::pair<int, int> BP(const Expr& v) {
detail::BPVisitor vis;
v->visit(vis);
return vis._BP;
return v->visit(vis);
}
}

Expand Down Expand Up @@ -497,6 +490,16 @@ namespace sym {
return Expr(store(Op::apply(l, r)));
}

template<class Op>
Expr dd_visit(const Dict& l, const Dict& r, Op) {
return Expr(simplify(Op::apply(l, r)));
}

template<class Op>
Expr dd_visit(const List& l, const List& r, Op) {
return Expr(simplify(Op::apply(l, r)));
}

//Direct evaluation of doubles
template<class T2>
Expr dd_visit(const double& l, const T2& r, detail::Subtract) {
Expand Down Expand Up @@ -873,21 +876,17 @@ namespace sym {

namespace detail {
template<class Config>
struct ReprVisitor : public sym::detail::VisitorHelper<ReprVisitor<Config> > {
template<class T> sym::Expr apply(const T& rhs) {
_repr = repr<Config>(rhs);
return sym::Expr();
struct ReprVisitor : public sym::detail::VisitorHelper<ReprVisitor<Config>, std::string> {
template<class T> std::string apply(const T& rhs) {
return repr<Config>(rhs);
}

std::string _repr;
};
}

template<class Config>
std::string repr(const sym::RTBase& b) {
detail::ReprVisitor<Config> visitor;
b.visit(visitor);
return visitor._repr;
return b.visit(visitor);
}

/*! \brief Give a representation of an Expr.
Expand Down
5 changes: 3 additions & 2 deletions stator/symbolic/unary_ops_rt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ namespace sym {
return std::shared_ptr<UnaryOp>(new UnaryOp(arg));
}

bool operator==(const UnaryOp<Expr,Op>& o) const {
return (_arg == o._arg);
bool operator==(const UnaryOp& o) const {
//Shortcut comparison before proceeding with item by item
return (this == &o) || (_arg == o._arg);
}

template<class RHS>
Expand Down
26 changes: 26 additions & 0 deletions stator/symbolic/units.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,29 @@
*/

#pragma once

namespace sym {
namespace detail {
struct Units {
static constexpr int leftBindingPower = 60;
static constexpr auto associativity = Associativity::LEFT;
static constexpr bool commutative = true;
static constexpr bool associative = true;
typedef NoIdentity left_identity;
typedef Unity right_identity;
typedef Null left_zero;
typedef NoIdentity right_zero;
static inline std::string l_repr() { return ""; }
static inline std::string repr() { return "{"; }
static inline std::string r_repr() { return "}"; }
static inline std::string l_latex_repr() { return ""; }
static inline std::string latex_repr() { return "\left\{"; }
static inline std::string r_latex_repr() { return "\right\}"; }
template<class L, class R> static auto apply(const L& l, const R& r) {
return BinaryOp<decltype(store(l)), detail::Units, decltype(store(r))>::create(l, r);
}
static constexpr int type_index = 18;
};
}

}
3 changes: 1 addition & 2 deletions stator/symbolic/variable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ namespace sym {

namespace detail {
template<auto N, class ...Args>
struct Type_index<Var<N, Args...> >
{ static const int value = 1; };
struct Type_index<Var<N, Args...> > { static const int value = 1; };
}
}

Expand Down
3 changes: 2 additions & 1 deletion stator/symbolic/variable_rt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ namespace sym {


inline bool operator==(const VarRT& o) const {
return _name == o._name;
//Shortcut comparison before proceeding with string comparison
return (this == &o) || (_name == o._name);
}

template<class RHS>
Expand Down
10 changes: 8 additions & 2 deletions tests/stator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_interface(self):
self.assertEqual(repr(Expr("1=1")),"Expr('1=1')")
self.assertEqual(repr(Expr("x[1]")),"Expr('x[1]')")

def test_array(self):
def test_list(self):
self.assertEqual(repr(Expr("[]")),"Expr('[]')")
self.assertEqual(repr(Expr("[1]")),"Expr('[1]')")
self.assertEqual(repr(Expr("[1, 2]")),"Expr('[1, 2]')")
Expand All @@ -22,11 +22,17 @@ def test_array(self):
df = derivative(f, x)
self.assertEqual(repr(simplify(df)), "Expr('[0, 1, 4*x]')")
self.assertEqual(sub(df, Expr('x=2')), [0, 1, 8])
self.assertEqual(Expr("[1,2,3,4]") + Expr("[0,1,2,3]"), Expr('[1, 3, 5, 7]').to_python())
self.assertEqual(Expr("[1,2,3,4]") - Expr("[0,1,2,3]"), Expr('[1, 1, 1, 1]').to_python())
self.assertEqual(Expr("[1,2,3,4]") * Expr("[0,1,2,3]"), Expr('[0, 2, 6, 12]').to_python())
self.assertEqual(Expr("[1,2,3,4]") / Expr("[1,2,3,2]"), Expr('[1, 1, 1, 2]').to_python())


def test_dict(self):
self.assertEqual(repr(Expr("{}")),"Expr('{}')")
self.assertEqual(repr(Expr("{x:1}")),"Expr('{x:1}')")
self.assertEqual(Expr("{x:1}") + Expr("{x:1, y:2}"), Expr('{x:2.0, y:2.0}').to_python())
self.assertEqual(Expr("{x:1}") - Expr("{x:1, y:2}"), Expr('{x:0, y:-2.0}').to_python())
self.assertEqual(Expr("{x:3}") * Expr("{x:2, y:2}"), Expr('{x:6}').to_python())

def test_sub_generic(self):
self.assertEqual(sub(Expr("x"), Expr('{x:2}')), 2)
Expand Down
3 changes: 2 additions & 1 deletion tests/symbolic_runtime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ UNIT_TEST( symbolic_list_basic )
auto& x = *x_ptr;
UNIT_TEST_CHECK_EQUAL(sub(Expr("[1, x, y]"), Expr(x=2)), Expr("[1,2,y]"));
UNIT_TEST_CHECK_EQUAL(derivative(Expr("[1, x, y]"), x), Expr("[0,1,0]"));
UNIT_TEST_CHECK_EQUAL(simplify(Expr("[1, 1, 1]")+Expr("[0, 1, 2]")), Expr("[1,2,3]"));
}

UNIT_TEST( symbolic_dict_basic )
Expand All @@ -237,7 +238,7 @@ UNIT_TEST( symbolic_dict_basic )
v[y] = Expr(3);
UNIT_TEST_CHECK_EQUAL(sub(Expr("x"), v), Expr("2"));
UNIT_TEST_CHECK_EQUAL(sub(Expr("y"), v), Expr("3"));

UNIT_TEST_CHECK_EQUAL(sub(Expr("[1, x, y]"), Expr("{x:2, y:3}")), Expr("[1,2,3]"));
UNIT_TEST_CHECK_EQUAL(sub(Expr("[1, x, y]"), v), Expr("[1,2,3]"));
}

Expand Down

0 comments on commit 22ae8fd

Please sign in to comment.