Skip to content

Commit

Permalink
The actual fix for Eigen usage
Browse files Browse the repository at this point in the history
  • Loading branch information
toastedcrumpets committed Sep 2, 2022
1 parent 7117ad3 commit 00e4578
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 258 deletions.
16 changes: 7 additions & 9 deletions stator/symbolic/binary_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ namespace sym {
static inline std::string latex_repr() { return "+"; }
static inline std::string r_latex_repr() { return ""; }
//Apply has to accept by const ref, as returned objs may reference/alias the arguments, so everything needs at least the parent scope
template<class L, class R> static auto apply(const L& l, const R& r) { return l + r; }
template<class L, class R> static auto apply(const L& l, const R& r) { return store(l + r); }
static constexpr int type_index = 8;
};

Expand All @@ -92,7 +92,7 @@ namespace sym {
static inline std::string l_latex_repr() { return ""; }
static inline std::string latex_repr() { return "-"; }
static inline std::string r_latex_repr() { return ""; }
template<class L, class R> static auto apply(const L& l, const R& r) { return l - r; }
template<class L, class R> static auto apply(const L& l, const R& r) { return store(l - r); }
static constexpr int type_index = 9;
};

Expand All @@ -112,7 +112,7 @@ namespace sym {
static inline std::string l_latex_repr() { return ""; }
static inline std::string latex_repr() { return "\\times "; }
static inline std::string r_latex_repr() { return ""; }
template<class L, class R> static auto apply(const L& l, const R& r) { return l * r; }
template<class L, class R> static auto apply(const L& l, const R& r) { return store(l * r); }
static constexpr int type_index = 10;
};

Expand All @@ -132,7 +132,7 @@ namespace sym {
static inline std::string l_latex_repr() { return "\\frac{"; }
static inline std::string latex_repr() { return "}{"; }
static inline std::string r_latex_repr() { return "}"; }
template<class L, class R> static auto apply(const L& l, const R& r) { return l / r; }
template<class L, class R> static auto apply(const L& l, const R& r) { return store(l / r); }
static constexpr int type_index = 11;
};

Expand All @@ -156,7 +156,7 @@ namespace sym {
//the MSVSC compiler gets confused
template<class L, class R,
typename = typename std::enable_if<!std::is_base_of<Eigen::EigenBase<R>, R>::value>::type>
static auto apply(const L& l, const R& r) { return pow(l, r); }
static auto apply(const L& l, const R& r) { return store(pow(l, r)); }
static constexpr int type_index = 12;
};

Expand All @@ -177,7 +177,7 @@ namespace sym {
static inline std::string latex_repr() { return "="; }
static inline std::string r_latex_repr() { return ""; }
template<class L, class R> static auto apply(const L& l, const R& r) {
return BinaryOp<decltype(store(l)), detail::Equality, decltype(store(r))>::create(l, r);
return BinaryOp<decltype(store(l)), detail::Equality, decltype(store(r))>::create(l, r);
}
static constexpr int type_index = 13;
};
Expand All @@ -199,9 +199,7 @@ namespace sym {
typedef NoIdentity right_zero;
typedef NoIdentity left_zero;
template<class L, class R>
static auto apply(const L& l, const R& r) -> decltype(store(l[r])) {
return l[r];
}
static auto apply(const L& l, const R& r) { return store(l[r]); }
static constexpr int type_index = 14;
};
}
Expand Down
214 changes: 214 additions & 0 deletions tests/symbolic_poly_solve_roots_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,220 @@ UNIT_TEST( poly_cubic_special_cases )
}
}

UNIT_TEST( poly_Sturm_chains )
{
using namespace sym;
const Polynomial<1> x{0, 1};

{ //Example from wikipedia (x^4+x^3-x-1)
auto f = expand(x*x*x*x + x*x*x - x - 1);
auto chain = sturm_chain(f);

UNIT_TEST_CHECK(compare_expression(chain.get(0), f));
UNIT_TEST_CHECK(compare_expression(chain.get(1), expand(4*x*x*x + 3*x*x - 1)));
UNIT_TEST_CHECK(compare_expression(chain.get(2), expand((3.0/16) * x*x + (3.0/4)*x + (15.0/16))));
UNIT_TEST_CHECK(compare_expression(chain.get(3), expand(-32*x -64)));
UNIT_TEST_CHECK(compare_expression(chain.get(4), Polynomial<0>{-3.0/16}));
UNIT_TEST_CHECK(compare_expression(chain.get(5), Polynomial<0>{0}));
UNIT_TEST_CHECK(compare_expression(chain.get(6), Polynomial<0>{0}));

//This polynomial has roots at -1 and +1
UNIT_TEST_CHECK_EQUAL(chain.sign_changes(-HUGE_VAL), 3u);
UNIT_TEST_CHECK_EQUAL(chain.sign_changes(0), 2u);
UNIT_TEST_CHECK_EQUAL(chain.sign_changes(HUGE_VAL), 1u);

UNIT_TEST_CHECK_EQUAL(chain.roots(0.5, 3.0), 1u);
UNIT_TEST_CHECK_EQUAL(chain.roots(-2.141, -0.314159265), 1u);
UNIT_TEST_CHECK_EQUAL(chain.roots(-HUGE_VAL, HUGE_VAL), 2u);
}
}

UNIT_TEST( descartes_sturm_and_budan_01_alesina_rootcount_test )
{
using namespace sym;
const Polynomial<1> x{0, 1};

//The values 0.5, and 2.0 are strong tests of the algorithms, as
//these are the division points for VCA and VAS algorithms.
const double roots[] = {-1e5, -0.14159265, -0.0001,0.1, 0.3333, 0.5, 0.8, 1.001, 2.0, 3.14159265, 1e7};

for (const double root1: roots)
for (const double root2: roots)
if (root1 != root2)
for (const double root3: roots)
if (!std::set<double>{root1, root2}.count(root3))
for (const double root4: roots)
if (!std::set<double>{root1, root2, root3}.count(root4))
for (const double root5: roots)
if (!std::set<double>{root1, root2, root3, root4}.count(root5))
for (int sign : {-1, +1}) {
//Test where all 5 roots of a 5th order
//Polynomial are real
auto f1 = expand(sign * (x - root1) * (x - root2) * (x - root3) * (x - root4) * (x - root5));

//Test with the same roots, but an additional 2
//imaginary roots
auto f2 = expand(f1 * (x * x - 3 * x + 4));

auto roots_in_range = [&](double a, double b) {
return size_t
(((root1 > a) && (root1 < b))
+((root2 > a) && (root2 < b))
+((root3 > a) && (root3 < b))
+((root4 > a) && (root4 < b))
+((root5 > a) && (root5 < b)))
; };

size_t roots_in_01 = roots_in_range(0, 1);

auto chain1 = sturm_chain(f1);
auto chain2 = sturm_chain(f2);

//Test interval [0,1]
switch (roots_in_01) {
case 0:
case 1:
UNIT_TEST_CHECK_EQUAL(budan_01_test(f1), roots_in_01);
UNIT_TEST_CHECK_EQUAL(budan_01_test(f2), roots_in_01);
UNIT_TEST_CHECK_EQUAL(alesina_galuzzi_test(f1, 0.0, 1.0), roots_in_01);
UNIT_TEST_CHECK_EQUAL(alesina_galuzzi_test(f2, 0.0, 1.0), roots_in_01);
break;
default:
UNIT_TEST_CHECK(budan_01_test(f1) >= roots_in_01);
UNIT_TEST_CHECK(budan_01_test(f2) >= roots_in_01);
UNIT_TEST_CHECK(alesina_galuzzi_test(f1, 0.0, 1.0) >= roots_in_01);
UNIT_TEST_CHECK(alesina_galuzzi_test(f2, 0.0, 1.0) >= roots_in_01);
break;
}
UNIT_TEST_CHECK_EQUAL(chain1.roots(0.0, 1.0), roots_in_01);
UNIT_TEST_CHECK_EQUAL(chain2.roots(0.0, 1.0), roots_in_01);

//Test interval [0, \infty]
size_t positive_roots = roots_in_range(0, HUGE_VAL);
switch (positive_roots) {
case 0:
case 1:
UNIT_TEST_CHECK_EQUAL(descartes_rule_of_signs(f1), positive_roots);
UNIT_TEST_CHECK_EQUAL(descartes_rule_of_signs(f2), positive_roots);
break;
default:
UNIT_TEST_CHECK(descartes_rule_of_signs(f1) >= positive_roots);
break;
}
UNIT_TEST_CHECK_EQUAL(chain1.roots(0.0, HUGE_VAL), positive_roots);
UNIT_TEST_CHECK_EQUAL(chain2.roots(0.0, HUGE_VAL), positive_roots);

//Try some others
UNIT_TEST_CHECK_EQUAL(chain1.roots(-HUGE_VAL, HUGE_VAL), 5u);
UNIT_TEST_CHECK_EQUAL(chain2.roots(-HUGE_VAL, HUGE_VAL), 5u);
UNIT_TEST_CHECK(alesina_galuzzi_test(f1,-1.0, 30.0) >= roots_in_range(-1, 30));
UNIT_TEST_CHECK(alesina_galuzzi_test(f1,-0.01, 5.0) >= roots_in_range(-0.01, 5));
}
}

UNIT_TEST( LMQ_upper_bound_test )
{
using namespace sym;
const Polynomial<1> x{0, 1};

const double roots[] = {-1e5, -0.14159265, 3.14159265, -0.0001,0.1, 0.3333, 0.6, 1.001, 2.0, 3.14159265, 1e7};

//Test simple expressions
for (const double root1: roots)
for (const double root2: roots)
for (const double root3: roots)
for (const double root4: roots)
for (int sign : {-1, +1})
{
auto f = expand(sign * (x - root1) * (x - root2) * (x - root3) * (x - root4));

double max_root = root1;
max_root = std::max(max_root, root2);
max_root = std::max(max_root, root3);
max_root = std::max(max_root, root4);

double bound = LMQ_upper_bound(f);
if (max_root < 0)
UNIT_TEST_CHECK_EQUAL(bound, 0);
else
UNIT_TEST_CHECK(bound >= max_root);
}

//Test expressions with zero coefficients
for (const double root1: roots)
for (const double root2: roots)
for (int sign : {-1, +1})
{
auto f = expand(sign * (x - root1) * (x - root2) + 0 * x*x*x*x*x);
double max_root = std::max(root1, root2);

double bound = LMQ_upper_bound(f);
if (max_root < 0)
UNIT_TEST_CHECK_EQUAL(bound, 0);
else
UNIT_TEST_CHECK(bound >= max_root);
}

//Test constant coefficients
UNIT_TEST_CHECK_EQUAL(LMQ_upper_bound(expand(1 + 0 * x*x*x*x*x)), 0);
}

UNIT_TEST( LMQ_lower_bound_test )
{
using namespace sym;
const Polynomial<1> x{0, 1};

const double roots[] = {-1e5, -0.14159265, 3.14159265, -0.0001,0.1, 0.3333, 0.6, 1.001, 2.0, 3.14159265, 1e7};

for (const double root1: roots)
for (const double root2: roots)
for (const double root3: roots)
for (const double root4: roots)
for (int sign : {-1, +1})
{
auto f = expand(sign * (x - root1) * (x - root2) * (x - root3) * (x - root4));

double min_pos_root = HUGE_VAL;
if (root1 >= 0)
min_pos_root = std::min(min_pos_root, root1);
if (root2 >= 0)
min_pos_root = std::min(min_pos_root, root2);
if (root3 >= 0)
min_pos_root = std::min(min_pos_root, root3);
if (root4 >= 0)
min_pos_root = std::min(min_pos_root, root4);

double bound = LMQ_lower_bound(f);
if (min_pos_root == HUGE_VAL)
UNIT_TEST_CHECK_EQUAL(bound, HUGE_VAL);
else
UNIT_TEST_CHECK(bound <= min_pos_root);
}

//Test expressions with zero coefficients
for (const double root1: roots)
for (const double root2: roots)
for (int sign : {-1, +1})
{
auto f = expand(sign * (x - root1) * (x - root2) + 0 * x*x*x*x*x);

double min_pos_root = HUGE_VAL;
if (root1 >= 0)
min_pos_root = std::min(min_pos_root, root1);
if (root2 >= 0)
min_pos_root = std::min(min_pos_root, root2);

double bound = LMQ_lower_bound(f);
if (min_pos_root == HUGE_VAL)
UNIT_TEST_CHECK_EQUAL(bound, HUGE_VAL);
else
UNIT_TEST_CHECK(bound <= min_pos_root);
}

//Test constant coefficients
UNIT_TEST_CHECK_EQUAL(LMQ_lower_bound(expand(1 + 0 * x*x*x*x*x)), HUGE_VAL);
}

UNIT_TEST( poly_root_tests)
{
using namespace sym;
Expand Down
Loading

0 comments on commit 00e4578

Please sign in to comment.