-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlearner.h
76 lines (67 loc) · 3.2 KB
/
learner.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Header file for learner class
//
// Copyright (C) 2012 Heidelberg University
//
// Author: Sascha Fendrich
//
// This file is part of Sol.
//
// Sol is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Sol is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with Sol. If not, see <http://www.gnu.org/licenses/>.
#ifndef LEARNER_H
#define LEARNER_H
#include <iosfwd>
#include <string>
#include <boost/program_options.hpp>
#include "data_set.h"
#include "model.h"
// Abstract base class for Sol's learners.
//
// Learner holds the command-line option description, the model being
// trained and the hyper-parameters shared by all learning algorithms.
// Concrete learners derive from it and implement the two private pure
// virtual hooks SingleUpdate() and Evaluate(). The non-virtual Learn()
// contains the SGD loop; NOTE(review): presumably it drives
// SingleUpdate() per instance, so confirm against the implementation file.
//
// The class is abstract and therefore always used through derived types,
// so it declares a virtual destructor. Without one, deleting a derived
// learner through a Learner* is undefined behaviour.
class Learner
{
 public:
  enum RegType { kRegNone, kRegL1, kRegL2 };  // Regularization type

  Learner ();                                 // Constructor
  virtual ~Learner () {}                      // Virtual: polymorphic base

  // Initialize options from an argc/argv style argument list.
  // NOTE(review): the meaning of the returned int is defined in the
  // implementation (not visible here); presumably 0 on success. Confirm.
  int Init (int argc, char **argv);
  int Run ();                                 // Run learning process

 private:
  void Learn (const DataSet &data_set);       // SGD loop

  // Hooks implemented by derived learners. They are private, so they are
  // only called from inside Learner (non-virtual interface pattern).
  // NOTE(review): the meaning of SingleUpdate's bool result is not visible
  // here. Confirm in the implementation.
  virtual bool SingleUpdate (const DataSet &data_set) = 0;  // Loss-update
  virtual void Evaluate (const DataSet &data_set) = 0;      // Evaluation

 protected:
  boost::program_options::options_description options_;  // Program options
  Model model_;                        // Learning model
  bool learn_;                         // Learn model on input data
  bool evaluate_;                      // Evaluate model on input data
  bool print_result_;                  // Print evaluation result to std::cout
  bool print_predictions_;             // Print predictions to std::cout
  bool pegasos_projection_;            // Use pegasos L2-ball projection
  std::string data_in_;                // Read data from file
  std::string model_in_;               // Read an initial model from file
  std::string model_out_;              // Write model to file
  bool write_intermediate_models_;     // Write model at each iteration
  bool decreasing_lr_;                 // Use decreasing learning rate
  float initial_learning_rate_;        // Initial learning rate
  float learning_rate_;                // Current learning rate
  float margin_;                       // Margin
  RegType reg_type_;                   // Regularization type
  float reg_param_;                    // Regularization parameter
  int reg_interval_;                   // Updates between regularization steps
  int num_classes_;                    // Number of classes (multi-class)
  int num_labels_;                     // Number of labels (multi-label)
  int num_features_;                   // Number of features
  int num_iterations_;                 // Number of iterations
  int num_instances_;                  // Number of instances in data set
  int num_submodels_;                  // Number of submodels
  int progress_interval_;              // Updates between progress reports
  unsigned random_seed_;               // Random seed
};
// Extraction operator that reads a Learner::RegType from `in` into
// `reg_type`. The accepted textual spellings and the error behaviour are
// defined with the implementation, which is not visible here.
// NOTE(review): presumably this exists so boost::program_options can parse
// the regularization option into Learner::reg_type_. Confirm in the
// implementation file.
std::istream& operator>> (std::istream& in, Learner::RegType& reg_type);
#endif