Commit b425a72

Added STEP and updated readme.
acb11711tx committed Dec 29, 2020
1 parent 42e74b6 commit b425a72
Showing 13 changed files with 643 additions and 88 deletions.
31 changes: 19 additions & 12 deletions README.md
@@ -11,7 +11,10 @@ The current version supports attribution methods and video classification models
#### Attribution methods:
* **Backprop-based**: Gradients, Gradients x Inputs, Integrated Gradients;
* **Activation-based**: GradCAM (TSM is not supported yet);
* **Perturbation-based** (the differences are sketched below):
  * **2D-EP**: An extended version of Extremal Perturbations for video inputs that perturbs each frame separately and regularizes the perturbed area in every frame toward the same target ratio.
  * **3D-EP**: An extended version of Extremal Perturbations for video inputs that perturbs across all frames jointly and regularizes the total perturbed area toward the target ratio.
  * **STEP**: Spatio-Temporal Extremal Perturbations, which add a dedicated regularization term encouraging spatiotemporal smoothness of the video attribution results.
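
For intuition, here is a minimal sketch (illustrative only; it assumes PyTorch and the `N x 1 x T x H x W` mask layout used in `visual_meth/perturbation.py`, and the function name `area_regularizer` is ours, not part of the repository) of how 2D-EP and 3D-EP constrain the perturbation area:

```python
import torch

def area_regularizer(masks, area, method):
    # masks: N x 1 x T x H x W soft masks with values in [0, 1]
    n, t = masks.shape[0], masks.shape[2]
    if method == '3d_ep':
        # One budget for the whole clip: sort all mask values and push the
        # lowest (1 - area) fraction toward 0, the rest toward 1.
        v = masks.reshape(n, -1).sort(dim=1)[0]
        vref = torch.ones_like(v)
        vref[:, :int(v.shape[1] * (1 - area))] = 0
        return ((v - vref) ** 2).mean(dim=1)
    elif method == '2d_ep':
        # An equal budget per frame: sort within each frame independently.
        a = masks.reshape(n, t, -1).sort(dim=2)[0]
        aref = torch.ones_like(a)
        aref[:, :, :int(a.shape[2] * (1 - area))] = 0
        return ((a - aref) ** 2).mean(dim=2).mean(dim=1)
```

STEP keeps the clip-level budget of 3D-EP and adds a convolutional "core" loss (class `CoreLossCalulator` in this commit) that favors a small number of contiguous spatiotemporal blobs.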

## Requirements

@@ -37,7 +40,7 @@ The current version supports attribution methods and video classification models
* **videos_dir**: Directory for video frames. Frames belonging to one video should be placed in one folder under this directory; the first part of the folder name (split by '-') is taken as the label name.
* **model**: Name of the test model. Default is R(2+1)D; current choices include R(2+1)D, R3D, MC3, I3D, and TSM.
* **pretrain_dataset**: Name of the dataset the test model was pretrained on. Choices include 'kinetics', 'epic-kitchens-verb', 'epic-kitchens-noun'.
* **vis_method**: Name of the visualization method. Choices include 'grad', 'grad*input', 'integrated_grad', 'grad_cam', '2d_ep', '3d_ep', 'step'.
* **save_label**: Extra label for saving results. If given, visualization results will be saved in ./visual_res/$vis_method$/$model$/$save_label$.
* **no_gpu**: If set, the demo runs on CPU; otherwise it runs on a single GPU.

@@ -50,36 +53,40 @@ Arguments for gradient methods:

### Examples

#### Spatiotemporal Perturbation (STEP) + R(2+1)D (pretrained on Kinetics-400)
`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method step --num_iter 2000 --perturb_area 0.1`

#### Spatiotemporal Perturbation (STEP) + TSM (pretrained on EPIC-Kitchens-Noun)
`$ python main.py --videos_dir VideoVisual/test_data/epic-kitchens-noun/sampled_frames --model tsm --pretrain_dataset epic-kitchens-noun --vis_method step --num_iter 2000 --perturb_area 0.05`

#### Integrated Gradients + I3D (pretrained on Kinetics-400)
`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model i3d --pretrain_dataset kinetics --vis_method integrated_grad`
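
The 2D-EP and 3D-EP variants share the same interface; an illustrative invocation (not one of the original examples, assuming the same test-data layout):

#### 3D-EP + R(2+1)D (pretrained on Kinetics-400)
`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method 3d_ep --num_iter 2000 --perturb_area 0.1`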


## Results

### Kinetics-400 (GT = ironing)
![Kinetics-400 (GT = ironing)](figures/res_fig_kinetics.png)

'Perturbation' denotes 3D-EP here.
### EPIC-Kitchens-Noun (GT = cupboard)
![EPIC-Kitchens-Noun (GT = cupboard)](figures/res_fig_epic.png)
'Perturbation' denotes 3D-EP here.

### GIF visualization of perturbation results (on the UCF101 dataset, by STEP)
<!-- #### Long Jump
![ucf101-longjump](figures/v_LongJump_g01_c06_frames.gif) ![ucf101-longjump](figures/v_LongJump_g01_c06_ptb.gif)
#### Walking With Dog
![ucf101-walkingdog](figures/v_WalkingWithDog_g06_c05_frames.gif) ![ucf101-walkingdog](figures/v_WalkingWithDog_g06_c05_ptb.gif) -->
| <div style="width:150px">Basketball 5%</div> | <div style="width:150px">Skijet 5%</div> | <div style="width:150px">Walking-With-Dog 10%</div> | <div style="width:150px">Fencing 10%</div> | <div style="width:150px">OpenFridge 10%</div> | <div style="width:150px">CloseDrawer 10%</div> | <div style="width:150px">OpenCupboard 15%</div> |
| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
| <img src="figures/step_gif_res/basketball_05.gif" width=150/> | <img src="figures/step_gif_res/skijet_05.gif" width=150/> | <img src="figures/step_gif_res/walking_with_dog_10.gif" width=150/> | <img src="figures/step_gif_res/fencing_10.gif" width=150/> | <img src="figures/step_gif_res/open_fridge_10.gif" width=150/> | <img src="figures/step_gif_res/close_drawer_10.gif" width=150/> | <img src="figures/step_gif_res/open_cupboard_15.gif" width=150/> |

## Reference

### Our preprint (to appear in WACV 2021):
```
@article{li2020comprehensive,
title={Towards Visually Explaining Video Understanding Networks with Perturbation},
author={Li, Zhenqiang and Wang, Weimin and Li, Zuoyue and Huang, Yifei and Sato, Yoichi},
journal={arXiv preprint arXiv:2005.00375},
year={2020}
}
```
Binary file added figures/step_gif_res/.DS_Store
Binary file added figures/step_gif_res/basketball_05.gif
Binary file added figures/step_gif_res/close_drawer_10.gif
Binary file added figures/step_gif_res/fencing_10.gif
Binary file added figures/step_gif_res/open_cupboard_15.gif
Binary file added figures/step_gif_res/open_fridge_10.gif
Binary file added figures/step_gif_res/skijet_05.gif
Binary file added figures/step_gif_res/walking_with_dog_10.gif
37 changes: 5 additions & 32 deletions main.py
@@ -33,16 +33,14 @@
from visual_meth.perturbation import video_perturbation
from visual_meth.grad_cam import grad_cam

parser = argparse.ArgumentParser()
parser.add_argument("--videos_dir", type=str, default='')
parser.add_argument("--model", type=str, default='r2plus1d',
choices=['r2plus1d', 'r3d', 'mc3', 'i3d', 'tsn', 'trn', 'tsm'])
parser.add_argument("--pretrain_dataset", type=str, default='kinetics',
choices=['', 'kinetics', 'epic-kitchens-verb', 'epic-kitchens-noun'])
parser.add_argument("--vis_method", type=str, default='integrated_grad',
choices=['grad', 'grad*input', 'integrated_grad', 'smooth_grad', 'grad_cam', 'perturb'])
choices=['grad', 'grad*input', 'integrated_grad', 'smooth_grad', 'grad_cam', 'step', '3d_ep', '2d_ep'])
parser.add_argument("--save_label", type=str, default='')
parser.add_argument("--no_gpu", action='store_true')
parser.add_argument("--num_iter", type=int, default=2000)
@@ -52,26 +50,6 @@
choices=['positive', 'negative', 'both'])
args = parser.parse_args()

if args.no_gpu:
device = torch.device("cpu")
num_devices = 0
@@ -166,17 +144,12 @@
raise Exception(f'Grad-CAM does not support {args.model} currently')
res = grad_cam(inp, label, model_ft, device, layer_name=layer_name, norm_vis=True)
heatmap_np = overlap_maps_on_voxel_np(inp_np, res[0,0].cpu().numpy(), norm_map=False)
    elif 'ep' in args.vis_method:
        sigma = 11 if inp.shape[-1] == 112 else 23
        res = video_perturbation(
            model_ft, inp, label, method=args.vis_method, areas=[args.perturb_area],
            sigma=sigma, max_iter=args.num_iter, variant="preserve",
            num_devices=num_devices, print_iter=200, perturb_type="blur")[0]
heatmap_np = overlap_maps_on_voxel_np(inp_np, res[0,0].cpu().numpy(), norm_map=False)

sample_name = sample[2][0].split("/")[-1]
147 changes: 105 additions & 42 deletions visual_meth/perturbation.py
@@ -207,32 +207,108 @@ def __str__(self):
f"- pyramid shape: {list(self.pyramid.shape)}"
)


class CoreLossCalulator:
def __init__ (self, bs, nt, nrow, ncol, areas, device,
num_key_frames=7, spatial_range=(11, 11),
core_shape='ellipsoid'):
self.bs = bs
self.nt = nt
self.nrow = nrow
self.ncol = ncol
self.areas = areas
self.narea = len(areas)
self.device = device

self.new_nrow = nrow
self.new_ncol = ncol

self.num_key_frames = num_key_frames
self.spatial_range = spatial_range
self.core_shape = core_shape
self.conv_stride = spatial_range[0]
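        # 0.5236 ≈ π/6, the volume fraction of an ellipsoid inscribed in its
        # bounding box; used to estimate the number of ones in a core kernel.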
self.core_kernel_ones = int(num_key_frames * spatial_range[0] * spatial_range[1] * 0.5236)
self.core_kernels = []
self.core_topks = []

for a_idx, area in enumerate(areas):
if self.core_shape == 'ellipsoid':
core_kernel = self.get_ellipsoid_kernel([num_key_frames, spatial_range[0], spatial_range[1]]).to(device)
elif self.core_shape == 'cylinder':
core_kernel = self.get_cylinder_kernel([num_key_frames, spatial_range[0], spatial_range[1]]).to(device)
else:
                raise Exception(f'Unsupported core_shape, given {self.core_shape}')
self.core_kernels.append(core_kernel)
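            # Number of cores whose combined volume covers the requested area
            # budget over the full T x H x W mask.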
self.core_topks.append(math.ceil((area*self.nt*self.new_nrow*self.new_ncol) / self.core_kernel_ones))

    # masks: A x N x 1 x T x S_out x S_out
    def calculate(self, masks):
        small_masks = masks

        losses = []
for a_idx, area in enumerate(self.areas):
core_kernel = self.core_kernels[a_idx]
core_kernel_sum = core_kernel.sum()
mask_conv = F.conv3d(small_masks[a_idx], core_kernel, bias=None,
stride=(1, self.conv_stride, self.conv_stride),
padding=0, dilation=1, groups=1)
            # Shuffling before the sort leaves the sorted values unchanged but
            # randomizes gradient routing among tied entries (e.g., runs of zeros).
            mask_conv_sorted = mask_conv.view(self.bs, -1)
            shuffled_inds = torch.randperm(mask_conv_sorted.shape[1], device=mask_conv_sorted.device)
            mask_conv_sorted = mask_conv_sorted[:, shuffled_inds].sort(dim=1, descending=True)[0]
loss = ((mask_conv_sorted[:, :self.core_topks[a_idx]]/core_kernel_sum - 1) ** 2).mean(dim=1) \
+ ((mask_conv_sorted[:, self.core_topks[a_idx]:]/core_kernel_sum - 0) ** 2).mean(dim=1)
losses.append(loss)
losses = torch.stack(losses, dim=0) # A x N
return losses

def get_ellipsoid_kernel(self, dims):
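        # Returns a 1 x 1 x d1 x d2 x d3 binary (float) kernel that is 1 inside
        # the ellipsoid inscribed in the d1 x d2 x d3 box.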
assert(len(dims) == 3)
d1, d2, d3 = dims
v1, v2, v3 = np.meshgrid(np.linspace(-1, 1, d1), np.linspace(-1, 1, d2), np.linspace(-1, 1, d3), indexing='ij')
dist = np.sqrt(v1 ** 2 + v2 ** 2 + v3 ** 2)
return torch.from_numpy((dist <= 1)[np.newaxis, np.newaxis, ...]).float()

def get_cylinder_kernel(self, dims):
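        # Returns a binary (float) kernel that is 1 inside the cylinder spanning
        # the full first (temporal) dimension, with a circular spatial section.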
assert(len(dims) == 3)
d1, d2, d3 = dims
v1, v2, v3 = np.meshgrid(np.linspace(-1, 1, d1), np.linspace(-1, 1, d2), np.linspace(-1, 1, d3), indexing='ij')
dist = np.sqrt(v2 ** 2 + v3 ** 2)
return torch.from_numpy((dist <= 1)[np.newaxis, np.newaxis, ...]).float()

def video_perturbation(model, input, target, method,
areas=[0.1], perturb_type=FADE_PERTURBATION,
max_iter=2000, num_levels=8, step=7, sigma=11, variant=PRESERVE_VARIANT,
print_iter=None, debug=False, reward_func="simple_log", resize=False,
        resize_mode='bilinear', smooth=0, num_devices=1,
        core_num_keyframe=7, core_spatial_size=11, core_shape='ellipsoid'):

if isinstance(areas, float):
areas = [areas]
momentum = 0.9
learning_rate = 0.05
regul_weight = 300
reward_weight = 100
if method == 'step':
reward_weight = 1
else:
reward_weight = 100
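    # STEP's core (smoothness) term starts disabled; it is switched on after a
    # warm-up by the schedule inside the optimization loop below.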
core_weight = 0
device = input.device

iter_period = 2000

regul_weight_last = max(regul_weight / 2, 1)

# input shape: NxCxTxHxW (1x3x16x112x112)
batch_size = input.shape[0]
num_frame = input.shape[2] #16
num_areas = len(areas)

if debug:
print(
f"spatiotemporal_perturbation:\n"
f"video_perturbation:\n"
f"- method: {method}\n"
f"- target: {target}\n"
f"- areas: {areas}\n"
f"- variant: {variant}\n"
@@ -270,10 +346,18 @@ def video_perturbation(model, input, target,
max_volume = np.prod(mask_generator.shape_out) * num_frame

# Prepare reference area vector.
    if method == 'step':
mask_core = CoreLossCalulator(batch_size, num_frame,
mask_generator.shape_out[0], mask_generator.shape_out[1],
areas, device, num_key_frames=core_num_keyframe,
spatial_range=(core_spatial_size, core_spatial_size),
core_shape=core_shape)
vref = torch.ones(batch_size, max_volume).to(device)
vref[:, :int(max_volume * (1 - areas[0]))] = 0
elif method == '3d_ep':
vref = torch.ones(batch_size, max_volume).to(device)
vref[:, :int(max_volume * (1 - areas[0]))] = 0
    elif method == '2d_ep':
aref = torch.ones(batch_size, num_frame, max_area).to(device)
aref[:, :, :int(max_area * (1 - areas[0]))] = 0

@@ -326,10 +410,14 @@

# Area regularization.
# padded_masks: N x 1 x T x S_out x S_out
        if method == 'step':
            mask_sorted = padded_masks.squeeze(1).reshape(batch_size, -1).sort(dim=1)[0]
            regul = ((mask_sorted - vref)**2).mean(dim=1) * regul_weight  # batch_size
            regul += mask_core.calculate(padded_masks.contiguous().unsqueeze(0)).squeeze(0) * core_weight
        elif method == '3d_ep':
            mask_sorted = padded_masks.squeeze(1).reshape(batch_size, -1).sort(dim=1)[0]
            regul = ((mask_sorted - vref)**2).mean(dim=1) * regul_weight  # batch_size
        elif method == '2d_ep':
mask_sorted = padded_masks.squeeze(1).reshape(batch_size, num_frame, -1).sort(dim=2)[0]
regul = ((mask_sorted - aref)**2).mean(dim=2).mean(dim=1) * regul_weight # batch_size

@@ -343,6 +431,11 @@

pmasks.data = pmasks.data.clamp(0, 1)

        if method == 'step':
            # Regularization weight doubles every 800 iterations (1.0008668^800 ≈ 2).
            regul_weight *= 1.0008668
            if t == 100:
                core_weight = 300  # enable the core (smoothness) term after warm-up

# Record energy
# hist: batch_size x 2 x num_iter
hist_item = torch.cat(
@@ -353,9 +446,6 @@
# print(f"Item shape: {hist_item.shape}")
hist = torch.cat((hist,hist_item), dim=2)

        if (print_iter is not None) and (t % print_iter == 0):
print("[{:04d}/{:04d}]".format(t + 1, max_iter), end="\n")
Expand All @@ -371,33 +461,6 @@ def video_perturbation(model, input, target,
variant, prob[i], pred_label[i]), end="")
print()

masks = masks.detach()
# Resize saliency map.
list_mask = []
Expand All @@ -411,4 +474,4 @@ def video_perturbation(model, input, target,
masks = torch.stack(list_mask, dim=2) # NxCxTxHxW

    # masks: N x C x T x H x W; hist: N x 2 x max_iter
    return masks, hist
2 changes: 0 additions & 2 deletions visual_meth/perturbation_area.py
@@ -22,7 +22,6 @@
DELETE_VARIANT = "delete"
DUAL_VARIANT = "dual"


def simple_log_reward(activation, target, variant):
N = target.shape[0]
bs = activation.shape[0]
@@ -147,7 +146,6 @@ def to(self, dev):
self.weight = self.weight.to(dev)
return self


class Perturbation:
def __init__(self, input, num_levels=8, max_blur=20, type=BLUR_PERTURBATION):
self.type = type