-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpytorch2onnx.py
41 lines (26 loc) · 1.24 KB
/
pytorch2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from net.st_gcn import Model
import torch
import torch.nn as nn
import torch.onnx
import numpy as np
# Input PyTorch checkpoint and output ONNX file (relative paths resolve
# against the current working directory).
weights_path, onnx_model = 'models/epoch50_model.pt', "stgcn.onnx"
def to_numpy(tensor):
    """Return a host-side NumPy array holding the data of ``tensor``.

    A tensor that requires grad is detached from the autograd graph first,
    because ``Tensor.numpy()`` refuses to convert such tensors. The tensor is
    then moved to the CPU before conversion.
    """
    host = tensor.detach() if tensor.requires_grad else tensor
    return host.cpu().numpy()
if __name__ == '__main__':
    # Export is done entirely on the CPU.
    device = torch.device("cpu")
    model = Model(
        in_channels=3,
        num_class=400,
        edge_importance_weighting=True,
        graph_args={'layout': 'openpose', 'strategy': 'spatial'},
    )

    # map_location remaps tensors that were saved on a GPU onto `device`;
    # without it, a CUDA-trained checkpoint fails to load on a CPU-only host.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # eval() switches layers such as dropout/batch-norm to inference behaviour
    # so that the traced graph reflects inference, not training.
    model.eval()

    # Random input used only to trace the graph; its values are irrelevant.
    # NOTE(review): the shape (1, 36, 3, 300) must match what Model.forward
    # expects -- confirm against net/st_gcn.py.
    dummy_input = torch.randn(1, 36, 3, 300, dtype=torch.float32, device=device)

    torch.onnx.export(
        model,                     # model being run
        (dummy_input,),            # tuple of positional model inputs
        onnx_model,                # where to save the model (path or file-like object)
        verbose=True,
        opset_version=11,          # ONNX opset version to target
        do_constant_folding=True,  # fold constant subgraphs for optimization
    )
    print('model generated')