Commit fe388ca by acb11711tx, committed Apr 27, 2020
1 parent: 77298ba
Showing 405 changed files with 2,781 additions and 0 deletions.
.gitignore
@@ -0,0 +1,9 @@
visual_res/*
frames/
__pycache__
*.pt
*.pth
*.mp4
*.pyc
surgery_gesture_r2plus1d.py
surgery*
README.md
@@ -1 +1,46 @@
# VideoVisual

This is a PyTorch demo implementing several visualization methods for video classification networks. The goal is to provide a toolkit (as [TorchRay](https://github.com/facebookresearch/TorchRay) does for images) for interpreting commonly used video classification networks such as I3D, R(2+1)D, and TSM. This is also known as the *attribution* task: determining which part of the input video is responsible for the value computed by the network.
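To make the attribution task concrete, here is a minimal sketch of the simplest backprop-based method (vanilla gradients) on a video clip. It is illustrative only: the toolkit's actual implementations live in `visual_meth/`, and `model` and the clip here are placeholders.

```python
import torch

def gradient_saliency(model, clip, target_class):
    """Vanilla gradients: sensitivity of one class score to every voxel.

    clip: video tensor of shape (1, 3, T, H, W).
    """
    clip = clip.clone().requires_grad_(True)
    score = model(clip)[0, target_class]  # scalar score of the target class
    score.backward()
    # Large gradient magnitudes mark the voxels the score depends on most.
    return clip.grad.abs()
```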
The current version supports the following video classification models and attribution methods:

#### Video classification models:
* **Pretrained on Kinetics-400**: I3D, R(2+1)D, R3D, MC3, TSM;
* **Pretrained on EPIC-Kitchens** (noun & verb): TSM.

#### Attribution methods:
* **Backprop-based**: Gradients, Gradients x Inputs, Integrated Gradients;
* **Activation-based**: GradCAM (does not support TSM yet);
* **Perturbation-based**: Extremal Perturbation and Spatiotemporal Perturbation (an extension of extremal perturbation to video inputs; see the sketch below).
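Unlike the backprop-based methods, the perturbation-based ones optimize a spatiotemporal mask of a fixed area that preserves the class score. Below is a deliberately simplified sketch using a soft area penalty, with hypothetical `model` and `clip`; the real implementation in `visual_meth/perturbation.py` builds on TorchRay's extremal perturbation, which enforces the area constraint and mask smoothness more carefully.

```python
import torch

def perturbation_sketch(model, clip, target_class, area=0.1, steps=200):
    """Preserve variant: keep roughly `area` of the voxels, fade out the rest."""
    mask = torch.full_like(clip[:, :1], 0.5, requires_grad=True)  # (1,1,T,H,W)
    optim = torch.optim.Adam([mask], lr=0.05)
    for _ in range(steps):
        m = mask.clamp(0, 1)
        score = model(clip * m)[0, target_class]  # score on the masked video
        # Keep the score high while spending no more than `area` of the mask.
        loss = -score + 10.0 * (m.mean() - area).abs()
        optim.zero_grad()
        loss.backward()
        optim.step()
    return mask.detach().clamp(0, 1)  # high values = evidence for the class
```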
## Requirements

* Python 3.6.5 or greater
* PyTorch 1.2.0 or greater
* matplotlib==2.2.3
* numpy==1.14.3
* opencv_python==4.1.2.30
* torchvision==0.4.0a0
* torchray==1.0.0.2
* tqdm==4.45.0
* pandas==0.23.3
* scikit_image==0.15.0
* Pillow==7.1.2
* scikit_learn==0.22.2.post1

## Running the code

### Examples

#### Spatiotemporal Perturbation + I3D (pretrained on Kinetics-400)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/kinetics/sampled_frames --model i3d --pretrain_dataset kinetics --vis_method perturb --num_iter 2000 --perturb_area 0.1`

#### Spatiotemporal Perturbation + TSM (pretrained on EPIC-Kitchens-noun)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/epic-kitchens-noun/sampled_frames --model tsm --pretrain_dataset epic-kitchens-noun --vis_method perturb --num_iter 2000 --perturb_area 0.05`

#### Integrated Gradients + R(2+1)D (pretrained on Kinetics-400)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method integrated_grad`

## License

TorchRay is CC-BY-NC licensed, as found in the [LICENSE](LICENSE) file.
datasets/universal_dataset.py
@@ -0,0 +1,74 @@
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import os
import json
import numpy as np
from tqdm import tqdm
from PIL import Image

import sys
sys.path.append(".")
sys.path.append("..")
from utils.LongRangeSample import long_range_sample


class UniversalDataset(Dataset):
    def __init__(self, data_dir, model_name, class_namelist, clip_length=16):
        self.data_dir = data_dir
        self.model_name = model_name
        self.class_namelist = class_namelist
        self.clip_length = clip_length

        self.video_names = sorted(os.listdir(data_dir))
        assert len(self.video_names) > 0, 'Given directory contains no video.'

        # Each model family expects its own input resolution and normalization.
        if model_name == 'i3d':
            self.transform = transforms.Compose([
                transforms.Resize((344, 256)),
                transforms.CenterCrop((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.43216, 0.39467, 0.37645], [0.22803, 0.22145, 0.21699]),
            ])
        elif model_name in ['r2plus1d', 'r3d', 'mc3']:
            self.transform = transforms.Compose([
                transforms.Resize((172, 128)),
                transforms.CenterCrop((112, 112)),
                transforms.ToTensor(),
                transforms.Normalize([0.43216, 0.39467, 0.37645], [0.22803, 0.22145, 0.21699]),
            ])
        elif model_name in ['tsm', 'tsn']:  # ImageNet statistics
            self.transform = transforms.Compose([
                transforms.Resize((344, 256)),
                transforms.CenterCrop((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        else:
            raise ValueError(f'No transform defined for model: {model_name}')

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, idx):
        video_name = self.video_names[idx]
        if '.mp4' not in video_name:
            video_frames_dir = os.path.join(self.data_dir, video_name)
            frame_names = sorted([f for f in os.listdir(video_frames_dir) if '.png' in f or '.jpg' in f])
            num_frame = len(frame_names)
            assert num_frame > self.clip_length, \
                f"Number of frames should be larger than {self.clip_length}, given {num_frame}"

            # Sample clip_length frame indices spread across the whole video.
            clip_fidxs = long_range_sample(num_frame, self.clip_length, 'first')
            clip_fidxs_tensor = torch.tensor(clip_fidxs).long()

            clip_frames = [Image.open(os.path.join(video_frames_dir,
                            f'{fidx+1:09d}.png')) for fidx in clip_fidxs]
            # Stack along dim=1 -> tensor of shape (3, clip_length, H, W).
            clip_tensor = torch.stack([self.transform(frame) for frame in clip_frames], dim=1)

            # Directory names start with '<class>-'; recover the label index.
            label_name = video_name.split('-')[0]
            label_idx = self.class_namelist.index(label_name)
            return clip_tensor, label_idx, video_name, clip_fidxs_tensor
        else:
            raise Exception('Cannot process MP4 file yet.')
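A minimal usage sketch of the dataset above (the frames directory and two-entry class list are hypothetical; main.py builds the real list from `classes.json` or the EPIC-Kitchens class csv):

```python
from torch.utils.data import DataLoader
from datasets.universal_dataset import UniversalDataset

class_namelist = ['archery', 'bowling']  # hypothetical class names
dataset = UniversalDataset('test_data/kinetics/sampled_frames', 'i3d',
                           class_namelist, clip_length=16)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
clip, label, video_name, frame_idxs = next(iter(loader))
print(clip.shape)  # torch.Size([1, 3, 16, 224, 224]) for i3d
```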
main.py
@@ -0,0 +1,186 @@
import argparse
import torch
import torchvision
from torch import nn
from torch.utils.data import Dataset, DataLoader

import os
import time
import copy
from tqdm import tqdm
import pickle
import numpy as np
import pandas as pd
import csv
import json

import warnings
warnings.filterwarnings("ignore")

import sys
sys.path.append(".")
sys.path.append("..")

from path_dict import PathDict
path_dict = PathDict()
proj_root = path_dict.proj_root

from utils.ImageShow import *

from visual_meth.integrated_grad import integrated_grad
from visual_meth.gradients import gradients
from visual_meth.perturbation import video_perturbation
from visual_meth.grad_cam import grad_cam
parser = argparse.ArgumentParser()
parser.add_argument("--videos_dir", type=str, default='')
parser.add_argument("--model", type=str, default='r2plus1d',
                    choices=['r2plus1d', 'r3d', 'mc3', 'i3d', 'tsn', 'trn', 'tsm'])
parser.add_argument("--pretrain_dataset", type=str, default='kinetics',
                    choices=['', 'kinetics', 'epic-kitchens-verb', 'epic-kitchens-noun'])
parser.add_argument("--vis_method", type=str, default='integrated_grad',
                    choices=['grad', 'grad*input', 'integrated_grad', 'grad_cam', 'perturb'])
parser.add_argument("--save_label", type=str, default='')
parser.add_argument("--no_gpu", action='store_true')
parser.add_argument("--num_iter", type=int, default=1000)
parser.add_argument('--polarity', type=str, default='positive',
                    choices=['positive', 'negative'])
parser.add_argument('--perturb_area', type=float, default=0.1,
                    choices=[0.01, 0.02, 0.05, 0.1, 0.15, 0.2])
args = parser.parse_args()
# Multi-GPU device selection is disabled for now: run on a single GPU, or on
# the CPU when --no_gpu is given.
if args.no_gpu:
    device = torch.device("cpu")
    num_devices = 0
else:
    device = torch.device("cuda")
    num_devices = 1

assert os.path.isdir(args.videos_dir), \
    f'Given directory of data does not exist: {args.videos_dir}.'

if args.pretrain_dataset == 'kinetics':
    if args.model == 'i3d':
        from model_def.i3d import I3D as model
        model_ft = model(num_classes=400)
        i3d_pt_dir = os.path.join(proj_root, 'model_param/kinetics400_rgb_i3d.pth')
        model_ft.load_state_dict(torch.load(i3d_pt_dir))
        clip_length = 16
    elif args.model == 'tsm':
        from model_def.tsm import tsm as model
        model_ft = model(400, segment_count=8, pretrained=args.pretrain_dataset)
        clip_length = 8
    else:  # Load pretrained models from torchvision directly
        clip_length = 16
        if args.model == 'r2plus1d':
            from torchvision.models.video import r2plus1d_18 as model
        elif args.model == 'mc3':
            from torchvision.models.video import mc3_18 as model
        elif args.model == 'r3d':
            from torchvision.models.video import r3d_18 as model
        else:
            raise Exception(f'Model {args.model} has no pretrained weights on {args.pretrain_dataset}.')
        model_ft = model(pretrained=True)

    model_ft = model_ft.to(device)
    model_ft.eval()

    kinetics400_classes = os.path.join(proj_root, 'test_data/kinetics/classes.json')
    class_namelist = json.load(open(kinetics400_classes))

elif 'epic-kitchens' in args.pretrain_dataset:
    if 'noun' in args.pretrain_dataset:
        epic_classes = os.path.join(proj_root, 'test_data/epic-kitchens-noun/EPIC_noun_classes.csv')
    elif 'verb' in args.pretrain_dataset:
        epic_classes = os.path.join(proj_root, 'test_data/epic-kitchens-verb/EPIC_verb_classes.csv')
    else:
        raise Exception(f'EPIC-Kitchens only supports two sub-tasks (noun & verb), given {args.pretrain_dataset}.')
    class_namelist = [row['class_key'] for ridx, row in pd.read_csv(epic_classes).iterrows()]
    class_num = len(class_namelist)

    if args.model == 'tsm':
        from model_def.tsm import tsm as model
        model_ft = model(class_num, segment_count=8, pretrained=args.pretrain_dataset)
        clip_length = 16
    else:
        raise Exception(f'{args.pretrain_dataset} only has a pretrained TSM model; given {args.model}.')
    model_ft = model_ft.to(device)
    model_ft.eval()

from datasets.universal_dataset import UniversalDataset as dataset
test_dataset = dataset(args.videos_dir, args.model, class_namelist, clip_length=clip_length)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
print(f'Num of test samples: {len(test_dataset)}')

for sample in tqdm(test_dataloader):
    inp = sample[0].to(device)
    label = sample[1].to(dtype=torch.long)

    inp_np = voxel_tensor_to_np(inp[0].detach().cpu())  # 3 x num_f x 224 x 224

    if args.vis_method == 'integrated_grad':
        res = integrated_grad(inp, label, model_ft, device, steps=50, polarity=args.polarity)
        heatmap_np = res[0].numpy()
    elif args.vis_method == 'grad':
        res = gradients(inp, label, model_ft, device, polarity=args.polarity)
        heatmap_np = res[0].numpy()
    elif args.vis_method == 'grad*input':
        res = gradients(inp, label, model_ft, device, multiply_input=True, polarity=args.polarity)
        heatmap_np = res[0].numpy()
    elif args.vis_method == 'grad_cam':
        # Grad-CAM needs the name of the last convolutional block of each model.
        if args.model in ['i3d']:
            layer_name = ['mixed_5c']
        elif args.model in ['r2plus1d', 'mc3', 'r3d']:
            layer_name = ['layer4']
        # elif args.model in ['tsm', 'tsn']:
        #     layer_name = ['model', 'base_model', 'layer4']
        else:
            raise Exception(f'Grad-CAM does not support {args.model} currently.')
        res = grad_cam(inp, label, model_ft, device, layer_name=layer_name, norm_vis=True)
        heatmap_np = overlap_maps_on_voxel_np(inp_np, res[0, 0].cpu().numpy(), norm_map=False)
    elif args.vis_method == 'perturb':
        # Blur kernel scales with input resolution (112 for the ResNet-3D
        # family, 224 for I3D/TSM).
        sigma = 11 if inp.shape[-1] == 112 else 23
        res = video_perturbation(
            model_ft, inp, label, areas=[args.perturb_area], sigma=sigma,
            max_iter=args.num_iter, variant="preserve",
            num_devices=num_devices, print_iter=100, perturb_type="fade")[0]
        heatmap_np = overlap_maps_on_voxel_np(inp_np, res[0, 0].cpu().numpy(), norm_map=False)

    sample_name = sample[2][0].split("/")[-1]
    plot_save_name = f"{sample_name}.png"
    if args.vis_method in ['grad', 'grad*input', 'integrated_grad']:
        plot_save_name = plot_save_name.replace('.png', f'{args.polarity}.png')
    plot_save_dir = os.path.join(proj_root, "visual_res", args.vis_method, args.model)
    if args.save_label != '':
        plot_save_dir = os.path.join(plot_save_dir, args.save_label)
    os.makedirs(plot_save_dir, exist_ok=True)

    show_txt = f"{sample_name}"
    plot_voxel_np(inp_np, heatmap_np, title=show_txt,
                  save_path=os.path.join(plot_save_dir, plot_save_name))
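
For reference, integrated gradients (the default `--vis_method`) averages gradients along a straight path from a baseline to the input. The sketch below is the textbook Riemann-sum formulation with a black baseline, not the repository's `visual_meth/integrated_grad.py` (which also handles the `--polarity` option); `model` and `clip` are placeholders.

```python
import torch

def integrated_gradients_sketch(model, clip, target_class, steps=50):
    """Textbook integrated gradients with a zero (black) baseline.

    clip: (1, 3, T, H, W) video tensor; returns attributions of the same shape.
    """
    baseline = torch.zeros_like(clip)
    total_grad = torch.zeros_like(clip)
    for i in range(1, steps + 1):
        # Interpolate between baseline and input, then take the score gradient.
        scaled = (baseline + (i / steps) * (clip - baseline)).requires_grad_(True)
        score = model(scaled)[0, target_class]
        total_grad += torch.autograd.grad(score, scaled)[0]
    # Average path gradient times the input difference.
    return (clip - baseline) * total_grad / steps
```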