Skip to content

Commit

Permalink
Missing files from previous commit
Browse files Browse the repository at this point in the history
  • Loading branch information
toastedcrumpets committed Sep 3, 2021
1 parent eb92c7e commit 1a566fe
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 105 deletions.
4 changes: 2 additions & 2 deletions pysrc/stator/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct ToPythonVisitor : public sym::detail::VisitorHelper<ToPythonVisitor, py::
return std::move(out);
}

py::object apply(const sym::Dict& v) {
py::object apply(const sym::DictRT& v) {
py::dict out;
for (const auto& itm : v)
out[to_python(itm.first)] = to_python(itm.second);
Expand All @@ -43,7 +43,7 @@ py::object to_python(const sym::Expr& b) {
}

sym::Expr make_Expr(const py::dict& d) {
auto out_ptr = sym::Dict::create();
auto out_ptr = sym::DictRT::create();
auto& out = *out_ptr;
for (const auto& item : d) {
sym::Expr key = py::cast<sym::Expr>(item.first);
Expand Down
21 changes: 10 additions & 11 deletions stator/symbolic/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ namespace sym {
return Base::_store[d];
}

template<class...Args>
bool operator==(const Addressing<Args...>& ad) const {
return (ad.getDimensions() == getDimensions()) && (ad._store == Base::_store);
}

auto begin() const { return Base::_store.begin(); }
auto begin(){ return Base::_store.begin(); }
auto end() const { return Base::_store.end(); }
Expand Down Expand Up @@ -169,6 +174,11 @@ namespace sym {
return Base::_store[coords_to_index(d)];
}

template<class...Args>
bool operator==(const Addressing<Args...>& ad) const {
return (ad.getDimensions() == getDimensions()) && (ad._store == Base::_store);
}

void resize(const Coords& d) {
_dimensions = d;
if constexpr (StoreSize == -1u) {
Expand Down Expand Up @@ -302,17 +312,6 @@ namespace sym {
typedef Addressing<T, Addressing_t> Base;
using Base::Base;

template<class... Args>
bool operator==(const Array<Args...>& o) const {
//Shortcut comparison before proceeding with item by item
return (this == &o) || (Base::_store == o._store);
}

template<class RHS> bool operator==(const RHS& r) const {
return false;
}


template<class ...Args>
static auto create(Args&&...args) {
return std::move(Array(args...));
Expand Down
17 changes: 9 additions & 8 deletions stator/symbolic/array_rt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,19 @@ namespace sym {
return ArrayPtr(new Array(vals));
}

//We need to force the use of the Addressing operator[], not the RTBaseHelper::operator[]
//We need to force the use of the Addressing operator[], not the RTBaseHelper::operator[], same for ==
using Base::operator[];
using Base::operator==;

template<class... Args>
bool operator==(const Array<Args...>& o) const {
//Shortcut comparison before proceeding with item by item
return (this == &o) || (Base::_store == o._store);
template<class RHS>
bool operator==(const RHS& ad) const {
return false;
}

template<class RHS> bool operator==(const RHS& r) const {
return false;
}
template<class...Args>
bool operator==(const Array<Args...>& ad) const {
return (ad.getDimensions() == getDimensions()) && (ad._store == Base::_store);
}
};

namespace detail {
Expand Down
34 changes: 23 additions & 11 deletions stator/symbolic/dict.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ namespace sym {
DictBase() {}
DictBase(const DictBase& l) = default;

typedef typename Store::key_type key_type;
typedef typename Store::value_type value_type;
typedef typename Store::reference reference;
typedef typename Store::const_reference const_reference;
typedef typename Store::iterator iterator;
typedef typename Store::const_iterator const_iterator;

Expand All @@ -43,12 +39,12 @@ namespace sym {
const_iterator end() const {return _store.end();}
const_iterator cend() const {return _store.cend();}

iterator find( const key_type& key ) { return _store.find(key); }
const_iterator find( const key_type& key ) const { return _store.find(key); }
iterator find( const Key& key ) { return _store.find(key); }
const_iterator find( const Key& key ) const { return _store.find(key); }

reference& operator[](const key_type& k) { return _store[k]; }
reference& at(const key_type& k) { return _store.at(k); }
const const_reference& at(const key_type& k) const { return _store.at(k); }
Value& operator[](const Key& k) { return _store[k]; }
Value& at(const Key& k) { return _store.at(k); }
const Value& at(const Key& k) const { return _store.at(k); }

typename Store::size_type size() const noexcept { return _store.size(); }
bool empty() const noexcept { return _store.empty(); }
Expand All @@ -64,11 +60,11 @@ namespace sym {
return false;
}

std::pair<iterator,bool> insert( const value_type& value ) {
std::pair<iterator,bool> insert( const typename Store::value_type& value ) {
return _store.insert(value);
}

std::pair<iterator,bool> insert(value_type&& value ) {
std::pair<iterator,bool> insert(typename Store::value_type&& value ) {
return _store.insert(std::move(value));
}

Expand All @@ -86,4 +82,20 @@ namespace sym {
struct Expr;
template<> class Dict<Expr, Expr>;
typedef Dict<Expr, Expr> DictRT;
}

namespace std
{
template<class Key, class Value> struct hash<sym::Dict<Key, Value>>
{
std::size_t operator()(sym::Dict<Key, Value> const& v) const noexcept
{
std::size_t seed = 16;
for (const auto& item : v) {
stator::hash_combine(seed, std::hash<typename std::decay<decltype(item.first)>::type>{}(item.first));
stator::hash_combine(seed, std::hash<typename std::decay<decltype(item.second)>::type>{}(item.second));
}
return seed;
}
};
}
81 changes: 25 additions & 56 deletions stator/symbolic/dict_rt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,66 +19,34 @@

#pragma once

#include <stator/symbolic/array.hpp>
#include <stator/symbolic/runtime.hpp>

namespace sym {
class Dict;
typedef std::shared_ptr<Dict> DictPtr;

class Dict: public RTBaseHelper<Dict> {
typedef std::unordered_map<Expr, Expr> Store;

Dict() {}
Dict(const Dict& l) = default;
public:
static auto create() {
return DictPtr(new Dict());
}

typedef Store::key_type key_type;
typedef Store::value_type value_type;
typedef Store::reference reference;
typedef Store::const_reference const_reference;
typedef typename Store::iterator iterator;
typedef typename Store::const_iterator const_iterator;

iterator begin() {return _store.begin();}
const_iterator begin() const {return _store.begin();}
const_iterator cbegin() const {return _store.cbegin();}
iterator end() {return _store.end();}
const_iterator end() const {return _store.end();}
const_iterator cend() const {return _store.cend();}

iterator find( const key_type& key ) { return _store.find(key); }
const_iterator find( const key_type& key ) const { return _store.find(key); }

Expr& operator[](const key_type& k) { return _store[k]; }
Expr& at(const key_type& k) { return _store.at(k); }
const Expr& at(const key_type& k) const { return _store.at(k); }

Store::size_type size() const noexcept { return _store.size(); }
bool empty() const noexcept { return _store.empty(); }

bool operator==(const Dict& o) const {
//Shortcut comparison before proceeding with item by item
return (this == &o) || (_store == o._store);
}
template<>
class Dict<Expr, Expr>: public RTBaseHelper<DictRT>, public DictBase<Expr, Expr> {
typedef DictBase<Expr, Expr> Base;

template<class RHS>
constexpr bool operator==(const RHS&) const {
return false;
}

std::pair<iterator,bool> insert( const value_type& value ) {
return _store.insert(value);
}

std::pair<iterator,bool> insert( value_type&& value ) {
return _store.insert(std::move(value));
}

private:
Store _store;
Dict() {}

public:

typedef std::shared_ptr<Dict> DictPtr;

static auto create() {
return DictPtr(new Dict());
}

using Base::operator==;
};

typedef Dict<Expr, Expr> DictRT;

namespace detail {
template<> struct Type_index<DictRT> { static const int value = 16; };
}

/*
auto operator+(const Dict& l, const Dict& r) {
auto out_ptr = Dict::create();
auto& out = *out_ptr;
Expand Down Expand Up @@ -185,4 +153,5 @@ namespace std
return seed;
}
};
*/
}
Empty file removed stator/symbolic/list_rt.hpp
Empty file.
2 changes: 1 addition & 1 deletion stator/symbolic/parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ namespace sym {

struct DictToken : public LeftOperatorBase {
Expr apply(ExprTokenizer& tk) const {
auto a_ptr = Dict::create();
auto a_ptr = DictRT::create();
auto& a = *a_ptr;

if (tk.next() == "}") {
Expand Down
25 changes: 11 additions & 14 deletions stator/symbolic/runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,7 @@ namespace sym {
template<class T> class ConstantRT;
template<> struct Var<nullptr>;
typedef Var<nullptr> VarRT;
template<> class Array<Expr, RowMajorAddressing<-1u, -1u>>;

class Dict;


namespace detail {
template<class RetType> struct VisitorInterface;
}
Expand Down Expand Up @@ -161,7 +158,7 @@ namespace sym {
Expr(const detail::NoIdentity&) { stator_throw() << "This should never be called as NoIdentity is not a usable type";}

explicit Expr(const ArrayRT&);
explicit Expr(const Dict&);
explicit Expr(const DictRT&);

/*! \brief Expression comparison operator.
Expand Down Expand Up @@ -230,7 +227,7 @@ namespace sym {
virtual RetType visit(const BinaryOp<Expr, detail::ArrayAccess, Expr>& ) = 0;
virtual RetType visit(const BinaryOp<Expr, detail::Units, Expr>& ) = 0;
virtual RetType visit(const ArrayRT& ) = 0;
virtual RetType visit(const Dict& ) = 0;
virtual RetType visit(const DictRT& ) = 0;
virtual RetType visit(const UnaryOp<Expr, detail::Negate>& ) = 0;
};

Expand Down Expand Up @@ -275,7 +272,7 @@ namespace sym {
{ return static_cast<Derived*>(this)->apply(x); }
inline virtual RetType visit(const ArrayRT& x)
{ return static_cast<Derived*>(this)->apply(x); }
inline virtual RetType visit(const Dict& x)
inline virtual RetType visit(const DictRT& x)
{ return static_cast<Derived*>(this)->apply(x); }
inline virtual RetType visit(const UnaryOp<Expr, detail::Negate>& x)
{ return static_cast<Derived*>(this)->apply(x); }
Expand Down Expand Up @@ -388,7 +385,7 @@ namespace sym {
inline Expr::Expr(const VarRT& v) : Base(v.shared_from_this()) {}

Expr::Expr(const ArrayRT& v) : Base(v.shared_from_this()) {}
Expr::Expr(const Dict& v) : Base(v.shared_from_this()) {}
Expr::Expr(const DictRT& v) : Base(v.shared_from_this()) {}

template<class T>
const T& Expr::as() const {
Expand Down Expand Up @@ -451,7 +448,7 @@ namespace sym {
return simplify(v);
}

Expr apply(const Dict& v) {
Expr apply(const DictRT& v) {
return simplify(v);
}

Expand Down Expand Up @@ -658,7 +655,7 @@ namespace sym {

namespace detail {
struct SubstituteDictRT : VisitorHelper<SubstituteDictRT> {
SubstituteDictRT(const Dict& replacement):
SubstituteDictRT(const DictRT& replacement):
_replacement(replacement)
{}

Expand Down Expand Up @@ -731,7 +728,7 @@ namespace sym {
return Expr();
}

const Dict& _replacement;
const DictRT& _replacement;
};

}
Expand All @@ -744,7 +741,7 @@ namespace sym {
return (result) ? result : f;
}

Expr sub(const Expr& f, const Dict& rep) {
Expr sub(const Expr& f, const DictRT& rep) {
detail::SubstituteDictRT visitor(rep);
Expr result = f->visit(visitor);
return (result) ? result : f;
Expand All @@ -761,7 +758,7 @@ namespace sym {
stator_throw() << "No substitution process available for " << v << "\n Needs to be a Equality or a Dict.";
}

Expr apply(const Dict& d) {
Expr apply(const DictRT& d) {
return sub(_f, d);
}

Expand Down Expand Up @@ -897,7 +894,7 @@ namespace sym {
case detail::Type_index<BinaryOp<Expr, detail::ArrayAccess, Expr>>::value: return c.visit(static_cast<const BinaryOp<Expr, detail::ArrayAccess, Expr>&>(*this));
case detail::Type_index<BinaryOp<Expr, detail::Units, Expr>>::value: return c.visit(static_cast<const BinaryOp<Expr, detail::Units, Expr>&>(*this));
case detail::Type_index<ArrayRT>::value: return c.visit(static_cast<const ArrayRT&>(*this));
case detail::Type_index<Dict>::value: return c.visit(static_cast<const Dict&>(*this));
case detail::Type_index<DictRT>::value: return c.visit(static_cast<const DictRT&>(*this));
case detail::Type_index<UnaryOp<Expr, detail::Negate>>::value: return c.visit(static_cast<const UnaryOp<Expr, detail::Negate>&>(*this));
default: stator_throw() << "Unhandled type index (" << _type_idx << ") for the visitor";
}
Expand Down
1 change: 1 addition & 0 deletions stator/symbolic/symbolic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ namespace sym {
#include <stator/symbolic/simplify.hpp>
#include <stator/symbolic/integrate.hpp>
#include <stator/symbolic/array.hpp>
#include <stator/symbolic/dict.hpp>
#include <stator/symbolic/taylor.hpp>
#include <stator/symbolic/units.hpp>

Expand Down
5 changes: 4 additions & 1 deletion tests/symbolic_array_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ UNIT_TEST(symbolic_array_runtime) {

UNIT_TEST_CHECK_EQUAL(sym::repr(test_array), "[1, x, y]");
UNIT_TEST_CHECK_EQUAL(sym::repr(derivative(test_array, x)), "[0, 1, 0]");
UNIT_TEST_CHECK_EQUAL(sym::sub(sym::ArrayRT::create({sym::Expr(1), x, y}), sym::Expr(x=2)), sym::Expr("[1,2,y]"));
auto f = sym::sub(sym::ArrayRT::create({sym::Expr(1), x, y}), sym::Expr(x=2));
auto result = sym::Expr("[1, 2, y]");
f == result;
UNIT_TEST_CHECK_EQUAL(f, result);
UNIT_TEST_CHECK_EQUAL(sym::simplify(sym::Expr("[1, 1, 1]")+sym::Expr("[0, 1, 2]")), sym::Expr("[1,2,3]"));
}
2 changes: 1 addition & 1 deletion tests/symbolic_runtime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ UNIT_TEST( symbolic_runtime_derivative )

UNIT_TEST( symbolic_dict_basic )
{
auto v_ptr = Dict::create();
auto v_ptr = DictRT::create();
auto& v = *v_ptr;

auto x_ptr = VarRT::create("x");
Expand Down

0 comments on commit 1a566fe

Please sign in to comment.