-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpredictor.cpp
202 lines (173 loc) · 4.93 KB
/
predictor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
#include "error.hpp"
#include "predictor.hpp"
#include "profiler.hpp"
#include "timer.h"
#include "timer.impl.hpp"
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iosfwd>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>
// Compile-time debug tracing: flip `#if 0` to `#if 1` to make every
// DEBUG_STMT site print the enclosing function name and line number.
#if 0
#define DEBUG_STMT std ::cout << __func__ << " " << __LINE__ << "\n";
#else
#define DEBUG_STMT
#endif
using namespace torch;
using std::string;
// Implemented elsewhere in the project: converts a torch::IValue into the
// C-ABI Torch_IValue returned through this wrapper's C interface.
extern Torch_IValue Torch_ConvertIValueToTorchIValue(torch::IValue value);
// Thin C++ wrapper around a loaded TorchScript module, backing the
// C-linkage Torch_* API below.  One instance lives behind each opaque
// Torch_PredictorContext handle.
class Predictor {
public:
  // Loads the module at `model_file` and selects CPU/CUDA from `device`.
  Predictor(const string &model_file, Torch_DeviceKind device);
  // Runs the module on `inputLength` tensor handles; the result is cached
  // in output_ for later retrieval via Torch_PredictorGetOutput.
  void Predict(Torch_TensorContext *cInputs, int inputLength);
  torch::jit::script::Module net_;  // the deserialized TorchScript network
  torch::IValue output_;            // result of the most recent Predict()
  torch::DeviceType mode_{torch::kCPU};  // execution device, CPU by default
  profile *prof_{nullptr};  // owned manual-profiling session; see Torch_ProfilingStart
  std::string profile_filename_{"profile.trace"};  // trace file written by RecordProfile
  bool profile_enabled_{false};  // when true, Predict() records an autograd profile
};
// Construct a predictor by deserializing the TorchScript module stored at
// `model_file`.  When the caller asks for CUDA, the device mode is switched
// and the network is moved onto the GPU; otherwise everything stays on CPU.
Predictor::Predictor(const string &model_file, Torch_DeviceKind device) {
  net_ = torch::jit::load(model_file);
  if (device == CUDA_DEVICE_KIND) {
    // mode_ defaults to torch::kCPU and is only ever set here, so folding
    // the original "set then test" pair into one branch is equivalent.
    mode_ = torch::kCUDA;
    net_.to(at::kCUDA);
  }
#ifdef PROFILING_ENABLED
  profile_enabled_ = true;
#endif
}
// Run the network on `inputLength` input tensor handles and cache the
// result in output_.  When profiling is enabled the forward pass runs
// under a RecordProfile guard whose destructor flushes a trace to
// profile_filename_.
//
// Fix: removed the leftover unconditional debug dump of every input
// tensor's shape to stdout — per-call logging on the hot inference path —
// and reserved the input vector up front to avoid reallocations.
void Predictor::Predict(Torch_TensorContext *cInputs, int inputLength) {
  std::vector<torch::jit::IValue> inputs;
  inputs.reserve(static_cast<size_t>(std::max(inputLength, 0)));
  for (int ii = 0; ii < inputLength; ii++) {
    // Each opaque handle wraps a Torch_Tensor whose `tensor` member is the
    // underlying at::Tensor.
    inputs.emplace_back(reinterpret_cast<Torch_Tensor *>(cInputs[ii])->tensor);
  }
  if (profile_enabled_) {
    // Scope the guard so the trace file is flushed before returning.
    autograd::profiler::RecordProfile guard(profile_filename_);
    output_ = net_.forward(inputs);
    return;
  }
  output_ = net_.forward(inputs);
}
// C API: allocate a Predictor for `model_file` on the requested device and
// return it as an opaque handle (0 on failure, with the error recorded in
// Torch_GlobalError by the HANDLE_TH_ERRORS machinery).  Ownership passes
// to the caller; release with Torch_PredictorDelete.
Torch_PredictorContext Torch_NewPredictor(const char *model_file, Torch_DeviceKind mode) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  // Raw `new` is deliberate here: the pointer crosses the C boundary as an
  // opaque handle and is reclaimed in Torch_PredictorDelete.
  auto *const ctx = new Predictor(model_file, mode);
  return reinterpret_cast<Torch_PredictorContext>(ctx);  // named cast, not C-style
  END_HANDLE_TH_ERRORS(Torch_GlobalError, (Torch_PredictorContext)0);
}
// Intentionally empty: libtorch needs no explicit global initialization,
// but the symbol is kept so callers have a uniform init entry point.
void InitPytorch() {}
// C API: run inference on `pred` with `inputLength` tensor handles.
// Silently no-ops on a null handle; torch exceptions are captured into
// Torch_GlobalError by the HANDLE_TH_ERRORS machinery.
// Fix: C-style cast replaced with an intent-revealing reinterpret_cast.
void Torch_PredictorRun(Torch_PredictorContext pred, Torch_TensorContext *cInputs, int inputLength) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr) {
    return;
  }
  predictor->Predict(cInputs, inputLength);
  END_HANDLE_TH_ERRORS(Torch_GlobalError, );
}
// C API: number of outputs produced by the last run — 1 for a single
// tensor, the tuple arity for a tuple, 0 otherwise (and on null handle or
// error).
// Fixes: C-style cast replaced with reinterpret_cast; the size_t → int
// narrowing on the tuple size is now explicit instead of implicit.
int Torch_PredictorNumOutputs(Torch_PredictorContext pred) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr) {
    return 0;
  }
  if (predictor->output_.isTensor()) {
    return 1;
  }
  if (predictor->output_.isTuple()) {
    // elements().size() is size_t; narrow explicitly to the int this C API
    // promises.
    return static_cast<int>(predictor->output_.toTuple()->elements().size());
  }
  return 0;
  END_HANDLE_TH_ERRORS(Torch_GlobalError, 0);
}
// C API: convert the cached result of the last run into the C-ABI
// Torch_IValue.  Returns a zero-initialized Torch_IValue on null handle or
// error.
// Fix: C-style cast replaced with reinterpret_cast.
Torch_IValue Torch_PredictorGetOutput(Torch_PredictorContext pred) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr) {
    return Torch_IValue{};
  }
  return Torch_ConvertIValueToTorchIValue(predictor->output_);
  END_HANDLE_TH_ERRORS(Torch_GlobalError, Torch_IValue{});
}
// C API: destroy a predictor handle, releasing any in-flight profiling
// session first.  Safe to call with a null handle.
// Fix: C-style cast replaced with reinterpret_cast; explicit nullptr
// comparison on prof_ for clarity.
void Torch_PredictorDelete(Torch_PredictorContext pred) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr) {
    return;
  }
  if (predictor->prof_ != nullptr) {
    predictor->prof_->reset();
    delete predictor->prof_;
    predictor->prof_ = nullptr;
  }
  delete predictor;
  END_HANDLE_TH_ERRORS(Torch_GlobalError, );
}
// C API: begin (or restart) a manual profiling session on `pred`.
// Null `name`/`metadata` are normalized to empty strings.  No-op on a null
// handle.
// Fix: C-style cast replaced with reinterpret_cast.
void Torch_ProfilingStart(Torch_PredictorContext pred, const char *name, const char *metadata) {
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr) {
    return;
  }
  if (name == nullptr) {
    name = "";
  }
  if (metadata == nullptr) {
    metadata = "";
  }
  if (predictor->prof_ == nullptr) {
    predictor->prof_ = new profile(name, metadata);
  } else {
    // NOTE(review): restarting an existing session keeps the original
    // name/metadata — reset() discards the new arguments.  Looks
    // intentional, but confirm against profile's contract.
    predictor->prof_->reset();
  }
}
// C API: mark the end of the current manual profiling session, if one is
// active.  No-op on a null handle or when no session was started.
// Fix: C-style cast replaced with reinterpret_cast.
void Torch_ProfilingEnd(Torch_PredictorContext pred) {
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr) {
    return;
  }
  if (predictor->prof_ != nullptr) {
    predictor->prof_->end();
  }
}
// C API: turn on autograd trace profiling for subsequent Predict() calls.
// No-op on a null handle.
// Fix: C-style cast replaced with reinterpret_cast.
void Torch_ProfilingEnable(Torch_PredictorContext pred) {
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr) {
    return;
  }
  predictor->profile_enabled_ = true;
}
// C API: turn off autograd trace profiling and reset any active manual
// profiling session.  No-op on a null handle.
// Fix: C-style cast replaced with reinterpret_cast.
void Torch_ProfilingDisable(Torch_PredictorContext pred) {
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr) {
    return;
  }
  if (predictor->prof_ != nullptr) {
    predictor->prof_->reset();
  }
  predictor->profile_enabled_ = false;
}
// C API: read back the trace file written during a profiled Predict() and
// return it as a heap-allocated C string (caller frees).  Returns an empty
// string on a null handle or when no manual profiling session exists, and
// a null pointer on error.
// Fixes: C-style cast replaced with reinterpret_cast; the two identical
// empty-string guards are merged.  Requires <sstream>/<fstream>/<cstring>,
// which the original file never included (only <iosfwd>) — added at the
// top of the file.
char *Torch_ProfilingRead(Torch_PredictorContext pred) {
  HANDLE_TH_ERRORS(Torch_GlobalError);
  auto *predictor = reinterpret_cast<Predictor *>(pred);
  if (predictor == nullptr || predictor->prof_ == nullptr) {
    return strdup("");
  }
  // Slurp the whole trace file; if it is missing, rdbuf() contributes
  // nothing and an empty string is returned.
  std::stringstream ss;
  std::ifstream in(predictor->profile_filename_);
  ss << in.rdbuf();
  return strdup(ss.str().c_str());
  END_HANDLE_TH_ERRORS(Torch_GlobalError, (char *)0);
}