# demo_utils.py
import torch
import numpy as np
from decord import VideoReader, cpu

from pose_modules.Mediapipe import Model
from data_transform import (Compose, MediapipeDataProcess, PoseSelect,
                            HandCorrection, NormalizeKeypoints,
                            TemporalSample, WindowCreate)
from models.model_params import HWGATParams

class CFG:
    def __init__(self):
        cuda_id = 0
        self.mode = 'test'
        self.src_len = 192
        self.feature_type = 'keypoints'
        self.input_dim = 2
        self.model_type = 'HWGAT'
        self.device = torch.device(
            f"cuda:{cuda_id}" if torch.cuda.is_available() else "cpu")
        print("Running on device =", self.device)
        self.model_params = HWGATParams(
            {'num_class': 2002, 'src_len': 192}, self.input_dim, self.device)
        self.batch_size = 1
        self.origin_idx = 0             # nose point
        self.anchor_points = [3, 4]     # shoulder points
        self.left_slice = [9, 19, 7]    # left-hand slice
        self.right_slice = [19, 29, 8]  # right-hand slice

        # MediaPipe holistic output is laid out as 33 pose landmarks,
        # 468 face landmarks, 21 left-hand landmarks and 21 right-hand
        # landmarks (543 in total). Select 9 pose keypoints plus 10
        # keypoints from each hand, offset by the preceding blocks.
        pose_kps = [0, 2, 5, 11, 12, 13, 14, 15, 16]
        hand_kps = [0, 4, 5, 8, 9, 12, 13, 16, 17, 20]
        left_hand_offset = 33 + 468        # skip pose + face landmarks
        right_hand_offset = 33 + 468 + 21  # skip pose + face + left hand
        kps = (pose_kps
               + [k + left_hand_offset for k in hand_kps]
               + [k + right_hand_offset for k in hand_kps])

        self.test_transform = self.val_transform = Compose([
            MediapipeDataProcess(),
            PoseSelect(kps, [0, 1]),
            HandCorrection(self.left_slice, self.right_slice),
            NormalizeKeypoints(self.origin_idx, self.anchor_points),
            TemporalSample(self.src_len),
            WindowCreate(self.src_len),
        ])
        self.class_map_path = 'input/FDMSE/class_map_FDMSE.csv'
        self.save_model_path = 'output/FDMSE/HWGAT_240402_1556/model_best_loss.pt'

def get_video_features(vid_name) -> tuple:
    """Run MediaPipe pose estimation on every frame of a video."""
    pose_model = Model()
    kp_shape = (543, 4)  # 543 holistic landmarks, 4 values per landmark

    # decord's VideoReader accepts both file paths and file-like objects,
    # so a single constructor call covers both cases.
    cap = VideoReader(vid_name, ctx=cpu(0))

    num_frames = len(cap)
    vid_height, vid_width = cap[0].shape[:2]
    features = np.zeros((num_frames, *kp_shape))
    i_th_frame = 0
    for image in cap:
        # Save the i-th frame's keypoint features.
        features[i_th_frame] = pose_model(image.asnumpy())[0]
        i_th_frame += 1
    return vid_name, features, i_th_frame, vid_width, vid_height

def get_video_data(video):
    vid_name, features, num_frames, vid_width, vid_height = get_video_features(video)
    data = {
        'feat': features,
        'num_frames': num_frames,
        'vid_name': vid_name,
        'vid_width': vid_width,
        'vid_height': vid_height
    }
    return data
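

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original pipeline):
# a minimal example of how the pieces above might be wired together for one
# video. 'sample.mp4' is a hypothetical path, and cfg.test_transform(data)
# assumes the Compose object from data_transform is callable on the data
# dict, as the transform names suggest. Model loading and inference are
# handled elsewhere in the repo and are not reproduced here.
if __name__ == "__main__":
    cfg = CFG()

    # Extract MediaPipe keypoints and metadata for a single video.
    data = get_video_data("sample.mp4")  # hypothetical video path
    print(f"{data['vid_name']}: {data['num_frames']} frames, "
          f"{data['vid_width']}x{data['vid_height']}, "
          f"raw feature shape {data['feat'].shape}")

    # Apply the same preprocessing used at test time before feeding the
    # sample to the HWGAT model.
    sample = cfg.test_transform(data)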