-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrandomForest.cpp
76 lines (67 loc) · 2.21 KB
/
randomForest.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// copyright Luca Istrate, Andrei Medar
#include "randomForest.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <map>
#include <numeric>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "decisionTree.h"

using std::vector;
using std::pair;
using std::string;
using std::mt19937;
// Returns up to num_to_return elements of `samples`, drawn uniformly at
// random without replacement (each position of `samples` is used at most
// once).
//
// @param samples        pool to draw from.
// @param num_to_return  number of elements wanted. Values <= 0 yield an empty
//                       result; values above samples.size() are clamped to
//                       samples.size(). The previous implementation looped
//                       forever in that case, or whenever `samples` contained
//                       repeated rows.
// @return min(max(num_to_return, 0), samples.size()) elements of `samples`.
//
// Randomness comes from std::rand(), so any std::srand() seeding done by the
// caller still controls the draw. Note that std::rand() only yields values up
// to RAND_MAX, which bounds how far each swap below can reach.
vector<vector<int>> get_random_samples(const vector<vector<int>> &samples,
                                       int num_to_return) {
  const std::size_t pool = samples.size();
  const std::size_t wanted =
      num_to_return <= 0
          ? 0
          : std::min(pool, static_cast<std::size_t>(num_to_return));

  // Partial Fisher-Yates shuffle over indices. After step i, order[0..i]
  // holds i + 1 distinct indices chosen uniformly at random. This always
  // terminates and costs O(n + k), unlike rejecting duplicates by comparing
  // whole rows.
  vector<std::size_t> order(pool);
  std::iota(order.begin(), order.end(), std::size_t{0});

  vector<vector<int>> ret;
  ret.reserve(wanted);
  for (std::size_t i = 0; i < wanted; ++i) {
    const std::size_t j =
        i + static_cast<std::size_t>(std::rand()) % (pool - i);
    std::swap(order[i], order[j]);
    ret.push_back(samples[order[i]]);
  }
  return ret;
}
// Records the forest configuration; the constructor body is empty, so no
// trees are built until build() is called.
// @param num_trees  number of trees that build() will create and train.
// @param samples    the training set (build() treats images.size() as the
//                   total number of training examples). It is stored in the
//                   member `images`. NOTE(review): whether that member is a
//                   copy or a reference depends on its declaration in
//                   randomForest.h, which is not visible here.
RandomForest::RandomForest(int num_trees, const vector<vector<int>> &samples)
: num_trees(num_trees), images(samples) {}
// Builds the forest: appends num_trees new trees to `trees` and trains each
// on its own random subset of `images`.
//
// Each tree receives images.size() / num_trees samples drawn without
// replacement by get_random_samples, with a minimum of one sample per tree.
// Without that minimum, a forest with more trees than images would train
// every tree on an empty set.
//
// Preconditions (asserted): `images` is non-empty and num_trees > 0. In
// builds where assert is disabled, num_trees <= 0 makes this a no-op instead
// of dividing by zero.
void RandomForest::build() {
  assert(!images.empty());
  assert(num_trees > 0);
  if (num_trees <= 0) {
    return;
  }

  const std::size_t per_tree = std::max<std::size_t>(
      1, images.size() / static_cast<std::size_t>(num_trees));

  for (int i = 0; i < num_trees; ++i) {
    // Kept non-const: Node::train's parameter type is declared elsewhere.
    vector<vector<int>> random_samples =
        get_random_samples(images, static_cast<int>(per_tree));
    // Construct the new tree in place, then train it on its subset.
    trees.emplace_back();
    trees.back().train(random_samples);
  }
}
// Returns the forest's majority-vote prediction for `image`.
//
// Each built tree (at most num_trees of them, and never more than `trees`
// holds) is queried exactly once, and its answer counts as one vote. The
// label with the most votes wins. On a tie, the label that reached the
// winning count first, in tree order, is returned, as before.
//
// @param image  the input passed unchanged to each tree's predict().
// @return the winning label, or -1 if no tree voted (e.g. build() was never
//         called). The previous version returned an uninitialised value in
//         that case.
int RandomForest::predict(const vector<int> &image) {
  // Votes are keyed by label, so any int label is counted safely. The old
  // fixed table of 10 slots was written out of bounds for labels outside
  // [0, 9].
  std::map<int, int> votes;
  int best_count = 0;
  int best_label = -1;

  const std::size_t voters =
      num_trees <= 0
          ? 0
          : std::min<std::size_t>(trees.size(),
                                  static_cast<std::size_t>(num_trees));

  for (std::size_t i = 0; i < voters; ++i) {
    // Query each tree once; the old loop traversed every tree four times.
    const int label = trees[i].predict(image);
    const int count = ++votes[label];
    // Strict '>' keeps the first label to reach the maximum on ties.
    if (count > best_count) {
      best_count = count;
      best_label = label;
    }
  }
  return best_label;
}