diff --git a/README.md b/README.md
index 1e844e1..a5dc4f4 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,10 @@ The current version supports attribution methods and video classification models
#### Attribution methods:
* **Backprop-based**: Gradients, Gradients x Inputs, Integrated Gradients;
* **Activation-based**: GradCAM (does not support TSM now);
-* **Perturbation-based**: Extremal Perturbation and Spatiotemporal Perturbation (An extension version of extremal perturbation on video inputs).
+* **Perturbation-based**:
+ * **2D-EP**: An extension of Extremal Perturbations to video inputs that perturbs each frame separately and regularizes the perturbed area in every frame toward the same target ratio (see the sketch after this list).
+ * **3D-EP**: An extension of Extremal Perturbations to video inputs that perturbs across all frames jointly and regularizes the total perturbed area over the whole clip toward the target ratio.
+ * **STEP**: Spatio-Temporal Extremal Perturbations, which adds a dedicated regularization term enforcing spatiotemporal smoothness of the video attribution results.
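+
+The three variants share the same optimization loop and differ mainly in how the mask area is regularized: 2D-EP sorts and constrains the mask per frame, while 3D-EP and STEP sort and constrain the whole clip at once. Below is a minimal, hypothetical sketch of that area constraint (illustrative names and shapes, not the exact code in `visual_meth/perturbation.py`):
+
+```python
+import torch
+
+def area_regul(mask: torch.Tensor, a: float, per_frame: bool) -> torch.Tensor:
+    # mask: soft mask of shape T x H x W with values in [0, 1]; a: target kept-area ratio
+    if per_frame:  # 2D-EP: every frame keeps the ratio a
+        v = mask.reshape(mask.shape[0], -1).sort(dim=1)[0]  # T x HW, ascending
+    else:          # 3D-EP / STEP: the whole clip keeps the ratio a
+        v = mask.reshape(1, -1).sort(dim=1)[0]              # 1 x THW, ascending
+    ref = torch.ones_like(v)                 # reference vector: (1 - a) zeros, then ones
+    ref[:, :int(v.shape[1] * (1 - a))] = 0
+    return ((v - ref) ** 2).mean()           # pushes a fraction `a` of the mask toward 1, the rest toward 0
+```
+
+On top of the 3D constraint, STEP adds its smoothness/core regularizer (see `video_perturbation` in `visual_meth/perturbation.py`).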
## Requirements
@@ -37,7 +40,7 @@ The current version supports attribution methods and video classification models
* **videos_dir**: Directory of video frames. Frames belonging to one video should be placed in one folder under this directory; the first part of the folder name, split by '-', is treated as the label name.
* **model**: Name of the test model. Default is R(2+1)D; current choices are R(2+1)D, R3D, MC3, I3D and TSM.
* **pretrain_dataset**: Name of the dataset the test model was pretrained on. Choices include 'kinetics', 'epic-kitchens-verb', 'epic-kitchens-noun'.
-* **vis_method**: Name of visualization methods. Choices include 'grad', 'grad*input', 'integrated_grad', 'grad_cam', 'perturb'. Here the 'perturb' means is spatiotemporal perturbation method.
+* **vis_method**: Name of the visualization method. Choices include 'grad', 'grad*input', 'integrated_grad', 'grad_cam', '2d_ep', '3d_ep', 'step'.
* **save_label**: Extra label for saving results. If given, visualization results will be saved in ./visual_res/$vis_method$/$model$/$save_label$.
* **no_gpu**: If set, the demo runs on the CPU; otherwise it runs on a single GPU.
@@ -50,36 +53,40 @@ Arguments for gradient methods:
### Examples
-#### Saptiotemporal Perturbation + I3D (pretrained on Kinetics-400)
-`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model i3d --pretrain_dataset kinetics --vis_method perturb --num_iter 2000 --perturb_area 0.1`
+#### Spatiotemporal Perturbation (STEP) + R(2+1)D (pretrained on Kinetics-400)
+`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method step --num_iter 2000 --perturb_area 0.1`
-#### Spatiotemporal Perturbation + TSM (pretrained on EPIC-Kitchens-noun)
-`$ python main.py --videos_dir VideoVisual/test_data/epic-kitchens-noun/sampled_frames --model tsm --pretrain_dataset epic-kitchens-noun --vis_method perturb --num_iter 2000 --perturb_area 0.05`
+#### Spatiotemporal Perturbation (STEP) + TSM (pretrained on EPIC-Kitchens-noun)
+`$ python main.py --videos_dir VideoVisual/test_data/epic-kitchens-noun/sampled_frames --model tsm --pretrain_dataset epic-kitchens-noun --vis_method step --num_iter 2000 --perturb_area 0.05`
-#### Integrated Gradients + R(2+1)D (pretrained on Kinetics-400)
-`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method integrated_grad`
+#### Integrated Gradients + I3D (pretrained on Kinetics-400)
+`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model i3d --pretrain_dataset kinetics --vis_method integrated_grad`
## Results
### Kinetics-400 (GT = ironing)

-
+'Perturbation' denotes 3D-EP here.
### EPIC-Kitchens-Noun (GT = cupboard)

+'Perturbation' denotes 3D-EP here.
-### GIF visualization of perturbation results (on UCF101 dataset)
-#### Long Jump
+### GIF visualization of perturbation results by STEP (on the UCF101 dataset)
+
+| Basketball 5% | Skijet 5% | Walking-With-Dog 10% | Fencing 10% | OpenFridge 10% | CloseDrawer 10% | OpenCupboard 15% |
+| ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- |
+| ![](figures/step_gif_res/basketball_05.gif) | ![](figures/step_gif_res/skijet_05.gif) | ![](figures/step_gif_res/walking_with_dog_10.gif) | ![](figures/step_gif_res/fencing_10.gif) | ![](figures/step_gif_res/open_fridge_10.gif) | ![](figures/step_gif_res/close_drawer_10.gif) | ![](figures/step_gif_res/open_cupboard_15.gif) |
## Reference
-### Ours preprint
+### Our preprint (to appear in WACV 2021):
```
@article{li2020comprehensive,
- title={A Comprehensive Study on Visual Explanations for Spatio-temporal Networks},
+ title={Towards Visually Explaining Video Understanding Networks with Perturbation},
author={Li, Zhenqiang and Wang, Weimin and Li, Zuoyue and Huang, Yifei and Sato, Yoichi},
journal={arXiv preprint arXiv:2005.00375},
year={2020}
diff --git a/figures/step_gif_res/.DS_Store b/figures/step_gif_res/.DS_Store
new file mode 100644
index 0000000..5008ddf
Binary files /dev/null and b/figures/step_gif_res/.DS_Store differ
diff --git a/figures/step_gif_res/basketball_05.gif b/figures/step_gif_res/basketball_05.gif
new file mode 100644
index 0000000..267f165
Binary files /dev/null and b/figures/step_gif_res/basketball_05.gif differ
diff --git a/figures/step_gif_res/close_drawer_10.gif b/figures/step_gif_res/close_drawer_10.gif
new file mode 100644
index 0000000..36e7472
Binary files /dev/null and b/figures/step_gif_res/close_drawer_10.gif differ
diff --git a/figures/step_gif_res/fencing_10.gif b/figures/step_gif_res/fencing_10.gif
new file mode 100644
index 0000000..53164bc
Binary files /dev/null and b/figures/step_gif_res/fencing_10.gif differ
diff --git a/figures/step_gif_res/open_cupboard_15.gif b/figures/step_gif_res/open_cupboard_15.gif
new file mode 100644
index 0000000..5be871e
Binary files /dev/null and b/figures/step_gif_res/open_cupboard_15.gif differ
diff --git a/figures/step_gif_res/open_fridge_10.gif b/figures/step_gif_res/open_fridge_10.gif
new file mode 100644
index 0000000..268d6d6
Binary files /dev/null and b/figures/step_gif_res/open_fridge_10.gif differ
diff --git a/figures/step_gif_res/skijet_05.gif b/figures/step_gif_res/skijet_05.gif
new file mode 100644
index 0000000..b5ffa7a
Binary files /dev/null and b/figures/step_gif_res/skijet_05.gif differ
diff --git a/figures/step_gif_res/walking_with_dog_10.gif b/figures/step_gif_res/walking_with_dog_10.gif
new file mode 100644
index 0000000..1642281
Binary files /dev/null and b/figures/step_gif_res/walking_with_dog_10.gif differ
diff --git a/main.py b/main.py
index 41db11d..c79a0f1 100644
--- a/main.py
+++ b/main.py
@@ -33,8 +33,6 @@
from visual_meth.perturbation import video_perturbation
from visual_meth.grad_cam import grad_cam
-from visual_meth.perturbation_area import spatiotemporal_perturbation
-
parser = argparse.ArgumentParser()
parser.add_argument("--videos_dir", type=str, default='')
parser.add_argument("--model", type=str, default='r2plus1d',
@@ -42,7 +40,7 @@
parser.add_argument("--pretrain_dataset", type=str, default='kinetics',
choices=['', 'kinetics', 'epic-kitchens-verb', 'epic-kitchens-noun'])
parser.add_argument("--vis_method", type=str, default='integrated_grad',
- choices=['grad', 'grad*input', 'integrated_grad', 'smooth_grad', 'grad_cam', 'perturb'])
+ choices=['grad', 'grad*input', 'integrated_grad', 'smooth_grad', 'grad_cam', 'step', '3d_ep', '2d_ep'])
parser.add_argument("--save_label", type=str, default='')
parser.add_argument("--no_gpu", action='store_true')
parser.add_argument("--num_iter", type=int, default=2000)
@@ -52,26 +50,6 @@
choices=['positive', 'negative', 'both'])
args = parser.parse_args()
-# assert args.num_gpu >= -1
-# if args.num_gpu == 0:
-# num_devices = 0
-# multi_gpu = False
-# device = torch.device("cpu")
-# elif args.num_gpu == 1:
-# num_devices = 1
-# multi_gpu = False
-# device = torch.device("cuda")
-# elif args.num_gpu == -1:
-# num_devices = torch.cuda.device_count()
-# multi_gpu = (num_devices > 1)
-# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# else:
-# num_devices = args.num_gpu
-# assert torch.cuda.device_count() >= num_devices, \
-# f'Assign {args.num_gpu} GPUs, but only detected only {torch.cuda.device_count()} GPUs. Exiting...'
-# multi_gpu = True
-# device = torch.device("cuda")
-
if args.no_gpu:
device = torch.device("cpu")
num_devices = 0
@@ -166,17 +144,12 @@
raise Exception(f'Grad-CAM does not support {args.model} currently')
res = grad_cam(inp, label, model_ft, device, layer_name=layer_name, norm_vis=True)
heatmap_np = overlap_maps_on_voxel_np(inp_np, res[0,0].cpu().numpy(), norm_map=False)
- elif args.vis_method == 'perturb':
+    elif args.vis_method in ('2d_ep', '3d_ep', 'step'):
sigma = 11 if inp.shape[-1] == 112 else 23
res = video_perturbation(
- model_ft, inp, label, areas=[args.perturb_area], sigma=sigma,
- max_iter=args.num_iter, variant="preserve",
- num_devices=num_devices, print_iter=100, perturb_type="fade")[0]
- # res = spatiotemporal_perturbation(
- # model_ft, inp, label, areas=[0.1, 0.2, 0.3], sigma=sigma,
- # max_iter=args.num_iter, variant="preserve",
- # num_devices=num_devices, print_iter=100, perturb_type="fade")[0]
- # print(res.shape)
+ model_ft, inp, label, method=args.vis_method, areas=[args.perturb_area],
+ sigma=sigma, max_iter=args.num_iter, variant="preserve",
+ num_devices=num_devices, print_iter=200, perturb_type="blur")[0]
heatmap_np = overlap_maps_on_voxel_np(inp_np, res[0,0].cpu().numpy(), norm_map=False)
sample_name = sample[2][0].split("/")[-1]
diff --git a/visual_meth/perturbation.py b/visual_meth/perturbation.py
index e1ee734..3abcaec 100644
--- a/visual_meth/perturbation.py
+++ b/visual_meth/perturbation.py
@@ -207,24 +207,99 @@ def __str__(self):
f"- pyramid shape: {list(self.pyramid.shape)}"
)
-def video_perturbation(model, input, target,
- areas=[0.1], method='spatiotemporal', perturb_type=FADE_PERTURBATION,
+
+class CoreLossCalculator:
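+    """Core regularizer used by STEP (description inferred from the code below):
+    the mask is convolved with a binary ellipsoid/cylinder kernel, and the top-k
+    responses are pushed toward full coverage while all others are pushed toward
+    zero, so the preserved region concentrates into compact spatiotemporal cores
+    whose total volume matches the target area ratio."""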
+    def __init__(self, bs, nt, nrow, ncol, areas, device,
+ num_key_frames=7, spatial_range=(11, 11),
+ core_shape='ellipsoid'):
+ self.bs = bs
+ self.nt = nt
+ self.nrow = nrow
+ self.ncol = ncol
+ self.areas = areas
+ self.narea = len(areas)
+ self.device = device
+
+ self.new_nrow = nrow
+ self.new_ncol = ncol
+
+ self.num_key_frames = num_key_frames
+ self.spatial_range = spatial_range
+ self.core_shape = core_shape
+ self.conv_stride = spatial_range[0]
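+        # 0.5236 ≈ pi/6, the volume fraction of an ellipsoid inscribed in its bounding box
+        # (the same constant is also reused for the cylinder kernel)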
+ self.core_kernel_ones = int(num_key_frames * spatial_range[0] * spatial_range[1] * 0.5236)
+ self.core_kernels = []
+ self.core_topks = []
+
+ for a_idx, area in enumerate(areas):
+ if self.core_shape == 'ellipsoid':
+ core_kernel = self.get_ellipsoid_kernel([num_key_frames, spatial_range[0], spatial_range[1]]).to(device)
+ elif self.core_shape == 'cylinder':
+ core_kernel = self.get_cylinder_kernel([num_key_frames, spatial_range[0], spatial_range[1]]).to(device)
+ else:
+                raise Exception(f'Unsupported core_shape, given {self.core_shape}')
+ self.core_kernels.append(core_kernel)
+ self.core_topks.append(math.ceil((area*self.nt*self.new_nrow*self.new_ncol) / self.core_kernel_ones))
+
+ # mask: AxNx1xTx S_out x S_out
+    def calculate(self, masks):
+        small_masks = masks
+
+        losses = []
+ for a_idx, area in enumerate(self.areas):
+ core_kernel = self.core_kernels[a_idx]
+ core_kernel_sum = core_kernel.sum()
+ mask_conv = F.conv3d(small_masks[a_idx], core_kernel, bias=None,
+ stride=(1, self.conv_stride, self.conv_stride),
+ padding=0, dilation=1, groups=1)
+ # mask_conv_sorted = mask_conv.view(self.bs, -1).sort(dim=1, descending=True)[0]
+ mask_conv_sorted = mask_conv.view(self.bs, -1)
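+            # shuffle before sorting (presumably to break ties among equal responses randomly, avoiding positional bias)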
+ shuffled_inds = torch.randperm(mask_conv_sorted.shape[1], device=mask_conv_sorted.device)
+ mask_conv_sorted = mask_conv_sorted[:, shuffled_inds].sort(dim=1, descending=True)[0]
+ # loss = ((mask_conv_sorted[:, :self.core_topks[a_idx]] - 1) ** 2).mean(dim=1) \
+ # + ((mask_conv_sorted[:, self.core_topks[a_idx]:] - 0) ** 2).mean(dim=1)
+ loss = ((mask_conv_sorted[:, :self.core_topks[a_idx]]/core_kernel_sum - 1) ** 2).mean(dim=1) \
+ + ((mask_conv_sorted[:, self.core_topks[a_idx]:]/core_kernel_sum - 0) ** 2).mean(dim=1)
+ losses.append(loss)
+ losses = torch.stack(losses, dim=0) # A x N
+ return losses
+
+ def get_ellipsoid_kernel(self, dims):
+ assert(len(dims) == 3)
+ d1, d2, d3 = dims
+ v1, v2, v3 = np.meshgrid(np.linspace(-1, 1, d1), np.linspace(-1, 1, d2), np.linspace(-1, 1, d3), indexing='ij')
+ dist = np.sqrt(v1 ** 2 + v2 ** 2 + v3 ** 2)
+ return torch.from_numpy((dist <= 1)[np.newaxis, np.newaxis, ...]).float()
+
+ def get_cylinder_kernel(self, dims):
+ assert(len(dims) == 3)
+ d1, d2, d3 = dims
+ v1, v2, v3 = np.meshgrid(np.linspace(-1, 1, d1), np.linspace(-1, 1, d2), np.linspace(-1, 1, d3), indexing='ij')
+ dist = np.sqrt(v2 ** 2 + v3 ** 2)
+ return torch.from_numpy((dist <= 1)[np.newaxis, np.newaxis, ...]).float()
+
+def video_perturbation(model, input, target, method,
+ areas=[0.1], perturb_type=FADE_PERTURBATION,
max_iter=2000, num_levels=8, step=7, sigma=11, variant=PRESERVE_VARIANT,
print_iter=None, debug=False, reward_func="simple_log", resize=False,
- resize_mode='bilinear', smooth=0, num_devices=1):
+ resize_mode='bilinear', smooth=0, num_devices=1,
+ core_num_keyframe=7, core_spatial_size=11, core_shape='ellipsoid'):
if isinstance(areas, float):
areas = [areas]
momentum = 0.9
learning_rate = 0.05
regul_weight = 300
- reward_weight = 100
+ if method == 'step':
+ reward_weight = 1
+ else:
+ reward_weight = 100
+ core_weight = 0
device = input.device
iter_period = 2000
- regul_weight_last = max(regul_weight / 2, 1)
-
# input shape: NxCxTxHxW (1x3x16x112x112)
batch_size = input.shape[0]
num_frame = input.shape[2] #16
@@ -232,7 +307,8 @@ def video_perturbation(model, input, target,
if debug:
print(
- f"spatiotemporal_perturbation:\n"
+ f"video_perturbation:\n"
+ f"- method: {method}\n"
f"- target: {target}\n"
f"- areas: {areas}\n"
f"- variant: {variant}\n"
@@ -270,10 +346,18 @@ def video_perturbation(model, input, target,
max_volume = np.prod(mask_generator.shape_out) * num_frame
# Prepare reference area vector.
- if method == 'spatiotemporal':
+ if method == 'step':
+        mask_core = CoreLossCalculator(batch_size, num_frame,
+ mask_generator.shape_out[0], mask_generator.shape_out[1],
+ areas, device, num_key_frames=core_num_keyframe,
+ spatial_range=(core_spatial_size, core_spatial_size),
+ core_shape=core_shape)
+ vref = torch.ones(batch_size, max_volume).to(device)
+ vref[:, :int(max_volume * (1 - areas[0]))] = 0
+ elif method == '3d_ep':
vref = torch.ones(batch_size, max_volume).to(device)
vref[:, :int(max_volume * (1 - areas[0]))] = 0
- elif method == 'extremal':
+ elif method == '2d_ep':
aref = torch.ones(batch_size, num_frame, max_area).to(device)
aref[:, :, :int(max_area * (1 - areas[0]))] = 0
@@ -326,10 +410,14 @@ def video_perturbation(model, input, target,
# Area regularization.
# padded_masks: N x 1 x T x S_out x S_out
- if method == 'spatiotemporal':
+ if method == 'step':
mask_sorted = padded_masks.squeeze(1).reshape(batch_size, -1).sort(dim=1)[0]
regul = ((mask_sorted - vref)**2).mean(dim=1) * regul_weight # batch_size
- elif method == 'extremal':
+ regul += mask_core.calculate(padded_masks.contiguous().unsqueeze(0)).squeeze(0) * core_weight # batch_size
+ elif method == '3d_ep':
+ mask_sorted = padded_masks.squeeze(1).reshape(batch_size, -1).sort(dim=1)[0]
+ regul = ((mask_sorted - vref)**2).mean(dim=1) * regul_weight # batch_size
+ elif method == '2d_ep':
mask_sorted = padded_masks.squeeze(1).reshape(batch_size, num_frame, -1).sort(dim=2)[0]
regul = ((mask_sorted - aref)**2).mean(dim=2).mean(dim=1) * regul_weight # batch_size
@@ -343,6 +431,11 @@ def video_perturbation(model, input, target,
pmasks.data = pmasks.data.clamp(0, 1)
+ if method == 'step':
+            regul_weight *= 1.0008668  # 1.0008668^800 ≈ 2: regul_weight doubles every 800 iterations
+ if t == 100:
+ core_weight = 300
+
# Record energy
# hist: batch_size x 2 x num_iter
hist_item = torch.cat(
@@ -353,9 +446,6 @@ def video_perturbation(model, input, target,
# print(f"Item shape: {hist_item.shape}")
hist = torch.cat((hist,hist_item), dim=2)
- # Diagnostics.
- # debug_this_iter = debug and (regul_weight / regul_weight_last >= 2)
-
# if (print_iter is not None and t % print_iter == 0) or debug_this_iter:
if (print_iter != None) and (t % print_iter == 0):
print("[{:04d}/{:04d}]".format(t + 1, max_iter), end="\n")
@@ -371,33 +461,6 @@ def video_perturbation(model, input, target,
variant, prob[i], pred_label[i]), end="")
print()
- # if debug_this_iter:
- # regul_weight_last = regul_weight
- # # for i, a in enumerate(areas):
- # for i in range(batch_size):
- # plt.figure(i, figsize=(20, 6))
- # plt.clf()
- # ncols = 4 if variant == DUAL_VARIANT else 3
- # plt.subplot(1, ncols, 1)
- # plt.plot(hist[i, 0].numpy())
- # plt.plot(hist[i, 1].numpy())
- # plt.plot(hist[i].sum(dim=0).numpy())
- # plt.legend(('energy', 'regul', 'both'))
- # plt.title(f'target area:{areas[0]:.2f}')
- # mask = padded_masks[i,:,7,:,:]
- # plt.subplot(1, ncols, 2)
- # imsc(mask, lim=[0, 1])
- # plt.title(
- # f"min:{mask.min().item():.2f}"
- # f" max:{mask.max().item():.2f}"
- # f" area:{mask.sum() / mask.numel():.2f}")
- # plt.subplot(1, ncols, 3)
- # imsc(perturb_x[i,:,7,:,:])
- # if variant == DUAL_VARIANT:
- # plt.subplot(1, ncols, 4)
- # imsc(perturb_x[i + batch_size,:,7,:,:])
- # plt.pause(0.001)
-
masks = masks.detach()
# Resize saliency map.
list_mask = []
@@ -411,4 +474,4 @@ def video_perturbation(model, input, target,
masks = torch.stack(list_mask, dim=2) # NxCxTxHxW
# masks: NxCxTxHxW; hist: Nx3xmax_iter
- return masks, hist, perturb_x
\ No newline at end of file
+ return masks, hist
\ No newline at end of file
diff --git a/visual_meth/perturbation_area.py b/visual_meth/perturbation_area.py
index 06ca22a..d2e69e9 100644
--- a/visual_meth/perturbation_area.py
+++ b/visual_meth/perturbation_area.py
@@ -22,7 +22,6 @@
DELETE_VARIANT = "delete"
DUAL_VARIANT = "dual"
-
def simple_log_reward(activation, target, variant):
N = target.shape[0]
bs = activation.shape[0]
@@ -147,7 +146,6 @@ def to(self, dev):
self.weight = self.weight.to(dev)
return self
-
class Perturbation:
def __init__(self, input, num_levels=8, max_blur=20, type=BLUR_PERTURBATION):
self.type = type
diff --git a/visual_meth/perturbation_area_new.py b/visual_meth/perturbation_area_new.py
new file mode 100644
index 0000000..9c66846
--- /dev/null
+++ b/visual_meth/perturbation_area_new.py
@@ -0,0 +1,514 @@
+import math
+import matplotlib.pyplot as plt
+import numpy as np
+import time
+
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+import sys
+sys.path.append(".")
+sys.path.append("..")
+from utils.CalAcc import process_activations
+from utils.ImageShow import *
+
+from torchray.utils import imsmooth, imsc
+from torchray.attribution.common import resize_saliency
+
+BLUR_PERTURBATION = "blur"
+FADE_PERTURBATION = "fade"
+
+PRESERVE_VARIANT = "preserve"
+DELETE_VARIANT = "delete"
+DUAL_VARIANT = "dual"
+
+def simple_log_reward(activation, target, variant):
+ N = target.shape[0]
+ bs = activation.shape[0]
+ b_repeat = int( bs // N )
+ device = activation.device
+
+ col_idx = target.repeat(b_repeat) # batch_size
+ row_idx = torch.arange(activation.shape[0], dtype=torch.long, device=device) # batch_size
+ prob = activation[row_idx, col_idx] # batch_size
+
+ if variant == DELETE_VARIANT:
+ reward = -torch.log(1-prob)
+ elif variant == PRESERVE_VARIANT:
+ reward = -torch.log(prob)
+ elif variant == DUAL_VARIANT:
+ reward = (-torch.log(1-prob[N:])) + (-torch.log(prob[:N]))
+ else:
+ assert False
+ return reward
+
+class MaskGenerator:
+ def __init__(self, shape, step, sigma, clamp=True, pooling_method='softmax'):
+ self.shape = shape
+ self.step = step
+ self.sigma = sigma
+ self.coldness = 20
+ self.clamp = clamp
+ self.pooling_method = pooling_method
+
+ assert int(step) == step
+
+ # self.kernel = lambda z: (z < 1).float()
+ self.kernel = lambda z: torch.exp(-2 * ((z - .5).clamp(min=0)**2))
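+        # smooth radial falloff: equals 1 for z < 0.5 and decays Gaussian-like beyond it (z is distance / sigma)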
+
+ self.margin = self.sigma
+ self.padding = 1 + math.ceil((self.margin + sigma) / step)
+ self.radius = 1 + math.ceil(sigma / step)
+ self.shape_in = [math.ceil(z / step) for z in self.shape]
+ self.shape_mid = [
+ z + 2 * self.padding - (2 * self.radius + 1) + 1
+ for z in self.shape_in
+ ]
+ self.shape_up = [self.step * z for z in self.shape_mid]
+ self.shape_out = [z - step + 1 for z in self.shape_up]
+
+ step_inv = [
+ torch.tensor(zm, dtype=torch.float32) /
+ torch.tensor(zo, dtype=torch.float32)
+ for zm, zo in zip(self.shape_mid, self.shape_up)
+ ]
+
+ # Generate kernel weights for smoothing mask_in
+ self.weight = torch.zeros((
+ 1,
+ (2 * self.radius + 1)**2,
+ self.shape_out[0],
+ self.shape_out[1]
+ ))
+
+ for ky in range(2 * self.radius + 1):
+ for kx in range(2 * self.radius + 1):
+ uy, ux = torch.meshgrid(
+ torch.arange(self.shape_out[0], dtype=torch.float32),
+ torch.arange(self.shape_out[1], dtype=torch.float32)
+ )
+ iy = torch.floor(step_inv[0] * uy) + ky - self.padding
+ ix = torch.floor(step_inv[1] * ux) + kx - self.padding
+
+ delta = torch.sqrt(
+ (uy - (self.margin + self.step * iy))**2 +
+ (ux - (self.margin + self.step * ix))**2
+ )
+
+ k = ky * (2 * self.radius + 1) + kx
+
+ self.weight[0, k] = self.kernel(delta / sigma)
+
+ def generate(self, mask_in):
+ # mask_in: Nx1xHxW --> mask: Nx1xS_outxS_out
+ mask = F.unfold(mask_in,
+ (2 * self.radius + 1,) * 2,
+ padding=(self.padding,) * 2)
+ mask = mask.reshape(
+ mask_in.shape[0], -1, self.shape_mid[0], self.shape_mid[1])
+ mask = F.interpolate(mask, size=self.shape_up, mode='nearest')
+ mask = F.pad(mask, (0, -self.step + 1, 0, -self.step + 1))
+ mask = self.weight * mask
+
+ if self.pooling_method == 'sigmoid':
+ if self.coldness == float('+Inf'):
+ mask = (mask.sum(dim=1, keepdim=True) - 5 > 0).float()
+ else:
+ mask = torch.sigmoid(
+ self.coldness * mask.sum(dim=1, keepdim=True) - 3
+ )
+ elif self.pooling_method == 'softmax':
+ if self.coldness == float('+Inf'): # max normalization
+ mask = mask.max(dim=1, keepdim=True)[0]
+ else: # smax normalization
+ mask = (
+ mask * F.softmax(self.coldness * mask, dim=1)
+ ).sum(dim=1, keepdim=True)
+ elif self.pooling_method == 'sum':
+ mask = mask.sum(dim=1, keepdim=True)
+ else:
+ assert False, f"Unknown pooling method {self.pooling_method}"
+
+ m = round(self.margin)
+ if self.clamp:
+ mask = mask.clamp(min=0, max=1)
+ cropped = mask[:, :, m:m + self.shape[0], m:m + self.shape[1]]
+ return cropped, mask
+
+ def to(self, dev):
+ """Switch to another device.
+ Args:
+ dev: PyTorch device.
+ Returns:
+ MaskGenerator: self.
+ """
+ self.weight = self.weight.to(dev)
+ return self
+
+class Perturbation:
+ def __init__(self, input, num_levels=8, max_blur=20, type=BLUR_PERTURBATION):
+ self.type = type
+ self.num_levels = num_levels
+ self.pyramid = []
+ assert num_levels >= 2
+ assert max_blur > 0
+ with torch.no_grad():
+ for sigma in torch.linspace(0, 1, self.num_levels):
+ if type == BLUR_PERTURBATION:
+ # input could be a batched tensor with size of NxCxHxW
+ y = imsmooth(input, sigma=(1 - sigma) * max_blur)
+ # ouput y has size of NxCxHxW
+ elif type == FADE_PERTURBATION:
+ y = input * sigma
+ else:
+ assert False
+ self.pyramid.append(y)
+ # self.pyramid = torch.cat(self.pyramid, dim=0)
+ self.pyramid = torch.stack(self.pyramid, dim=1) # NxLxCxHxW, L=num_levels
+
+ def apply(self, mask):
+ # mask: A*N*T x1xHxW
+ n = mask.shape[0] # n = A*N*T
+ inp_n = self.pyramid.shape[0] # inp_n = N*T
+ num_area = int(n / inp_n) # A
+ # starred expression: unpack a list to separated numbers
+ w = mask.reshape(n, 1, *mask.shape[1:]) # A*N*T x1x1xHxW, mask.unsqueeze(1)
+ w = w * (self.num_levels - 1) # w = 7*w
+ k = w.floor() # Integral part of w
+ w = w - k # Fractional part of w
+ k = k.long() # Transfer k to long int
+
+ y = self.pyramid.repeat(num_area, 1, 1, 1, 1) # A*N*T xLxCxHxW
+ k = k.expand(n, 1, *y.shape[2:]) # A*N*T x1xCxHxW, channel dim: 1-->3
+
+ y0 = torch.gather(y, 1, k) # select low level, Nx1xCxHxW
+ y1 = torch.gather(y, 1, torch.clamp(k + 1, max=self.num_levels - 1)) # select high level, Nx1xCxHxW
+
+ # return ((1 - w) * y0 + w * y1).squeeze(dim=1)
+ perturb_x = ((1 - w) * y0 + w * y1) #Nx1xCxHxW
+ return perturb_x
+
+ def to(self, dev):
+ """Switch to another device.
+ Args:
+ dev: PyTorch device.
+ Returns:
+ Perturbation: self.
+ """
+ self.pyramid.to(dev)
+ return self
+
+ def __str__(self):
+ return (
+ f"Perturbation:\n"
+ f"- type: {self.type}\n"
+ f"- num_levels: {self.num_levels}\n"
+ f"- pyramid shape: {list(self.pyramid.shape)}"
+ )
+
+class CoreLossCalculator:
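+    """Core regularizer used by STEP (description inferred from the code below):
+    the mask is convolved with a binary ellipsoid/cylinder kernel, and the top-k
+    responses are pushed toward full coverage while all others are pushed toward
+    zero, so the preserved region concentrates into compact spatiotemporal cores
+    whose total volume matches the target area ratio."""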
+    def __init__(self, bs, nt, nrow, ncol, areas, device,
+ num_key_frames=7, spatial_range=(11, 11),
+ core_shape='ellipsoid'):
+ self.bs = bs
+ self.nt = nt
+ self.nrow = nrow
+ self.ncol = ncol
+ self.areas = areas
+ self.narea = len(areas)
+ self.device = device
+
+ self.new_nrow = nrow
+ self.new_ncol = ncol
+
+ self.num_key_frames = num_key_frames
+ self.spatial_range = spatial_range
+ self.core_shape = core_shape
+ self.conv_stride = spatial_range[0]
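+        # 0.5236 ≈ pi/6, the volume fraction of an ellipsoid inscribed in its bounding box
+        # (the same constant is also reused for the cylinder kernel)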
+ self.core_kernel_ones = int(num_key_frames * spatial_range[0] * spatial_range[1] * 0.5236)
+ self.core_kernels = []
+ self.core_topks = []
+
+ for a_idx, area in enumerate(areas):
+ if self.core_shape == 'ellipsoid':
+ core_kernel = self.get_ellipsoid_kernel([num_key_frames, spatial_range[0], spatial_range[1]]).to(device)
+ elif self.core_shape == 'cylinder':
+ core_kernel = self.get_cylinder_kernel([num_key_frames, spatial_range[0], spatial_range[1]]).to(device)
+ else:
+                raise Exception(f'Unsupported core_shape, given {self.core_shape}')
+ self.core_kernels.append(core_kernel)
+ self.core_topks.append(math.ceil((area*self.nt*self.new_nrow*self.new_ncol) / self.core_kernel_ones))
+
+ # mask: AxNx1xTx S_out x S_out
+    def calculate(self, masks):
+        small_masks = masks
+
+        losses = []
+ for a_idx, area in enumerate(self.areas):
+ core_kernel = self.core_kernels[a_idx]
+ core_kernel_sum = core_kernel.sum()
+ mask_conv = F.conv3d(small_masks[a_idx], core_kernel, bias=None,
+ stride=(1, self.conv_stride, self.conv_stride),
+ padding=0, dilation=1, groups=1)
+ # mask_conv_sorted = mask_conv.view(self.bs, -1).sort(dim=1, descending=True)[0]
+ mask_conv_sorted = mask_conv.view(self.bs, -1)
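+            # shuffle before sorting (presumably to break ties among equal responses randomly, avoiding positional bias)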
+ shuffled_inds = torch.randperm(mask_conv_sorted.shape[1], device=mask_conv_sorted.device)
+ mask_conv_sorted = mask_conv_sorted[:, shuffled_inds].sort(dim=1, descending=True)[0]
+ # loss = ((mask_conv_sorted[:, :self.core_topks[a_idx]] - 1) ** 2).mean(dim=1) \
+ # + ((mask_conv_sorted[:, self.core_topks[a_idx]:] - 0) ** 2).mean(dim=1)
+ loss = ((mask_conv_sorted[:, :self.core_topks[a_idx]]/core_kernel_sum - 1) ** 2).mean(dim=1) \
+ + ((mask_conv_sorted[:, self.core_topks[a_idx]:]/core_kernel_sum - 0) ** 2).mean(dim=1)
+ losses.append(loss)
+ losses = torch.stack(losses, dim=0) # A x N
+ return losses
+
+ def get_ellipsoid_kernel(self, dims):
+ assert(len(dims) == 3)
+ d1, d2, d3 = dims
+ v1, v2, v3 = np.meshgrid(np.linspace(-1, 1, d1), np.linspace(-1, 1, d2), np.linspace(-1, 1, d3), indexing='ij')
+ dist = np.sqrt(v1 ** 2 + v2 ** 2 + v3 ** 2)
+ return torch.from_numpy((dist <= 1)[np.newaxis, np.newaxis, ...]).float()
+
+ def get_cylinder_kernel(self, dims):
+ assert(len(dims) == 3)
+ d1, d2, d3 = dims
+ v1, v2, v3 = np.meshgrid(np.linspace(-1, 1, d1), np.linspace(-1, 1, d2), np.linspace(-1, 1, d3), indexing='ij')
+ dist = np.sqrt(v2 ** 2 + v3 ** 2)
+ return torch.from_numpy((dist <= 1)[np.newaxis, np.newaxis, ...]).float()
+
+# Control smoothness
+# masks: AxNx1xTx S_out x S_out
+def diff_loss(masks, num_diff=3):
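+    # penalizes the 1st- through num_diff-th order temporal finite differences of the
+    # masks, encouraging the saliency to change smoothly from frame to frame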
+ narea = masks.shape[0]
+ bs = masks.shape[1]
+ diffs = [masks[:, :, :, 1:, ...] - masks[:, :, :, :-1, ...]]
+ for _ in range(num_diff - 1):
+ diffs.append(diffs[-1][:, :, :, 1:, ...] - diffs[-1][:, :, :, :-1, ...])
+ diffs_mean = torch.stack([(item.view(narea, bs, -1) ** 2).mean(dim=2) for item in diffs], dim=0)
+ return diffs_mean.mean(dim=0) # A x N
+
+# masks: AxNx1xTx S_out x S_out
+# total_ones: A
+def sum_loss(masks, total_ones):
+ narea = masks.shape[0]
+ bs = masks.shape[1]
+ return (masks.view(narea, bs, -1).sum(dim=2) / total_ones.unsqueeze(-1) - 1) ** 2
+
+def video_perturbation(model,
+ input,
+ target,
+ areas=[0.1],
+ perturb_type=BLUR_PERTURBATION,
+ max_iter=2000,
+ num_levels=8,
+ step=7,
+ sigma=11,
+ tsigma=0,
+ with_diff=False,
+ with_core=False,
+ variant=PRESERVE_VARIANT,
+ print_iter=None,
+ debug=False,
+ reward_func="simple_log",
+ resize=False,
+ resize_mode='bilinear',
+ smooth=0,
+ gpu_id=0,
+ core_num_keyframe=5,
+ core_spatial_size=11,
+ core_shape='ellipsoid'):
+
+ if isinstance(areas, float):
+ areas = [areas]
+ momentum = 0.9
+ learning_rate = 0.05
+ regul_weight = 300
+ reward_weight = 1
+ core_weight = 0
+ # core_weight = 300
+
+ device = input.device
+ torch.cuda.set_device(gpu_id)
+
+ regul_weight_last = max(regul_weight / 2, 1)
+
+ # input shape: NxCxTxHxW (1x3x16x112x112)
+ batch_size = input.shape[0] # N
+ num_frame = input.shape[2] # T=16
+ num_area = len(areas) # A
+
+ if debug:
+ print(
+ f"spatiotemporal_perturbation:\n"
+ f"- target: {target}\n"
+ f"- areas: {areas}\n"
+ f"- variant: {variant}\n"
+ f"- max_iter: {max_iter}\n"
+ f"- step/sigma: {step}, {sigma}\n"
+ f"- voxel size: {list(input.shape)}\n"
+ f"- reward function: {reward_func}"
+ )
+ print(f"- Target: {target.detach().cpu().tolist()}")
+
+ # Disable gradients for model parameters.
+ for p in model.parameters():
+ p.requires_grad_(False)
+ model.eval()
+
+ # y = model(input)
+ # ymin, ymin_idx = torch.min(y, dim=1)
+ # print(f'Min index: {ymin_idx[0].item()}, Min: {ymin[0].item()}')
+
+ # NxCxTxHxW --> N*T x CxHxW
+ pmt_inp = input.transpose(1,2).contiguous() # NxTxCxHxW
+ pmt_inp = pmt_inp.view(batch_size*num_frame, *pmt_inp.shape[2:]) # N*T x CxHxW
+
+ # Get the perturbation operator.
+ # perturbation.pyramid: T*N x LxCxHxW
+ perturbation = Perturbation(pmt_inp, num_levels=num_levels,
+ type=perturb_type).to(device)
+
+ # Prepare the mask generator (generating mask(134x134) from pmask(16x16)).
+ shape = perturbation.pyramid.shape[3:] # 112x112
+ mask_generator = MaskGenerator(shape, step, sigma, pooling_method='softmax').to(device)
+ h, w = mask_generator.shape_in # h=112/step, w=112/step, 16x16
+ pmasks = torch.ones(num_area*batch_size*num_frame, 1, h, w).to(device) #A*N*T x 1x16x16
+
+ if with_core:
+        mask_core = CoreLossCalculator(batch_size, num_frame,
+ mask_generator.shape_out[0], mask_generator.shape_out[1],
+ areas, device, num_key_frames=core_num_keyframe,
+ spatial_range=(core_spatial_size, core_spatial_size),
+ core_shape=core_shape)
+
+ max_area = np.prod(mask_generator.shape_out)
+ max_volume = np.prod(mask_generator.shape_out) * num_frame
+ total_ones = torch.zeros(num_area).to(device)
+ # Prepare reference area vector.
+ vref = torch.ones(num_area, batch_size, max_volume).to(device)
+ for a_idx, area in enumerate(areas):
+ total_ones[a_idx] = int(area * max_volume)
+ vref[a_idx, :, :int(max_volume * (1 - area))] = 0
+
+ # Initialize optimizer.
+ optimizer = optim.SGD([pmasks],
+ lr=learning_rate,
+ momentum=momentum,
+ dampening=momentum)
+ hist = torch.zeros((num_area, batch_size, 2, 0))
+
+ sum_time = 0
+ for t in range(max_iter):
+ end_time = time.time()
+
+ pmasks.requires_grad_(True)
+ masks, padded_masks = mask_generator.generate(pmasks)
+
+ if variant == DELETE_VARIANT:
+ perturb_x = perturbation.apply(1 - masks) # A*N*T x 1xCxHxW
+ elif variant == PRESERVE_VARIANT:
+ perturb_x = perturbation.apply(masks) # A*N*T x 1xCxHxW
+ elif variant == DUAL_VARIANT:
+ perturb_x = torch.cat((
+ perturbation.apply(masks), #preserve
+ perturbation.apply(1 - masks), #delete
+ ), dim = 1) # A*N*T x 2xCxHxW
+ else:
+ assert False
+
+ perturb_x = perturb_x.view(num_area, batch_size, num_frame, *perturb_x.shape[1:]) # AxNxTx2xCxHxW
+ perturb_x = perturb_x.permute(3,0,1,4,2,5,6).contiguous() # 2xAxNxCxTxHxW
+ perturb_x = perturb_x.view(perturb_x.shape[0]*num_area*batch_size, *perturb_x.shape[3:]) # 2*A*N x CxTxHxW
+
+ masks = masks.view(num_area, batch_size, num_frame, *masks.shape[1:]).transpose(2,3) # AxNx1xTxHxW
+ padded_masks = padded_masks.view(num_area, batch_size, num_frame, \
+ *padded_masks.shape[1:]).transpose(2,3) # AxNx1xTx S_out x S_out
+
+ # Evaluate the model on the masked data
+ # The input of model should have size of NxCxTxHxW
+ y = model(perturb_x) # 2*A*N x num_classes, default the model has softmax
+
+ # Cal probability
+ prob, pred_label, pred_label_prob = process_activations(y, target, softmaxed=True)
+
+        # Get reward. Only simple_log_reward is defined in this file, so other
+        # reward functions are rejected explicitly instead of raising a NameError.
+        if reward_func == "simple_log":
+            reward = simple_log_reward(y, target, variant=variant) # 2*A*N
+        else:
+            raise NotImplementedError(f"reward_func '{reward_func}' is not defined in perturbation_area_new.py")
+        reward = reward.view(-1, num_area, batch_size).mean(dim=0) * reward_weight # A x N
+
+ # Area regularization.
+ # padded_masks: A x N x 1 x T x S_out x S_out
+ # mask_sorted = padded_masks.squeeze(2).reshape(num_area, batch_size, -1).sort(dim=2)[0] # A x N x T*S_out*S_out
+ mask_sorted = padded_masks.squeeze(2).reshape(num_area, batch_size, -1) # A x N x T*S_out*S_out
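+        # shuffle before sorting (presumably to break ties among equal mask values randomly)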
+ shuffled_inds = torch.randperm(mask_sorted.shape[2], device=mask_sorted.device)
+ mask_sorted = mask_sorted[:,:,shuffled_inds]
+ mask_sorted = mask_sorted.sort(dim=2)[0]
+
+ regul = ((mask_sorted - vref)**2).mean(dim=2) * regul_weight # A x N
+ if with_diff:
+ regul += diff_loss(padded_masks) * regul_weight
+ if with_core:
+ regul += mask_core.calculate(padded_masks.contiguous()) * core_weight
+
+ # Energy summary
+ energy = (reward + regul).sum()
+
+ # Record energy
+ # hist: num_area x batch_size x 2 x num_iter
+ hist_item = torch.cat((reward.detach().cpu().view(num_area, batch_size, 1, 1),
+ regul.detach().cpu().view(num_area, batch_size, 1, 1)), dim=2)
+ hist = torch.cat((hist, hist_item), dim=3)
+
+ # Gradient step.
+ optimizer.zero_grad()
+ energy.backward()
+ optimizer.step()
+
+ pmasks.data = pmasks.data.clamp(0, 1)
+
+        regul_weight *= 1.0008668  # 1.0008668^800 ≈ 2: regul_weight doubles every 800 iterations
+ # regul_weight *= 1.001734
+ if t == 100:
+ core_weight = 300
+
+ sum_time += time.time() - end_time
+ # Print iteration information
+ if (print_iter != None) and (t % print_iter == 0):
+ print(f"[{t+1:04d}/{max_iter:04d}]", end="\n")
+ for i in range(batch_size):
+ for a_idx, area in enumerate(areas):
+ if variant == "dual":
+ print(f" [GPU: {gpu_id} area:{area:.2f} loss:{hist[a_idx,i,0,-1]:.2f} reg:{hist[a_idx,i,1,-1]:.2f} "\
+ f"presv:{prob[a_idx*batch_size+i]:.2f}/{pred_label[a_idx*batch_size+i]} "\
+ f"del:{prob[(num_area+a_idx)*batch_size+i]:.2f}/{pred_label[(num_area+a_idx)*batch_size+i]}]", end="")
+ else:
+ print(f" [GPU: {gpu_id} area:{area:.2f} loss:{hist[a_idx,i,0,-1]:.2f} reg:{hist[a_idx,i,1,-1]:.2f} "\
+ f"{variant}:{prob[a_idx*batch_size+i]:.2f}/{pred_label[a_idx*batch_size+i]}]", end="")
+ print()
+ avg_time = sum_time/(t+1)
+ print(f"GPU: {gpu_id}, Average time: {avg_time:.3f}.")
+
+ masks = masks.detach()
+ # Resize saliency map.
+ list_mask = []
+ for a_idx, area in enumerate(areas):
+ area_mask = []
+ for frame_idx in range(num_frame):
+ mask = masks[a_idx,:,:,frame_idx,:,:] # NxCxHxW
+ mask = resize_saliency(pmt_inp, mask, resize, mode=resize_mode)
+ # Smooth saliency map.
+ if smooth > 0:
+ mask = imsmooth(mask, sigma=smooth * min(mask.shape[2:]), padding_mode='constant')
+ area_mask.append(mask)
+ area_mask = torch.stack(area_mask, dim=2) # NxCxTxHxW
+ list_mask.append(area_mask)
+ masks = torch.stack(list_mask, dim=1).cpu() # NxAxCxTxHxW
+
+    # masks: NxAxCxTxHxW; hist: AxNx2xmax_iter
+ return masks, hist
\ No newline at end of file