Skip to content

Commit

Permalink
Change reset function to correspond to gegelati changes
Browse files Browse the repository at this point in the history
  • Loading branch information
QuentinVacher-rl committed Jan 12, 2024
1 parent 5469dbf commit 419c60b
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion mnist/src/mnist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void MNIST::doAction(uint64_t actionID)
this->changeCurrentImage();
}

void MNIST::reset(size_t seed, Learn::LearningMode mode)
void MNIST::reset(size_t seed, Learn::LearningMode mode, uint16_t iterationNumber, uint64_t generationNumber)
{
// Reset the classificationTable
ClassificationLearningEnvironment::reset(seed);
Expand Down
3 changes: 2 additions & 1 deletion mnist/src/mnist.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class MNIST : public Learn::ClassificationLearningEnvironment {
virtual void doAction(uint64_t actionID) override;

/// Inherited via LearningEnvironment
virtual void reset(size_t seed = 0, Learn::LearningMode mode = Learn::LearningMode::TRAINING) override;
virtual void reset(size_t seed = 0, Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0, uint64_t generationNumber = 0) override;

/// Inherited via LearningEnvironment
virtual std::vector<std::reference_wrapper<const Data::DataHandler>> getDataSources() override;
Expand Down
2 changes: 1 addition & 1 deletion pendulum/src/Learn/pendulum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ std::vector<std::reference_wrapper<const Data::DataHandler>> Pendulum::getDataSo
return result;
}

void Pendulum::reset(size_t seed, Learn::LearningMode mode)
void Pendulum::reset(size_t seed, Learn::LearningMode mode, uint16_t iterationNumber, uint64_t generationNumber)
{
// Create seed from seed and mode
size_t hash_seed = Data::Hash<size_t>()(seed) ^ Data::Hash<Learn::LearningMode>()(mode);
Expand Down
3 changes: 2 additions & 1 deletion pendulum/src/Learn/pendulum.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ class Pendulum : public Learn::LearningEnvironment
virtual std::vector<std::reference_wrapper<const Data::DataHandler>> getDataSources() override;

/// Inherited via LearningEnvironment
virtual void reset(size_t seed = 0, Learn::LearningMode mode = Learn::LearningMode::TRAINING) override;
virtual void reset(size_t seed = 0, Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0, uint64_t generationNumber = 0) override;

/**
* \brief Get the action from its associated ID.
Expand Down
2 changes: 1 addition & 1 deletion stickgame/src/Learn/stickGameAdversarial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void StickGameAdversarial::doAction(uint64_t actionID)
}
}

void StickGameAdversarial::reset(size_t seed, Learn::LearningMode mode)
void StickGameAdversarial::reset(size_t seed, Learn::LearningMode mode, uint16_t iterationNumber, uint64_t generationNumber)
{
// Create seed from seed and mode
size_t hash_seed =
Expand Down
4 changes: 3 additions & 1 deletion stickgame/src/Learn/stickGameAdversarial.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ class StickGameAdversarial : public Learn::AdversarialLearningEnvironment
// Inherited via LearningEnvironment
virtual void reset(
size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING) override;
Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override;

// Inherited via LearningEnvironment
virtual std::vector<std::reference_wrapper<const Data::DataHandler>>
Expand Down
2 changes: 1 addition & 1 deletion tic-tac-toe/src/Learn/TicTacToe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void TicTacToe::randomPlay(double symbolOfPlayer) {
this->board.setDataAt(typeid(double), i, symbolOfPlayer);
}

void TicTacToe::reset(size_t seed, Learn::LearningMode mode) {
void TicTacToe::reset(size_t seed, Learn::LearningMode mode, uint16_t iterationNumber, uint64_t generationNumber) {
// Create seed from seed and mode
size_t hash_seed = Data::Hash<size_t>()(seed) ^Data::Hash<Learn::LearningMode>()(mode);
this->rng.setSeed(hash_seed);
Expand Down
4 changes: 3 additions & 1 deletion tic-tac-toe/src/Learn/TicTacToe.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ class TicTacToe : public Learn::AdversarialLearningEnvironment {
/// Inherited via LearningEnvironment
virtual void
reset(size_t seed = 0,
Learn::LearningMode mode = Learn::LearningMode::TRAINING) override;
Learn::LearningMode mode = Learn::LearningMode::TRAINING,
uint16_t iterationNumber = 0,
uint64_t generationNumber = 0) override;

/// Inherited via LearningEnvironment
virtual std::vector<std::reference_wrapper<const Data::DataHandler>>
Expand Down

0 comments on commit 419c60b

Please sign in to comment.