#include"sampleuffmnist.h"
//!
//! sampleUffMNIST.cpp
//! This file contains the implementation of the Uff MNIST sample.
//! It creates the network using the MNIST model converted to uff.
//!
//! It can be run with the following command line:
//! Command: ./sample_uff_mnist [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]
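//! For example (assuming the MNIST data files live under data/mnist/):
//! ./sample_uff_mnist --datadir=data/mnist/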
//!
//!
//! \brief Creates the network, configures the builder and creates the network engine
//!
//! \details This function creates the MNIST network by parsing the Uff model
//! and builds the engine that will be used to run MNIST (mEngine)
//!
//! \return Returns true if the engine was created successfully and false otherwise
//!
bool SampleUffMNIST::build()
{
    // Create the builder
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(
        nvinfer1::createInferBuilder(gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }
    // Use the builder to create an empty network definition
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if (!network)
    {
        return false;
    }
    // Create a builder configuration object
    auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }
    // Create a UFF parser
    auto parser = SampleUniquePtr<nvuffparser::IUffParser>(nvuffparser::createUffParser());
    if (!parser)
    {
        return false;
    }
    // Populate the network by parsing the UFF model
    constructNetwork(parser, network);
    // The maximum batch size which can be used at execution time,
    // and also the batch size for which the engine will be optimized
    builder->setMaxBatchSize(mParams.batchSize);
    config->setMaxWorkspaceSize(16_MiB);
    config->setFlag(BuilderFlag::kGPU_FALLBACK);
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
    }
    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
        builder->buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
    if (!mEngine)
    {
        return false;
    }
    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 3);
    return true;
}
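// Note: once build() succeeds, the engine could also be serialized for reuse.
// A minimal sketch, assuming the standard nvinfer1::ICudaEngine::serialize()
// API ("mnist.engine" is a hypothetical file name, not part of this sample):
//
//   auto serialized = SampleUniquePtr<nvinfer1::IHostMemory>(mEngine->serialize());
//   std::ofstream engineFile("mnist.engine", std::ios::binary);
//   engineFile.write(static_cast<const char*>(serialized->data()), serialized->size());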
//!
//! \brief Uses a UFF parser to create the MNIST network and marks the output layers
//!
//! \param parser Pointer to the UFF parser used to populate the network
//!
//! \param network Pointer to the network that will be populated with the MNIST network
//!
// Registers the parser's input and output tensors, then calls parse() to populate the network
void SampleUffMNIST::constructNetwork(
    SampleUniquePtr<nvuffparser::IUffParser>& parser,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network)
{
    // There should be exactly one input and one output tensor
    assert(mParams.inputTensorNames.size() == 1);
    assert(mParams.outputTensorNames.size() == 1);
    // Register the TensorFlow input: a 1x28x28 grayscale image in NCHW order
    parser->registerInput(mParams.inputTensorNames[0].c_str(),
        nvinfer1::Dims3(1, 28, 28),
        nvuffparser::UffInputOrder::kNCHW);
    parser->registerOutput(mParams.outputTensorNames[0].c_str());
    // Parse the model weights as kFLOAT (FP32); this populates the network
    parser->parse(mParams.uffFileName.c_str(), *network, nvinfer1::DataType::kFLOAT);
    // If INT8 inference is requested, assign dummy per-tensor dynamic ranges
    if (mParams.int8)
    {
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }
}
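//!
//! \brief A minimal sketch of what a helper like samplesCommon::setAllTensorScales
//! does (illustrative only; the real helper ships with the TensorRT samples'
//! common code and handles additional edge cases). It assigns a fixed symmetric
//! dynamic range to every layer input/output tensor so the builder can run INT8
//! without a calibration cache.
//!
static void setAllTensorScalesSketch(
    nvinfer1::INetworkDefinition* network, float inScale, float outScale)
{
    for (int i = 0; i < network->getNbLayers(); i++)
    {
        nvinfer1::ILayer* layer = network->getLayer(i);
        for (int j = 0; j < layer->getNbInputs(); j++)
        {
            // Optional inputs may be null; tensors already given a range are skipped
            nvinfer1::ITensor* input = layer->getInput(j);
            if (input && !input->dynamicRangeIsSet())
            {
                input->setDynamicRange(-inScale, inScale);
            }
        }
        for (int j = 0; j < layer->getNbOutputs(); j++)
        {
            nvinfer1::ITensor* output = layer->getOutput(j);
            if (output && !output->dynamicRangeIsSet())
            {
                output->setDynamicRange(-outScale, outScale);
            }
        }
    }
}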
//!
//! \brief Reads the input data, preprocesses, and stores the result in a managed buffer
//!
bool SampleUffMNIST::processInput(const samplesCommon::BufferManager& buffers,
    const std::string& inputTensorName,
    int inputFileIdx) const
{
    const int inputH = mInputDims.d[1];
    const int inputW = mInputDims.d[2];
    std::vector<uint8_t> fileData(inputH * inputW);
    readPGMFile(
        locateFile(std::to_string(inputFileIdx) + ".pgm", mParams.dataDirs),
        fileData.data(), inputH, inputW);
    // Print an ASCII-art representation of the digit
    gLogInfo << "Input:\n";
    for (int i = 0; i < inputH * inputW; i++)
    {
        gLogInfo << (" .:-=+*#%@"[fileData[i] / 26])
                 << (((i + 1) % inputW) ? "" : "\n");
    }
    gLogInfo << std::endl;
    // Normalize and invert the pixel values: [0, 255] -> [1.0, 0.0]
    float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
    for (int i = 0; i < inputH * inputW; i++)
    {
        hostInputBuffer[i] = 1.0f - float(fileData[i]) / 255.0f;
    }
    return true;
}
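//!
//! \brief A minimal sketch of what a PGM reader like readPGMFile has to do
//! (illustrative only; the real helper ships with the samples' common code).
//! It assumes a binary "P5" file with no comment lines and a max value of 255,
//! and that <fstream> is available via sampleuffmnist.h.
//!
static void readPGMFileSketch(const std::string& fileName, uint8_t* buffer, int h, int w)
{
    std::ifstream infile(fileName, std::ifstream::binary);
    std::string magic;
    int fileW{0}, fileH{0}, maxVal{0};
    infile >> magic >> fileW >> fileH >> maxVal; // "P5", dimensions, max gray value
    assert(magic == "P5" && fileH == h && fileW == w && maxVal == 255);
    infile.seekg(1, infile.cur); // skip the single whitespace after the header
    infile.read(reinterpret_cast<char*>(buffer), h * w);
}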
//!
//! \brief Verifies that the inference output is correct
//!
bool SampleUffMNIST::verifyOutput(const samplesCommon::BufferManager& buffers,
    const std::string& outputTensorName,
    int groundTruthDigit) const
{
    const float* prob = static_cast<const float*>(buffers.getHostBuffer(outputTensorName));
    gLogInfo << "Output:\n";
    float val{0.0f};
    int idx{0};
    // Determine the index with the highest output value
    for (int i = 0; i < kDIGITS; i++)
    {
        if (val < prob[i])
        {
            val = prob[i];
            idx = i;
        }
    }
    // Print the output value for each index
    for (int j = 0; j < kDIGITS; j++)
    {
        gLogInfo << j << "=> " << std::setw(10) << prob[j] << "\t : ";
        // Emphasize the index with the highest output value
        if (j == idx)
        {
            gLogInfo << "***";
        }
        gLogInfo << "\n";
    }
    gLogInfo << std::endl;
    return (idx == groundTruthDigit);
}
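// Design note: the argmax loop above could equally use the standard library,
// e.g. (assuming <algorithm> is available):
//   const int idx = static_cast<int>(
//       std::distance(prob, std::max_element(prob, prob + kDIGITS)));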
//!
//! \brief Runs the TensorRT inference engine for this sample
//!
//! \details This function is the main execution function of the sample.
//! It allocates the buffer, sets inputs, executes the engine, and verifies the output.
//!
bool SampleUffMNIST::infer()
{
    // Create the RAII buffer manager object
    samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);
    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(
        mEngine->createExecutionContext());
    if (!context)
    {
        return false;
    }
    bool outputCorrect = true;
    float total = 0;
    // Try to infer each digit 0-9
    for (int digit = 0; digit < kDIGITS; digit++)
    {
        if (!processInput(buffers, mParams.inputTensorNames[0], digit))
        {
            return false;
        }
        // Copy data from host input buffers to device input buffers
        buffers.copyInputToDevice();
        const auto t_start = std::chrono::high_resolution_clock::now();
        // Execute the inference work
        if (!context->execute(mParams.batchSize,
                buffers.getDeviceBindings().data()))
        {
            return false;
        }
        const auto t_end = std::chrono::high_resolution_clock::now();
        const float ms = std::chrono::duration<float, std::milli>(t_end - t_start).count();
        total += ms;
        // Copy data from device output buffers to host output buffers
        buffers.copyOutputToHost();
        // Check and print the output of the inference
        outputCorrect &= verifyOutput(buffers, mParams.outputTensorNames[0], digit);
    }
    total /= kDIGITS;
    gLogInfo << "Average over " << kDIGITS << " runs is " << total << " ms."
             << std::endl;
    return outputCorrect;
}
//!
//! \brief Used to clean up any state created in the sample class
//!
bool SampleUffMNIST::teardown()
{
    // Release the protobuf resources held process-wide by the UFF parser library
    nvuffparser::shutdownProtobufLibrary();
    return true;
}
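//!
//! \brief Illustrative driver showing the intended call order
//! build() -> infer() -> teardown(). This is a sketch, guarded out because the
//! sample's real main() is defined elsewhere; initializeSampleParams() is a
//! hypothetical helper that would fill the parameters from the command line.
//!
#if 0
int main(int argc, char** argv)
{
    SampleUffMNIST sample(initializeSampleParams(argc, argv));
    if (!sample.build())
    {
        return EXIT_FAILURE;
    }
    const bool ok = sample.infer();
    sample.teardown();
    return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}
#endif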