-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmakeDatasetMMAPS_WR.py
105 lines (100 loc) · 4.77 KB
/
makeDatasetMMAPS_WR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import glob
import random
from spatial_transforms import Compose, ToTensor, CenterCrop, Scale, Normalize, MultiScaleCornerCrop, RandomHorizontalFlip, Binary
def gen_split(root_dir, stackSize, train):
    """Collect the RGB clip directories for the train or test user split.

    Walks ``root_dir/processed_frames2/<user>/<action>/<instance>/rgb`` and
    keeps every instance directory that holds at least ``stackSize`` usable
    frames.

    Args:
        root_dir: dataset root containing the ``processed_frames2`` folder.
        stackSize: minimum number of frames an instance needs to be included.
        train: if True scan users S1, S3 and S4; otherwise the held-out
            user S2.

    Returns:
        Tuple of three parallel lists: ``(clip_dirs, labels, num_frames)``
        where ``clip_dirs`` are paths to the ``rgb`` folders, ``labels`` are
        integer class ids and ``num_frames`` are the per-clip frame counts.
    """
    clip_dirs = []
    labels = []
    num_frames = []
    # Action-folder name -> class id, shared across users so the same action
    # always maps to the same label.
    action_labels = {}
    frames_root = os.path.join(root_dir, 'processed_frames2')
    # Fixed user split: S1/S3/S4 for training, S2 for evaluation.
    wanted_users = ('S1', 'S3', 'S4') if train else ('S2',)
    for dir_user in sorted(os.listdir(frames_root)):
        user_dir = os.path.join(frames_root, dir_user)
        if os.path.isfile(user_dir) or dir_user not in wanted_users:
            continue
        class_id = 0
        for target in sorted(os.listdir(user_dir)):
            if target in action_labels:
                # Reuse the id assigned when this action was first seen.
                class_id = action_labels[target]
            else:
                # NOTE(review): an action first seen mid-scan is assigned
                # previous-label + 1, which can collide with an id already
                # given to a different action. Preserved from the original
                # logic; confirm against the dataset layout before changing.
                action_labels[target] = class_id
            target_dir = os.path.join(user_dir, target)
            if os.path.isfile(target_dir):
                continue
            for inst in sorted(os.listdir(target_dir)):
                inst_dir = os.path.join(target_dir, inst, 'rgb')
                # Frames named '...(1).png' are duplicates: they match the
                # '*.png' count too, so subtracting removes them exactly once.
                frame_count = (len(glob.glob1(inst_dir, '*.png'))
                               - len(glob.glob1(inst_dir, '*(1).png')))
                if frame_count >= stackSize:
                    clip_dirs.append(inst_dir)
                    labels.append(class_id)
                    num_frames.append(frame_count)
            class_id += 1
    return clip_dirs, labels, num_frames
class makeDataset(Dataset):
    """PyTorch dataset of (RGB sequence, motion-map sequence, label) triples.

    Each item samples ``seqLen`` frames uniformly from one action instance
    and returns the spatially transformed RGB frames, the matching 7x7
    motion maps (continuous or binarized) and the integer class label.
    """

    def __init__(self, root_dir, spatial_transform=None, seqLen=20,
                 train=True, mulSeg=False, numSeg=1, fmt='.png', regression=True):
        """Build the clip index and the RGB / motion-map transform pipelines.

        Args:
            root_dir: dataset root passed to ``gen_split``.
            spatial_transform: augmentation applied to both RGB frames and
                motion maps so they stay spatially aligned.
            seqLen: number of frames sampled per clip.
            train: selects the train (S1/S3/S4) or test (S2) user split.
            mulSeg, numSeg: stored as-is; not used in this file.
            fmt: frame-file extension.
            regression: if True keep continuous map values, otherwise
                binarize them at threshold 0.4.
        """
        # ImageNet statistics — presumably the backbone is ImageNet
        # pretrained; confirm against the model code.
        normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # Minimum stack size of 5 frames (see gen_split).
        self.images, self.labels, self.numFrames = gen_split(root_dir, 5, train)
        # Data-augmentation transform shared by both modalities.
        self.main_spatial_transform = spatial_transform
        self.spatial_transform_rgb = Compose(
            [self.main_spatial_transform, ToTensor(), normalize])
        if regression:
            # Regression target: keep the raw 7x7 map values.
            self.spatial_transform_mmaps = Compose(
                [self.main_spatial_transform, Scale(7), ToTensor()])
        else:
            # Classification target: binarize the 7x7 map at 0.4.
            self.spatial_transform_mmaps = Compose(
                [self.main_spatial_transform, Scale(7), ToTensor(), Binary(0.4)])
        self.train = train
        self.mulSeg = mulSeg
        self.numSeg = numSeg
        self.seqLen = seqLen
        self.fmt = fmt

    def __len__(self):
        """Number of clips in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load one clip.

        Returns:
            Tuple ``(rgb_seq, map_seq, label)`` where ``rgb_seq`` is a
            ``(seqLen, C, H, W)`` tensor stack and ``map_seq`` the matching
            stack of transformed motion maps.
        """
        vid_name = self.images[idx]
        label = self.labels[idx]
        num_frame = self.numFrames[idx]
        inp_seq = []
        inp_seq_map = []
        # Draw one set of augmentation parameters for the whole sequence so
        # every frame of the clip receives the same crop/flip.
        self.main_spatial_transform.randomize_parameters()
        # NOTE(review): str.replace swaps every 'rgb' occurrence in the
        # path, not only the trailing folder — fine unless 'rgb' appears
        # elsewhere in the path; verify against the dataset layout.
        mmaps_vid_name = vid_name.replace('rgb', 'mmaps')
        for i in np.linspace(1, num_frame, self.seqLen, endpoint=False):
            frame_idx = int(np.floor(i))
            fl_name = vid_name + '/' + 'rgb' + str(frame_idx).zfill(4) + self.fmt
            fm_name = mmaps_vid_name + '/' + 'map' + str(frame_idx).zfill(4) + self.fmt
            if not os.path.exists(fm_name):
                # Some map frames are missing on disk; fall back to the next
                # index (preserved from the original behavior).
                fm_name = (mmaps_vid_name + '/' + 'map'
                           + str(int(np.floor(i + 1))).zfill(4) + self.fmt)
            # Fix: close the image file handles promptly instead of leaking
            # them (the transforms force the pixel data to load inside the
            # with-block, so closing afterwards is safe).
            with Image.open(fl_name) as img:
                inp_seq.append(self.spatial_transform_rgb(img.convert('RGB')))
            with Image.open(fm_name) as img_map:
                inp_seq_map.append(
                    self.spatial_transform_mmaps(img_map.convert('L')))
        inp_seq = torch.stack(inp_seq, 0)
        inp_seq_map = torch.stack(inp_seq_map, 0)
        return inp_seq, inp_seq_map, label