diff --git a/.gitignore b/.gitignore index b6e4761..13cab7c 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,11 @@ dmypy.json # Pyre type checker .pyre/ + +.DS_Store + +.pth +.pt +.pth.tar +model_param/ResNet101_flow_pretrain.pth.tar +model_param/ResNet101_rgb_pretrain.pth.tar diff --git a/chopstick_fusion_attention_transition.py b/chopstick_fusion_attention_transition.py new file mode 100644 index 0000000..6df61c5 --- /dev/null +++ b/chopstick_fusion_attention_transition.py @@ -0,0 +1,233 @@ +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader + +from dataset.chopstick_dataset import ChopstickDataset, ChopstickDataset_Pair +from dataset.chopstick_dataset import get_flow_feature_dict, get_rgb_feature_dict +from common import train, test, save_best_result + +import os +from os.path import join, isdir, isfile, exists +import argparse +import csv + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default='full', + choices=['full', 'only_x', 'only_htop', 'fc_att', 'no_att', 'cbam', 'sca', 'video_lstm', 'visual']) +parser.add_argument("--feature_type", type=str, default='resnet101_conv5', choices=['resnet101_conv5', 'resnet101_conv4']) +parser.add_argument("--epoch_num", type=int, default=30) +parser.add_argument("--split_index", type=int, default=0, choices=[0,1,2,3,4]) +parser.add_argument("--label", type=str, default='Full_model') + +args = parser.parse_args() + +''' +class model (nn.Module): + def __init__ (self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2,2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1,1)) + self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.rnn_att_size = 128 + self.rnn_top_size = 128 + + self.rnn_top = nn.GRUCell(self.x_size, self.rnn_top_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.rnn_att = nn.GRUCell(self.x_size+self.rnn_top_size, self.rnn_att_size) + for param in self.rnn_att.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.a_size = 32 + self.xa_fc = nn.Linear(self.x_size, self.a_size, bias=True) + self.ha_fc = nn.Linear(self.rnn_att_size, self.a_size, bias=True) + self.a_fc = nn.Linear(self.a_size, 1, bias=False) + + self.score_fc = nn.Linear(self.rnn_top_size, 1, bias=True) + + self.x_ln = nn.LayerNorm(self.x_size) + self.h_ln = nn.LayerNorm(self.rnn_top_size) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.1) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward (self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.randn(batch_size, self.rnn_top_size).to(video_tensor.device) + h_att = torch.randn(batch_size, self.rnn_att_size).to(video_tensor.device) + for frame_idx in range(seq_len): + featmap = video_tensor[:,frame_idx,:,:,:] #batch_size x 2D x 14 x 14 + + X = self.relu(self.pre_conv1(featmap)) #batch_size x C x 7 x 7 + X = self.pre_conv2(X) + x_avg = self.x_avgpool(X).view(batch_size, -1) #batch_size x C + x_max = self.x_maxpool(X).view(batch_size, -1) + + rnn_att_in = torch.cat((self.x_ln(x_avg+x_max),self.h_ln(h_top)), dim=1) +# rnn_att_in = torch.cat((x_avg+x_max, h_top), dim=1) + h_att = self.rnn_att(rnn_att_in, h_att) 
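# Note (added for clarity): in this commented-out copy of the attention model, the attention
# GRU above is driven by the layer-normalised avg+max pooled frame feature concatenated with
# the previous top-level hidden state h_top, so the spatial attention map computed in the next
# few lines is conditioned on the scoring RNN's history as well as on the current frame.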
#batch_size x rnn_att_size + + X_tmp = X.view(batch_size, self.x_size, -1).transpose(1,2) #batch_size x 49 x C + h_att_tmp = h_att.unsqueeze(1).expand(-1,X_tmp.size(1),-1) #batch_size x 49 x rnn_att_size + + a = self.tanh(self.xa_fc(X_tmp)+self.ha_fc(h_att_tmp)) + a = self.a_fc(a).unsqueeze(2) #batch_size x 49 + alpha = self.softmax(a) + s_att = alpha.view(batch_size, 1, X.size(2), X.size(3)) + video_soft_att.append(s_att) + + X = X * s_att #batch_size x C x 7 x 7 + rnn_top_in = torch.sum(X.view(batch_size, self.x_size, -1), dim=2) #batch_size x C x 7 x 7 + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + video_soft_att = torch.stack(video_soft_att, dim=1) #batch_size x seq_len x 1 x 14 x 14 + video_tmpr_att = torch.zeros(batch_size, seq_len) + return final_score, video_soft_att, video_tmpr_att +''' + +def read_model(model_type, feature_type, num_seg): + feature_size = 2048 if feature_type == 'resnet101_conv5' else 1024 + if model_type in ['full', 'only_x', 'only_htop', 'fc_att', 'no_att']: + from model_def.Spa_Att import model + return model(feature_size, num_seg, variant=model_type) + elif model_type in ['cbam']: + from model_def.CBAM_Att import model + return model(feature_size, num_seg) + elif model_type in ['sca']: + from model_def.SCA_Att import model + return model(feature_size, num_seg) + elif model_type in ['video_lstm']: + from model_def.VideoLSTM import model + return model(feature_size, num_seg) + elif model_type in ['visual']: + from model_def.Visual_Att import model + return model(feature_size, num_seg) + else: + raise Exception(f'Unsupport model type of {model_type}.') + +def get_train_test_pairs_dict (annotation_dir, split_idx): + train_pairs_dict = {} + train_videos = set() + train_csv = join(annotation_dir, 'chopstick_using_train_'+format(split_idx, '01d')+'.csv') + with open(train_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((row[0], row[1])) + train_pairs_dict[key] = 1 + train_videos.update(key) + csvfile.close() + + test_pairs_dict = {} + test_videos = set() + test_csv = join(annotation_dir, 'chopstick_using_val_'+format(split_idx, '01d')+'.csv') + with open(test_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((row[0], row[1])) + test_pairs_dict[key] = 1 + if row[0] not in train_videos: + test_videos.add(row[0]) + if row[1] not in train_videos: + test_videos.add(row[1]) + csvfile.close() + + return train_pairs_dict, test_pairs_dict, train_videos, test_videos + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dataset_dir = '../dataset/ChopstickUsing/ChopstickUsing_Stationary_800x450' + annotation_dir = '../dataset/ChopstickUsing/ChopstickUsing_Annotation/splits' + + video_name_list = os.listdir(dataset_dir) + video_rgb_feature_dict = get_rgb_feature_dict(dataset_dir, args.feature_type) + video_flow_feature_dict = get_flow_feature_dict(dataset_dir, args.feature_type) + + best_acc_keeper = [] + for split_idx in range(1, 5): + print("Split: "+format(split_idx, '01d')) + train_pairs_dict, test_pairs_dict, train_videos, test_videos = get_train_test_pairs_dict(annotation_dir, split_idx) + + num_seg = 25 + dataset_train = ChopstickDataset_Pair('fusion', video_rgb_feature_dict, video_flow_feature_dict, + train_pairs_dict, seg_sample=num_seg) + dataloader_train = DataLoader(dataset_train, batch_size=16, 
shuffle=True) + + dataset_test = ChopstickDataset('fusion', video_rgb_feature_dict, video_flow_feature_dict, + video_name_list, seg_sample=num_seg) + dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + + model_ins = read_model(args.model, args.feature_type, num_seg) + + save_label = f'Chopstick/{args.model}/{split_idx:01d}' + + best_acc = 0.0 + if args.continue_train: + ckpt_dir = join('checkpoints', save_label, + 'best_checkpoint.pth.tar') + if exists(checkpoint): + checkpoint = torch.load(ckpt_dir) + model_ins.load_state_dict(checkpoint['state_dict']) + best_acc = checkpoint['best_acc'] + print("Start from previous checkpoint, with rank_cor: {:.4f}".format( + checkpoint['best_acc'])) + else: + print("No previous checkpoint. \nStart from scratch.") + else: + print("Start from scratch.") + + + model_ins.to(device) + + criterion = nn.MarginRankingLoss(margin=0.5) + + # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_ins.parameters()), + # lr=5e-6, weight_decay=0, amsgrad=False) + optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_ins.parameters()), + lr=5e-4, momentum=0.9, weight_decay=1e-2) + + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2) + + min_loss = 1.0 + no_imprv = 0 + for epoch in range(args.epoch_num): + train(dataloader_train, model_ins, criterion, optimizer, epoch, device) + epoch_loss, epoch_acc = test(dataloader_test, test_pairs_dict, model_ins, criterion, epoch, device) + + if epoch_acc >= best_acc: + best_acc = epoch_acc + save_best_result(dataloader_test, test_videos, + model_ins, device, best_acc, save_label) + + if epoch_loss <= min_loss: + min_loss = epoch_loss + no_imprv = 0 + else: + no_imprv += 1 + print('Best acc: {:.3f}'.format(best_acc)) + # if no_imprv > 3: + # break + best_acc_keeper.append(best_acc) + + for split_idx, best_acc in enumerate(best_acc_keeper): + print(f'Split: {split_idx+1}, {best_acc:.4f}') + print('Avg:', '{:.4f}'.format(sum(best_acc_keeper)/4)) diff --git a/common.py b/common.py new file mode 100644 index 0000000..49906a6 --- /dev/null +++ b/common.py @@ -0,0 +1,204 @@ +import torch + +import os +from os.path import join + +from tqdm import tqdm +import cv2 +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image + +from utils.ImageShow import voxel_tensor_to_np, overlap_maps_on_voxel_np, img_np_show + +################################################## +# Train & Test # +################################################## +def train(dataloader, model, criterion, optimizer, epoch, device): + model.train() + + running_loss = 0.0 + running_acc = 0.0 + + for batch_idx, pair in enumerate(tqdm(dataloader)): + v1_tensor = pair['video1']['video'].to(device) + v2_tensor = pair['video2']['video'].to(device) + label = pair['label'].to(device).squeeze(1) + + v1_score, v1_satt, v1_tatt = model(v1_tensor) + v2_score, v2_satt, v2_tatt = model(v2_tensor) + + loss = criterion(v1_score, v2_score, label) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + running_acc += torch.nonzero((label*(v1_score-v2_score)) + > 0).size(0) / v1_tensor.size(0) + + epoch_loss = running_loss / len(dataloader) + epoch_acc = running_acc / len(dataloader) + print(f'Train: Epoch {epoch}, Loss:{epoch_loss:.4f}, Acc:{epoch_acc:.4f}') + + +def test(dataloader, pairs_dict, model, criterion, epoch, device): + model.eval() + + videos_score = {} + videos_satt = {} + for batch_idx, sample in enumerate(tqdm(dataloader)): + video_name 
= sample['name'] + v_tensor = sample['video'].to(device) + sampled_idx_list = sample['sampled_index'] + + with torch.no_grad(): + v_score, v_satt, v_tatt = model(v_tensor) + for i in range(v_tensor.size(0)): + videos_score[video_name[i]] = v_score[i].unsqueeze(0) + videos_satt[video_name[i]] = v_satt[i].unsqueeze(0) + + running_loss = 0.0 + running_acc = 0.0 + + pairs_list = list(pairs_dict.keys()) + for v1_name, v2_name in pairs_list: + v1_score = videos_score[v1_name] + v2_score = videos_score[v2_name] + v1_satt = videos_satt[v1_name] + v2_satt = videos_satt[v2_name] + label = torch.Tensor([pairs_dict[(v1_name, v2_name)]]).to(device) + + loss = criterion(v1_score, v2_score, label) + + running_loss += loss.item() + running_acc += torch.nonzero((label*(v1_score-v2_score)) > 0).size(0) + + epoch_loss = running_loss / len(pairs_list) + epoch_acc = running_acc / len(pairs_list) + + print(f'Test: Epoch {epoch}, Loss:{epoch_loss:.4f}, Acc:{epoch_acc:.4f}') + + return epoch_loss, epoch_acc + + +def save_best_result(dataloader, test_videos, model, device, best_acc, save_label): + model.eval() + + best_checkpoint = {'state_dict': model.state_dict(), 'best_acc': best_acc} + ckpt_dir = join('checkpoints', save_label) + os.makedirs(ckpt_dir, exist_ok=True) + torch.save(best_checkpoint, join(ckpt_dir, 'best_checkpoint.pth.tar')) + + videos_score = {} + htmp_dir = join('heatmaps', save_label) + os.makedirs(htmp_dir, exist_ok=True) + for batch_idx, sample in enumerate(dataloader): + video_name = sample['name'] + v_tensor = sample['video'].to(device) + sampled_idx_list = sample['sampled_index'] + + with torch.no_grad(): + v_score, v_satt, v_tatt = model(v_tensor) + + for i in range(v_tensor.size(0)): + pred_score = v_score[i].item() + videos_score[video_name[i]] = pred_score + + if video_name[i] in test_videos: + plot_video_heatmaps(v_tensor[i], v_satt[i], title=f'{video_name[i]} Pred:{pred_score:.4f}', + save_path=join(htmp_dir, f'{video_name[i]}.jpg')) + + score_dir = join('pred_scores', save_label) + os.makedirs(score_dir, exist_ok=True) + with open(join(score_dir, 'scores.txt'), 'w') as f: + score = v_score.detach().cpu().item() + f.writeline(f'{score:.4f}\n') + f.close() + +def plot_video_heatmaps (video_tensor, heatmap, title=None, save_path=None, save_separately=False): + # video_tensor: 3xLx112x112 + # heatmap: 1xLx112x112 + num_timesteps = video_tensor.shape[1] + assert num_timesteps == heatmap.shape[1] + + video_imgs = voxel_tensor_to_np(video_tensor) # np, 0~1, 3xLx112x112 + video_imgs_uint = np.uint8(video_imgs * 255) + + if torch.is_tensor(heatmap): + heatmap = heatmap.squeeze(0).numpy() # np, 0~1, Lx112x112 + else: + heatmap = heatmap[0] + + overlaps = overlap_maps_on_voxel_np(video_imgs, heatmap) # np, 0~1, 3xLx112x112 + overlaps_uint = np.uint8(overlaps * 255) + + if save_separately and save_path != None: + separate_save_dir = os.path.splitext(save_path)[0] + os.makedirs(separate_save_dir, exist_ok=True) + + # save plot imgs, explanation heatmap + num_subline = 2 + num_row = num_subline * ( (num_timesteps-1) // 8 + 1 ) + plt.clf() + fig = plt.figure(figsize=(16,num_row*2)) + for i in range(num_timesteps): + plt.subplot(num_row, 8, (i//8)*8*num_subline+i%8+1) + img_np_show(video_imgs_uint[:,i]) + plt.title(i, fontsize=8) + + plt.subplot(num_row, 8, (i//8)*8*num_subline+i%8+8+1) + img_np_show(overlaps_uint[:,i]) + + if save_separately: + video_img = Image.fromarray(video_imgs_uint[:,i].transpose(1,2,0)) + video_img.save(os.path.join(separate_save_dir, f'img_{i}.jpg')) + exp_img = 
Image.fromarray(overlaps_uint[:,i].transpose(1,2,0)) + exp_img.save(os.path.join(separate_save_dir, f'exp_{i}.jpg')) + + if title != None: + fig.suptitle(title, fontsize=14) + + if save_path != None: + save_dir = os.path.dirname(os.path.abspath(save_path)) + os.makedirs(save_dir, exist_ok=True) + + ext = os.path.splitext(save_path)[1].strip('.') + plt.savefig(save_path, format=ext, bbox_inches='tight') + + plt.close(fig) + +# # batched_heatmaps: batch_size x seq_len x 1 x 7 x 7 +# def save_heatmaps(batched_inputs, batched_heatmaps, save_dir, size, video_name, rand_idx_list, t_att): +# batch_size = batched_heatmaps.size(0) +# seq_len = batched_heatmaps.size(1) + +# for batch_offset in range(batch_size): +# att_save_dir = join(save_dir, video_name[batch_offset]) +# os.makedirs(att_save_dir, exist_ok=True) + +# dataset_name, video_name = video_name[batch_offset].split('_') +# dataset_dir = dataset_dir = join( +# '../dataset', dataset_name, dataset_name+'_640x480') +# ori_frames_dir = join(dataset_dir, video_name, 'frame') + +# for seq_idx in range(seq_len): +# frame_idx = int(rand_idx_list[seq_idx][batch_offset].item()) + +# heatmap = batched_heatmaps[batch_offset, seq_idx, 0, :, :] +# heatmap = (heatmap-heatmap.min()) / (heatmap.max()-heatmap.min()) +# heatmap = np.array(heatmap*255.0).astype(np.uint8) +# heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) +# heatmap = cv2.resize(heatmap, size) + +# ori_frame = cv2.imread( +# join(ori_frames_dir, format(frame_idx, '05d')+'.jpg')) +# ori_frame = cv2.resize(ori_frame, size) + +# comb = cv2.addWeighted(ori_frame, 0.6, heatmap, 0.4, 0) +# t_att_value = t_att[batch_offset, seq_idx].item() +# pic_save_dir = join(att_save_dir, format( +# frame_idx, '05d')+'_'+format(t_att_value, '.2f')+'.jpg') +# cv2.imwrite(pic_save_dir, comb) + diff --git a/dataset/chopstick_dataset.py b/dataset/chopstick_dataset.py new file mode 100644 index 0000000..012bd18 --- /dev/null +++ b/dataset/chopstick_dataset.py @@ -0,0 +1,283 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +import utility +from Models.ResNet import resnet101, resnet34 + +import os +from os.path import join, isdir, isfile +import glob +import math + +from PIL import Image +import pickle +from scipy.io import loadmat +from scipy.stats import spearmanr +from scipy.ndimage.filters import gaussian_filter +import numpy as np +from tqdm import tqdm +import cv2 +import csv +import time + +def extract_save_rgb_feature (dataset_dir): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_rgb_extractor = resnet101(pretrained=True, channel=3, output='conv5').to(device) + pretrained_model = torch.load('Models/ResNet101_rgb_pretrain.pth.tar') + resnet_rgb_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_rgb_extractor.parameters(): + param.requires_grad = False + + transform = transforms.Compose([ + transforms.Resize((448,448), interpolation=3), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_dir = join(video_dir, 'frame') + + seq_len = len(os.listdir(video_frames_dir)) + video_frames_tensor = 
torch.stack([ + transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.jpg'))) + for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 + video_frames_tensor = video_frames_tensor.unsqueeze(0) + + # get video's resnet feature maps + video_feature = [] + for idx in range(5, seq_len-5, 1): + batched_frames = video_frames_tensor[:,idx,:,:,:].to(device) # 1x3x448x448 + with torch.no_grad(): + batched_feature = resnet_rgb_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len x 2048 x 14 x 14 + torch.save(video_feature, join(video_dir, 'feature', 'resnet_pretrain_conv5.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') +# utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/chopstick_resnet101_conv5/', +# (320,180), video_name, dataset_dir) + return videos_feature_dict + +def extract_save_flow_feature (dataset_dir): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_flow_extractor = resnet101(pretrained=True, channel=20, output='conv5').to(device) + pretrained_model = torch.load('Models/ResNet101_flow_pretrain.pth.tar') + resnet_flow_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_flow_extractor.parameters(): + param.requires_grad = False + + transform = transforms.Compose([ + transforms.Resize((448,448), interpolation=3), + transforms.ToTensor(), + ]) + + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_tensor = [] + + seq_len = int( len(os.listdir(join(video_dir, 'flow_tvl1'))) / 2 ) + for frame_idx in range(1, seq_len+1): + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_x_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_y_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor = torch.cat(video_frames_tensor, 0).unsqueeze(0) #1 x seq_lenx2 x 448 x 448 + + # get video's resnet feature maps, 5-times down-sample + video_feature = [] + for idx in range(0, seq_len-9, 1): + batched_frames = video_frames_tensor[:,2*idx:2*(idx+10),:,:].to(device) # 1x20x448x448 + with torch.no_grad(): + batched_feature = resnet_flow_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len-9 x 2048 x 14 x 14 + + feature_save_dir = join(video_dir, 'feature') + if not os.path.isdir(feature_save_dir): + os.system('mkdir -p '+feature_save_dir) + torch.save(video_feature, join(feature_save_dir, 'resnet_flow_10_pretrain_conv5.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') +# utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/chopstick_resnet101_10_pretrain_conv5/', +# (320,180), video_name, dataset_dir) + return videos_feature_dict + +def get_rgb_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_pretrain_conv5.pt')) + elif feature_type == 
'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature +# print(video_name, video_feature.size(0)) + return videos_feature_dict + +def get_flow_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature +# print(video_name, video_feature.size(0)) + return videos_feature_dict + +def get_train_test_pairs_dict(annotation_dir, split_idx): + train_pairs_dict = {} + train_videos = set() + train_csv = join(annotation_dir, 'chopstick_using_train_' + + format(split_idx, '01d')+'.csv') + with open(train_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((row[0], row[1])) + train_pairs_dict[key] = 1 + train_videos.update(key) + csvfile.close() + + test_pairs_dict = {} + test_videos = set() + test_csv = join(annotation_dir, 'chopstick_using_val_' + + format(split_idx, '01d')+'.csv') + with open(test_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((row[0], row[1])) + test_pairs_dict[key] = 1 + if row[0] not in train_videos: + test_videos.add(row[0]) + if row[1] not in train_videos: + test_videos.add(row[1]) + csvfile.close() + + return train_pairs_dict, test_pairs_dict, train_videos, test_videos +######################################################################################################################### +# Chopstick_Dataset # +######################################################################################################################### +def video_sample (video_tensor, rand_idx_list): + sampled_video_tensor = torch.stack([video_tensor[idx,:,:,:] for idx in rand_idx_list], dim=0) + return sampled_video_tensor + +class ChopstickDataset (Dataset): + def __init__ (self, f_type, video_rgb_feature_dict, video_flow_feature_dict, video_name_list, seg_sample=None): + self.f_type = f_type + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + self.video_name_list = video_name_list + self.dataset_dir = '../dataset/ChopstickUsing/ChopstickUsing_Stationary_800x450' + + def __len__ (self): + return len(self.video_name_list) + + def __getitem__ (self, i): + video_name = self.video_name_list[i] + video_sample = self.read_one_video(video_name) + return video_sample + + def read_one_video (self, video_name): + video_dir = join(self.dataset_dir, video_name) + + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_rgb_tensor.size(0) + rand_idx_list = utility.avg_last_sample(seq_len, self.seg_sample) + + video_flow_tensor = video_sample(video_flow_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + video_rgb_tensor = video_sample(video_rgb_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = 
video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) ##num_seg x 4096 x 14 x 14 + + rand_idx_list = [num+5 for num in rand_idx_list] + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_idx_list} + return sample + +class ChopstickDataset_Pair (Dataset): + def __init__ (self, f_type, video_rgb_feature_dict, video_flow_feature_dict, pairs_dict, seg_sample=None): + self.f_type = f_type + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + + self.dataset_dir = '../dataset/MIT_Dive_Dataset/diving_samples_len_ori_800x450' + + self.pairs_dict = pairs_dict + self.pairs_list = list(self.pairs_dict.keys()) + + def __len__ (self): + return len(self.pairs_list) + + def __getitem__ (self, i): + v1_name, v2_name = self.pairs_list[i] + v1_sample = self.read_one_video(v1_name) + v2_sample = self.read_one_video(v2_name) + label = torch.Tensor([ self.pairs_dict[(v1_name, v2_name)] ]) + + pair = {'video1': v1_sample, 'video2': v2_sample, 'label': label} + return pair + + def read_one_video (self, video_name): + video_dir = join(self.dataset_dir, video_name) + + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_rgb_tensor.size(0) + rand_idx_list = utility.avg_rand_sample(seq_len, self.seg_sample) + + video_flow_tensor = video_sample(video_flow_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + video_rgb_tensor = video_sample(video_rgb_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) #num_seg x 4096 x 14 x 14 + + rand_idx_list = [num+5 for num in rand_idx_list] + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_idx_list} + return sample + +if __name__ == "__main__": + extract_save_rgb_feature('../dataset/ChopstickUsing/ChopstickUsing_Stationary_800x450') + extract_save_flow_feature('../dataset/ChopstickUsing/ChopstickUsing_Stationary_800x450') +# get_rgb_feature_dict('../dataset/ChopstickUsing/ChopstickUsing_Stationary_800x450') diff --git a/dataset/diving_dataset.py b/dataset/diving_dataset.py new file mode 100644 index 0000000..a0a465a --- /dev/null +++ b/dataset/diving_dataset.py @@ -0,0 +1,285 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +import utility +from Models.ResNet import resnet101, resnet34 + +import os +from os.path import join, isdir, isfile +import glob +import math + +from PIL import Image +import pickle +from scipy.io import loadmat +from scipy.stats import spearmanr +from scipy.ndimage.filters import gaussian_filter +import numpy as np +from tqdm import tqdm +import cv2 +import csv +import time + +def extract_save_rgb_feature (dataset_dir): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_rgb_extractor = resnet101(pretrained=True, channel=3, output='conv5').to(device) + pretrained_model = torch.load('Models/ResNet101_rgb_pretrain.pth.tar') + resnet_rgb_extractor.load_state_dict(pretrained_model['state_dict']) + for param in 
resnet_rgb_extractor.parameters(): + param.requires_grad = False + + transform = transforms.Compose([ + transforms.CenterCrop((448,448)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + videos_feature_dict = {} + video_name_list = [format(video_idx, '03d') for video_idx in range(1, 160)] + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_dir = join(video_dir, 'frame') + video_feature_dir = join(video_dir, 'feature') + + if not isdir(video_feature_dir): + os.system('mkdir '+video_feature_dir) +# for file in os.listdir(video_feature_dir): +# if 'heatmaps' not in file: +# os.system('rm '+join(video_feature_dir, file)) + + seq_len = len(os.listdir(video_frames_dir)) + video_frames_tensor = torch.stack([ + transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.jpg'))) + for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 + video_frames_tensor = video_frames_tensor.unsqueeze(0) + + # get video's resnet feature maps + video_feature = [] + for idx in range(seq_len): + batched_frames = video_frames_tensor[:,idx,:,:,:].to(device) # 1x3x448x448 + with torch.no_grad(): + batched_feature = resnet_rgb_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len x 2048 x 14 x 14 + torch.save(video_feature, join(video_dir, 'feature', 'resnet101_rgb_pretrain_conv5.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') + utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/MITDive_resnet101_rgb_pretrain_conv5/', + (224,224), video_name, dataset_dir) + + return videos_feature_dict + +def extract_save_flow_feature (dataset_dir): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_flow_extractor = resnet101(pretrained=True, channel=20, output='conv5').to(device) + pretrained_model = torch.load('Models/ResNet101_flow_pretrain.pth.tar') + resnet_flow_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_flow_extractor.parameters(): + param.requires_grad = False + + transform = transforms.Compose([ + transforms.CenterCrop((448,448)), + transforms.ToTensor(), + ]) + + videos_feature_dict = {} + video_name_list = [format(video_idx, '03d') for video_idx in range(1, 160)] + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_tensor = [] + + seq_len = int( len(os.listdir(join(video_dir, 'flow_tvl1'))) / 2 ) + for frame_idx in range(1, seq_len+1): + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_x_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_y_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor = torch.cat(video_frames_tensor, 0).unsqueeze(0) #1 x seq_lenx2 x 448 x 448 + + # get video's resnet feature maps + video_feature = [] + for idx in range(seq_len-9): + batched_frames = video_frames_tensor[:,2*idx:2*(idx+10),:,:].to(device) # 1x20x448x448 + with torch.no_grad(): + batched_feature = resnet_flow_extractor(batched_frames) # 1xCxWxH + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len-9 x 256 x 28 x 28 + torch.save(video_feature, join(video_dir, 'feature', 
'resnet101_flow_10_pretrain_conv5.pt')) +# os.system('rm '+join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv5.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') + utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/MITDive_resnet101_flow_pretrain_conv5/', + (224,224), video_name, dataset_dir) + return videos_feature_dict + +def get_rgb_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = [format(video_idx, '03d') for video_idx in range(1, 160)] + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet101_rgb_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet101_rgb_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature + return videos_feature_dict + +def get_flow_feature_dict (dataset_dir, feature_type='resnet101_conv5.pt'): + videos_feature_dict = {} + video_name_list = [format(video_idx, '03d') for video_idx in range(1, 160)] + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet101_flow_10_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet101_flow_10_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature + return videos_feature_dict + +def del_featmaps (dataset_dir): + video_name_list = [format(video_idx, '03d') for video_idx in range(1, 160)] + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_feature_dir = join(video_dir, 'feature') + + for file in os.listdir(video_feature_dir): + if 'heatmaps' in file or 'resnet101' in file: + pass + else: + print("delete: ", file) + os.system('rm '+join(video_feature_dir, file)) + +def del_flow (dataset_dir): + video_name_list = [format(video_idx, '03d') for video_idx in range(1, 160)] + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_flow_dir = join(video_dir, 'flow') + + os.system('rm -rf '+video_flow_dir) + print('delete: ', video_flow_dir) + +######################################################################################################################### +# MIT_Dive_Dataset # +######################################################################################################################### +def video_sample (video_tensor, rand_idx_list): + sampled_video_tensor = torch.stack([video_tensor[idx,:,:,:] for idx in rand_idx_list], dim=0) + return sampled_video_tensor + +class MITDiveDataset (Dataset): + def __init__ (self, f_type, video_rgb_feature_dict, video_flow_feature_dict, video_idx_list, seg_sample): + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + self.video_idx_list = video_idx_list + self.f_type = f_type + self.overall_scores = np.load('../dataset/MIT_Dive_Dataset/diving_samples_len_ori_800x450/diving_overall_scores.npy') + self.overall_scores = torch.FloatTensor(np.squeeze(self.overall_scores)) + + def __len__ (self): + return len(self.video_idx_list) + + def __getitem__ (self, i): 
+ video_idx = self.video_idx_list[i] + video_name = format(video_idx, '03d') + video_sample = self.read_one_video(video_name) + return video_sample + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_flow_tensor.size(0) #length of optical flow stack (index: 0~N-10) + + rand_flow_idx_list = utility.avg_last_sample(seq_len, self.seg_sample) + video_flow_tensor = video_sample(video_flow_tensor, rand_flow_idx_list) #num_seg x 2048 x 14 x 14 + + rand_rgb_idx_list = [flow_idx+5 for flow_idx in rand_flow_idx_list] + video_rgb_tensor = video_sample(video_rgb_tensor, rand_rgb_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) #num_seg x 4096 x 14 x 14 + + video_score = self.overall_scores[int(video_name)-1] + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_rgb_idx_list, 'score': video_score} + return sample + +######################################################################################################################### +# MIT_Dive_Dataset for Pair # +######################################################################################################################### +class MITDiveDataset_Pair(Dataset): + def __init__(self, f_type, video_rgb_feature_dict, video_flow_feature_dict, pairs_dict, seg_sample): + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + self.f_type = f_type + + self.pairs_dict = pairs_dict + self.pairs_list = list(self.pairs_dict.keys()) + + def __len__(self): + return len(self.pairs_dict) + + def __getitem__(self, index): + v1_idx, v2_idx = self.pairs_list[index] + v1_name = format(v1_idx, '03d') + v2_name = format(v2_idx, '03d') + v1_sample = self.read_one_video(v1_name) + v2_sample = self.read_one_video(v2_name) + label = torch.Tensor([ self.pairs_dict[(v1_idx, v2_idx)] ]) + + pair = {'video1': v1_sample, 'video2': v2_sample, 'label': label} + return pair + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_flow_tensor.size(0) #length of optical flow stack (index: 0~N-10) + + rand_flow_idx_list = utility.avg_rand_sample(seq_len, self.seg_sample) + video_flow_tensor = video_sample(video_flow_tensor, rand_flow_idx_list) #num_seg x 2048 x 14 x 14 + + rand_rgb_idx_list = [flow_idx+5 for flow_idx in rand_flow_idx_list] + video_rgb_tensor = video_sample(video_rgb_tensor, rand_rgb_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) ##num_seg x 4096 x 14 x 14 + + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_rgb_idx_list} + return sample + +if __name__ == "__main__": + # extract_save_flow_feature('../dataset/MIT_Dive_Dataset/diving_samples_len_ori_800x450') + # extract_save_rgb_feature('../dataset/MIT_Dive_Dataset/diving_samples_len_ori_800x450') + del_flow('../dataset/MIT_Dive_Dataset/diving_samples_len_ori_800x450') \ No newline at 
end of file diff --git a/dataset/dough_dataset.py b/dataset/dough_dataset.py new file mode 100644 index 0000000..f42eb1a --- /dev/null +++ b/dataset/dough_dataset.py @@ -0,0 +1,275 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +import utility +from Models.ResNet import resnet101, resnet34 + +import os +from os.path import join, isdir, isfile +import glob +import math + +from PIL import Image +import pickle +from scipy.io import loadmat +from scipy.stats import spearmanr +import numpy as np +from tqdm import tqdm +import cv2 +import csv +import time + +def extract_save_rgb_feature (dataset_dir): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_rgb_extractor = resnet101(pretrained=True, channel=3, output='conv5').to(device) + pretrained_model = torch.load('Models/ResNet101_rgb_pretrain.pth.tar') + resnet_rgb_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_rgb_extractor.parameters(): + param.requires_grad = False + + transform = transforms.Compose([ + transforms.Resize((448,448), interpolation=3), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_dir = join(video_dir, 'frame') + + seq_len = len(os.listdir(video_frames_dir)) + video_frames_tensor = torch.stack([ + transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.jpg'))) + for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 + video_frames_tensor = video_frames_tensor.unsqueeze(0) + + # get video's resnet feature maps + video_feature = [] + for idx in range(5, seq_len-5, 50): + batched_frames = video_frames_tensor[:,idx,:,:,:].to(device) # 1x3x448x448 + with torch.no_grad(): + batched_feature = resnet_rgb_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len x 2048 x 14 x 14 + torch.save(video_feature, join(video_dir, 'feature', 'resnet_pretrain_conv5.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') +# utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/chopstick_resnet101_conv5/', +# (320,180), video_name, dataset_dir) + return videos_feature_dict + +def extract_save_flow_feature (dataset_dir): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_flow_extractor = resnet101(pretrained=True, channel=20, output='conv5').to(device) + pretrained_model = torch.load('Models/ResNet101_flow_pretrain.pth.tar') + resnet_flow_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_flow_extractor.parameters(): + param.requires_grad = False + + transform = transforms.Compose([ + transforms.Resize((448,448), interpolation=3), + transforms.ToTensor(), + ]) + + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_tensor = [] + + seq_len = int( len(os.listdir(join(video_dir, 'flow_tvl1'))) / 2 ) + for frame_idx in range(1, 
seq_len+1): + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_x_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_y_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor = torch.cat(video_frames_tensor, 0).unsqueeze(0) #1 x seq_lenx2 x 448 x 448 + + # get video's resnet feature maps, 5-times down-sample + video_feature = [] + for idx in range(0, seq_len-9, 50): + batched_frames = video_frames_tensor[:,2*idx:2*(idx+10),:,:].to(device) # 1x20x448x448 + with torch.no_grad(): + batched_feature = resnet_flow_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len-9 x 2048 x 14 x 14 + + feature_save_dir = join(video_dir, 'feature') + if not os.path.isdir(feature_save_dir): + os.system('mkdir -p '+feature_save_dir) + torch.save(video_feature, join(feature_save_dir, 'resnet_flow_10_pretrain_conv5.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') +# utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/chopstick_resnet101_10_pretrain_conv5/', +# (320,180), video_name, dataset_dir) + return videos_feature_dict + +def get_rgb_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature +# print(video_name, video_feature.size(0)) + return videos_feature_dict + +def get_flow_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature +# print(video_name, video_feature.size(0)) + return videos_feature_dict + +def get_train_test_pairs_dict(annotation_dir, split_idx): + train_pairs_dict = {} + train_videos = set() + train_csv = join(annotation_dir, 'DoughRolling_train_' + + format(split_idx, '01d')+'.csv') + with open(train_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((row[0], row[1])) + train_pairs_dict[key] = 1 + train_videos.update(key) + csvfile.close() + + test_pairs_dict = {} + test_videos = set() + test_csv = join(annotation_dir, 'DoughRolling_val_' + + format(split_idx, '01d')+'.csv') + with open(test_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((row[0], row[1])) + test_pairs_dict[key] = 1 + if key[0] not in train_videos: + test_videos.add(key[0]) + if key[1] not in train_videos: + test_videos.add(key[1]) + csvfile.close() + + return train_pairs_dict, test_pairs_dict, train_videos, test_videos + 
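A minimal usage sketch of the helpers above, wiring them into train/test loaders the same way chopstick_fusion_attention_transition.py does for the chopstick data; the annotation directory, split index, and loader settings here are assumptions, and the dataset path is the one used in this file's __main__ block.

def example_dough_loaders(split_idx=1, num_seg=25, batch_size=16):
    """Sketch: build paired training and per-video test loaders for DoughRolling.
    Assumes the split CSVs live under annotation_dir and the cached ResNet features
    have already been extracted by the functions above."""
    dataset_dir = '/data/lzq/dataset/DoughRolling/DoughRolling_600x450'
    annotation_dir = '../dataset/DoughRolling/splits'   # assumed location of DoughRolling_{train,val}_*.csv

    rgb_feats = get_rgb_feature_dict(dataset_dir)        # {video_name: seq_len x 2048 x 14 x 14}
    flow_feats = get_flow_feature_dict(dataset_dir)
    train_pairs, test_pairs, train_videos, test_videos = get_train_test_pairs_dict(
        annotation_dir, split_idx)

    # Paired ranking samples for training, single videos for evaluation.
    train_set = DoughDataset_Pair('fusion', rgb_feats, flow_feats, train_pairs, seg_sample=num_seg)
    test_set = DoughDataset('fusion', rgb_feats, flow_feats, list(rgb_feats.keys()), seg_sample=num_seg)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    return train_loader, test_loader, test_pairs, test_videos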
+######################################################################################################################### +# Chopstick_Dataset # +######################################################################################################################### +def video_sample (video_tensor, rand_idx_list): + sampled_video_tensor = torch.stack([video_tensor[idx,:,:,:] for idx in rand_idx_list], dim=0) + return sampled_video_tensor + +class DoughDataset (Dataset): + def __init__ (self, f_type, video_rgb_feature_dict, video_flow_feature_dict, video_name_list, seg_sample=None): + self.f_type = f_type + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + self.video_name_list = video_name_list + + def __len__ (self): + return len(self.video_name_list) + + def __getitem__ (self, i): + video_name = self.video_name_list[i] + video_sample = self.read_one_video(video_name) + return video_sample + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_rgb_tensor.size(0) + rand_idx_list = utility.avg_last_sample(seq_len, self.seg_sample) + + video_flow_tensor = video_sample(video_flow_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + video_rgb_tensor = video_sample(video_rgb_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) ##num_seg x 4096 x 14 x 14 + + rand_idx_list = [50*num+5 for num in rand_idx_list] + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_idx_list} + return sample + +class DoughDataset_Pair (Dataset): + def __init__ (self, f_type, video_rgb_feature_dict, video_flow_feature_dict, pairs_dict, seg_sample=None): + self.f_type = f_type + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + + self.pairs_dict = pairs_dict + self.pairs_list = list(self.pairs_dict.keys()) + + def __len__ (self): + return len(self.pairs_list) + + def __getitem__ (self, i): + v1_name, v2_name = self.pairs_list[i] + v1_sample = self.read_one_video(v1_name) + v2_sample = self.read_one_video(v2_name) + label = torch.Tensor([ self.pairs_dict[(v1_name, v2_name)] ]) + + pair = {'video1': v1_sample, 'video2': v2_sample, 'label': label} + return pair + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_rgb_tensor.size(0) + rand_idx_list = utility.avg_rand_sample(seq_len, self.seg_sample) + + video_flow_tensor = video_sample(video_flow_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + video_rgb_tensor = video_sample(video_rgb_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) #num_seg x 4096 x 14 x 14 + + rand_idx_list = [50*num+5 for num in rand_idx_list] + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_idx_list} + return sample + +if __name__ == "__main__": + 
extract_save_flow_feature('/data/lzq/dataset/DoughRolling/DoughRolling_600x450') + extract_save_rgb_feature('/data/lzq/dataset/DoughRolling/DoughRolling_600x450') diff --git a/dataset/drawing_dataset.py b/dataset/drawing_dataset.py new file mode 100644 index 0000000..bc0f0c7 --- /dev/null +++ b/dataset/drawing_dataset.py @@ -0,0 +1,304 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +import utility +from Models.ResNet import resnet101, resnet34 + +import os +from os.path import join, isdir, isfile +import glob +import math +import argparse + +from PIL import Image +import pickle +from scipy.io import loadmat +from scipy.stats import spearmanr +from scipy.ndimage.filters import gaussian_filter +import numpy as np +from tqdm import tqdm +import cv2 +import csv +import time + +parser = argparse.ArgumentParser() +parser.add_argument("--dataset_name", type=str, default='All', choices=['All', 'SonicDrawing', 'HandDrawing']) +parser.add_argument("--feature_type", type=str, default='resnet101_conv5', choices=['resnet101_conv4', 'resnet101_conv5']) + +args = parser.parse_args() + +# featmap: seq_len x 2048 x 14 x 14 +def save_featmap_heatmaps (featmap, save_dir, size, video_name, dataset_dir): + seq_len = featmap.size(0) + + s = torch.norm(featmap, p=2, dim=1, keepdim=True) # seq_len x 1 x 14 x 14 + s = F.normalize(s.view(seq_len, -1),dim=1).view(s.size()) + + att_save_dir = join(save_dir, video_name) + ori_frames_dir = join(dataset_dir, video_name, 'frame') + + if not os.path.isdir(att_save_dir): + os.system('mkdir -p '+att_save_dir) + else: + os.system('rm -rf '+att_save_dir) + os.system('mkdir -p '+att_save_dir) + + for seq_idx in range(seq_len): + frame_idx = 5*int(seq_idx) + + heatmap = s[seq_idx,0,:,:] + heatmap = (heatmap-heatmap.min()) / (heatmap.max()-heatmap.min()) + heatmap = np.array(heatmap*255.0).astype(np.uint8) + heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) + heatmap = cv2.resize(heatmap, size) + + ori_frame = cv2.imread(join(ori_frames_dir, format(frame_idx, '05d')+'.png')) + if 'Dive' in save_dir: + ori_frame = ori_frame[1:449,176:624] + ori_frame = cv2.resize(ori_frame, size) + + comb = cv2.addWeighted(ori_frame, 0.6, heatmap, 0.4, 0) + pic_save_dir = join(att_save_dir, format(frame_idx, '05d')+'.jpg') + cv2.imwrite(pic_save_dir, comb) + +def extract_save_rgb_feature (dataset_dir): + feat_name = args.feature_type.split('_')[1] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_rgb_extractor = resnet101(pretrained=True, channel=3, output=feat_name).to(device) + pretrained_model = torch.load('Models/ResNet101_rgb_pretrain.pth.tar') + resnet_rgb_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_rgb_extractor.parameters(): + param.requires_grad = False + + if feat_name == 'conv4': + resize = (224,224) + elif feat_name == 'conv5': + resize = (448,448) + transform = transforms.Compose([ + transforms.Resize(resize, interpolation=3), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_dir = join(video_dir, 'frame') + + seq_len = len(os.listdir(video_frames_dir)) + 
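# Note (added for clarity): the stack below assumes frames are named 00000.png, 00001.png, ...
# consecutively and that the 'frame' folder contains only those images, so len(os.listdir(...))
# equals the frame count. Features are then extracted from every 5th frame between index 5 and
# seq_len-6 (range(5, seq_len-5, 5)), which is why DrawingDataset later maps a sampled feature
# index back to the original frame as 5*idx+5.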
video_frames_tensor = torch.stack([ + transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.png'))) + for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 + video_frames_tensor = video_frames_tensor.unsqueeze(0) + + # get video's resnet feature maps + video_feature = [] + for idx in range(5, seq_len-5, 5): + batched_frames = video_frames_tensor[:,idx,:,:,:].to(device) # 1x3x448x448 + with torch.no_grad(): + batched_feature = resnet_rgb_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len x 2048 x 14 x 14 + + feature_save_dir = join(video_dir, 'feature') + if not os.path.isdir(feature_save_dir): + os.system('mkdir -p '+feature_save_dir) + torch.save(video_feature, join(video_dir, 'feature', 'resnet_pretrain_'+feat_name+'.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') + save_featmap_heatmaps(video_feature.cpu().data, + 'results/'+args.dataset_name+'_resnet101_'+feat_name+'/', + (320,180), video_name, dataset_dir) + return videos_feature_dict + +def extract_save_flow_feature (dataset_dir): + feat_name = args.feature_type.split('_')[1] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_flow_extractor = resnet101(pretrained=True, channel=20, output=feat_name).to(device) + pretrained_model = torch.load('Models/ResNet101_flow_pretrain.pth.tar') + resnet_flow_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_flow_extractor.parameters(): + param.requires_grad = False + + if feat_name == 'conv4': + resize = (224,224) + elif feat_name == 'conv5': + resize = (448,448) + transform = transforms.Compose([ + transforms.Resize(resize, interpolation=3), + transforms.ToTensor(), + ]) + + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_tensor = [] + + seq_len = int( len(os.listdir(join(video_dir, 'flow_tvl1'))) / 2 ) + for frame_idx in range(1, seq_len+1): + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_x_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_y_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor = torch.cat(video_frames_tensor, 0).unsqueeze(0) #1 x seq_lenx2 x 448 x 448 + + # get video's resnet feature maps, 5-times down-sample + video_feature = [] + for idx in range(0, seq_len-9, 5): + batched_frames = video_frames_tensor[:,2*idx:2*(idx+10),:,:].to(device) # 1x20x448x448 + with torch.no_grad(): + batched_feature = resnet_flow_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len-9 x 2048 x 14 x 14 + + feature_save_dir = join(video_dir, 'feature') + if not os.path.isdir(feature_save_dir): + os.system('mkdir -p '+feature_save_dir) +# if os.path.isfile(join(video_dir, 'resnet_flow_10_pretrain_conv5.pt')): +# os.system('rm -f '+join(video_dir, 'resnet_flow_10_pretrain_conv5.pt')) + torch.save(video_feature, join(feature_save_dir, 'resnet_flow_10_pretrain_'+feat_name+'.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') + save_featmap_heatmaps(video_feature.cpu().data, + 
'results/'+args.dataset_name+'_resnet101_10_pretrain_'+feat_name+'/', + (320,180), video_name, dataset_dir) + return videos_feature_dict + +def get_rgb_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = [video_name for video_name in os.listdir(dataset_dir) + if isdir(join(dataset_dir, video_name))] + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature + return videos_feature_dict + +def get_flow_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = [video_name for video_name in os.listdir(dataset_dir) + if isdir(join(dataset_dir, video_name))] + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature + return videos_feature_dict + +# ######################################################################################################################### +# # Chopstick_Dataset # +# ######################################################################################################################### +def video_sample (video_tensor, rand_idx_list): + sampled_video_tensor = torch.stack([video_tensor[idx,:,:,:] for idx in rand_idx_list], dim=0) + return sampled_video_tensor + +class DrawingDataset (Dataset): + def __init__ (self, f_type, video_rgb_feature_dict, video_flow_feature_dict, video_name_list, seg_sample=None): + self.f_type = f_type + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + self.video_name_list = video_name_list + + def __len__ (self): + return len(self.video_name_list) + + def __getitem__ (self, i): + video_name = self.video_name_list[i] + video_sample = self.read_one_video(video_name) + return video_sample + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_rgb_tensor.size(0) + rand_idx_list = utility.avg_last_sample(seq_len, self.seg_sample) + + video_flow_tensor = video_sample(video_flow_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + video_rgb_tensor = video_sample(video_rgb_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) ##num_seg x 4096 x 14 x 14 + + rand_idx_list = [5*num+5 for num in rand_idx_list] + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_idx_list} + return sample + +class DrawingDataset_Pair (Dataset): + def __init__ (self, f_type, video_rgb_feature_dict, video_flow_feature_dict, pairs_dict, seg_sample=None): + self.f_type = f_type + self.seg_sample = seg_sample + self.video_rgb_feature_dict = 
video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + + self.pairs_dict = pairs_dict + self.pairs_list = list(self.pairs_dict.keys()) + + def __len__ (self): + return len(self.pairs_list) + + def __getitem__ (self, i): + v1_name, v2_name = self.pairs_list[i] + v1_sample = self.read_one_video(v1_name) + v2_sample = self.read_one_video(v2_name) + label = torch.Tensor([ self.pairs_dict[(v1_name, v2_name)] ]) + + pair = {'video1': v1_sample, 'video2': v2_sample, 'label': label} + return pair + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_rgb_tensor.size(0) + rand_idx_list = utility.avg_rand_sample(seq_len, self.seg_sample) + + video_flow_tensor = video_sample(video_flow_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + video_rgb_tensor = video_sample(video_rgb_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) ##num_seg x 4096 x 14 x 14 + + rand_idx_list = [5*num+5 for num in rand_idx_list] + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_idx_list} + return sample + +if __name__ == "__main__": +# extract_save_flow_feature('../dataset/'+args.dataset_name+'/'+args.dataset_name+'_Stationary_800x450') +# extract_save_rgb_feature('../dataset/'+args.dataset_name+'/'+args.dataset_name+'_Stationary_800x450') + get_flow_feature_dict('../dataset/'+args.dataset_name+'/'+args.dataset_name+'_Stationary_800x450') \ No newline at end of file diff --git a/dataset/grasp_dataset.py b/dataset/grasp_dataset.py new file mode 100644 index 0000000..b444a5e --- /dev/null +++ b/dataset/grasp_dataset.py @@ -0,0 +1,432 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +import utility +from Models.ResNet import resnet101 +# from Models.c3d_model import C3D +# from Models import resnet_3d +from Models.VGG import vgg16 + +import os +from os.path import join, isdir, isfile +import glob +import math + +from PIL import Image +import pickle +from scipy.io import loadmat +from scipy.stats import spearmanr +from scipy.ndimage.filters import gaussian_filter +import numpy as np +from tqdm import tqdm +import cv2 +import csv + +# def extract_vgg_feature (dataset_dir): +# dataset_name = 'grasp' +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# resnet_rgb_extractor = vgg16(pretrained=True).to(device) +# resnet_rgb_extractor.eval() +# for param in resnet_rgb_extractor.parameters(): +# param.requires_grad = False + +# transform = transforms.Compose([ +# transforms.Resize((448,448), interpolation=3), +# # transforms.Resize((224,224), interpolation=3), +# transforms.ToTensor(), +# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +# ]) + +# videos_feature_dict = {} +# video_name_list = [video_name for video_name in os.listdir(dataset_dir) if '_L' in video_name] +# for video_name in video_name_list: +# print(video_name+' is being processing.') + +# video_dir = join(dataset_dir, video_name) +# video_frames_dir = join(video_dir, 'frame') + +# seq_len = len(os.listdir(video_frames_dir)) 
+# video_frames_tensor = torch.stack([ +# transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.jpg'))) +# for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 +# video_frames_tensor = video_frames_tensor.unsqueeze(0) + +# # get video's resnet feature maps +# video_feature = [] +# for idx in range(seq_len): +# batched_frames = video_frames_tensor[:,idx,:,:,:].to(device) # 1x3x448x448 +# with torch.no_grad(): +# batched_feature = resnet_rgb_extractor(batched_frames) # 1xCxWxH +# video_feature.append(batched_feature) +# video_feature = torch.cat(video_feature, 0).cpu() # seq_len x C x W x H +# # torch.save(video_feature, join(video_dir, 'feature', 'resnet101_rgb_pretrain_conv5.pt')) + +# videos_feature_dict[video_name] = video_feature +# print(video_feature.size()) +# print(video_name+' is finished.') +# utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/'+dataset_name+'_vgg16_rgb_pool5/', +# (240,360), video_name, dataset_dir) +# return videos_feature_dict + +# def extract_3d_feature (dataset_dir): +# dataset_name = 'grasp' +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# resnet_rgb_extractor = resnet_3d.resnet34(sample_size=112, sample_duration=16, shortcut_type='A', +# num_classes=400, last_fc=True, output='conv4') +# model_dict = resnet_rgb_extractor.state_dict() +# # print(list(model_dict.keys())[:5]) +# checkpoint = torch.load('Models/resnet34_3d_kinetics.pth') +# pretrain_dict = checkpoint['state_dict'] +# # print(list(pretrain_dict.keys())[:5]) +# # print([k for k in pretrain_dict.keys() if 'module.' not in k]) +# pretrain_dict = {k[7:]: v for k,v in pretrain_dict.items() if k in model_dict} +# model_dict.update(pretrain_dict) +# resnet_rgb_extractor.load_state_dict(model_dict) +# # resnet_rgb_extractor.load_state_dict(torch.load('Models/resnet34_3d_kinetics.pth')['state_dict']) +# for param in resnet_rgb_extractor.parameters(): +# param.requires_grad = False +# resnet_rgb_extractor.to(device) +# resnet_rgb_extractor.eval() + +# transform = transforms.Compose([ +# transforms.Resize((224,224), interpolation=3), +# transforms.ToTensor(), +# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +# # transforms.Normalize(mean=[114.7748, 107.7354, 99.4750], std=[1,1,1]) +# ]) + +# videos_feature_dict = {} +# video_name_list = [video_name for video_name in os.listdir(dataset_dir) if '_L' in video_name] +# for video_name in video_name_list[0:5]: +# print(video_name+' is being processing.') + +# video_dir = join(dataset_dir, video_name) +# video_frames_dir = join(video_dir, 'frame') + +# seq_len = len(os.listdir(video_frames_dir)) +# video_frames_tensor = torch.stack([ +# transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.jpg'))) +# for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 +# video_frames_tensor = video_frames_tensor.unsqueeze(0) + +# # get video's resnet feature maps +# video_feature = [] +# for idx in range(0, seq_len-15): +# batched_frames = video_frames_tensor[:,idx:idx+16,:,:,:].transpose(1,2).to(device) # 1x3x16x224x224 +# with torch.no_grad(): +# batched_feature = resnet_rgb_extractor(batched_frames) # 1x512x1x7x7 +# h,w = batched_feature.size()[3:] +# batched_feature = batched_feature.view(1,-1,h,w) # 1xCxWxH +# video_feature.append(batched_feature) +# video_feature = torch.cat(video_feature, 0).cpu() # seq_len x C x W x H +# torch.save(video_feature, join(video_dir, 'feature', 'c3d_rgb_pretrain_conv5b.pt')) + +# 
videos_feature_dict[video_name] = video_feature +# print(video_feature.size()) +# print(video_name+' is finished.') +# utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/'+dataset_name+'_c3d_rgb_pretrain_conv5b/', +# (240,360), video_name, dataset_dir) +# return videos_feature_dict + +# def extract_c3d_feature (dataset_dir): +# dataset_name = 'grasp' +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# c3d_extractor = C3D(output='pool5').to(device) +# c3d_dict = c3d_extractor.state_dict() +# pretrained_dict = torch.load('Models/c3d_rgb_pretrain.pickle') +# pretrained_dict = {k: v for k,v in pretrained_dict.items() if k in c3d_dict} +# # c3d_dict.update(pretrained_dict) +# c3d_extractor.load_state_dict(pretrained_dict) +# for param in c3d_extractor.parameters(): +# param.requires_grad = False +# c3d_extractor.eval() + +# transform = transforms.Compose([ +# transforms.Resize((224,224), interpolation=3), +# transforms.ToTensor(), +# # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +# # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1,1,1]) +# ]) + +# videos_feature_dict = {} +# video_name_list = [video_name for video_name in os.listdir(dataset_dir) if '_L' in video_name] +# for video_name in video_name_list[0:5]: +# print(video_name+' is being processing.') + +# video_dir = join(dataset_dir, video_name) +# video_frames_dir = join(video_dir, 'frame') + +# seq_len = len(os.listdir(video_frames_dir)) +# video_frames_tensor = torch.stack([ +# transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.jpg'))) +# for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 +# video_frames_tensor = video_frames_tensor.unsqueeze(0) + +# # get video's resnet feature maps +# video_feature = [] +# for idx in range(0, seq_len-15): +# batched_frames = video_frames_tensor[:,idx:idx+16,:,:,:].transpose(1,2).to(device) # 1x3x16x224x224 +# with torch.no_grad(): +# batched_feature = c3d_extractor(batched_frames) # 1x512x2x14x14 +# h,w = batched_feature.size()[3:] +# batched_feature = batched_feature.view(1,-1,h,w) # 1xCxWxH +# video_feature.append(batched_feature) +# video_feature = torch.cat(video_feature, 0).cpu() # seq_len x C x W x H +# torch.save(video_feature, join(video_dir, 'feature', 'c3d_rgb_pretrain_conv5b.pt')) + +# videos_)feature_dict[video_name] = video_feature +# print(video_feature.size()) +# print(video_name+' is finished.') +# utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/'+dataset_name+'_c3d_rgb_pretrain_conv5b/', +# (240,360), video_name, dataset_dir) +# return videos_feature_dict + +def extract_save_rgb_feature (dataset_dir): + dataset_name = 'grasp' + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_rgb_extractor = resnet101(pretrained=True, channel=3, output='conv5').to(device) + pretrained_model = torch.load('Models/ResNet101_rgb_pretrain.pth.tar') + resnet_rgb_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_rgb_extractor.parameters(): + param.requires_grad = False + + transform = transforms.Compose([ +# transforms.Resize((448,448), interpolation=3), + transforms.Resize((224,224), interpolation=3), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + videos_feature_dict = {} + video_name_list = [video_name for video_name in os.listdir(dataset_dir) if '_L' in video_name] + for video_name in video_name_list: + print(video_name+' is being processing.') + 
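+        # Descriptive note on the per-video steps below (added comment, not original code):
+        # every frame in <video>/frame is loaded, resized to 224x224 and normalised with
+        # ImageNet statistics, then pushed one at a time (no temporal subsampling here)
+        # through the pretrained ResNet-101 RGB extractor loaded above. The stacked conv
+        # feature maps are cached to <video>/feature/resnet101_rgb_pretrain_conv5.pt so that
+        # get_rgb_feature_dict() can read them back later instead of re-extracting, and a
+        # heatmap visualisation is written under results/ via utility.save_featmap_heatmaps.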
+ video_dir = join(dataset_dir, video_name) + video_frames_dir = join(video_dir, 'frame') + + seq_len = len(os.listdir(video_frames_dir)) + video_frames_tensor = torch.stack([ + transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.jpg'))) + for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 + video_frames_tensor = video_frames_tensor.unsqueeze(0) + + # get video's resnet feature maps + video_feature = [] + for idx in range(seq_len): + batched_frames = video_frames_tensor[:,idx,:,:,:].to(device) # 1x3x448x448 + with torch.no_grad(): + batched_feature = resnet_rgb_extractor(batched_frames) # 1xCxWxH + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len x C x W x H + torch.save(video_feature, join(video_dir, 'feature', 'resnet101_rgb_pretrain_conv5.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') + utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/'+dataset_name+'_resnet101_rgb_pretrain_conv5/', + (240,360), video_name, dataset_dir) + return videos_feature_dict + +def extract_save_flow_feature (dataset_dir): + dataset_name = 'grasp' + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_flow_extractor = resnet101(pretrained=True, channel=20, output='conv4').to(device) + pretrained_model = torch.load('Models/ResNet101_flow_pretrain.pth.tar') + resnet_flow_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_flow_extractor.parameters(): + param.requires_grad = False + + transform = transforms.Compose([ + transforms.Resize((224,224), interpolation=3), + transforms.ToTensor(), + ]) + + videos_feature_dict = {} + video_name_list = [video_name for video_name in os.listdir(dataset_dir) if '_L' in video_name] + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_tensor = [] + + seq_len = int( len(os.listdir(join(video_dir, 'flow'))) / 3 ) + for frame_idx in range(1, seq_len+1): + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow', 'flow_x_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow', 'flow_y_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor = torch.cat(video_frames_tensor, 0).unsqueeze(0) #1 x seq_lenx2 x 448 x 448 + + # get video's resnet feature maps + video_feature = [] + for idx in range(seq_len-9): + batched_frames = video_frames_tensor[:,2*idx:2*(idx+10),:,:].to(device) # 1x20x448x448 + with torch.no_grad(): + batched_feature = resnet_flow_extractor(batched_frames) # 1xCxWxH + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len-9 x 256 x 28 x 28 + torch.save(video_feature, join(video_dir, 'feature', 'resnet101_flow_10_pretrain_conv4.pt')) +# if isfile(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv5.pt')): +# os.system('mv '+join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv5.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') + utility.save_featmap_heatmaps(video_feature.cpu().data, 'results/'+dataset_name+'_resnet101_flow_pretrain_conv4/', + (240,360), video_name, dataset_dir) + return videos_feature_dict + +def get_rgb_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = [video_name for 
video_name in os.listdir(dataset_dir) if '_L' in video_name] + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet101_rgb_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet101_rgb_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature + return videos_feature_dict + +def get_flow_feature_dict (dataset_dir, feature_type='resnet101_conv5'): + videos_feature_dict = {} + video_name_list = [video_name for video_name in os.listdir(dataset_dir) if '_L' in video_name] + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet101_flow_10_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet101_flow_10_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature +# print(video_name, video_feature.size(0)) + return videos_feature_dict + +def del_featmaps (dataset_dir): + video_name_list = [video_name for video_name in os.listdir(dataset_dir) if '_L' in video_name] + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_feature_dir = join(video_dir, 'feature') + + for file in os.listdir(video_feature_dir): + if 'heatmaps' in file or 'resnet101' in file: + pass + else: + print("delete: ", file) + os.system('rm '+join(video_feature_dir, file)) + if 'resnet101_rgb_pretrain_conv5_7x7' in file: + print("delete: ", file) + os.system('rm '+join(video_feature_dir, file)) + +# ============================================================================= # +# Dataset for One Video # +# ============================================================================= # +def video_sample (video_tensor, rand_idx_list): + sampled_video_tensor = torch.stack([video_tensor[idx,:,:,:] for idx in rand_idx_list], dim=0) + return sampled_video_tensor + +class GraspDataset(Dataset): + def __init__(self, f_type, video_rgb_feature_dict, video_flow_feature_dict, video_name_list, seg_sample=None): + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + self.video_name_list = video_name_list + self.f_type = f_type + + def __len__(self): + return len(self.video_name_list) + + def __getitem__(self, index): + video_name = self.video_name_list[index] + video_sample = self.read_one_video(video_name) + return video_sample + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_flow_tensor.size(0) #length of optical flow stack (index: 0~N-10) + + rand_flow_idx_list = utility.avg_last_sample(seq_len, self.seg_sample) + video_flow_tensor = video_sample(video_flow_tensor, rand_flow_idx_list) #num_seg x 2048 x 14 x 14 + + rand_rgb_idx_list = [flow_idx+5 for flow_idx in rand_flow_idx_list] + video_rgb_tensor = video_sample(video_rgb_tensor, rand_rgb_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) 
##num_seg x 4096 x 14 x 14 + + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_rgb_idx_list} + return sample + +# ============================================================================= # +# Dataset for One Pair # +# ============================================================================= # +class GraspDataset_Pair(Dataset): + def __init__(self, f_type, video_rgb_feature_dict, video_flow_feature_dict, pairs_dict, seg_sample=None): + self.seg_sample = seg_sample + self.video_rgb_feature_dict = video_rgb_feature_dict + self.video_flow_feature_dict = video_flow_feature_dict + self.f_type = f_type + + self.pairs_dict = pairs_dict + self.pairs_list = list(self.pairs_dict.keys()) + + def __len__(self): + return len(self.pairs_dict) + + def __getitem__(self, index): + v1_name, v2_name = self.pairs_list[index] + v1_sample = self.read_one_video(v1_name) + v2_sample = self.read_one_video(v2_name) + label = torch.Tensor([ self.pairs_dict[(v1_name, v2_name)] ]) + + pair = {'video1': v1_sample, 'video2': v2_sample, 'label': label} + return pair + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_flow_tensor.size(0) #length of optical flow stack (index: 0~N-10) + + rand_flow_idx_list = utility.avg_rand_sample(seq_len, self.seg_sample) + video_flow_tensor = video_sample(video_flow_tensor, rand_flow_idx_list) #num_seg x 2048 x 14 x 14 + + rand_rgb_idx_list = [flow_idx+5 for flow_idx in rand_flow_idx_list] + video_rgb_tensor = video_sample(video_rgb_tensor, rand_rgb_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) ##num_seg x 4096 x 14 x 14 + + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_rgb_idx_list} + return sample + +if __name__ == "__main__": +# extract_save_flow_feature('../dataset/InfantsGrasping/InfantsGrasping_480x720') +# extract_save_rgb_feature('../dataset/InfantsGrasping/InfantsGrasping_480x720') + del_featmaps('../dataset/InfantsGrasping/InfantsGrasping_480x720') diff --git a/dataset/sugury_dataset.py b/dataset/sugury_dataset.py new file mode 100644 index 0000000..12431a7 --- /dev/null +++ b/dataset/sugury_dataset.py @@ -0,0 +1,275 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +import utility +from Models.ResNet import resnet101, resnet34 + +import os +from os.path import join, isdir, isfile +import glob +import math +import argparse + +from PIL import Image +import pickle +from scipy.io import loadmat +from scipy.stats import spearmanr +from scipy.ndimage.filters import gaussian_filter +import numpy as np +from tqdm import tqdm +import cv2 +import csv +import time + +parser = argparse.ArgumentParser() +parser.add_argument("--dataset_name", type=str, default='All', choices=['All', 'KnotTying', 'Suturing', 'NeedlePassing']) +parser.add_argument("--feature_type", type=str, default='resnet101_conv5', choices=['resnet101_conv4', 'resnet101_conv5']) + +args = parser.parse_args() + +def extract_save_rgb_feature (dataset_dir): + feat_name = args.feature_type.split('_')[1] + device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_rgb_extractor = resnet101(pretrained=True, channel=3, output=feat_name).to(device) + pretrained_model = torch.load('Models/ResNet101_rgb_pretrain.pth.tar') + resnet_rgb_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_rgb_extractor.parameters(): + param.requires_grad = False + + if feat_name == 'conv4': + resize = (224,224) + elif feat_name == 'conv5': + resize = (448,448) + transform = transforms.Compose([ + transforms.Resize(resize, interpolation=3), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_dir = join(video_dir, 'frame') + + seq_len = len(os.listdir(video_frames_dir)) + video_frames_tensor = torch.stack([ + transform(Image.open(join(video_frames_dir, format(frame_idx, '05d')+'.jpg'))) + for frame_idx in range(seq_len)], dim=0) #1 x seq_len x 3 x 448 x 448 + video_frames_tensor = video_frames_tensor.unsqueeze(0) + + # get video's resnet feature maps + video_feature = [] + for idx in range(5, seq_len-5, 5): + batched_frames = video_frames_tensor[:,idx,:,:,:].to(device) # 1x3x448x448 + with torch.no_grad(): + batched_feature = resnet_rgb_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len x 2048 x 14 x 14 + + feature_save_dir = join(video_dir, 'feature') + if not os.path.isdir(feature_save_dir): + os.system('mkdir -p '+feature_save_dir) + torch.save(video_feature, join(video_dir, 'feature', 'resnet_pretrain_'+feat_name+'.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') + utility.save_featmap_heatmaps(video_feature.cpu().data, + 'results/'+args.dataset_name+'_resnet101_'+feat_name+'/', + (320,240), video_name, dataset_dir) + return videos_feature_dict + +def extract_save_flow_feature (dataset_dir): + feat_name = args.feature_type.split('_')[1] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + resnet_flow_extractor = resnet101(pretrained=True, channel=20, output=feat_name).to(device) + pretrained_model = torch.load('Models/ResNet101_flow_pretrain.pth.tar') + resnet_flow_extractor.load_state_dict(pretrained_model['state_dict']) + for param in resnet_flow_extractor.parameters(): + param.requires_grad = False + + if feat_name == 'conv4': + resize = (224,224) + elif feat_name == 'conv5': + resize = (448,448) + transform = transforms.Compose([ + transforms.Resize(resize, interpolation=3), + transforms.ToTensor(), + ]) + + videos_feature_dict = {} + video_name_list = os.listdir(dataset_dir) + for video_name in video_name_list: + print(video_name+' is being processing.') + + video_dir = join(dataset_dir, video_name) + video_frames_tensor = [] + + seq_len = int( len(os.listdir(join(video_dir, 'flow_tvl1'))) / 2 ) + for frame_idx in range(1, seq_len+1): + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_x_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor.append( + transform(Image.open(join(video_dir, 'flow_tvl1', 'flow_y_'+format(frame_idx, '05d')+'.jpg')))) + video_frames_tensor = torch.cat(video_frames_tensor, 0).unsqueeze(0) #1 x seq_lenx2 x 448 x 448 + + # get video's resnet feature 
maps, 5-times down-sample + video_feature = [] + for idx in range(0, seq_len-9, 5): + batched_frames = video_frames_tensor[:,2*idx:2*(idx+10),:,:].to(device) # 1x20x448x448 + with torch.no_grad(): + batched_feature = resnet_flow_extractor(batched_frames) # 1x2048x14x14 + video_feature.append(batched_feature) + video_feature = torch.cat(video_feature, 0).cpu() # seq_len-9 x 2048 x 14 x 14 + + feature_save_dir = join(video_dir, 'feature') + if not os.path.isdir(feature_save_dir): + os.system('mkdir -p '+feature_save_dir) + torch.save(video_feature, join(feature_save_dir, 'resnet_flow_10_pretrain_'+feat_name+'.pt')) + + videos_feature_dict[video_name] = video_feature + print(video_feature.size()) + print(video_name+' is finished.') + utility.save_featmap_heatmaps(video_feature.cpu().data, + 'results/'+args.dataset_name+'_resnet101_10_pretrain_'+feat_name+'/', + (320,240), video_name, dataset_dir) + return videos_feature_dict + +def get_rgb_feature_dict (dataset_dir, feature_type='resnet_conv5'): + videos_feature_dict = {} + video_name_list = [video_name for video_name in os.listdir(dataset_dir) + if isdir(join(dataset_dir, video_name))] + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature + return videos_feature_dict + +def get_flow_feature_dict (dataset_dir, feature_type='resnet_conv5'): + videos_feature_dict = {} + video_name_list = [video_name for video_name in os.listdir(dataset_dir) + if isdir(join(dataset_dir, video_name))] + for video_name in video_name_list: + video_dir = join(dataset_dir, video_name) + + if feature_type == 'resnet101_conv5': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv5.pt')) + elif feature_type == 'resnet101_conv4': + video_feature = torch.load(join(video_dir, 'feature', 'resnet_flow_10_pretrain_conv4.pt')) + + videos_feature_dict[video_name] = video_feature + return videos_feature_dict + +# ######################################################################################################################### +# # Surgery_Dataset # +# ######################################################################################################################### +def video_sample (video_tensor, rand_idx_list): + sampled_video_tensor = torch.stack([video_tensor[idx,:,:,:] for idx in rand_idx_list], dim=0) + return sampled_video_tensor + +class SuguryDataset (Dataset): + def __init__ (self, ds_root, video_name_list, seg_sample=None): + self.ds_root = ds_root + self.video_name_list = video_name_list + self.seg_sample = seg_sample + + resize = (448, 448) + self.rgb_transform = transforms.Compose([ + transforms.Resize(resize, interpolation=3), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.flow_transform = transforms.Compose([ + transforms.Resize(resize, interpolation=3), + transforms.ToTensor(), + ]) + + def __len__ (self): + return len(self.video_name_list) + + def __getitem__ (self, i): + video_name = self.video_name_list[i] + video_sample = self.read_one_video(video_name) + return video_sample + + def read_one_video (self, video_name): + video_dir = join(self.ds_root, video_name) + video_rgb_dir = join(video_dir, 'frame') + 
video_flow_dir = join(video_dir, 'flow_tvl1') + + seq_len = len(os.listdir(video_rgb_dir)) + rand_idx_list = utility.avg_last_sample(seq_len-10, self.seg_sample) + + rgb_tensors = [] + flow_tensors = [] + sampled_index = [] + for idx in rand_idx_list: + frame_idx = idx + 5 + rgb_tensors.append(self.rgb_transform(Image.open( + join(video_rgb_dir, f'{frame_idx:05d}.jpg')))) + for fi in range(frame_idx-5, frame_idx+5): + flow_tensors.append(self.flow_transform(Image.open( + join(video_flow_dir, f'flow_x_{fi:05d}.jpg')))) + flow_tensors.append(self.flow_transform(Image.open( + join(video_flow_dir, f'flow_y_{fi:05d}.jpg')))) + sampled_index.append(frame_idx) + + sample = {'name': video_name, 'rgb': rgb_tensors, 'flow': flow_tensors, 'sampled_index': rand_idx_list} + return sample + +class SuguryDataset_Pair (Dataset): + def __init__ (self, pairs_dict, seg_sample=None): + self.seg_sample = seg_sample + + self.pairs_dict = pairs_dict + self.pairs_list = list(self.pairs_dict.keys()) + + def __len__ (self): + return len(self.pairs_list) + + def __getitem__ (self, i): + v1_name, v2_name = self.pairs_list[i] + v1_sample = self.read_one_video(v1_name) + v2_sample = self.read_one_video(v2_name) + label = torch.Tensor([ self.pairs_dict[(v1_name, v2_name)] ]) + + pair = {'video1': v1_sample, 'video2': v2_sample, 'label': label} + return pair + + def read_one_video (self, video_name): + video_rgb_tensor = self.video_rgb_feature_dict[video_name] + video_flow_tensor = self.video_flow_feature_dict[video_name] + + seq_len = video_rgb_tensor.size(0) + rand_idx_list = utility.avg_rand_sample(seq_len, self.seg_sample) + + video_flow_tensor = video_sample(video_flow_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + video_rgb_tensor = video_sample(video_rgb_tensor, rand_idx_list) #num_seg x 2048 x 14 x 14 + + if self.f_type == 'flow': + video_tensor = video_flow_tensor + elif self.f_type == 'rgb': + video_tensor = video_rgb_tensor + elif self.f_type == 'fusion': + video_tensor = torch.cat([video_rgb_tensor, video_flow_tensor], dim=1) ##num_seg x 4096 x 14 x 14 + + rand_idx_list = [5*num+5 for num in rand_idx_list] + sample = {'name': video_name, 'video': video_tensor, 'sampled_index': rand_idx_list} + return sample + diff --git a/diving_fusion_attention_transition.py b/diving_fusion_attention_transition.py new file mode 100644 index 0000000..1ab8f95 --- /dev/null +++ b/diving_fusion_attention_transition.py @@ -0,0 +1,349 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +from dataset.diving_dataset import MITDiveDataset, MITDiveDataset_Pair +from dataset.diving_dataset import get_flow_feature_dict, get_rgb_feature_dict +import utility + +import os +from os.path import join, isdir, isfile +import glob +import math +import argparse + +from PIL import Image +import pickle +from scipy.io import loadmat +from scipy.stats import spearmanr +from scipy.ndimage.filters import gaussian_filter +import numpy as np +from tqdm import tqdm +import cv2 +import csv +import pdb + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default='full', + choices=['full', 'only_x', 'only_htop', 'fc_att', 'no_att', 'cbam', 'sca', 'video_lstm', 'visual']) +parser.add_argument("--feature_type", type=str, default='resnet101_conv5', choices=['resnet101_conv4', 'resnet101_conv5']) +parser.add_argument("--epoch_num", type=int, default=30) 
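+# Illustrative usage (a sketch added for clarity, not part of the original script):
+# --model picks the attention variant dispatched by read_model() below, and
+# --feature_type sets the expected feature size (2048 for resnet101_conv5, 1024 for
+# resnet101_conv4). Example invocations, assuming the ResNet feature maps were already
+# extracted and cached by dataset/diving_dataset.py:
+#   python diving_fusion_attention_transition.py --model full --feature_type resnet101_conv5
+#   python diving_fusion_attention_transition.py --model cbam --feature_type resnet101_conv5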
+parser.add_argument("--split_index", type=int, default=0, choices=[0,1,2,3,4]) +parser.add_argument("--label", type=str, default='Full_model') + +args = parser.parse_args() + +''' +class model (nn.Module): + def __init__ (self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2,2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1,1)) + self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.rnn_att_size = 128 + self.rnn_top_size = 128 + + self.rnn_top = nn.GRUCell(self.x_size, self.rnn_top_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.rnn_att = nn.GRUCell(self.x_size+self.rnn_top_size, self.rnn_att_size) + for param in self.rnn_att.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.a_size = 32 + self.xa_fc = nn.Linear(self.x_size, self.a_size, bias=True) + self.ha_fc = nn.Linear(self.rnn_att_size, self.a_size, bias=True) + self.a_fc = nn.Linear(self.a_size, 1, bias=False) + + self.score_fc = nn.Linear(self.rnn_top_size, 1, bias=True) + +# self.x_ln = nn.LayerNorm(self.x_size) +# self.h_ln = nn.LayerNorm(self.rnn_top_size) + self.ln = nn.LayerNorm(self.rnn_top_size+self.x_size) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.2) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward (self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.zeros(batch_size, self.rnn_top_size).to(video_tensor.device) + h_att = torch.zeros(batch_size, self.rnn_att_size).to(video_tensor.device) + for frame_idx in range(seq_len): + featmap = video_tensor[:,frame_idx,:,:,:] #batch_size x 2D x 14 x 14 + + X = self.relu(self.pre_conv1(featmap)) #batch_size x C x 7 x 7 + X = self.pre_conv2(X) + x_avg = self.x_avgpool(X).view(batch_size, -1) #batch_size x C + x_max = self.x_maxpool(X).view(batch_size, -1) + +# rnn_att_in = torch.cat((self.x_ln(x_avg+x_max),self.h_ln(h_top)), dim=1) +# rnn_att_in = torch.cat((x_avg+x_max, h_top), dim=1) + rnn_att_in = self.ln( torch.cat((x_avg+x_max, h_top), dim=1) ) + h_att = self.rnn_att(rnn_att_in, h_att) #batch_size x rnn_att_size + + X_tmp = X.view(batch_size, self.x_size, -1).transpose(1,2) #batch_size x 49 x C + h_att_tmp = h_att.unsqueeze(1).expand(-1,X_tmp.size(1),-1) #batch_size x 49 x rnn_att_size + a = self.tanh(self.xa_fc(X_tmp)+self.ha_fc(h_att_tmp)) + a = self.a_fc(a).unsqueeze(2) #batch_size x 49 + alpha = self.softmax(a) + s_att = alpha.view(batch_size, 1, X.size(2), X.size(3)) + video_soft_att.append(s_att) + + X = X * s_att #batch_size x C x 7 x 7 + rnn_top_in = torch.sum(X.view(batch_size, self.x_size, -1), dim=2) #batch_size x C + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + video_soft_att = torch.stack(video_soft_att, dim=1) #batch_size x seq_len x 1 x 14 x 14 + video_tmpr_att = torch.zeros(batch_size, seq_len) + return final_score, video_soft_att, video_tmpr_att +''' + + +def read_model(model_type, feature_type, num_seg): + feature_size = 2048 if feature_type == 'resnet101_conv5' else 1024 + if model_type in ['full', 'only_x', 'only_htop', 'fc_att', 'no_att']: + from model_def.Spa_Att import model + return model(feature_size, num_seg, variant=model_type) + elif model_type in 
['cbam']: + from model_def.CBAM_Att import model + return model(feature_size, num_seg) + elif model_type in ['sca']: + from model_def.SCA_Att import model + return model(feature_size, num_seg) + elif model_type in ['video_lstm']: + from model_def.VideoLSTM import model + return model(feature_size, num_seg) + elif model_type in ['visual']: + from model_def.Visual_Att import model + return model(feature_size, num_seg) + else: + raise Exception(f'Unsupport model type of {model_type}.') + +################################################## +# Train & Test # +################################################## +def train (dataloader, model, criterion, optimizer, epoch, device, write_txt=False): + model.train() + + running_loss = 0.0 + running_acc = 0.0 + + for batch_idx, pair in enumerate(tqdm(dataloader)): + v1_tensor = pair['video1']['video'].to(device) + v2_tensor = pair['video2']['video'].to(device) + label = pair['label'].to(device).squeeze(1) + + v1_score, v1_satt, v1_tatt = model(v1_tensor) + v2_score, v2_satt, v2_tatt = model(v2_tensor) + + loss = criterion(v1_score, v2_score, label) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + running_loss += loss.item() + running_acc += torch.nonzero((label*(v1_score-v2_score))>0).size(0) / v1_tensor.size(0) + + epoch_loss = running_loss / len(dataloader) + epoch_acc = running_acc / len(dataloader) + output = 'Epoch:{} '.format(epoch,)+'Train, Loss:{:.4f}, Acc:{:.4f}'.format(epoch_loss, epoch_acc) + print(output) + + if write_txt: + file = open("results/diving_fusion_attention_transition.txt", "a") + if epoch==0: + file.write("============================================================================\n") + file.write(output+'\n') + file.close() + +def test (dataloader, pairs_dict, model, criterion, epoch, device, write_txt): + model.eval() + + pred_score_list = torch.Tensor([]).to(dtype=torch.float32) + gt_score_list = torch.Tensor([]).to(dtype=torch.float32) + + videos_score = {} + for batch_idx, sample in enumerate(tqdm(dataloader)): + video_name = sample['name'] + gt_score = sample['score'] + v_tensor = sample['video'].to(device) + sampled_idx_list = sample['sampled_index'] + + with torch.no_grad(): + v_score, v_satt, v_tatt = model(v_tensor) + for i in range(v_tensor.size(0)): + videos_score[video_name[i]] = v_score[i].unsqueeze(0) + + pred_score_list = torch.cat((pred_score_list, v_score.cpu().data), 0) + gt_score_list = torch.cat((gt_score_list, gt_score.cpu().data), 0) + + running_loss = 0.0 + running_acc = 0.0 + + pairs_list = list(pairs_dict.keys()) + for v1_idx, v2_idx in pairs_list: + v1_name = format(v1_idx, '03d') + v2_name = format(v2_idx, '03d') + v1_score = videos_score[v1_name] + v2_score = videos_score[v2_name] + label = torch.Tensor([ pairs_dict[(v1_idx, v2_idx)] ]).to(device) + + loss = criterion(v1_score, v2_score, label) + + running_loss += loss.item() + running_acc += torch.nonzero((label*(v1_score-v2_score))>0).size(0) + + epoch_loss = running_loss / len(pairs_list) + epoch_acc = running_acc / len(pairs_list) + + rankcorr, _ = spearmanr(pred_score_list, gt_score_list) + + output = 'Epoch:{} '.format(epoch,)+'Test, Loss:{:.4f}, Acc:{:.4f}, RankCor:{:.4f}'.format(epoch_loss, epoch_acc, rankcorr) + print(output) + + if write_txt: + file = open("results/diving_fusion_attention_transition.txt", "a") + file.write(output+'\n') + file.close() + +# print(pred_score_list) + return epoch_loss, epoch_acc, rankcorr + +def save_best_result (dataloader, model, device, best_rankcorr, dataset_dir, test_video_idx_list): + 
model.eval() + file_name = 'diving_tr/'+args.label + + utility.save_best_checkpoint(epoch, model, best_rankcorr, join('checkpoints',file_name)) + + videos_score = {} + for sample in dataloader: + video_name = sample['name'] + v_tensor = sample['video'].to(device) + gt_score = sample['score'] + sampled_idx_list = sample['sampled_index'] + + with torch.no_grad(): + v_score, v_satt, v_tatt = model(v_tensor) + for i in range(v_tensor.size(0)): + videos_score[video_name[i]] = [v_score[i].item(), gt_score[i].item()] + + utility.save_heatmaps(v_satt.cpu().data, join('results',file_name), (320,320), + video_name, sampled_idx_list, v_tatt.cpu().data, dataset_dir) + + os.system('mkdir -p '+join('results/videos_score',file_name)) + with open(join('results/videos_score',file_name,'0.pickle'),'wb') as f: + pickle.dump(videos_score, f) + f.close() + +def get_train_test_pairs_dict (dataset_root, train_video_idx_list, test_video_idx_list, cross=False): + overall_scores = np.load(join(dataset_root, 'diving_overall_scores.npy')) + overall_scores = torch.FloatTensor(np.squeeze(overall_scores)) + + score_diff_thres = 5.0 + + train_pairs_dict = {} + for idx_1, video_idx_1 in enumerate(train_video_idx_list): + for idx_2 in range(idx_1+1, len(train_video_idx_list)): + video_idx_2 = train_video_idx_list[idx_2] + key = tuple((video_idx_1, video_idx_2)) + score_diff = overall_scores[video_idx_1-1] - overall_scores[video_idx_2-1] + if score_diff > score_diff_thres: + label = 1.0 + train_pairs_dict[key] = label + elif score_diff < -score_diff_thres: + label = -1.0 + train_pairs_dict[key] = label + + test_pairs_dict = {} + for idx_1, video_idx_1 in enumerate(test_video_idx_list): + for idx_2 in range(idx_1+1, len(test_video_idx_list)): + video_idx_2 = test_video_idx_list[idx_2] + key = tuple((video_idx_1, video_idx_2)) + score_diff = overall_scores[video_idx_1-1] - overall_scores[video_idx_2-1] + if score_diff > score_diff_thres: + label = 1.0 + test_pairs_dict[key] = label + elif score_diff < -score_diff_thres: + label = -1.0 + test_pairs_dict[key] = label + if cross: + for video_idx_1 in test_video_idx_list: + for video_idx_2 in train_video_idx_list: + key = tuple((video_idx_1, video_idx_2)) + score_diff = overall_scores[video_idx_1-1] - overall_scores[video_idx_2-1] + if score_diff > score_diff_thres: + label = 1.0 + test_pairs_dict[key] = label + elif score_diff < -score_diff_thres: + label = -1.0 + test_pairs_dict[key] = label + + return train_pairs_dict, test_pairs_dict + +# ============================================================================= # +# main # +# ============================================================================= # +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dataset_dir = '../dataset/MIT_Dive_Dataset/diving_samples_len_ori_800x450' + video_rgb_feature_dict = get_rgb_feature_dict(dataset_dir, 'resnet101_conv5') + video_flow_feature_dict = get_flow_feature_dict(dataset_dir, 'resnet101_conv5') + + video_idx_list = list(range(1, 160)) + train_video_idx_list = list(range(1, 101)) + test_video_idx_list = list(range(101, 160)) + train_pairs_dict, test_pairs_dict = get_train_test_pairs_dict( + dataset_dir, train_video_idx_list, test_video_idx_list, cross=False) + + num_seg = 25 + dataset_train = MITDiveDataset_Pair('fusion', video_rgb_feature_dict, video_flow_feature_dict, + train_pairs_dict, num_seg) + dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True) + + dataset_test = MITDiveDataset('fusion', 
video_rgb_feature_dict, video_flow_feature_dict, + test_video_idx_list, num_seg) + dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + + # model_ins = model(2048, num_seg) + # model_ins.to(device) + model_ins = read_model(args.model, args.feature_type, num_seg) + + criterion = nn.MarginRankingLoss(margin=0.5) + +# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_ins.parameters()), +# lr=1e-6, weight_decay=5e-4, amsgrad=True) + optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_ins.parameters()), + lr=5e-5, momentum=0.9, weight_decay=1e-3) + + best_rankcorr = -1.1 + for epoch in range(20): + train(dataloader_train, model_ins, criterion, optimizer, epoch, device, write_txt=True) + epoch_loss, epoch_acc, rankcorr = test(dataloader_test, test_pairs_dict, + model_ins, criterion, epoch, device, write_txt=True) + if rankcorr > best_rankcorr: + best_rankcorr = rankcorr + save_best_result(dataloader_test, model_ins, device, best_rankcorr, dataset_dir, test_video_idx_list) + print('best rankcorr: {:.3f}'.format(best_rankcorr)) diff --git a/dough_fusion_attention_transition.py b/dough_fusion_attention_transition.py new file mode 100644 index 0000000..d821495 --- /dev/null +++ b/dough_fusion_attention_transition.py @@ -0,0 +1,231 @@ +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader + +from dataset.dough_dataset import DoughDataset, DoughDataset_Pair +from dataset.dough_dataset import get_flow_feature_dict, get_rgb_feature_dict +from common import train, test, save_best_result + +import os +from os.path import join, isdir, isfile, exists +import argparse +import csv + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default='full', + choices=['full', 'only_x', 'only_htop', 'fc_att', 'no_att', 'cbam', 'sca', 'video_lstm', 'visual']) +parser.add_argument("--feature_type", type=str, default='resnet101_conv5', choices=['resnet101_conv5', 'resnet101_conv4']) +parser.add_argument("--epoch_num", type=int, default=70) +parser.add_argument("--split_index", type=int, default=0, choices=[0,1,2,3,4]) +parser.add_argument("--label", type=str, default='') + +args = parser.parse_args() + +''' +class model (nn.Module): + def __init__ (self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2,2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1,1)) + self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.rnn_att_size = 128 + self.rnn_top_size = 128 + + self.rnn_top = nn.GRUCell(self.x_size, self.rnn_top_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.rnn_att = nn.GRUCell(self.x_size+self.rnn_top_size, self.rnn_att_size) + for param in self.rnn_att.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.a_size = 32 + self.xa_fc = nn.Linear(self.x_size, self.a_size, bias=True) + self.ha_fc = nn.Linear(self.rnn_att_size, self.a_size, bias=True) + self.a_fc = nn.Linear(self.a_size, 1, bias=False) + + self.score_fc = nn.Linear(self.rnn_top_size, 1, bias=True) + + self.x_ln = nn.LayerNorm(self.x_size) + self.h_ln = nn.LayerNorm(self.rnn_top_size) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.1) + + # video_featmaps: batch_size x seq_len x D x w x h + def 
forward (self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.randn(batch_size, self.rnn_top_size).to(video_tensor.device) + h_att = torch.randn(batch_size, self.rnn_att_size).to(video_tensor.device) + for frame_idx in range(seq_len): + featmap = video_tensor[:,frame_idx,:,:,:] #batch_size x 2D x 14 x 14 + + X = self.relu(self.pre_conv1(featmap)) #batch_size x C x 7 x 7 + X = self.pre_conv2(X) + x_avg = self.x_avgpool(X).view(batch_size, -1) #batch_size x C + x_max = self.x_maxpool(X).view(batch_size, -1) + + rnn_att_in = torch.cat((self.x_ln(x_avg+x_max),self.h_ln(h_top)), dim=1) +# rnn_att_in = torch.cat((x_avg+x_max, h_top), dim=1) + h_att = self.rnn_att(rnn_att_in, h_att) #batch_size x rnn_att_size + + X_tmp = X.view(batch_size, self.x_size, -1).transpose(1,2) #batch_size x 49 x C + h_att_tmp = h_att.unsqueeze(1).expand(-1,X_tmp.size(1),-1) #batch_size x 49 x rnn_att_size + + a = self.tanh(self.xa_fc(X_tmp)+self.ha_fc(h_att_tmp)) + a = self.a_fc(a).unsqueeze(2) #batch_size x 49 + alpha = self.softmax(a) + s_att = alpha.view(batch_size, 1, X.size(2), X.size(3)) + video_soft_att.append(s_att) + + X = X * s_att #batch_size x C x 7 x 7 + rnn_top_in = torch.sum(X.view(batch_size, self.x_size, -1), dim=2) #batch_size x C x 7 x 7 + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + video_soft_att = torch.stack(video_soft_att, dim=1) #batch_size x seq_len x 1 x 14 x 14 + video_tmpr_att = torch.zeros(batch_size, seq_len) + return final_score, video_soft_att, video_tmpr_att +''' + +def read_model(model_type, feature_type, num_seg): + feature_size = 2048 if feature_type == 'resnet101_conv5' else 1024 + if model_type in ['full', 'only_x', 'only_htop', 'fc_att', 'no_att']: + from model_def.Spa_Att import model + return model(feature_size, num_seg, variant=model_type) + elif model_type in ['cbam']: + from model_def.CBAM_Att import model + return model(feature_size, num_seg) + elif model_type in ['sca']: + from model_def.SCA_Att import model + return model(feature_size, num_seg) + elif model_type in ['video_lstm']: + from model_def.VideoLSTM import model + return model(feature_size, num_seg) + elif model_type in ['visual']: + from model_def.Visual_Att import model + return model(feature_size, num_seg) + else: + raise Exception(f'Unsupport model type of {model_type}.') + +def get_train_test_pairs_dict (annotation_dir, split_idx): + train_pairs_dict = {} + train_videos = set() + train_csv = join(annotation_dir, 'DoughRolling_train_'+format(split_idx, '01d')+'.csv') + with open(train_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((row[0], row[1])) + train_pairs_dict[key] = 1 + train_videos.update(key) + csvfile.close() + + test_pairs_dict = {} + test_videos = set() + test_csv = join(annotation_dir, 'DoughRolling_val_'+format(split_idx, '01d')+'.csv') + with open(test_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((row[0], row[1])) + test_pairs_dict[key] = 1 + if key[0] not in train_videos: + test_videos.add(key[0]) + if key[1] not in train_videos: + test_videos.add(key[1]) + csvfile.close() + + return train_pairs_dict, test_pairs_dict, train_videos, test_videos + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dataset_dir = 
'../dataset/DoughRolling/DoughRolling_600x450' + annotation_dir = '../dataset/DoughRolling/DoughRolling_Annotation/splits' + + video_name_list = os.listdir(dataset_dir) + video_rgb_feature_dict = get_rgb_feature_dict(dataset_dir, args.feature_type) + video_flow_feature_dict = get_flow_feature_dict(dataset_dir, args.feature_type) + + best_acc_keeper = [] + for split_idx in range(1, 5): + print("Split: "+format(split_idx, '01d')) + train_pairs_dict, test_pairs_dict, train_videos, test_videos = get_train_test_pairs_dict(annotation_dir, split_idx) + + num_seg = 25 + dataset_train = DoughDataset_Pair('fusion', video_rgb_feature_dict, video_flow_feature_dict, + train_pairs_dict, seg_sample=num_seg) + dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True) + + dataset_test = DoughDataset('fusion', video_rgb_feature_dict, video_flow_feature_dict, + video_name_list, seg_sample=num_seg) + dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + + model_ins = read_model(args.model, args.feature_type, num_seg) + + save_label = f'DoughRolling/{args.model}/{split_idx:01d}' + + best_acc = 0.0 + if args.continue_train: + ckpt_dir = join('checkpoints', save_label, + 'best_checkpoint.pth.tar') + if exists(checkpoint): + checkpoint = torch.load(ckpt_dir) + model_ins.load_state_dict(checkpoint['state_dict']) + best_acc = checkpoint['best_acc'] + print("Start from previous checkpoint, with rank_cor: {:.4f}".format( + checkpoint['best_acc'])) + else: + print("No previous checkpoint. \nStart from scratch.") + else: + print("Start from scratch.") + + model_ins.to(device) + + criterion = nn.MarginRankingLoss(margin=0.5) + + optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_ins.parameters()), + lr=5e-6, weight_decay=0, amsgrad=False) +# optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_ins.parameters()), +# lr=1e-3, momentum=0.9, weight_decay=1e-3) + +# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2) + + min_loss = 1.0 + no_imprv = 0 + for epoch in range(args.epoch_num): + train(dataloader_train, model_ins, criterion, optimizer, epoch, device) + epoch_loss, epoch_acc = test(dataloader_test, test_pairs_dict, model_ins, criterion, epoch, device) + + if epoch_acc >= best_acc: + best_acc = epoch_acc + save_best_result(dataloader_test, test_videos, + model_ins, device, best_acc, save_label) + + if epoch_loss <= min_loss: + min_loss = epoch_loss + else: + no_imprv += 1 + print('Best acc: {:.3f}'.format(best_acc)) +# if no_imprv > 3: +# break + best_acc_keeper.append(best_acc) + + for split_idx, best_acc in enumerate(best_acc_keeper): + print(f'Split: {split_idx+1}, {best_acc:.4f}') + print('Avg:', '{:.4f}'.format(sum(best_acc_keeper)/4)) diff --git a/drawing_fusion_attention_transition.py b/drawing_fusion_attention_transition.py new file mode 100644 index 0000000..4db2a7a --- /dev/null +++ b/drawing_fusion_attention_transition.py @@ -0,0 +1,270 @@ +from logging import raiseExceptions +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms + +from dataset.drawing_dataset import DrawingDataset, DrawingDataset_Pair +from dataset.drawing_dataset import get_flow_feature_dict, get_rgb_feature_dict +from common import train, test, save_best_result + +import os +from os.path import join, isdir, isfile, exists +import argparse +import csv + +parser = 
argparse.ArgumentParser() +parser.add_argument("--dataset_name", type=str, default='All', + choices=['All', 'HandDrawing', 'SonicDrawing']) +parser.add_argument("--model", type=str, default='full', + choices=['full', 'only_x', 'only_htop', 'fc_att', 'no_att', 'cbam', 'sca', 'video_lstm', 'visual']) +parser.add_argument("--feature_type", type=str, default='resnet101_conv5', choices=['resnet101_conv4', 'resnet101_conv5']) +parser.add_argument("--epoch_num", type=int, default=30, choices=[10,20,30]) +parser.add_argument("--split_index", type=int, default=0, choices=[0,1,2,3,4]) +parser.add_argument("--label", type=str, default='') + +args = parser.parse_args() + +''' +class model (nn.Module): + def __init__ (self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2,2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1,1)) + self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.rnn_att_size = 128 + self.rnn_top_size = 128 + + self.rnn_top = nn.GRUCell(self.x_size, self.rnn_top_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.rnn_att = nn.GRUCell(self.x_size+self.rnn_top_size, self.rnn_att_size) + for param in self.rnn_att.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.a_size = 32 + self.xa_fc = nn.Linear(self.x_size, self.a_size, bias=True) + self.ha_fc = nn.Linear(self.rnn_att_size, self.a_size, bias=True) + self.a_fc = nn.Linear(self.a_size, 1, bias=False) + + self.score_fc = nn.Linear(self.rnn_top_size, 1, bias=True) + + self.x_ln = nn.LayerNorm(self.x_size) + self.h_ln = nn.LayerNorm(self.rnn_top_size) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.1) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward (self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.zeros(batch_size, self.rnn_top_size).to(video_tensor.device) + h_att = torch.zeros(batch_size, self.rnn_att_size).to(video_tensor.device) + for frame_idx in range(seq_len): + featmap = video_tensor[:,frame_idx,:,:,:] #batch_size x 2D x 14 x 14 + + X = self.relu(self.pre_conv1(featmap)) #batch_size x C x 7 x 7 + X = self.pre_conv2(X) + x_avg = self.x_avgpool(X).view(batch_size, -1) #batch_size x C + x_max = self.x_maxpool(X).view(batch_size, -1) + + rnn_att_in = torch.cat((self.x_ln(x_avg+x_max),self.h_ln(h_top)), dim=1) +# rnn_att_in = torch.cat((x_avg+x_max, h_top), dim=1) + h_att = self.rnn_att(rnn_att_in, h_att) #batch_size x rnn_att_size + + X_tmp = X.view(batch_size, self.x_size, -1).transpose(1,2) #batch_size x 49 x C + h_att_tmp = h_att.unsqueeze(1).expand(-1,X_tmp.size(1),-1) #batch_size x 49 x rnn_att_size + + a = self.tanh(self.xa_fc(X_tmp)+self.ha_fc(h_att_tmp)) + a = self.a_fc(a).unsqueeze(2) #batch_size x 49 + alpha = self.softmax(a) + s_att = alpha.view(batch_size, 1, X.size(2), X.size(3)) + video_soft_att.append(s_att) + + X = X * s_att #batch_size x C x 7 x 7 + rnn_top_in = torch.sum(X.view(batch_size, self.x_size, -1), dim=2) #batch_size x C x 7 x 7 + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + video_soft_att = torch.stack(video_soft_att, dim=1) #batch_size x seq_len x 1 x 14 x 14 + video_tmpr_att = torch.zeros(batch_size, 
seq_len) + return final_score, video_soft_att, video_tmpr_att +''' + +def read_model(model_type, feature_type, num_seg): + feature_size = 2048 if feature_type == 'resnet101_conv5' else 1024 + if model_type in ['full', 'only_x', 'only_htop', 'fc_att', 'no_att']: + from model_def.Spa_Att import model + return model(feature_size, num_seg, variant=model_type) + elif model_type in ['cbam']: + from model_def.CBAM_Att import model + return model(feature_size, num_seg) + elif model_type in ['sca']: + from model_def.SCA_Att import model + return model(feature_size, num_seg) + elif model_type in ['video_lstm']: + from model_def.VideoLSTM import model + return model(feature_size, num_seg) + elif model_type in ['visual']: + from model_def.Visual_Att import model + return model(feature_size, num_seg) + else: + raise Exception(f'Unsupport model type of {model_type}.') + +def get_train_test_pairs_dict (annotation_dir, dataset_name, split_idx): + train_pairs_dict = {} + train_videos = set() + train_csv = join(annotation_dir, dataset_name+'_train_'+format(split_idx, '01d')+'.csv') + with open(train_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((dataset_name+'_'+row[0], dataset_name+'_'+row[1])) + train_pairs_dict[key] = 1 + train_videos.update(key) + csvfile.close() + + test_pairs_dict = {} + tets_videos = set() + test_csv = join(annotation_dir, dataset_name+'_val_'+format(split_idx, '01d')+'.csv') + with open(test_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + v0 = dataset_name+'_'+row[0] + v1 = dataset_name+'_'+row[1] + key = tuple((v0, v1)) + test_pairs_dict[key] = 1 + if v0 not in train_videos: + test_videos.add(v0) + if v1 not in train_videos: + test_videos.add(v1) + csvfile.close() + + return train_pairs_dict, test_pairs_dict, train_videos, test_videos + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dataset_name_list = ['SonicDrawing', 'HandDrawing'] if args.dataset_name=='All' else [args.dataset_name] + + # read name_list & feature dict of all the videos + video_name_list = [] + video_rgb_feature_dict = {} + video_flow_feature_dict = {} + for dataset_name in dataset_name_list: + dataset_dir = join('../dataset', dataset_name, dataset_name+'_Stationary_800x450') + sub_video_name_list = [video_name for video_name in os.listdir(dataset_dir) if isdir(join(dataset_dir, video_name))] + sub_video_rgb_feature_dict = get_rgb_feature_dict(dataset_dir, args.feature_type) + sub_video_flow_feature_dict = get_flow_feature_dict(dataset_dir, args.feature_type) + for video_name in sub_video_name_list: + sub_video_rgb_feature_dict[dataset_name+'_'+video_name] = sub_video_rgb_feature_dict.pop(video_name) + sub_video_flow_feature_dict[dataset_name+'_'+video_name] = sub_video_flow_feature_dict.pop(video_name) + sub_video_name_list = [dataset_name+'_'+video_name for video_name in sub_video_name_list] + + video_name_list += sub_video_name_list + video_rgb_feature_dict.update(sub_video_rgb_feature_dict) + video_flow_feature_dict.update(sub_video_flow_feature_dict) + del sub_video_name_list, sub_video_rgb_feature_dict, sub_video_flow_feature_dict + + best_acc_keeper = [] + for split_idx in range(1, 5): + print("Split: "+format(split_idx, '01d')) + + # read pairs dict of videos belonging to this split + train_pairs_dict = {} + test_pairs_dict = {} + train_videos = set() + test_videos = set() + for dataset_name 
in dataset_name_list: + annotation_dir = join('../dataset', dataset_name, dataset_name+'_Annotation/splits') + sub_train_pairs_dict, sub_test_pairs_dict, sub_train_videos, sub_test_videos = get_train_test_pairs_dict( + annotation_dir, dataset_name, split_idx) + + train_pairs_dict.update(sub_train_pairs_dict) + test_pairs_dict.update(sub_test_pairs_dict) + train_videos.update(sub_train_videos) + test_videos.update(sub_test_videos) + del sub_train_pairs_dict, sub_test_pairs_dict, sub_train_videos, sub_test_videos + + num_seg = 25 + dataset_train = DrawingDataset_Pair('fusion', video_rgb_feature_dict, video_flow_feature_dict, + train_pairs_dict, seg_sample=num_seg) + dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True) + + dataset_test = DrawingDataset('fusion', video_rgb_feature_dict, video_flow_feature_dict, + video_name_list, seg_sample=num_seg) + dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + + model_ins = read_model(args.model, args.feature_type, num_seg) + + save_label = f'Drawing_{args.dataset_name}/{args.model}/{split_idx:01d}' + + best_acc = 0.0 + if args.continue_train: + ckpt_dir = join('checkpoints', save_label, + 'best_checkpoint.pth.tar') + if exists(checkpoint): + checkpoint = torch.load(ckpt_dir) + model_ins.load_state_dict(checkpoint['state_dict']) + best_acc = checkpoint['best_acc'] + print("Start from previous checkpoint, with rank_cor: {:.4f}".format( + checkpoint['best_acc'])) + else: + print("No previous checkpoint. \nStart from scratch.") + else: + print("Start from scratch.") + + model_ins.to(device) + + criterion = nn.MarginRankingLoss(margin=0.5) + +# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_ins.parameters()), +# lr=1e-5, weight_decay=5e-4, amsgrad=True) + optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_ins.parameters()), + lr=1e-3, momentum=0.9, weight_decay=1e-3) #real l2 reg = weight_decay*lr + +# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2) + + min_loss = 1.0 + no_imprv = 0 + for epoch in range(args.epoch_num): + train(dataloader_train, model_ins, criterion, optimizer, epoch, device) + epoch_loss, epoch_acc = test(dataloader_test, test_pairs_dict, model_ins, criterion, epoch, device) + + if epoch_acc >= best_acc: + best_acc = epoch_acc + save_best_result(dataloader_test, test_videos, + model_ins, device, best_acc, save_label) + + if epoch_loss <= min_loss: + min_loss = epoch_loss + no_imprv = 0 + else: + no_imprv += 1 + print('Best acc: {:.3f}'.format(best_acc)) +# if no_imprv > 3: +# break + best_acc_keeper.append(best_acc) + + for split_idx, best_acc in enumerate(best_acc_keeper): + print(f'Split: {split_idx+1}, {best_acc:.4f}') + print('Avg:', '{:.4f}'.format(sum(best_acc_keeper)/4)) diff --git a/grasp_fusion_attention_transition.py b/grasp_fusion_attention_transition.py new file mode 100644 index 0000000..dbbd278 --- /dev/null +++ b/grasp_fusion_attention_transition.py @@ -0,0 +1,267 @@ +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader + +from dataset.grasp_dataset import GraspDataset, GraspDataset_Pair +from dataset.grasp_dataset import get_flow_feature_dict, get_rgb_feature_dict + +import os +from os.path import join, isdir, isfile, exists +import math +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default='full', + choices=['full', 'only_x', 'only_htop', 'fc_att', 'no_att', 'cbam', 'sca', 'video_lstm', 'visual']) 
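# --- Editor's illustrative sketch (not part of the patch) ---------------------
# All of these training scripts optimise a pairwise ranking objective: each
# training sample is a pair of videos, the network scores both, and
# nn.MarginRankingLoss(margin=0.5) pushes the score of the higher-skill video
# at least 0.5 above the other. The tiny scorer below is a hypothetical
# stand-in for the attention models; only the loss usage mirrors the scripts.
import torch
from torch import nn

toy_scorer = nn.Sequential(nn.Flatten(), nn.Linear(256 * 7 * 7, 1))  # stand-in model

feat_better = torch.randn(4, 256, 7, 7)   # features of the higher-skill video in each pair
feat_worse = torch.randn(4, 256, 7, 7)    # features of the lower-skill video
score_better = toy_scorer(feat_better).squeeze(1)   # shape: (4,)
score_worse = toy_scorer(feat_worse).squeeze(1)

label = torch.ones(4)                      # +1 means the first score should be higher
criterion = nn.MarginRankingLoss(margin=0.5)
loss = criterion(score_better, score_worse, label)  # max(0, -(s1 - s2) + 0.5) per pair
loss.backward()
print(float(loss))
# -----------------------------------------------------------------------------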
+parser.add_argument("--feature_type", type=str, default='resnet101_conv5', + choices=['resnet101_conv4', 'resnet101_conv5']) +parser.add_argument("--epoch_num", type=int, default=20) +parser.add_argument("--split_index", type=int, default=0, choices=[0,1,2,3,4]) +parser.add_argument("--label", type=str, default='') + +args = parser.parse_args() + +''' +class model (nn.Module): + def __init__ (self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2,2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1,1)) + self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.rnn_att_size = 128 + self.rnn_top_size = 128 + + self.rnn_top = nn.GRUCell(self.x_size, self.rnn_top_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.rnn_att = nn.GRUCell(self.x_size+self.rnn_top_size, self.rnn_att_size) + for param in self.rnn_att.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.a_size = 32 + self.xa_fc = nn.Linear(self.x_size, self.a_size, bias=True) + self.ha_fc = nn.Linear(self.rnn_att_size, self.a_size, bias=True) + self.a_fc = nn.Linear(self.a_size, 1, bias=False) + + self.score_fc = nn.Linear(self.rnn_top_size, 1, bias=True) + + self.x_ln = nn.LayerNorm(self.x_size) + self.h_ln = nn.LayerNorm(self.rnn_top_size) +# self.ln = nn.LayerNorm(self.rnn_top_size+self.x_size) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.2) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward (self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.randn(batch_size, self.rnn_top_size).to(video_tensor.device) + h_att = torch.randn(batch_size, self.rnn_att_size).to(video_tensor.device) + for frame_idx in range(seq_len): + featmap = video_tensor[:,frame_idx,:,:,:] #batch_size x 2D x 14 x 14 + + X = self.relu(self.pre_conv1(featmap)) #batch_size x C x 7 x 7 + X = self.pre_conv2(X) + x_avg = self.x_avgpool(X).view(batch_size, -1) #batch_size x C + x_max = self.x_maxpool(X).view(batch_size, -1) + + rnn_att_in = torch.cat((self.x_ln(x_avg+x_max),self.h_ln(h_top)), dim=1) +# rnn_att_in = torch.cat((x_avg+x_max, h_top), dim=1) +# rnn_att_in = self.ln( torch.cat((x_avg+x_max, h_top), dim=1) ) + h_att = self.rnn_att(rnn_att_in, h_att) #batch_size x rnn_att_size + + X_tmp = X.view(batch_size, self.x_size, -1).transpose(1,2) #batch_size x 49 x C + h_att_tmp = h_att.unsqueeze(1).expand(-1,X_tmp.size(1),-1) #batch_size x 49 x rnn_att_size + a = self.tanh(self.xa_fc(X_tmp)+self.ha_fc(h_att_tmp)) + a = self.a_fc(a).unsqueeze(2) #batch_size x 49 + alpha = self.softmax(a) + s_att = alpha.view(batch_size, 1, X.size(2), X.size(3)) + video_soft_att.append(s_att) + + X = X * s_att #batch_size x C x 7 x 7 + rnn_top_in = torch.sum(X.view(batch_size, self.x_size, -1), dim=2) #batch_size x C + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + video_soft_att = torch.stack(video_soft_att, dim=1) #batch_size x seq_len x 1 x 14 x 14 + video_tmpr_att = torch.zeros(batch_size, seq_len) + return final_score, video_soft_att, video_tmpr_att +''' + +def read_model(model_type, feature_type, num_seg): + feature_size = 2048 if feature_type == 'resnet101_conv5' else 1024 + if model_type in 
['full', 'only_x', 'only_htop', 'fc_att', 'no_att']: + from model_def.Spa_Att import model + return model(feature_size, num_seg, variant=model_type) + elif model_type in ['cbam']: + from model_def.CBAM_Att import model + return model(feature_size, num_seg) + elif model_type in ['sca']: + from model_def.SCA_Att import model + return model(feature_size, num_seg) + elif model_type in ['video_lstm']: + from model_def.VideoLSTM import model + return model(feature_size, num_seg) + elif model_type in ['visual']: + from model_def.Visual_Att import model + return model(feature_size, num_seg) + else: + raise Exception(f'Unsupport model type of {model_type}.') + +def get_train_test_videos_list (video_name_list, split_index, split_num): + train_video_list = [] + test_video_list = [] + + video_num = len(video_name_list) + test_video_num = int(math.floor(video_num / split_num)) + test_video_indexs = range(split_index*test_video_num, (split_index+1)*test_video_num) + + for video_index, video_name in enumerate(video_name_list): + if video_index in test_video_indexs: + test_video_list.append(video_name) + else: + train_video_list.append(video_name) + return train_video_list, test_video_list + +def get_train_test_pairs_dict (dataset_root, train_video_list, test_video_list, cross=True): + # Read pairs' annotation file + pairs_annotation_file = open(join(dataset_root, "annotation.txt"), "r") + all_pairs_dict = {} + lines = pairs_annotation_file.readlines() + for line in lines: + video_name_1, video_name_2, label = line.strip().split(' ') + all_pairs_dict[tuple((video_name_1, video_name_2))] = int(label) + + train_pairs_dict = {} + train_videos_num = len(train_video_list) +# print('training videos num:', train_videos_num) + for video_index, video_name_1 in enumerate(train_video_list): + for i in range(video_index+1, train_videos_num): + video_name_2 = train_video_list[i] + key = tuple((video_name_1, video_name_2)) + key_inv = tuple((video_name_2, video_name_1)) + if (key in all_pairs_dict) and (all_pairs_dict[key] != 0): + train_pairs_dict[key] = all_pairs_dict[key] + elif (key_inv in all_pairs_dict) and (all_pairs_dict[key_inv] != 0): + train_pairs_dict[key_inv] = all_pairs_dict[key_inv] + + test_pairs_dict = {} + test_video_num = len(test_video_list) +# print('validation videos num:', test_video_num) + for video_index, video_name_1 in enumerate(test_video_list): + for i in range(video_index+1, test_video_num): + video_name_2 = test_video_list[i] + key = tuple((video_name_1, video_name_2)) + key_inv = tuple((video_name_2, video_name_1)) + if (key in all_pairs_dict) and (all_pairs_dict[key] != 0): + test_pairs_dict[key] = all_pairs_dict[key] + elif (key_inv in all_pairs_dict) and (all_pairs_dict[key_inv] != 0): + test_pairs_dict[key_inv] = all_pairs_dict[key_inv] + if cross: + for video_name_1 in test_video_list: + for video_name_2 in train_video_list: + key = tuple((video_name_1, video_name_2)) + key_inv = tuple((video_name_2, video_name_1)) + if (key in all_pairs_dict) and (all_pairs_dict[key] != 0): + test_pairs_dict[key] = all_pairs_dict[key] + elif (key_inv in all_pairs_dict) and (all_pairs_dict[key_inv] != 0): + test_pairs_dict[key_inv] = all_pairs_dict[key_inv] + + return train_pairs_dict, test_pairs_dict + +# ============================================================================= # +# main # +# ============================================================================= # +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dataset_dir = 
'../dataset/InfantsGrasping/InfantsGrasping_480x720' + video_name_list = [video_name for video_name in os.listdir(dataset_dir) if '_L' in video_name] + video_rgb_feature_dict = get_rgb_feature_dict(dataset_dir, args.feature_type) + video_flow_feature_dict = get_flow_feature_dict(dataset_dir, args.feature_type) + + num_seg = 25 + + split_num = 4 + best_acc_keeper = [] + for split_idx in range(0, split_num): + print("Split: "+format(split_idx, '01d')) + train_video_list, test_video_list = get_train_test_videos_list(video_name_list, split_idx, split_num) + train_pairs_dict, test_pairs_dict = get_train_test_pairs_dict( + dataset_dir, train_video_list, test_video_list, cross=True) + + dataset_train = GraspDataset_Pair('fusion', video_rgb_feature_dict, video_flow_feature_dict, + train_pairs_dict, seg_sample=num_seg) + dataloader_train = DataLoader(dataset_train, batch_size=30, shuffle=True) + + dataset_test = GraspDataset('fusion', video_rgb_feature_dict, video_flow_feature_dict, + video_name_list, seg_sample=num_seg) + dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + + model_ins = read_model(args.model, args.feature_type, num_seg) + + save_label = f'Grasp/{args.model}/{split_idx:01d}' + + best_acc = 0.0 + if args.continue_train: + ckpt_dir = join('checkpoints', save_label, + 'best_checkpoint.pth.tar') + if exists(checkpoint): + checkpoint = torch.load(ckpt_dir) + model_ins.load_state_dict(checkpoint['state_dict']) + best_acc = checkpoint['best_acc'] + print("Start from previous checkpoint, with rank_cor: {:.4f}".format( + checkpoint['best_acc'])) + else: + print("No previous checkpoint. \nStart from scratch.") + else: + print("Start from scratch.") + model_ins.to(device) + + criterion = nn.MarginRankingLoss(margin=0.5) + +# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_ins.parameters()), +# lr=5e-6, weight_decay=5e-4, amsgrad=True) + optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_ins.parameters()), + lr=5e-4, momentum=0.9, weight_decay=1e-2) + +# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2) + + min_loss = 1.0 + no_imprv = 0 + for epoch in range(args.epoch_num): + train(dataloader_train, model_ins, criterion, optimizer, epoch, device) + epoch_loss, epoch_acc = test(dataloader_test, test_pairs_dict, model_ins, criterion, epoch, device) + + if epoch_acc >= best_acc: + best_acc = epoch_acc + save_best_result(dataloader_test, test_video_list, model_ins, device, best_acc, save_label) + + if epoch_loss <= min_loss: + min_loss = epoch_loss + no_imprv = 0 + else: + no_imprv += 1 + print('Best acc: {:.3f}'.format(best_acc)) + # if no_imprv > 3: + # break + best_acc_keeper.append(best_acc) + + for split_idx, best_acc in enumerate(best_acc_keeper): + print(f'Split: {split_idx+1}, {best_acc:.4f}') + print('Avg:', '{:.4f}'.format(sum(best_acc_keeper)/4)) diff --git a/model_def/Att_Pool.py b/model_def/Att_Pool.py new file mode 100644 index 0000000..d4667ee --- /dev/null +++ b/model_def/Att_Pool.py @@ -0,0 +1,58 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable + +# For Attention Pooling (NeurIPS 2017) + +class model (nn.Module): + def __init__(self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2, 2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1, 1)) + 
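# --- Editor's illustrative sketch (not part of the patch) ---------------------
# Shape check for the shared front-end used by the model_def classes: the fused
# RGB+flow ResNet-101 conv5 map has 2*2048 channels at 14x14, the 2x2 stride-2
# convolution halves the spatial size to 7x7, and the 1x1 convolution projects
# down to x_size=256 channels. Tensor names here are made up for the example.
import torch
from torch import nn

feature_size = 2048                                   # resnet101_conv5
pre_conv1 = nn.Conv2d(2 * feature_size, 512, (2, 2), stride=2)
pre_conv2 = nn.Conv2d(512, 256, (1, 1))

fused = torch.randn(8, 2 * feature_size, 14, 14)      # a batch of fused two-stream maps
x = pre_conv2(torch.relu(pre_conv1(fused)))
print(x.shape)                                        # torch.Size([8, 256, 7, 7])
# -----------------------------------------------------------------------------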
self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.bottom_up = nn.Conv2d(self.x_size, 1, (1, 1)) + self.top_down = nn.Conv2d(self.x_size, 1, (1, 1)) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.2) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward(self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + video_score = [] + for frame_idx in range(seq_len): + # batch_size x 2D x 14 x 14 + featmap = video_tensor[:, frame_idx, :, :, :] + + X = self.relu(self.pre_conv1(featmap)) # batch_size x C x 7 x 7 + X = self.relu(self.pre_conv2(X)) # batch_size x C x 7 x 7 + + x_bu = self.bottom_up(X) # batch_size x 1 x 7 x 7 + x_td = self.top_down(X) # btahc_size x 1 x 7 x 7 + + score = torch.sum( + (x_bu*x_td).view(batch_size, -1), dim=1) # batch_size + video_score.append(score) + + s_att = x_bu * x_td + video_soft_att.append(s_att.detach().cpu()) + + video_score = torch.stack(video_score, dim=1) # batch_size x seq_len + final_score = torch.mean(video_score, dim=1) # batch_size + # batch_size x seq_len x 1 x 7 x 7 + video_soft_att = torch.stack(video_soft_att, dim=1) + return final_score, video_soft_att diff --git a/model_def/CBAM_Att.py b/model_def/CBAM_Att.py new file mode 100644 index 0000000..72b82e2 --- /dev/null +++ b/model_def/CBAM_Att.py @@ -0,0 +1,89 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable + +# For baseline CBAM Attention (ECCV 2018) + +class model (nn.Module): + def __init__(self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2, 2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1, 1)) + self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.rnn_att_size = 128 + self.rnn_top_size = 128 + + self.rnn_top = nn.GRUCell(self.x_size, self.rnn_top_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.x_mid_size = int(feature_size/8) + self.fc_shrk = nn.Linear(self.x_size, self.x_mid_size, bias=True) + self.fc_clps = nn.Linear(self.x_mid_size, self.x_size, bias=True) + + self.ins_norm = nn.InstanceNorm2d(2, affine=True) + self.conv_s1 = nn.Conv2d(2, 32, (3, 3), padding=(1, 1)) + self.conv_s2 = nn.Conv2d(32, 1, (1, 1)) + + self.score_fc = nn.Linear(self.rnn_top_size, 1, bias=True) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.2) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward(self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.randn(batch_size, self.rnn_top_size).to( + video_tensor.device) + for frame_idx in range(seq_len): + # batch_size x 2D x 14 x 14 + featmap = video_tensor[:, frame_idx, :, :, :] + + X = self.relu(self.pre_conv1(featmap)) # batch_size x C x 7 x 7 + X = self.pre_conv2(X) + + x_avg = self.x_avgpool(X).view(batch_size, -1) # batch_size x C + x_avg = self.relu(self.fc_shrk(x_avg)) + x_avg = self.fc_clps(x_avg) + x_max = self.x_maxpool(X).view(batch_size, -1) + x_max = self.relu(self.fc_shrk(x_max)) + x_max = self.fc_clps(x_max) + ch_att = self.sigmoid(x_avg+x_max) 
# batch_size x D + ch_att = ch_att.view(batch_size, self.x_size, 1, 1) + X = X * ch_att # batch_size x D x 14 x 14 + + s_avg = torch.mean(X, dim=1, keepdim=True) + s_max, _ = torch.max(X, dim=1, keepdim=True) + # batch_size x 2 x 14 x 14 + s_cat = torch.cat((s_avg, s_max), dim=1) + + s_cat = self.ins_norm(s_cat) + s_att = self.relu(self.conv_s1(s_cat)) + s_att = self.conv_s2(s_att) # batch_size x 1 x 7 x 7 + s_att = self.softmax(s_att.view(batch_size, -1)).view(s_att.size()) + video_soft_att.append(s_att.detach().cpu()) + + X = X * s_att # batch_size x C x 7 x 7 + rnn_top_in = torch.sum( + X.view(batch_size, self.x_size, -1), dim=2) # batch_size x C + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + # batch_size x seq_len x 1 x 14 x 14 + video_soft_att = torch.stack(video_soft_att, dim=1) + return final_score, video_soft_att diff --git a/model_def/Feature_Extractor.py b/model_def/Feature_Extractor.py new file mode 100644 index 0000000..b0f9753 --- /dev/null +++ b/model_def/Feature_Extractor.py @@ -0,0 +1,38 @@ +import os + +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F + +import torchvision +from torchvision.models import resnet101 + +class feature_extractor (nn.Module): + def __init__(self, input_type): + super(feature_extractor, self).__init__() + assert input_type in ['rgb', 'flow'] + self.input_type = input_type + + self.resnet = resnet101(pretrained=False) + self.resnet.fc.out_features = 101 + + if self.input_type == 'rgb': + pretrained_wgt = torch.load(os.path.join('model_param', 'ResNet101_rgb_pretrain.pth.tar')) + elif self.input_type == 'flow': + pretrained_wgt = torch.load(os.path.join('model_param', 'ResNet101_flow_pretrain.pth.tar')) + pretrained_wgt = pretrained_wgt['state_dict'] + pretrained_wgt = {k.replace('fc_custom', 'fc'): v for k, v in pretrained_wgt.items()} + self.resnet.load_state_dict(pretrained_wgt) + + self.feat_extractor = nn.Sequential(*list(self.resnet.children())[:-2]) # Remove avgpool and final fc + + def forward (self, inputs): + with torch.no_grad(): + bs, ch, nt, h, w = inputs.shape + frames = torch.cat(torch.unbind(inputs, dim=2), dim=0) # T*B x C x H x W + features = self.feat_extractor(frames) # T*B x 2048 x 14 x 14 + features = torch.stack(features.split(bs, dim=0), dim=1) # B x T x 2048 x 14 x 14 + return features + + diff --git a/model_def/SCA_Att.py b/model_def/SCA_Att.py new file mode 100644 index 0000000..e77867c --- /dev/null +++ b/model_def/SCA_Att.py @@ -0,0 +1,87 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable + +# For baseline SCA-CNN (CVPR 2017) + +class model (nn.Module): + def __init__(self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2, 2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1, 1)) + self.x_avgpool = nn.AvgPool2d(7) + + self.h_size = 128 + self.rnn_top = nn.GRUCell(self.x_size, self.h_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.k_size = int(self.x_size/4) + self.fc_xc = nn.Linear(1, self.k_size, bias=True) + self.fc_hc = nn.Linear(self.h_size, self.k_size, bias=False) + self.fc_b = nn.Linear(self.k_size, 1, bias=True) + + self.conv_s = nn.Conv2d(self.x_size, self.k_size, (1, 1), bias=True) + self.fc_hs = 
nn.Linear(self.h_size, self.k_size, bias=False) + self.fc_a = nn.Linear(self.k_size, 1, bias=True) + + self.score_fc = nn.Linear(self.h_size, 1, bias=True) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.2) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward(self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.randn(batch_size, self.h_size).to(video_tensor.device) + for frame_idx in range(seq_len): + # batch_size x 2D x 14 x 14 + featmap = video_tensor[:, frame_idx, :, :, :] + + X = self.relu(self.pre_conv1(featmap)) + X = self.pre_conv2(X) # batch_size x C x 7 x 7 + + x_avg = self.x_avgpool(X).view(batch_size, -1) # batch_size x C + # batch_size x k x C + tmp_bx = self.fc_xc(x_avg.unsqueeze(-1)).transpose(1, 2) + tmp_bh = self.fc_hc(h_top).unsqueeze(-1) # batch_size x k x 1 + b = self.tanh(tmp_bx + tmp_bh) # batch_size x k x C + # batch_size x C x 1 + beta = self.sigmoid(self.fc_b(b.transpose(1, 2))) + + ch_att = beta.unsqueeze(-1) # batch_size x C x 1 x 1 + X = X * ch_att # batch_size x C x 14 x 14 + + tmp_ax = self.conv_s(X).view( + batch_size, self.k_size, -1) # batch_size x k x 49 + tmp_ah = self.fc_hs(h_top).unsqueeze(-1) # batch_size x k x 1 + a = self.tanh(tmp_ax + tmp_ah) # batch_size x k x 49 + # batch_size x 49 x 1 + alpha = self.softmax(self.fc_a(a.transpose(1, 2))) + + s_att = alpha.view(batch_size, 1, X.size( + 2), X.size(3)) # batch_size x 1 x 7 x 7 + X = X * s_att # batch_size x C x 7 x 7 + + video_soft_att.append(s_att.detach().cpu()) + rnn_top_in = torch.sum( + X.view(batch_size, self.x_size, -1), dim=2) # batch_size x C + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + # batch_size x seq_len x 1 x 14 x 14 + video_soft_att = torch.stack(video_soft_att, dim=1) + return final_score, video_soft_att diff --git a/model_def/Spa_Att.py b/model_def/Spa_Att.py new file mode 100644 index 0000000..fc3772f --- /dev/null +++ b/model_def/Spa_Att.py @@ -0,0 +1,121 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F + +# For our ICCVW 2019 + +class model (nn.Module): + def __init__(self, feature_size, num_seg, variant='full'): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + self.variant = variant + assert variant in ['full', 'only_x', 'only_htop', 'fc_att', 'no_att'] + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2, 2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1, 1)) + self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.rnn_att_size = 128 + self.rnn_top_size = 128 + + if self.variant == 'fc_att': + self.fc_att = nn.Linear( + self.x_size+self.rnn_top_size, self.rnn_att_size, bias=True) + else: + if self.variant == 'full': + self.rnn_att = nn.GRUCell(self.x_size+self.rnn_top_size, self.rnn_att_size) + elif self.variant == 'only_x': + self.rnn_att = nn.GRUCell(self.x_size, self.rnn_att_size) + elif self.variant == 'only_htop': + self.rnn_att = nn.GRUCell(self.rnn_top_size, self.rnn_att_size) + for param in self.rnn_att.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.a_size = 32 + self.xa_fc = nn.Linear(self.x_size, self.a_size, bias=True) + self.ha_fc = nn.Linear(self.rnn_att_size, self.a_size, bias=True) + self.a_fc = nn.Linear(self.a_size, 1, bias=False) + + 
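# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The additive spatial attention that Spa_Att.forward builds from xa_fc, ha_fc
# and a_fc: each of the 49 locations of the 7x7 map is scored from its feature
# vector plus the attention-GRU state, squashed with tanh, projected to a
# scalar, and normalised with a softmax over locations. Names are illustrative.
import torch
from torch import nn

B, C, H, W = 2, 256, 7, 7
rnn_att_size, a_size = 128, 32

xa_fc = nn.Linear(C, a_size)
ha_fc = nn.Linear(rnn_att_size, a_size)
a_fc = nn.Linear(a_size, 1, bias=False)

X = torch.randn(B, C, H, W)                # per-frame feature map
h_att = torch.randn(B, rnn_att_size)       # attention-RNN hidden state

X_flat = X.view(B, C, -1).transpose(1, 2)                  # B x 49 x C
h_rep = h_att.unsqueeze(1).expand(-1, X_flat.size(1), -1)  # B x 49 x rnn_att_size
scores = a_fc(torch.tanh(xa_fc(X_flat) + ha_fc(h_rep)))    # B x 49 x 1
alpha = torch.softmax(scores.squeeze(-1), dim=1)           # B x 49, each row sums to 1
s_att = alpha.view(B, 1, H, W)                             # spatial attention map
attended = (X * s_att).view(B, C, -1).sum(dim=2)           # B x C, input to rnn_top
print(alpha.sum(dim=1))                                    # each entry is 1
# -----------------------------------------------------------------------------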
self.rnn_top = nn.GRUCell(self.x_size, self.rnn_top_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.score_fc = nn.Linear(self.rnn_top_size, 1, bias=True) + + self.x_ln = nn.LayerNorm(self.x_size) + self.h_ln = nn.LayerNorm(self.rnn_top_size) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.2) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward(self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.randn(batch_size, self.rnn_top_size).to( + video_tensor.device) + h_att = torch.randn(batch_size, self.rnn_att_size).to( + video_tensor.device) + for frame_idx in range(seq_len): + # batch_size x 2D x 14 x 14 + featmap = video_tensor[:, frame_idx, :, :, :] + + X = self.relu(self.pre_conv1(featmap)) # batch_size x C x 7 x 7 + X = self.pre_conv2(X) + x_avg = self.x_avgpool(X).view(batch_size, -1) # batch_size x C + x_max = self.x_maxpool(X).view(batch_size, -1) + + if self.variant == 'no_att': + rnn_top_in = x_avg + else: + if self.variant == 'full': + rnn_att_in = torch.cat((self.x_ln(x_avg+x_max), self.h_ln(h_top)), dim=1) + # batch_size x rnn_att_size + h_att = self.rnn_att(rnn_att_in, h_att) + elif self.variant == 'only_x': + rnn_att_in = self.x_ln(x_avg+x_max) + # batch_size x rnn_att_size + h_att = self.rnn_att(rnn_att_in, h_att) + elif self.variant == 'only_htop': + rnn_att_in = self.h_ln(h_top) + # batch_size x rnn_att_size + h_att = self.rnn_att(rnn_att_in, h_att) + elif self.variant == 'fc_att': + rnn_att_in = torch.cat( + (self.x_ln(x_avg+x_max), self.h_ln(h_top)), dim=1) + h_att = self.fc_att(rnn_att_in) + + # batch_size x 49 x C + X_tmp = X.view(batch_size, self.x_size, -1).transpose(1, 2) + # batch_size x 49 x rnn_att_size + h_att_tmp = h_att.unsqueeze(1).expand(-1, X_tmp.size(1), -1) + a = self.tanh(self.xa_fc(X_tmp)+self.ha_fc(h_att_tmp)) + a = self.a_fc(a).unsqueeze(2) # batch_size x 49 + alpha = self.softmax(a) + s_att = alpha.view(batch_size, 1, X.size(2), X.size(3)) + video_soft_att.append(s_att) + + X = X * s_att # batch_size x C x 7 x 7 + rnn_top_in = torch.sum( + X.view(batch_size, self.x_size, -1), dim=2) # batch_size x C + + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + if self.variant == 'no_att': + video_soft_att = torch.zeros(batch_size, seq_len, 1, 14, 14) + else: + # batch_size x seq_len x 1 x 14 x 14 + video_soft_att = torch.stack(video_soft_att, dim=1) + return final_score, video_soft_att diff --git a/model_def/VideoLSTM.py b/model_def/VideoLSTM.py new file mode 100644 index 0000000..bf5eccc --- /dev/null +++ b/model_def/VideoLSTM.py @@ -0,0 +1,122 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F + +# For baseline Video LSTM + +class ConvLSTMCell(nn.Module): + def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): + super(ConvLSTMCell, self).__init__() + + self.height, self.width = input_size + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.kernel_size = kernel_size + self.padding = kernel_size[0] // 2, kernel_size[1] // 2 + self.bias = bias + + self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, + out_channels=4 * self.hidden_dim, + kernel_size=self.kernel_size, + padding=self.padding, + bias=self.bias) + + def forward(self, input_tensor, cur_state): + h_cur, 
c_cur = cur_state + + # concatenate along channel axis + combined = torch.cat([input_tensor, h_cur], dim=1) + + combined_conv = self.conv(combined) + cc_i, cc_f, cc_o, cc_g = torch.split( + combined_conv, self.hidden_dim, dim=1) + i = torch.sigmoid(cc_i) + f = torch.sigmoid(cc_f) + o = torch.sigmoid(cc_o) + g = torch.tanh(cc_g) + + c_next = f * c_cur + i * g + h_next = o * torch.tanh(c_next) + + return h_next, c_next + + +class model (nn.Module): + def __init__(self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + self.f_width = 14 + self.f_height = 14 + + self.x_size = 512 + self.conv_rgb = nn.Conv2d(self.f_size, self.x_size, (2, 2), stride=2) + self.conv_flow = nn.Conv2d(self.f_size, self.x_size, (2, 2), stride=2) + + self.htop_size = 128 + self.hatt_size = 128 + + self.top_lstm = ConvLSTMCell((7, 7), self.x_size, self.htop_size, + kernel_size=(3, 3), bias=True) + self.att_lstm = ConvLSTMCell((7, 7), self.x_size+self.htop_size, self.hatt_size, + kernel_size=(3, 3), bias=True) + + self.conv_z1 = nn.Conv2d( + self.x_size+self.hatt_size, 256, (1, 1), bias=True) + self.conv_z2 = nn.Conv2d(256, 1, (1, 1), bias=False) + + self.avgpool = nn.AvgPool2d(7) + self.score_fc = nn.Linear(self.htop_size*49, 1, bias=True) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.3) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward(self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] +# video_htop = [] + htop = torch.randn(batch_size, self.htop_size, + 7, 7).to(video_tensor.device) + ctop = torch.randn(batch_size, self.htop_size, + 7, 7).to(video_tensor.device) + hatt = torch.randn(batch_size, self.hatt_size, + 7, 7).to(video_tensor.device) + catt = torch.randn(batch_size, self.hatt_size, + 7, 7).to(video_tensor.device) + for frame_idx in range(seq_len): + # batch_size x 2D x 14 x 14 + featmap = video_tensor[:, frame_idx, :, :, :] +# rgb, flow = torch.split(featmap, self.f_size, dim=1) + rgb = self.conv_rgb(featmap[:, :2048, :, :]) + flow = self.conv_flow(featmap[:, 2048:, :, :]) + + att_lstm_in = torch.cat((flow, htop), dim=1) + hatt, catt = self.att_lstm(att_lstm_in, (hatt, catt)) + + z = self.conv_z1(torch.cat((flow, hatt), dim=1)) + z = self.tanh(z) + z = self.conv_z2(z) + s_att = self.softmax(z.view(batch_size, -1) + ).view(batch_size, 1, 7, 7) + video_soft_att.append(s_att) + + top_lstm_in = rgb * s_att + htop, ctop = self.top_lstm(top_lstm_in, (htop, ctop)) +# video_htop.append(htop) + +# video_htop = torch.stack(video_htop, dim=1) +# tmp = torch.mean(video_htop, dim=1, keepdim=False).view(batch_size, -1) + tmp = htop.view(batch_size, -1) + tmp = self.dropout(tmp) + final_score = self.score_fc(tmp).squeeze(1) + # batch_size x seq_len x 1 x 14 x 14 + video_soft_att = torch.stack(video_soft_att, dim=1) + return final_score, video_soft_att \ No newline at end of file diff --git a/model_def/Visual_Att.py b/model_def/Visual_Att.py new file mode 100644 index 0000000..431b264 --- /dev/null +++ b/model_def/Visual_Att.py @@ -0,0 +1,96 @@ +import torch +from torch import nn +from torch.nn import Parameter +from torch.nn import functional as F +from torch.autograd import Variable + +# For baseline Visual Attention (Arxiv 2015) + +class CustomLoss (nn.Module): + def __init__(self, lamda): + super(CustomLoss, self).__init__() + self.lamda = lamda + self.ranking_loss = 
nn.MarginRankingLoss(margin=0.5) + + def forward(self, v1_score, v2_score, label, video_soft_att): + #video_soft_att: batch_size x seq_len x 1 x 7 x 7 + batch_size = video_soft_att.size(0) + seq_len = video_soft_att.size(1) + + l1 = self.ranking_loss(v1_score, v2_score, label) + + video_soft_att = video_soft_att.view( + batch_size, seq_len, -1) # batch_size x seq_len x 49 + l2 = (1-torch.sum(video_soft_att, dim=1))**2 # batch_size x 49 + l2 = torch.mean(l2, dim=1) # batch_size + + return l1+l2 + +class model (nn.Module): + def __init__(self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2, 2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1, 1)) + self.x_avgpool = nn.AvgPool2d(7) + + self.h_size = 128 + self.fc_initC = nn.Linear(self.x_size, self.h_size, bias=True) + self.fc_initH = nn.Linear(self.x_size, self.h_size, bias=True) + self.rnn_top = nn.LSTMCell(self.x_size, self.h_size) + + self.fc_hl = nn.Linear(self.h_size, 49, bias=False) + + self.score_fc = nn.Linear(self.h_size, 1, bias=True) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.2) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward(self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_X = [] + init_vec = 0 + # featmap: batch_size x seq_len x 2D x 14 x 14 + for frame_idx in range(seq_len): + # batch_size x 2D x 14 x 14 + featmap = video_tensor[:, frame_idx, :, :, :] + X = self.relu(self.pre_conv1(featmap)) # batch_size x C x 7 x 7 + X = self.relu(self.pre_conv2(X)) # batch_size x C x 7 x 7 + video_X.append(X) + x_avg = self.x_avgpool(X).view(batch_size, -1) # batch_size x C + init_vec += x_avg + # batch_size x seq_len x C x 7 x 7 + video_X = torch.stack(video_X, dim=1) + init_vec /= seq_len + + c_top = self.fc_initC(init_vec) # batch_size x h_size + h_top = self.fc_initH(init_vec) # batch_size x h_size + + video_soft_att = [] + for frame_idx in range(seq_len): + X = video_X[:, frame_idx, :, :, :] # batch_size x C x 7 x 7 + + l = self.fc_hl(h_top) # batch_size x 49 + l = self.softmax(l) + s_att = l.view(batch_size, 1, 7, 7) + + X = X * s_att # batch_size x C x 7 x 7 + video_soft_att.append(s_att.detach().cpu()) + + rnn_top_in = torch.sum( + X.view(batch_size, self.x_size, -1), dim=2) # batch_size x C + h_top, c_top = self.rnn_top(rnn_top_in, (h_top, c_top)) + + final_score = self.score_fc(h_top).squeeze(1) + # batch_size x seq_len x 1 x 14 x 14 + video_soft_att = torch.stack(video_soft_att, dim=1) + return final_score, video_soft_att diff --git a/sugury_fusion_attention_transition.py b/sugury_fusion_attention_transition.py new file mode 100644 index 0000000..c659191 --- /dev/null +++ b/sugury_fusion_attention_transition.py @@ -0,0 +1,266 @@ +import torch +from torch import nn +from torch.utils.data import Dataset, DataLoader + +from dataset.sugury_dataset import SuguryDataset, SuguryDataset_Pair +from dataset.sugury_dataset import get_flow_feature_dict, get_rgb_feature_dict +from common import train, test, save_best_result + +import os +from os.path import join, isdir, isfile, exists +import argparse +import csv + +parser = argparse.ArgumentParser() +parser.add_argument("--dataset_name", type=str, default='All', + choices=['All', 'KnotTying', 'Suturing', 'NeedlePassing']) +parser.add_argument("--model", type=str, 
default='full', + choices=['full', 'only_x', 'only_htop', 'fc_att', 'no_att', 'cbam', 'sca', 'video_lstm', 'visual']) +parser.add_argument("--feature_type", type=str, default='resnet101_conv5', choices=['resnet101_conv4', 'resnet101_conv5']) +parser.add_argument("--epoch_num", type=int, default=30) +parser.add_argument("--label", type=str, default='') +parser.add_argument("--continue_train", action='store_true') + +args = parser.parse_args() + +''' +class model (nn.Module): + def __init__ (self, feature_size, num_seg): + super(model, self).__init__() + self.f_size = feature_size + self.num_seg = num_seg + + self.x_size = 256 + self.pre_conv1 = nn.Conv2d(2*self.f_size, 512, (2,2), stride=2) + self.pre_conv2 = nn.Conv2d(512, self.x_size, (1,1)) + self.x_avgpool = nn.AvgPool2d(7) + self.x_maxpool = nn.MaxPool2d(7) + + self.rnn_att_size = 128 + self.rnn_top_size = 128 + + self.rnn_top = nn.GRUCell(self.x_size, self.rnn_top_size) + for param in self.rnn_top.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + +# self.fc_att = nn.Linear(self.x_size+self.rnn_top_size, self.rnn_att_size, bias=False) +# self.rnn_att = nn.GRUCell(self.rnn_att_size, self.rnn_att_size) + self.rnn_att = nn.GRUCell(self.x_size+self.rnn_top_size, self.rnn_att_size) + for param in self.rnn_att.parameters(): + if param.dim() > 1: + torch.nn.init.orthogonal_(param) + + self.a_size = 32 + self.xa_fc = nn.Linear(self.x_size, self.a_size, bias=True) + self.ha_fc = nn.Linear(self.rnn_att_size, self.a_size, bias=True) + self.a_fc = nn.Linear(self.a_size, 1, bias=False) + + self.score_fc = nn.Linear(self.rnn_top_size, 1, bias=True) + + self.x_ln = nn.LayerNorm(self.x_size) + self.h_ln = nn.LayerNorm(self.rnn_top_size) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(1) + self.dropout = nn.Dropout(p=0.2) + + # video_featmaps: batch_size x seq_len x D x w x h + def forward (self, video_tensor): + batch_size = video_tensor.shape[0] + seq_len = video_tensor.shape[1] + + video_soft_att = [] + h_top = torch.randn(batch_size, self.rnn_top_size).to(video_tensor.device) + h_att = torch.randn(batch_size, self.rnn_att_size).to(video_tensor.device) + for frame_idx in range(seq_len): + featmap = video_tensor[:,frame_idx,:,:,:] #batch_size x 2D x 14 x 14 + + X = self.relu(self.pre_conv1(featmap)) #batch_size x C x 7 x 7 + X = self.pre_conv2(X) + x_avg = self.x_avgpool(X).view(batch_size, -1) #batch_size x C + x_max = self.x_maxpool(X).view(batch_size, -1) + + +# rnn_att_in = torch.cat((x_avg+x_max, h_top), dim=1) +# rnn_att_in = self.fc_att(rnn_att_in) + rnn_att_in = torch.cat((self.x_ln(x_avg+x_max),self.h_ln(h_top)), dim=1) + h_att = self.rnn_att(rnn_att_in, h_att) #batch_size x rnn_att_size + + X_tmp = X.view(batch_size, self.x_size, -1).transpose(1,2) #batch_size x 49 x C + h_att_tmp = h_att.unsqueeze(1).expand(-1,X_tmp.size(1),-1) #batch_size x 49 x rnn_att_size + a = self.tanh(self.xa_fc(X_tmp)+self.ha_fc(h_att_tmp)) + a = self.a_fc(a).unsqueeze(2) #batch_size x 49 + alpha = self.softmax(a) + s_att = alpha.view(batch_size, 1, X.size(2), X.size(3)) + video_soft_att.append(s_att) + + X = X * s_att #batch_size x C x 7 x 7 + rnn_top_in = torch.sum(X.view(batch_size, self.x_size, -1), dim=2) #batch_size x C + h_top = self.rnn_top(rnn_top_in, h_top) + + final_score = self.score_fc(h_top).squeeze(1) + video_soft_att = torch.stack(video_soft_att, dim=1) #batch_size x seq_len x 1 x 14 x 14 + video_tmpr_att = torch.zeros(batch_size, seq_len) + return final_score, 
video_soft_att, video_tmpr_att +''' + +def read_model(model_type, feature_type, num_seg): + feature_size = 2048 if feature_type == 'resnet101_conv5' else 1024 + if model_type in ['full', 'only_x', 'only_htop', 'fc_att', 'no_att']: + from model_def.Spa_Att import model + return model(feature_size, num_seg, variant=model_type) + elif model_type in ['cbam']: + from model_def.CBAM_Att import model + return model(feature_size, num_seg) + elif model_type in ['sca']: + from model_def.SCA_Att import model + return model(feature_size, num_seg) + elif model_type in ['video_lstm']: + from model_def.VideoLSTM import model + return model(feature_size, num_seg) + elif model_type in ['visual']: + from model_def.Visual_Att import model + return model(feature_size, num_seg) + else: + raise Exception(f'Unsupport model type of {model_type}.') + +def get_train_test_pairs_dict(annotation_dir, dataset_name, split_idx): + train_pairs_dict = {} + train_videos = set() + train_csv = join(annotation_dir, dataset_name+'_train_' + + format(split_idx, '01d')+'.csv') + with open(train_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + key = tuple((dataset_name+'_'+row[0], dataset_name+'_'+row[1])) + train_pairs_dict[key] = 1 + train_videos.update(key) + csvfile.close() + + test_pairs_dict = {} + test_videos = set() + test_csv = join(annotation_dir, dataset_name+'_val_' + + format(split_idx, '01d')+'.csv') + with open(test_csv, 'r') as csvfile: + csvreader = csv.reader(csvfile) + for row_idx, row in enumerate(csvreader): + if row_idx != 0: + v0 = dataset_name+'_'+row[0] + v1 = dataset_name+'_'+row[1] + key = tuple((v0, v1)) + test_pairs_dict[key] = 1 + if v0 not in train_videos: + test_videos.add(v0) + if v1 not in train_videos: + test_videos.add(v1) + csvfile.close() + + return train_pairs_dict, test_pairs_dict, train_videos, test_videos + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + dataset_name_list = ['KnotTying', 'NeedlePassing', 'Suturing'] if args.dataset_name=='All' else [args.dataset_name] + + # read name_list & feature dict of all the videos + video_name_list = [] + video_rgb_feature_dict = {} + video_flow_feature_dict = {} + for dataset_name in dataset_name_list: + dataset_dir = join('../dataset', dataset_name, dataset_name+'_640x480') + sub_video_name_list = [video_name for video_name in os.listdir(dataset_dir) if isdir(join(dataset_dir, video_name))] + sub_video_rgb_feature_dict = get_rgb_feature_dict(dataset_dir, args.feature_type) + sub_video_flow_feature_dict = get_flow_feature_dict(dataset_dir, args.feature_type) + for video_name in sub_video_name_list: + sub_video_rgb_feature_dict[dataset_name+'_'+video_name] = sub_video_rgb_feature_dict.pop(video_name) + sub_video_flow_feature_dict[dataset_name+'_'+video_name] = sub_video_flow_feature_dict.pop(video_name) + sub_video_name_list = [dataset_name+'_'+video_name for video_name in sub_video_name_list] + + video_name_list += sub_video_name_list + video_rgb_feature_dict.update(sub_video_rgb_feature_dict) + video_flow_feature_dict.update(sub_video_flow_feature_dict) + del sub_video_name_list, sub_video_rgb_feature_dict, sub_video_flow_feature_dict + + best_acc_keeper = [] + for split_idx in range(1, 5): + print("Split: "+format(split_idx, '01d')) + + # read pairs dict of videos belonging to this split + train_pairs_dict = {} + test_pairs_dict = {} + train_videos = set() + test_videos = set() + for dataset_name in dataset_name_list: + 
annotation_dir = join('../dataset', dataset_name, dataset_name+'_Annotation/splits') + sub_train_pairs_dict, sub_test_pairs_dict, sub_train_videos, sub_test_videos = get_train_test_pairs_dict( + annotation_dir, dataset_name, split_idx) + + train_pairs_dict.update(sub_train_pairs_dict) + test_pairs_dict.update(sub_test_pairs_dict) + train_videos.update(sub_train_videos) + test_videos.update(sub_test_videos) + del sub_train_pairs_dict, sub_test_pairs_dict, sub_train_videos, sub_test_videos + + num_seg = 25 + dataset_train = SuguryDataset_Pair('fusion', video_rgb_feature_dict, video_flow_feature_dict, + train_pairs_dict, seg_sample=num_seg) + dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True) + + dataset_test = SuguryDataset('fusion', video_rgb_feature_dict, video_flow_feature_dict, + video_name_list, seg_sample=num_seg) + dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False) + + model_ins = read_model(args.model, args.feature_type, num_seg) + + save_label = f'Surgery_{args.dataset_name}/{args.model}/{split_idx:01d}' + + best_acc = 0.0 + if args.continue_train: + ckpt_dir = join('checkpoints', save_label, 'best_checkpoint.pth.tar') + if exists(checkpoint): + checkpoint = torch.load(ckpt_dir) + model_ins.load_state_dict(checkpoint['state_dict']) + best_acc = checkpoint['best_acc'] + print("Start from previous checkpoint, with rank_cor: {:.4f}".format(checkpoint['best_acc'])) + else: + print("No previous checkpoint. \nStart from scratch.") + else: + print("Start from scratch.") + + model_ins.to(device) + + criterion = nn.MarginRankingLoss(margin=0.5) + +# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_ins.parameters()), +# lr=1e-5, weight_decay=0, amsgrad=True) + optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model_ins.parameters()), + lr=5e-4, momentum=0.9, weight_decay=1e-2) #real l2 reg = weight_decay*lr +# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2) + + min_loss = 1.0 + no_imprv = 0 + for epoch in range(args.epoch_num): + train(dataloader_train, model_ins, criterion, optimizer, epoch, device) + epoch_loss, epoch_acc = test(dataloader_test, test_pairs_dict, model_ins, criterion, epoch, device) + + if epoch_acc >= best_acc: + best_acc = epoch_acc + save_best_result(dataloader_test, test_videos, model_ins, device, best_acc, save_label) + + if epoch_loss <= min_loss: + min_loss = epoch_loss + no_imprv = 0 + else: + no_imprv += 1 + print('Best acc: {:.3f}'.format(best_acc)) +# if no_imprv > 3: +# break + best_acc_keeper.append(best_acc) + + for split_idx, best_acc in enumerate(best_acc_keeper): + print(f'Split: {split_idx+1}, {best_acc:.4f}') + print('Avg:', '{:.4f}'.format(sum(best_acc_keeper)/4)) diff --git a/utility.py b/utility.py new file mode 100644 index 0000000..0bddb21 --- /dev/null +++ b/utility.py @@ -0,0 +1,311 @@ +import torch +from torch import nn +from torch.nn import functional as F +import os +from os.path import join, isdir, isfile +import cv2 +import numpy as np +import random +import math +import csv + +# batched_heatmaps: batch_size x seq_len x 1 x 7 x 7 +def save_heatmaps (batched_heatmaps, save_dir, size, video_name, rand_idx_list, t_att, dataset_dir): + batch_size = batched_heatmaps.size(0) + seq_len = batched_heatmaps.size(1) + + for batch_offset in range(batch_size): + att_save_dir = join(save_dir, video_name[batch_offset]) + ori_frames_dir = join(dataset_dir, video_name[batch_offset], 'frame') + + if not os.path.isdir(att_save_dir): +# os.system('mkdir 
-p '+att_save_dir) + os.makedirs(att_save_dir) + else: + os.system('rm -rf '+att_save_dir) +# os.system('mkdir -p '+att_save_dir) + os.makedirs(att_save_dir) + +# print(rand_idx_list) + for seq_idx in range(seq_len): + frame_idx = int(rand_idx_list[seq_idx][batch_offset].item()) + + heatmap = batched_heatmaps[batch_offset,seq_idx,0,:,:] + heatmap = (heatmap-heatmap.min()) / (heatmap.max()-heatmap.min()) + heatmap = np.array(heatmap*255.0).astype(np.uint8) + heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) + heatmap = cv2.resize(heatmap, size) + + ori_frame = cv2.imread(join(ori_frames_dir, format(frame_idx, '05d')+'.jpg')) + if 'diving' in save_dir: + ori_frame = ori_frame[1:449,176:624] + ori_frame = cv2.resize(ori_frame, size) + + comb = cv2.addWeighted(ori_frame, 0.6, heatmap, 0.4, 0) +# print(t_att) + t_att_value = t_att[batch_offset, seq_idx].item() + pic_save_dir = join(att_save_dir, format(frame_idx, '05d')+'_'+format(t_att_value, '.2f')+'.jpg') + cv2.imwrite(pic_save_dir, comb) + +# featmap: seq_len x 2048 x 14 x 14 +def save_featmap_heatmaps (featmap, save_dir, size, video_name, dataset_dir): + seq_len = featmap.size(0) + + s = torch.norm(featmap, p=2, dim=1, keepdim=True) # seq_len x 1 x 14 x 14 + s = F.normalize(s.view(seq_len, -1),dim=1).view(s.size()) + + att_save_dir = join(save_dir, video_name) + ori_frames_dir = join(dataset_dir, video_name, 'frame') + + if not os.path.isdir(att_save_dir): + os.system('mkdir -p '+att_save_dir) + else: + os.system('rm -rf '+att_save_dir) + os.system('mkdir -p '+att_save_dir) + + for seq_idx in range(seq_len): + frame_idx = int(seq_idx) + + heatmap = s[seq_idx,0,:,:] + heatmap = (heatmap-heatmap.min()) / (heatmap.max()-heatmap.min()) + heatmap = np.array(heatmap*255.0).astype(np.uint8) + heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) + heatmap = cv2.resize(heatmap, size) + + ori_frame = cv2.imread(join(ori_frames_dir, format(frame_idx, '05d')+'.jpg')) + if 'Dive' in save_dir: + ori_frame = ori_frame[1:449,176:624] + ori_frame = cv2.resize(ori_frame, size) + + comb = cv2.addWeighted(ori_frame, 0.6, heatmap, 0.4, 0) + pic_save_dir = join(att_save_dir, format(frame_idx, '05d')+'.jpg') + cv2.imwrite(pic_save_dir, comb) + +def save_best_checkpoint (epoch, model, best_acc, save_dir): + best_checkpoint = {'epoch': epoch, + 'state_dict': model.state_dict(), + 'best_acc': best_acc} + + if not os.path.isdir(save_dir): +# os.system('mkdir -p '+save_dir) + os.makedirs(save_dir) + torch.save(best_checkpoint, join(save_dir, 'best_checkpoint.pth.tar')) + +def save_checkpoint (model, epoch, file_dir): + checkpoint_dir = join(file_dir, format(epoch, '03d')) + if not os.path.isdir(checkpoint_dir): + os.system('mkdir -p '+checkpoint_dir) + + torch.save(model.state_dict(), join(checkpoint_dir, 'checkpoint.pth.tar')) + +def save_record (file_name, type, epoch, epoch_loss, rank_cor): + new_file = open(file_name, "a") if isfile(file_name) else open(file_name, "w") + with new_file: + writer = csv.writer(new_file) + writer.writerow(["Epoch: ", epoch, "type: ", type]) + writer.writerow(["epoch_loss: ", format(epoch_loss,".2f"), "rank_cor: ", format(rank_cor,".2f")]) + +def avg_rand_sample (seq_len, num_seg): + r = int(seq_len / num_seg) + real_num_seg = int(math.ceil(seq_len / r)) + + frame_ind = [] + for i in range(0, real_num_seg-1): + frame_ind.append(random.randint(i*r, (i+1)*r-1)) + frame_ind.append(random.randint((real_num_seg-1)*r, seq_len-1)) + + frame_ind = frame_ind[len(frame_ind)-num_seg:] + return frame_ind + +def avg_first_sample (seq_len, 
num_seg): + r = int(seq_len / num_seg) + + frame_ind = [] + for i in range(0, seq_len, r): + frame_ind.append(i) + frame_ind = frame_ind[len(frame_ind)-num_seg:] + return frame_ind + +def avg_last_sample (seq_len, num_seg): + r = int(seq_len / num_seg) + + frame_ind = [] + for i in range(seq_len-1, -1, -r): + frame_ind.append(i) + + frame_ind = frame_ind[0:num_seg] + frame_ind.reverse() + return frame_ind + +def get_train_test_videos_list (video_record_list, split_index, split_num): + train_video_list = [] + test_video_list = [] + + video_num = len(video_record_list) + test_video_num = int(math.floor(video_num / split_num)) + test_video_indexs = range(split_index*test_video_num, (split_index+1)*test_video_num) + + for video_index, video_record in enumerate(video_record_list): + if video_index in test_video_indexs: + test_video_list.append(video_record) + else: + train_video_list.append(video_record) + return train_video_list, test_video_list + +# Ensure the file 'pairs_annotation.txt' exist +def get_train_test_pairs_dict (dataset_root, train_video_list, test_video_list, cross=True): + # Read pairs' annotation file + pairs_annotation_file = open(join(dataset_root, "annotation.txt"), "r") + all_pairs_dict = {} + lines = pairs_annotation_file.readlines() + for line in lines: + video_name_1, video_name_2, label = line.strip().split(' ') + all_pairs_dict[tuple((video_name_1, video_name_2))] = int(label) + + train_pairs_dict = {} + train_videos_num = len(train_video_list) +# print('training videos num:', train_videos_num) + for video_index, video_record in enumerate(train_video_list): + video_name_1 = video_record.video_name + for i in range(video_index+1, train_videos_num): + video_name_2 = train_video_list[i].video_name + key = tuple((video_name_1, video_name_2)) + key_inv = tuple((video_name_2, video_name_1)) + if (key in all_pairs_dict) and (all_pairs_dict[key] != 0): + train_pairs_dict[key] = all_pairs_dict[key] + elif (key_inv in all_pairs_dict) and (all_pairs_dict[key_inv] != 0): + train_pairs_dict[key_inv] = all_pairs_dict[key_inv] + + test_pairs_dict = {} + test_video_num = len(test_video_list) +# print('validation videos num:', test_video_num) + for video_index, video_record in enumerate(test_video_list): + video_name_1 = video_record.video_name + for i in range(video_index+1, test_video_num): + video_name_2 = test_video_list[i].video_name + key = tuple((video_name_1, video_name_2)) + key_inv = tuple((video_name_2, video_name_1)) + if (key in all_pairs_dict) and (all_pairs_dict[key] != 0): + test_pairs_dict[key] = all_pairs_dict[key] + elif (key_inv in all_pairs_dict) and (all_pairs_dict[key_inv] != 0): + test_pairs_dict[key_inv] = all_pairs_dict[key_inv] + if cross: + for video_record_test in test_video_list: + video_name_1 = video_record_test.video_name + for video_record_train in train_video_list: + video_name_2 = video_record_train.video_name + key = tuple((video_name_1, video_name_2)) + key_inv = tuple((video_name_2, video_name_1)) + if (key in all_pairs_dict) and (all_pairs_dict[key] != 0): + test_pairs_dict[key] = all_pairs_dict[key] + elif (key_inv in all_pairs_dict) and (all_pairs_dict[key_inv] != 0): + test_pairs_dict[key_inv] = all_pairs_dict[key_inv] + + return train_pairs_dict, test_pairs_dict + +# heatmaps_tensor: seq_len x 1 x w x h +def merge_heatmaps (heatmaps_tensor, num_merge, type='max'): + if type=='max': + pooling = nn.MaxPool3d((num_merge,1,1)) + elif type=='avg': + pooling = nn.AvgPool3d((num_merge,1,1)) + + seq_len = heatmaps_tensor.size(0) + heatmaps_tensor = 
heatmaps_tensor.unsqueeze(0) + + heatmaps_tensor = heatmaps_tensor.transpose(1,2) + merged_heatmaps_tensor = [] + for i in range(0, seq_len-num_merge+1): + merged_heatmaps_tensor.append(pooling(heatmaps_tensor[:,:,i:i+num_merge,:,:])) + + merged_heatmaps_tensor = torch.cat(merged_heatmaps_tensor, 1) +# print(merged_heatmaps_tensor.shape) + merged_heatmaps_tensor = merged_heatmaps_tensor.squeeze(0) + return merged_heatmaps_tensor + +class HingeL1Loss(nn.Module): + def __init__ (self, margin=0, size_average=True): + super(HingeL1Loss, self).__init__() + self.margin = margin + self.size_average=True + + def forward (self, input, target): + d = torch.clamp(torch.abs(input-target)-self.margin, min=0) + return torch.mean(d) if self.size_average else torch.sum(d) + +class SoftAttLoss (nn.Module): + def __init__ (self, size_average=True): + super(SoftAttLoss, self).__init__() + self.size_average = size_average + + def forward (self, pred_heatmaps, target_heatmaps): + batch_size = pred_heatmaps.size(0) + seq_len = pred_heatmaps.size(1) + + pred_heatmaps = F.normalize(pred_heatmaps.view(batch_size, seq_len, -1),dim=2) + target_heatmaps = F.normalize(target_heatmaps.view(batch_size, seq_len, -1),dim=2) + l = torch.norm(pred_heatmaps-target_heatmaps, p=2, dim=2) #batch_size x seq_len + l = torch.mean(l, dim=1) #batch_size + l = torch.mean(l) if self.size_average else torch.sum(l) + return l + +class HardAttLoss (nn.Module): + def __init__ (self, size_average=True): + super(HardAttLoss, self).__init__() + self.size_average = size_average + + def forward (self, pred_heatmaps): + batch_size = pred_heatmaps.size(0) + seq_len = pred_heatmaps.size(1) + + hard_att_max, _ = torch.max(pred_heatmaps.view(batch_size, seq_len, -1), dim=2) + l = 1.0 - hard_att_max #batch_size x seq_len + l = torch.mean(l, dim=1) + l = torch.mean(l) if self.size_average else torch.sum(l) + return l + +class OuterAttLoss (nn.Module): + def __init__ (self, size_average=True): + super(OuterAttLoss, self).__init__() + self.size_average = size_average + + def forward (self, pred_heatmaps, target_heatmaps): + batch_size = pred_heatmaps.size(0) + seq_len = pred_heatmaps.size(1) + + +# pred_heatmaps = F.normalize(pred_heatmaps.view(batch_size, seq_len, -1),dim=2) + pred_heatmaps = pred_heatmaps.view(batch_size, seq_len, -1) +# target_heatmaps = F.normalize(target_heatmaps.view(batch_size, seq_len, -1),dim=2) + target_heatmaps = target_heatmaps.view(batch_size, seq_len, -1) + + outer = (target_heatmaps==0).to(dtype=torch.float) + outer = pred_heatmaps*outer #batch_size x seq_len x 14*14 + l = torch.sum(outer, dim=2) #batch_size x seq_len +# l = torch.norm(outer, p=2, dim=2) + l = torch.mean(l, dim=1) #batch_size + l = torch.mean(l) if self.size_average else torch.sum(l) + return l +# ============================================================================= # +# Video info store class # +# ============================================================================= # +class VideoRecord(object): + def __init__(self, file_name): + self._file_name = file_name + self._data = file_name.strip().split('_') + + @property + def label(self): + return float(self._data[2][1:]) + + @property + def frame_rate(self): + return int(self._data[5]) + + @property + def video_name(self): + return str(self._file_name) + + @property + def video_len(self): + return int(self._data[4])-int(self._data[3]) + 1 \ No newline at end of file diff --git a/utils/ImageShow.py b/utils/ImageShow.py new file mode 100644 index 0000000..a2a81d8 --- /dev/null +++ 
b/utils/ImageShow.py @@ -0,0 +1,283 @@ +import matplotlib.pyplot as plt +import torch +import torchvision.transforms.functional as TF + +import numpy as np +from PIL import Image +from skimage import transform, filters +import math +import os + +import matplotlib +matplotlib.use("Agg") + + +def pil_to_tensor(pil_image): + r"""Convert a PIL image to a tensor. + Args: + pil_image (:class:`PIL.Image`): PIL image. + Returns: + :class:`torch.Tensor`: the image as a :math:`3\times H\times W` tensor + in the [0, 1] range. + """ + pil_image = np.array(pil_image) + if len(pil_image.shape) == 2: + pil_image = pil_image[:, :, None] + return torch.tensor(pil_image, dtype=torch.float32).permute(2, 0, 1) / 255 + + +def img_tensor_to_pil(img_tensor): + lim = [img_tensor.min(), img_tensor.max()] + img_tensor = img_tensor - lim[0] # also makes a copy + img_tensor.mul_(1 / (lim[1] - lim[0])) + img_tensor = torch.clamp(img_tensor, min=0, max=1) + img_pil = TF.to_pil_image(img_tensor) + return img_pil + + +def img_tensor_to_np(img_tensor): + lim = [img_tensor.min(), img_tensor.max()] + img_tensor = img_tensor - lim[0] # also makes a copy + img_tensor.mul_(1 / (lim[1] - lim[0])) + img_tensor = torch.clamp(img_tensor, min=0, max=1) + img_np = img_tensor.numpy() + return img_np + + +def voxel_tensor_to_np(voxel_tensor): + # voxel_tensor: CxTxHxW + voxel_np = [] + for t in range(voxel_tensor.shape[1]): + img_np = img_tensor_to_np(voxel_tensor[:, t, :, :]) + voxel_np.append(img_np) + voxel_np = np.stack(voxel_np, axis=1) + return voxel_np + + +def map_to_colormap(attMap, resize=(), norm_map=False, blur=False): + attMap = attMap.copy() + if norm_map: + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + + if resize != (): + attMap = transform.resize(attMap, resize, order=3) + + if blur: + attMap = filters.gaussian(attMap, 0.02*max(resize)) + attMap -= attMap.min() + attMap /= attMap.max() + + cmap = plt.get_cmap('jet') + colormap = cmap(attMap) + colormap = np.delete(colormap, 3, 2) + colormap = np.transpose(colormap, (2, 0, 1)) # 3x224x224 + return attMap, colormap + + +def overlap_map_on_img(img_np, attMap, norm_map=False, blur=False): + # img_np: CxHxW, attMap: h x w + plt.axis('off') + resized_map, colormap = map_to_colormap( + attMap, resize=(img_np.shape[1:]), norm_map=False, blur=blur) + resized_map = 1*(1-resized_map**0.8)*img_np + (resized_map**0.8)*colormap + return resized_map + + +def overlap_maps_on_voxel_np(voxel_np, attMaps, norm_map=False, blur=False): + overlaps = [] + for t in range(voxel_np.shape[1]): + img_np = voxel_np[:, t, :, :] + attMap = attMaps[t, :, :] + overlap = overlap_map_on_img(img_np, attMap, norm_map, blur) + overlaps.append(overlap) + overlaps = np.stack(overlaps, axis=1) # 3x16x224x224 + return overlaps + + +def img_np_show(img_np, interpolation='lanczos'): + bitmap = np.transpose(img_np, (1, 2, 0)) # HxWxC + handle = plt.imshow( + bitmap, interpolation=interpolation, vmin=0, vmax=1) + curr_ax = plt.gca() + curr_ax.axis('off') + return handle + + +def imsc(img, *args, lim=None, quiet=False, interpolation='lanczos', **kwargs): + r"""Rescale and displays an image represented as a img. + The function scales the img :attr:`im` to the [0 ,1] range. + The img is assumed to have shape :math:`3\times H\times W` (RGB) + :math:`1\times H\times W` (grayscale). + Args: + img (:class:`torch.Tensor` or :class:`PIL.Image`): image. + quiet (bool, optional): if False, do not display image. + Default: ``False``. 
+        lim (list, optional): maximum and minimum intensity values used for
+            rescaling. Default: ``None``.
+        interpolation (str, optional): the interpolation mode to use with
+            :func:`matplotlib.pyplot.imshow` (e.g. ``'lanczos'`` or
+            ``'nearest'``). Default: ``'lanczos'``.
+    Returns:
+        :class:`torch.Tensor`: the rescaled image tensor and the plot handle.
+    """
+    if isinstance(img, Image.Image):
+        img = pil_to_tensor(img)
+    handle = None
+    with torch.no_grad():
+        if not lim:
+            lim = [img.min(), img.max()]
+        img = img - lim[0]  # also makes a copy
+        img.mul_(1 / (lim[1] - lim[0]))
+        img = torch.clamp(img, min=0, max=1)
+        if not quiet:
+            bitmap = img.expand(
+                3, *img.shape[1:]).permute(1, 2, 0).cpu().numpy()
+            handle = plt.imshow(
+                bitmap, *args, interpolation=interpolation, **kwargs)
+            curr_ax = plt.gca()
+            curr_ax.axis('off')
+    return img, handle
+
+
+def plot_voxel(voxel, saliency, show_plot=False, save_path=None):
+    # voxel, saliency: C x T x H x W tensors
+    num_frame = voxel.shape[1]
+    num_row = 2 * num_frame//8
+
+    plt.clf()
+    fig = plt.figure(figsize=(16, num_row*2))
+    for i in range(num_frame):
+        plt.subplot(num_row, 8, (i//8)*16+i % 8+1)
+        imsc(voxel[:, i, :, :])
+        plt.title(i, fontsize=8)
+
+        plt.subplot(num_row, 8, (i//8)*16+i % 8+8+1)
+        imsc(saliency[:, i, :, :], interpolation='none')
+
+    # Save figure if path is specified.
+    if save_path:
+        save_dir = os.path.dirname(os.path.abspath(save_path))
+        # Create directory if necessary.
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        ext = os.path.splitext(save_path)[1].strip('.')
+        plt.savefig(save_path, format=ext, bbox_inches='tight')
+
+    # Show plot if desired.
+    if show_plot:
+        plt.show()
+
+    plt.close(fig)
+
+
+def plot_voxel_wbbox(voxel, saliency, bbox_tensor,
+                     show_plot=False, save_path=None):
+    # voxel: C x T x H x W, bbox_tensor: T x 4 (x0, y0, x1, y1)
+    num_frame = voxel.shape[1]
+    num_row = 2 * num_frame//8
+
+    # Draw each bounding box into the green channel of the corresponding frame.
+    for idx in range(num_frame):
+        x0, y0, x1, y1 = bbox_tensor[idx, :].tolist()
+        voxel[1, idx, y0:y1+1, x0] = 1.0
+        voxel[1, idx, y0:y1+1, x1] = 1.0
+        voxel[1, idx, y0, x0:x1+1] = 1.0
+        voxel[1, idx, y1, x0:x1+1] = 1.0
+
+    plt.clf()
+    fig = plt.figure(figsize=(16, num_row*2))
+    for i in range(num_frame):
+        plt.subplot(num_row, 8, (i//8)*16+i % 8+1)
+        imsc(voxel[:, i, :, :])
+        plt.title(i, fontsize=8)
+
+        plt.subplot(num_row, 8, (i//8)*16+i % 8+8+1)
+        imsc(saliency[:, i, :, :], interpolation='none')
+
+    # Save figure if path is specified.
+    if save_path:
+        save_dir = os.path.dirname(os.path.abspath(save_path))
+        # Create directory if necessary.
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        ext = os.path.splitext(save_path)[1].strip('.')
+        plt.savefig(save_path, format=ext, bbox_inches='tight')
+
+    # Show plot if desired.
+    if show_plot:
+        plt.show()
+
+    plt.close(fig)
+
+
+def plot_voxel_np(voxel_np, saliency_np, title=None,
+                  show_plot=False, save_path=None):
+    # voxel_np, saliency_np: C x T x H x W arrays
+    num_frame = voxel_np.shape[1]
+    num_row = 2 * num_frame//8
+
+    plt.clf()
+    fig = plt.figure(figsize=(16, num_row*2))
+    for i in range(num_frame):
+        plt.subplot(num_row, 8, (i//8)*16+i % 8+1)
+        img_np_show(voxel_np[:, i, :, :])
+        plt.title(i, fontsize=8)
+
+        plt.subplot(num_row, 8, (i//8)*16+i % 8+8+1)
+        img_np_show(saliency_np[:, i, :, :], interpolation='none')
+
+    if title is not None:
+        fig.suptitle(title, fontsize=14)
+
+    # Save figure if path is specified.
+    if save_path:
+        save_dir = os.path.dirname(os.path.abspath(save_path))
+        # Create directory if necessary.
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        ext = os.path.splitext(save_path)[1].strip('.')
+        plt.savefig(save_path, format=ext, bbox_inches='tight')
+
+    # Show plot if desired.
+    if show_plot:
+        plt.show()
+
+    plt.close(fig)
+
+
+def plot_voxel_wbbox_np(voxel_np, saliency_np, bbox_tensor, title=None,
+                        show_plot=False, save_path=None):
+    # voxel_np: C x T x H x W, bbox_tensor: T x 4 (x0, y0, x1, y1)
+    num_frame = voxel_np.shape[1]
+    num_row = 2 * num_frame//8
+
+    # Draw each bounding box into the green channel of the corresponding frame.
+    for idx in range(num_frame):
+        x0, y0, x1, y1 = bbox_tensor[idx, :].tolist()
+        voxel_np[1, idx, y0:y1+1, x0] = 1.0
+        voxel_np[1, idx, y0:y1+1, x1] = 1.0
+        voxel_np[1, idx, y0, x0:x1+1] = 1.0
+        voxel_np[1, idx, y1, x0:x1+1] = 1.0
+
+    plt.clf()
+    fig = plt.figure(figsize=(16, num_row*2))
+    for i in range(num_frame):
+        plt.subplot(num_row, 8, (i//8)*16+i % 8+1)
+        img_np_show(voxel_np[:, i, :, :])
+        plt.title(i, fontsize=8)
+
+        plt.subplot(num_row, 8, (i//8)*16+i % 8+8+1)
+        img_np_show(saliency_np[:, i, :, :], interpolation='none')
+
+    if title is not None:
+        fig.suptitle(title, fontsize=14)
+
+    # Save figure if path is specified.
+    if save_path:
+        save_dir = os.path.dirname(os.path.abspath(save_path))
+        # Create directory if necessary.
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        ext = os.path.splitext(save_path)[1].strip('.')
+        plt.savefig(save_path, format=ext, bbox_inches='tight')
+
+    # Show plot if desired.
+    if show_plot:
+        plt.show()
+
+    plt.close(fig)
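The helpers in utils/ImageShow.py are intended for overlaying per-frame attention maps on the original frames and saving the result as an image grid. The snippet below is a minimal usage sketch; the `frames` tensor, `att_maps` array, their shapes, and the output path are placeholders chosen for illustration, not names defined in this patch.

    # Minimal usage sketch; `frames`, `att_maps` and the output path are placeholders.
    import numpy as np
    import torch
    from utils.ImageShow import voxel_tensor_to_np, overlap_maps_on_voxel_np, plot_voxel_np

    frames = torch.rand(3, 16, 224, 224)      # C x T x H x W clip in [0, 1]
    att_maps = np.random.rand(16, 14, 14)     # T x h x w attention maps

    voxel_np = voxel_tensor_to_np(frames)                             # per-frame rescale to [0, 1]
    overlaps = overlap_maps_on_voxel_np(voxel_np, att_maps,
                                        norm_map=True, blur=True)     # blend heatmaps onto frames
    plot_voxel_np(voxel_np, overlaps, title='soft attention',
                  save_path='vis/soft_attention.png')                 # frames on top rows, overlays below

plot_voxel and plot_voxel_wbbox follow the same grid layout but operate on torch tensors, with the *_wbbox variants additionally drawing per-frame bounding boxes into the green channel before plotting.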