Commit

first commit.

acb11711tx committed Apr 27, 2020
1 parent 77298ba commit fe388ca
Showing 405 changed files with 2,781 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .gitignore
@@ -0,0 +1,9 @@
visual_res/*
frames/
__pycache__
*.pt
*.pth
*.mp4
*.pyc
surgery_gesture_r2plus1d.py
surgery*
45 changes: 45 additions & 0 deletions README.md
@@ -1 +1,46 @@
# VideoVisual

This is a PyTorch demo implementing several visualization methods for video classification networks. The goal is to provide a toolkit (as [TorchRay](https://github.com/facebookresearch/TorchRay) does for images) for interpreting commonly used video classification networks such as I3D, R(2+1)D, and TSM. This is also known as the *attribution* task: determining which part of the input video is responsible for the value computed by the network.

The current version supports the following video classification models and attribution methods:

#### Video classification models:
* **Pretrained on Kinetics-400**: I3D, R(2+1)D, R3D, MC3, TSM;
* **Pretrained on EPIC-Kitchens** (noun & verb): TSM.

#### Attribution methods:
* **Backprop-based**: Gradients, Gradients x Inputs, Integrated Gradients;
* **Activation-based**: Grad-CAM (TSM is not supported yet);
* **Perturbation-based**: Extremal Perturbation and Spatiotemporal Perturbation (an extension of Extremal Perturbation to video inputs).
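
To give a feel for the backprop-based family, here is a minimal sketch of vanilla gradient attribution for a single video clip (the actual implementations live under `visual_meth/`; `model` and the tensor shapes are placeholders):

```python
import torch

def vanilla_gradients(model, clip, target_class):
    # clip: (1, 3, T, H, W) video tensor; target_class: int label index
    clip = clip.clone().detach().requires_grad_(True)
    logits = model(clip)                # (1, num_classes)
    logits[0, target_class].backward()  # d(score) / d(input)
    return clip.grad.abs()              # one attribution value per input voxel
```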

## Requirements

* Python 3.6.5 or greater
* PyTorch 1.2.0 or greater
* matplotlib==2.2.3
* numpy==1.14.3
* opencv_python==4.1.2.30
* torchvision==0.4.0a0
* torchray==1.0.0.2
* tqdm==4.45.0
* pandas==0.23.3
* scikit_image==0.15.0
* Pillow==7.1.2
* scikit_learn==0.22.2.post1
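
Apart from PyTorch and torchvision (best installed following the instructions at pytorch.org), the pinned dependencies can be installed with pip, e.g.:

`$ pip install matplotlib==2.2.3 numpy==1.14.3 opencv_python==4.1.2.30 torchray==1.0.0.2 tqdm==4.45.0 pandas==0.23.3 scikit_image==0.15.0 Pillow==7.1.2 scikit_learn==0.22.2.post1`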

## Running the code

### Examples

#### Spatiotemporal Perturbation + I3D (pretrained on Kinetics-400)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/kinetics/sampled_frames --model i3d --pretrain_dataset kinetics --vis_method perturb --num_iter 2000 --perturb_area 0.1`

#### Spatiotemporal Perturbation + TSM (pretrained on EPIC-Kitchens-noun)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/epic-kitchens-noun/sampled_frames --model tsm --pretrain_dataset epic-kitchens-noun --vis_method perturb --num_iter 2000 --perturb_area 0.05`

#### Integrated Gradients + R(2+1)D (pretrained on Kinetics-400)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method integrated_grad`
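
#### Grad-CAM + R(2+1)D (pretrained on Kinetics-400)
This command is assembled from the flags defined in `main.py`; other model/method combinations follow the same pattern.

`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method grad_cam`

Visualizations are saved as PNGs under `visual_res/<vis_method>/<model>/` (see the save logic at the bottom of `main.py`).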

## License

This project is CC-BY-NC licensed, as found in the [LICENSE](LICENSE) file.
74 changes: 74 additions & 0 deletions datasets/universal_dataset.py
@@ -0,0 +1,74 @@
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import os
import json
import numpy as np
from tqdm import tqdm
from PIL import Image

import sys
sys.path.append(".")
sys.path.append("..")
from utils.LongRangeSample import long_range_sample

class UniversalDataset(Dataset):
    def __init__(self, data_dir, model_name, class_namelist, clip_length=16):
self.data_dir = data_dir
self.model_name = model_name
self.class_namelist = class_namelist
self.clip_length = clip_length

self.video_names = sorted(os.listdir(data_dir))
        assert len(self.video_names) > 0, f'Given directory contains no videos: {data_dir}'

if model_name == 'i3d':
self.transform = transforms.Compose([
transforms.Resize((344, 256)),
transforms.CenterCrop((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.43216, 0.39467, 0.37645], [0.22803, 0.22145, 0.21699]),
])
elif model_name in ['r2plus1d', 'r3d', 'mc3']:
self.transform = transforms.Compose([
transforms.Resize((172, 128)),
transforms.CenterCrop((112, 112)),
transforms.ToTensor(),
transforms.Normalize([0.43216, 0.39467, 0.37645], [0.22803, 0.22145, 0.21699]),
])
elif model_name in ['tsm', 'tsn']: # mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
self.transform = transforms.Compose([
transforms.Resize((344, 256)),
transforms.CenterCrop((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, idx):
video_name = self.video_names[idx]
if '.mp4' not in video_name:
video_frames_dir = os.path.join(self.data_dir, video_name)
frame_names = sorted([f for f in os.listdir(video_frames_dir) if '.png' in f or '.jpg' in f])
num_frame = len(frame_names)
assert num_frame > self.clip_length, \
f"Number of frames should be larger than {self.clip_length}, given {num_frame}"

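            # Pick `clip_length` frame indices spread across the whole video
            # (long-range sampling; 'first' is assumed to take the first frame
            # of each temporal segment).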
clip_fidxs = long_range_sample(num_frame, self.clip_length, 'first')
clip_fidxs_tensor = torch.tensor(clip_fidxs).long()

clip_frames = [Image.open(os.path.join(video_frames_dir,
f'{fidx+1:09d}.png')) for fidx in clip_fidxs]
clip_tensor = torch.stack([self.transform(frame) for frame in clip_frames], dim=1)

label_name = video_name.split('-')[0]
label_idx = self.class_namelist.index(label_name)
return clip_tensor, label_idx, video_name, clip_fidxs_tensor
else:
            raise NotImplementedError('Raw .mp4 files are not supported yet; extract frames to a directory first.')


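A minimal sketch of how this dataset is consumed (mirroring the usage in main.py below; the directory and class names are placeholders):

```python
from torch.utils.data import DataLoader
from datasets.universal_dataset import UniversalDataset

# Hypothetical class names; frame folders are expected to be named
# '<class>-<id>' and to contain 1-indexed, zero-padded PNG frames
# ('000000001.png', ...), per __getitem__ above.
class_namelist = ['archery', 'bowling']
dataset = UniversalDataset('test_data/kinetics/sampled_frames', 'i3d',
                           class_namelist, clip_length=16)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
clip, label_idx, video_name, frame_idxs = next(iter(loader))
print(clip.shape)  # (1, 3, 16, 224, 224) with the i3d transform
```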
186 changes: 186 additions & 0 deletions main.py
@@ -0,0 +1,186 @@
import argparse
import torch
import torchvision
from torch import nn
from torch.utils.data import Dataset, DataLoader

import os
import time
import copy
from tqdm import tqdm
import pickle
import numpy as np
import pandas as pd
import csv
import json

import warnings
warnings.filterwarnings("ignore")

import sys
sys.path.append(".")
sys.path.append("..")

from path_dict import PathDict
path_dict = PathDict()
proj_root = path_dict.proj_root

from utils.ImageShow import *

from visual_meth.integrated_grad import integrated_grad
from visual_meth.gradients import gradients
from visual_meth.perturbation import video_perturbation
from visual_meth.grad_cam import grad_cam

parser = argparse.ArgumentParser()
parser.add_argument("--videos_dir", type=str, default='')
parser.add_argument("--model", type=str, default='r2plus1d',
choices=['r2plus1d', 'r3d', 'mc3', 'i3d', 'tsn', 'trn', 'tsm'])
parser.add_argument("--pretrain_dataset", type=str, default='kinetics',
choices=['', 'kinetics', 'epic-kitchens-verb', 'epic-kitchens-noun'])
parser.add_argument("--vis_method", type=str, default='integrated_grad',
choices=['grad', 'grad*input', 'integrated_grad', 'grad_cam', 'perturb'])
parser.add_argument("--save_label", type=str, default='')
parser.add_argument("--no_gpu", action='store_true')
parser.add_argument("--num_iter", type=int, default=1000)
parser.add_argument('--polarity', type=str, default='positive',
choices=['positive', 'negative'])
parser.add_argument('--perturb_area', type=float, default=0.1,
choices=[0.01, 0.02, 0.05, 0.1, 0.15, 0.2])
args = parser.parse_args()

# assert args.num_gpu >= -1
# if args.num_gpu == 0:
# num_devices = 0
# multi_gpu = False
# device = torch.device("cpu")
# elif args.num_gpu == 1:
# num_devices = 1
# multi_gpu = False
# device = torch.device("cuda")
# elif args.num_gpu == -1:
# num_devices = torch.cuda.device_count()
# multi_gpu = (num_devices > 1)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# else:
# num_devices = args.num_gpu
# assert torch.cuda.device_count() >= num_devices, \
#         f'Assign {args.num_gpu} GPUs, but only detected {torch.cuda.device_count()} GPUs. Exiting...'
# multi_gpu = True
# device = torch.device("cuda")

if args.no_gpu:
device = torch.device("cpu")
num_devices = 0
else:
device = torch.device("cuda")
num_devices = 1

assert os.path.isdir(args.videos_dir), \
f'Given directory of data does not exist: {args.videos_dir}.'

if args.pretrain_dataset == 'kinetics':
if args.model == 'i3d':
from model_def.i3d import I3D as model
model_ft = model(num_classes=400)
i3d_pt_dir = os.path.join(proj_root, 'model_param/kinetics400_rgb_i3d.pth')
model_ft.load_state_dict(torch.load(i3d_pt_dir))
clip_length = 16
elif args.model == 'tsm':
from model_def.tsm import tsm as model
model_ft = model(400, segment_count=8, pretrained=args.pretrain_dataset)
clip_length = 8
else: # Load pretrained models from PyTorch directly
clip_length = 16
if args.model == 'r2plus1d':
from torchvision.models.video import r2plus1d_18 as model
elif args.model == 'mc3':
from torchvision.models.video import mc3_18 as model
elif args.model == 'r3d':
from torchvision.models.video import r3d_18 as model
else:
            raise Exception(f'Model {args.model} has no pretrained weights for {args.pretrain_dataset}.')
model_ft = model(pretrained=True)

model_ft = model_ft.to(device)
model_ft.eval()
# if multi_gpu:
# model_ft = nn.DataParallel(model_ft, device_ids=list(range(num_devices)))

kinetics400_classes = os.path.join(proj_root, 'test_data/kinetics/classes.json')
class_namelist = json.load(open(kinetics400_classes))

elif 'epic-kitchens' in args.pretrain_dataset:
if 'noun' in args.pretrain_dataset:
epic_classes = os.path.join(proj_root, 'test_data/epic-kitchens-noun/EPIC_noun_classes.csv')
elif 'verb' in args.pretrain_dataset:
epic_classes = os.path.join(proj_root, 'test_data/epic-kitchens-verb/EPIC_verb_classes.csv')
else:
raise Exception(f'EPIC-Kitchens only supports two sub-tasks (noun & verb), given {args.pretrain_dataset}.')
class_namelist = [row['class_key'] for ridx, row in pd.read_csv(epic_classes).iterrows()]
class_num = len(class_namelist)

if args.model == 'tsm':
from model_def.tsm import tsm as model
model_ft = model(class_num, segment_count=8, pretrained=args.pretrain_dataset)
        clip_length = 8  # one frame per segment, matching segment_count above
else:
        raise Exception(f'{args.pretrain_dataset} only provides a pretrained TSM model; given {args.model}.')
model_ft = model_ft.to(device)
model_ft.eval()

from datasets.universal_dataset import UniversalDataset as dataset
test_dataset = dataset(args.videos_dir, args.model, class_namelist, clip_length=clip_length)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
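# batch_size is fixed to 1: the attribution methods and plotting below process one clip at a time.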
print(f'Num of test samples: {len(test_dataset)}')

for sample in tqdm(test_dataloader):
inp = sample[0].to(device)
label = sample[1].to(dtype=torch.long)

    inp_np = voxel_tensor_to_np(inp[0].detach().cpu())  # 3 x num_f x 224 x 224

if args.vis_method == 'integrated_grad':
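        # Integrated gradients: average the gradients along a straight-line
        # path from a baseline (conventionally all zeros) to the input,
        # here over 50 interpolation steps.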
res = integrated_grad(inp, label, model_ft, device, steps=50, polarity=args.polarity)
heatmap_np = res[0].numpy()
elif args.vis_method == 'grad':
res = gradients(inp, label, model_ft, device, polarity=args.polarity)
heatmap_np = res[0].numpy()
elif args.vis_method == 'grad*input':
res = gradients(inp, label, model_ft, device, multiply_input=True, polarity=args.polarity)
heatmap_np = res[0].numpy()
elif args.vis_method == 'grad_cam':
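        # Grad-CAM needs a target convolutional layer; the last conv block
        # of each supported backbone is used here.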
if args.model in ['i3d']:
layer_name = ['mixed_5c']
elif args.model in ['r2plus1d', 'mc3', 'r3d']: # Load pretrained models from PyTorch directly
layer_name = ['layer4']
# elif args.model in ['tsm', 'tsn']:
# layer_name = ['model', 'base_model', 'layer4']
else:
raise Exception(f'Grad-CAM does not support {args.model} currently')
res = grad_cam(inp, label, model_ft, device, layer_name=layer_name, norm_vis=True)
heatmap_np = overlap_maps_on_voxel_np(inp_np, res[0,0].cpu().numpy(), norm_map=False)
elif args.vis_method == 'perturb':
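        # Spatiotemporal (extremal) perturbation: search for a mask covering
        # `perturb_area` of the video that best preserves the prediction
        # ("preserve" variant); the blur sigma is assumed to scale with the
        # input resolution (112 vs 224 px).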
sigma = 11 if inp.shape[-1] == 112 else 23
res = video_perturbation(
model_ft, inp, label, areas=[args.perturb_area], sigma=sigma,
max_iter=args.num_iter, variant="preserve",
num_devices=num_devices, print_iter=100, perturb_type="fade")[0]
heatmap_np = overlap_maps_on_voxel_np(inp_np, res[0,0].cpu().numpy(), norm_map=False)

sample_name = sample[2][0].split("/")[-1]
plot_save_name = f"{sample_name}.png"
if args.vis_method in ['grad', 'grad*input', 'integrated_grad']:
        plot_save_name = plot_save_name.replace('.png', f'_{args.polarity}.png')
plot_save_dir = os.path.join(proj_root, "visual_res", args.vis_method, args.model)
if args.save_label != '':
plot_save_dir = os.path.join(plot_save_dir, args.save_label)
os.makedirs(plot_save_dir, exist_ok=True)

show_txt = f"{sample_name}"
plot_voxel_np(inp_np, heatmap_np, title=show_txt,
save_path=os.path.join(plot_save_dir, plot_save_name) )



