// model.h
#ifndef MODEL_H
#define MODEL_H
#include "utils.h"
/// Plain aggregate bundling the training settings and the running training
/// state that Model reads and writes through its `param` member.
/// Values of -1 appear to act as "not set yet" sentinels — TODO confirm
/// against the code that consumes them.
struct model_parameter
{
    // --- stopping / scheduling ---
    double min_delta_loss{-1.0};   ///< presumably the loss-change threshold for stopping; verify in trainer
    int loss_eval_circle{5};       ///< presumably the number of iterations between loss evaluations
    int max_iterations{200};       ///< printed as the denominator by Model::printVerbose

    // --- running state ---
    double loss{-1.0};             ///< printed as "loss" by Model::printVerbose
    double delta_loss{1e10};       ///< starts at a large value
    int verbose{1};                ///< settable via Model::setParam("verbose", ...)
    int current_iteration{-1};     ///< printed as the numerator by Model::printVerbose
    double training_acc{-1.0};     ///< printed as a percentage ("acc")
    double max_acc{-1.0};          ///< printed as a percentage ("max")

    // --- optimiser settings ---
    double learning_rate{0.01};
    int batch_size{1};
};
/**
 * Abstract base class for trainable models.
 *
 * @tparam ParamStruct  Parameter bundle stored in `param`. printVerbose() and
 *                      setParam() access the members current_iteration,
 *                      max_iterations, loss, training_acc, max_acc,
 *                      min_delta_loss, loss_eval_circle and verbose, so the
 *                      type must provide them when those functions are used
 *                      (model_parameter does).
 *
 * NOTE(review): `shared_ptr` and `DataLoader` are used unqualified and are
 * expected to be made visible by "utils.h" — confirm.
 */
template <typename ParamStruct>
class Model
{
public:
    /// Creates a model without a data loader (`dataloader` stays nullptr).
    Model() = default;

    /// Creates a model that shares ownership of the given data loader.
    Model(shared_ptr<DataLoader> _data_loader): dataloader(_data_loader){}

    /// Virtual because Model is a polymorphic base: deleting a derived model
    /// through a Model pointer is undefined behaviour with a non-virtual
    /// destructor.
    virtual ~Model() = default;

    /// Training entry point; every concrete model must implement it.
    virtual void train() = 0;
    // virtual Eigen::MatrixXd predict(const Eigen::MatrixXd &input, bool post_process) = 0;

    /**
     * Prints one progress line to std::cout:
     * "<current>/<max>\tloss: <loss>\tacc: <acc>%\tmax: <max_acc>%".
     * Accuracies are multiplied by 100 and truncated (not rounded) to two
     * decimals. The line ends with std::endl, so the stream is flushed.
     */
    void printVerbose() const
    {
        // Scale to 1/100 of a percent, truncate to int, then divide back to percent.
        const double acc_percent = static_cast<int>(param.training_acc*10000)/100.;
        const double max_percent = static_cast<int>(param.max_acc*10000)/100.;

        std::cout << param.current_iteration << "/" << param.max_iterations << "\t";
        std::cout << "loss: " << param.loss << "\t" << "acc: " << acc_percent << "%";
        std::cout <<"\t" << "max: " << max_percent << "%" << std::endl;
    }

    /**
     * Sets one parameter selected by name.
     *
     * @tparam T          Value type. Every branch below is compiled for each
     *                    T, so T must be assignable to all four fields.
     * @param param_flag  One of "min_delta_loss", "loss_eval_circle",
     *                    "max_iterations" or "verbose". Any other name is
     *                    ignored and nothing is modified.
     * @param value       New value, implicitly converted to the field's type.
     */
    template<typename T>
    void setParam(const std::string &param_flag, T value)
    {
        if(param_flag == "min_delta_loss")
            param.min_delta_loss = value;
        else if(param_flag == "loss_eval_circle")
            param.loss_eval_circle = value;
        else if(param_flag == "max_iterations")
            param.max_iterations = value;
        else if(param_flag == "verbose")
            param.verbose = value;
    }

protected:
    /// Data source shared with the caller; nullptr when default-constructed.
    shared_ptr<DataLoader> dataloader = nullptr;
    /// Hyper-parameters and running training state.
    ParamStruct param;
};
#endif