-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.cpp
125 lines (103 loc) · 3.21 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <string>
#include <boost/python.hpp>
#include <boost/python/numpy.hpp>
#include "CompiledNN.h"
#include "Model.h"
namespace py = boost::python;
namespace np = py::numpy;
// Python-facing wrapper around NeuralNetwork::Model.
// Holds the model by value; a model file is loaded from any Python object
// that can be converted to a string (path-likes, str, ...).
struct ModelWrapper
{
  NeuralNetwork::Model model;

  // Creates an empty model.
  ModelWrapper() = default;

  // Immediately loads the model from the given file.
  explicit ModelWrapper(const py::object &file)
  {
    load(file);
  }

  // Marks the input at the given index as expecting unsigned 8-bit data.
  void setInputUInt8(std::size_t index)
  {
    model.setInputUInt8(index);
  }

  // Returns whether the input at the given index expects unsigned 8-bit data.
  bool isInputUInt8(std::size_t index) const
  {
    return model.isInputUInt8(index);
  }

  // Resets the model to its empty state.
  void clear()
  {
    model.clear();
  }

  // Loads the model from the given file. The argument is first converted to a
  // Python str and then extracted as a std::string path.
  void load(const py::object &file)
  {
    const std::string path{py::extract<std::string>(py::str{file})};
    model.load(path);
  }
};
// Python-facing wrapper around NeuralNetwork::CompiledNN.
struct CompiledNNWrapper
{
  NeuralNetwork::CompiledNN compiledNN;
  // Owner object handed to np::from_data for every exported array.
  // NOTE(review): this is default-constructed (i.e. None), so the numpy views
  // presumably do NOT keep this wrapper alive — callers on the Python side
  // must keep the CompiledNN object alive while using its arrays; confirm.
  py::object tensorOwner;

  // Compiles the given model into executable code.
  void compile(const ModelWrapper &model)
  {
    compiledNN.compile(model.model);
  }

  // Returns whether a successfully compiled network is available.
  bool valid() const
  {
    return compiledNN.valid();
  }

  // Number of input tensors of the compiled network.
  std::size_t numOfInputs() const
  {
    return compiledNN.numOfInputs();
  }

  // Wraps the tensor's float storage in a numpy array without copying,
  // using contiguous row-major strides. Rank 0 tensors (and, if asserts are
  // disabled, any rank above 4) yield an empty scalar array instead.
  np::ndarray tensor2ndarray(NeuralNetwork::TensorXf &tensor)
  {
    const auto &dims{tensor.dims()};
    assert(dims.size() <= 4);
    const auto dtype{np::dtype::get_builtin<float>()};
    if (dims.size() == 1)
      return np::from_data(tensor.data(), dtype,
                           py::make_tuple(dims[0]),
                           py::make_tuple(sizeof(float)),
                           tensorOwner);
    if (dims.size() == 2)
      return np::from_data(tensor.data(), dtype,
                           py::make_tuple(dims[0], dims[1]),
                           py::make_tuple(dims[1] * sizeof(float), sizeof(float)),
                           tensorOwner);
    if (dims.size() == 3)
      return np::from_data(tensor.data(), dtype,
                           py::make_tuple(dims[0], dims[1], dims[2]),
                           py::make_tuple(dims[1] * dims[2] * sizeof(float), dims[2] * sizeof(float), sizeof(float)),
                           tensorOwner);
    if (dims.size() == 4)
      return np::from_data(tensor.data(), dtype,
                           py::make_tuple(dims[0], dims[1], dims[2], dims[3]),
                           py::make_tuple(dims[1] * dims[2] * dims[3] * sizeof(float), dims[2] * dims[3] * sizeof(float), dims[3] * sizeof(float), sizeof(float)),
                           tensorOwner);
    // Rank 0 (or an unexpected rank with NDEBUG): empty scalar array.
    return np::empty(py::make_tuple(), dtype);
  }

  // Returns a numpy view of the input tensor at the given index.
  np::ndarray input(std::size_t index)
  {
    return tensor2ndarray(compiledNN.input(index));
  }

  // Number of output tensors of the compiled network.
  std::size_t numOfOutputs() const
  {
    return compiledNN.numOfOutputs();
  }

  // Returns a numpy view of the output tensor at the given index.
  np::ndarray output(std::size_t index)
  {
    return tensor2ndarray(compiledNN.output(index));
  }

  // Runs inference on the current contents of the input tensors.
  void apply() const
  {
    compiledNN.apply();
  }
};
// Module entry point: registers the Model and CompiledNN classes with Python.
BOOST_PYTHON_MODULE(PyCompiledNN)
{
  Py_Initialize();
  np::initialize();  // must run before any numpy usage in this module

  // Model: loading and per-input configuration.
  py::class_<ModelWrapper, boost::noncopyable> modelClass{"Model"};
  modelClass.def(py::init<py::object>());
  modelClass.def("setInputUInt8", &ModelWrapper::setInputUInt8);
  modelClass.def("isInputUInt8", &ModelWrapper::isInputUInt8);
  modelClass.def("clear", &ModelWrapper::clear);
  modelClass.def("load", &ModelWrapper::load);

  // CompiledNN: compilation, tensor access and inference.
  py::class_<CompiledNNWrapper, boost::noncopyable> compiledNNClass{"CompiledNN"};
  compiledNNClass.def("compile", &CompiledNNWrapper::compile);
  compiledNNClass.def("valid", &CompiledNNWrapper::valid);
  compiledNNClass.def("numOfInputs", &CompiledNNWrapper::numOfInputs);
  compiledNNClass.def("input", &CompiledNNWrapper::input);
  compiledNNClass.def("numOfOutputs", &CompiledNNWrapper::numOfOutputs);
  compiledNNClass.def("output", &CompiledNNWrapper::output);
  compiledNNClass.def("apply", &CompiledNNWrapper::apply);
}