Skip to content

Commit

Permalink
Manage ColumnLayoutRelationData with a separate class
Browse files Browse the repository at this point in the history
Algorithms don't ask for table options when they are passed an external
instance
  • Loading branch information
BUYT-1 authored and chernishev committed Apr 15, 2024
1 parent be87f07 commit 7bb14c4
Show file tree
Hide file tree
Showing 25 changed files with 128 additions and 74 deletions.
28 changes: 17 additions & 11 deletions src/core/algorithms/create_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

namespace algos {

template <typename AlgorithmBase = Algorithm>
std::unique_ptr<AlgorithmBase> CreateAlgorithmInstance(AlgorithmType algorithm) {
auto const create = [](auto I) -> std::unique_ptr<AlgorithmBase> {
template <typename AlgorithmBase = Algorithm, typename... ConstructorArgs>
std::unique_ptr<AlgorithmBase> CreateAlgorithmInstance(AlgorithmType algorithm,
ConstructorArgs&&... args) {
auto const create = [&args...](auto I) -> std::unique_ptr<AlgorithmBase> {
using AlgorithmType = std::tuple_element_t<I, AlgorithmTypes>;
if constexpr (std::is_convertible_v<AlgorithmType *, AlgorithmBase *>) {
return std::make_unique<AlgorithmType>();
if constexpr (std::is_convertible_v<AlgorithmType*, AlgorithmBase*>) {
return std::make_unique<AlgorithmType>(std::forward<ConstructorArgs>(args)...);
} else {
throw std::invalid_argument(
"Cannot use " + boost::typeindex::type_id<AlgorithmType>().pretty_name() +
Expand All @@ -23,16 +24,21 @@ std::unique_ptr<AlgorithmBase> CreateAlgorithmInstance(AlgorithmType algorithm)
static_cast<size_t>(algorithm), create);
}

template <typename AlgorithmBase>
std::vector<AlgorithmType> GetAllDerived() {
template <typename Base>
bool IsBaseOf(AlgorithmType algorithm) {
auto const is_derived = [](auto I) -> bool {
using AlgorithmType = std::tuple_element_t<I, AlgorithmTypes>;
return std::is_base_of_v<AlgorithmBase, AlgorithmType>;
using AlgoType = std::tuple_element_t<I, AlgorithmTypes>;
return std::is_base_of_v<Base, AlgoType>;
};
return boost::mp11::mp_with_index<std::tuple_size<AlgorithmTypes>>(
static_cast<size_t>(algorithm), is_derived);
}

template <typename AlgorithmBase>
std::vector<AlgorithmType> GetAllDerived() {
std::vector<AlgorithmType> derived_from_base{};
for (AlgorithmType algo : AlgorithmType::_values()) {
if (boost::mp11::mp_with_index<std::tuple_size<AlgorithmTypes>>(static_cast<size_t>(algo),
is_derived)) {
if (IsBaseOf<AlgorithmBase>(algo)) {
derived_from_base.push_back(algo);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/core/algorithms/fd/aidfd/aid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ namespace algos {

Aid::Aid() : FDAlgorithm({kDefaultPhaseName}) {
RegisterOptions();
MakeOptionsAvailable({config::TableOpt.GetName()});
MakeOptionsAvailable({config::kTableOpt.GetName()});
}

void Aid::RegisterOptions() {
RegisterOption(config::TableOpt(&input_table_));
RegisterOption(config::kTableOpt(&input_table_));
}

void Aid::LoadDataInternal() {
Expand Down
5 changes: 3 additions & 2 deletions src/core/algorithms/fd/depminer/depminer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

namespace algos {

Depminer::Depminer()
: PliBasedFDAlgorithm({"AgreeSets generation", "Finding CMAXSets", "Finding LHS"}) {}
Depminer::Depminer(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({"AgreeSets generation", "Finding CMAXSets", "Finding LHS"},
relation_manager) {}

using boost::dynamic_bitset, std::make_shared, std::shared_ptr, std::setw, std::vector, std::list,
std::dynamic_pointer_cast;
Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/depminer/depminer.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Depminer : public PliBasedFDAlgorithm {
unsigned long long ExecuteInternal() final;

public:
Depminer();
Depminer(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);
};

} // namespace algos
3 changes: 2 additions & 1 deletion src/core/algorithms/fd/dfd/dfd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

namespace algos {

DFD::DFD() : PliBasedFDAlgorithm({kDefaultPhaseName}) {
DFD::DFD(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({kDefaultPhaseName}, relation_manager) {
RegisterOptions();
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/dfd/dfd.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class DFD : public PliBasedFDAlgorithm {
unsigned long long ExecuteInternal() final;

public:
DFD();
DFD(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);
};

} // namespace algos
3 changes: 2 additions & 1 deletion src/core/algorithms/fd/fastfds/fastfds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ namespace algos {

using std::vector, std::set;

FastFDs::FastFDs() : PliBasedFDAlgorithm({"Agree sets generation", "Finding minimal covers"}) {
FastFDs::FastFDs(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({"Agree sets generation", "Finding minimal covers"}, relation_manager) {
RegisterOptions();
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/fastfds/fastfds.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace algos {

class FastFDs : public PliBasedFDAlgorithm {
public:
FastFDs();
FastFDs(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);

private:
using OrderingComparator = std::function<bool(Column const&, Column const&)>;
Expand Down
3 changes: 2 additions & 1 deletion src/core/algorithms/fd/fd_mine/fd_mine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ namespace algos {

using boost::dynamic_bitset;

FdMine::FdMine() : PliBasedFDAlgorithm({kDefaultPhaseName}) {}
FdMine::FdMine(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({kDefaultPhaseName}, relation_manager) {}

void FdMine::ResetStateFd() {
candidate_set_.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/fd_mine/fd_mine.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class FdMine : public PliBasedFDAlgorithm {
unsigned long long ExecuteInternal() override;

public:
FdMine();
FdMine(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);
};

} // namespace algos
4 changes: 2 additions & 2 deletions src/core/algorithms/fd/fdep/fdep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ namespace algos {

FDep::FDep() : FDAlgorithm({kDefaultPhaseName}) {
RegisterOptions();
MakeOptionsAvailable({config::TableOpt.GetName()});
MakeOptionsAvailable({config::kTableOpt.GetName()});
}

void FDep::RegisterOptions() {
RegisterOption(config::TableOpt(&input_table_));
RegisterOption(config::kTableOpt(&input_table_));
}

void FDep::LoadDataInternal() {
Expand Down
3 changes: 2 additions & 1 deletion src/core/algorithms/fd/fun/fun.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ bool FunQuadruple::Contains(Vertical const& that) const {
return candidate_.Contains(that);
}

FUN::FUN() : PliBasedFDAlgorithm({kDefaultPhaseName}) {}
FUN::FUN(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({kDefaultPhaseName}, relation_manager) {}

void FUN::ResetStateFd() {
fds_.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/fun/fun.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class FunQuadruple {

class FUN : public PliBasedFDAlgorithm {
public:
FUN();
FUN(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);

// Entities from the algorithm itself
private:
Expand Down
3 changes: 2 additions & 1 deletion src/core/algorithms/fd/hyfd/hyfd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

namespace algos::hyfd {

HyFD::HyFD() : PliBasedFDAlgorithm({}) {}
HyFD::HyFD(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({}, relation_manager) {}

unsigned long long HyFD::ExecuteInternal() {
using namespace hy;
Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/hyfd/hyfd.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class HyFD : public PliBasedFDAlgorithm {
void RegisterFDs(std::vector<RawFD>&& fds, std::vector<algos::hy::ClusterId> const& og_mapping);

public:
HyFD();
HyFD(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);
};

} // namespace algos::hyfd
3 changes: 2 additions & 1 deletion src/core/algorithms/fd/pfdtane/pfdtane.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ void PFDTane::MakeExecuteOptsAvailableFDInternal() {

void PFDTane::ResetStateFd() {}

PFDTane::PFDTane() : PliBasedFDAlgorithm({kDefaultPhaseName}) {
PFDTane::PFDTane(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({kDefaultPhaseName}, relation_manager) {
RegisterOptions();
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/pfdtane/pfdtane.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class PFDTane : public PliBasedFDAlgorithm {
unsigned long long ExecuteInternal() final;

public:
PFDTane();
PFDTane(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);
static double CalculateUccError(model::PositionListIndex const* pli,
ColumnLayoutRelationData const* relation_data);

Expand Down
31 changes: 15 additions & 16 deletions src/core/algorithms/fd/pli_based_fd_algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,26 @@

namespace algos {

PliBasedFDAlgorithm::PliBasedFDAlgorithm(std::vector<std::string_view> phase_names)
: FDAlgorithm(std::move(phase_names)) {
RegisterOptions();
MakeOptionsAvailable({config::TableOpt.GetName(), config::EqualNullsOpt.GetName()});
PliBasedFDAlgorithm::PliBasedFDAlgorithm(
std::vector<std::string_view> phase_names,
std::optional<ColumnLayoutRelationDataManager> relation_manager)
: FDAlgorithm(std::move(phase_names)),
relation_manager_(relation_manager.has_value()
? *relation_manager
: ColumnLayoutRelationDataManager{
&input_table_, &is_null_equal_null_, &relation_}) {
if (relation_manager.has_value()) return;
RegisterRelationManagerOptions();
MakeOptionsAvailable({config::kTableOpt.GetName(), config::kEqualNullsOpt.GetName()});
}

void PliBasedFDAlgorithm::RegisterOptions() {
RegisterOption(config::TableOpt(&input_table_));
RegisterOption(config::EqualNullsOpt(&is_null_equal_null_));
void PliBasedFDAlgorithm::RegisterRelationManagerOptions() {
RegisterOption(config::kTableOpt(&input_table_));
RegisterOption(config::kEqualNullsOpt(&is_null_equal_null_));
}

void PliBasedFDAlgorithm::LoadDataInternal() {
relation_ = ColumnLayoutRelationData::CreateFrom(*input_table_, is_null_equal_null_);
relation_ = relation_manager_.GetRelation();

if (relation_->GetColumnData().empty()) {
throw std::runtime_error("Got an empty dataset: FD mining is meaningless.");
Expand All @@ -37,12 +44,4 @@ std::vector<Column const*> PliBasedFDAlgorithm::GetKeys() const {
return keys;
}

void PliBasedFDAlgorithm::LoadData(std::shared_ptr<ColumnLayoutRelationData> data) {
if (data->GetColumnData().empty()) {
throw std::runtime_error("Got an empty dataset: FD mining is meaningless.");
} // TODO: this has to be repeated for every "alternative" data load
relation_ = std::move(data);
ExecutePrepare(); // TODO: this has to be repeated for every "alternative" data load
}

} // namespace algos
34 changes: 29 additions & 5 deletions src/core/algorithms/fd/pli_based_fd_algorithm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <optional>

#include "config/equal_nulls/type.h"
#include "config/tabular_data/input_table_type.h"
#include "fd_algorithm.h"
Expand All @@ -8,11 +10,35 @@
namespace algos {

class PliBasedFDAlgorithm : public FDAlgorithm {
public:
class ColumnLayoutRelationDataManager {
private:
config::InputTable* input_table_;
config::EqNullsType* is_null_equal_null_;
std::shared_ptr<ColumnLayoutRelationData>* relation_;

public:
ColumnLayoutRelationDataManager(
config::InputTable* input_table, config::EqNullsType* is_null_equal_null,
std::shared_ptr<ColumnLayoutRelationData>* relation_ptr) noexcept
: input_table_(input_table),
is_null_equal_null_(is_null_equal_null),
relation_(relation_ptr) {}

std::shared_ptr<ColumnLayoutRelationData> GetRelation() const {
if (*relation_ == nullptr)
*relation_ =
ColumnLayoutRelationData::CreateFrom(**input_table_, *is_null_equal_null_);
return *relation_;
}
};

private:
config::InputTable input_table_;
config::EqNullsType is_null_equal_null_;
ColumnLayoutRelationDataManager const relation_manager_;

void RegisterOptions();
void RegisterRelationManagerOptions();

protected:
std::shared_ptr<ColumnLayoutRelationData> relation_;
Expand All @@ -27,12 +53,10 @@ class PliBasedFDAlgorithm : public FDAlgorithm {
}

public:
explicit PliBasedFDAlgorithm(std::vector<std::string_view> phase_names);
PliBasedFDAlgorithm(std::vector<std::string_view> phase_names,
std::optional<ColumnLayoutRelationDataManager> relation_manager);

std::vector<Column const*> GetKeys() const override;

using Algorithm::LoadData;
void LoadData(std::shared_ptr<ColumnLayoutRelationData> data);
};

} // namespace algos
3 changes: 2 additions & 1 deletion src/core/algorithms/fd/pyro/pyro.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace algos {

std::mutex search_spaces_mutex;

Pyro::Pyro() : PliBasedFDAlgorithm({kDefaultPhaseName}) {
Pyro::Pyro(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({kDefaultPhaseName}, relation_manager) {
RegisterOptions();
fd_consumer_ = [this](auto const& fd) {
this->DiscoverFd(fd);
Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/pyro/pyro.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Pyro : public DependencyConsumer, public PliBasedFDAlgorithm {
unsigned long long ExecuteInternal() final;

public:
Pyro();
Pyro(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);
};

} // namespace algos
3 changes: 2 additions & 1 deletion src/core/algorithms/fd/tane/tane.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ namespace algos {

using boost::dynamic_bitset;

Tane::Tane() : PliBasedFDAlgorithm({kDefaultPhaseName}) {
Tane::Tane(std::optional<ColumnLayoutRelationDataManager> relation_manager)
: PliBasedFDAlgorithm({kDefaultPhaseName}, relation_manager) {
RegisterOptions();
}

Expand Down
2 changes: 1 addition & 1 deletion src/core/algorithms/fd/tane/tane.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Tane : public PliBasedFDAlgorithm {
int count_of_ucc_ = 0;
long apriori_millis_ = 0;

Tane();
Tane(std::optional<ColumnLayoutRelationDataManager> relation_manager = std::nullopt);

static double CalculateZeroAryFdError(ColumnData const* rhs,
ColumnLayoutRelationData const* relation_data);
Expand Down
Loading

0 comments on commit 7bb14c4

Please sign in to comment.