Skip to content

Commit

Permalink
update readme.
Browse files Browse the repository at this point in the history
  • Loading branch information
acb11711tx committed Apr 28, 2020
1 parent 847ac1a commit c9ec5b0
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ The current version supports attribution methods and video classification models
* **pretrain_dataset**: Dataset name that test model pretrained on. Choices include 'kinetics', 'epic-kitchens-verb', 'epic-kitchens-noun'.
* **vis_method**: Name of visualization methods. Choices include 'grad', 'grad*input', 'integrated_grad', 'grad_cam', 'perturb'. Here the 'perturb' means is spatiotemporal perturbation method.
* **save_label**: Extra label for saving results. If given, visualization results will be saved in ./visual_res/$vis_method$/$model$/$save_label$.
* no_gpu: If set, the demo will be run on CPU, else run on only one GPU.
* **no_gpu**: If set, the demo will be run on CPU, else run on only one GPU.

Arguments for perturb:
* **num_iter**: Number of iterations to get the perturbation results. Default is set to 2000 for better convergence.
Expand All @@ -51,13 +51,13 @@ Arguments for gradient methods:
### Examples

#### Saptiotemporal Perturbation + I3D (pretrained on Kinetics-400)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/kinetics/sampled_frames --model i3d --pretrain_dataset kinetics --vis_method perturb --num_iter 2000 --perturb_area 0.1`
`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model i3d --pretrain_dataset kinetics --vis_method perturb --num_iter 2000 --perturb_area 0.1`

#### Spatiotemporal Perturbation + TSM (pretrained on EPIC-Kitchens-noun)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/epic-kitchens-noun/sampled_frames --model tsm --pretrain_dataset epic-kitchens-noun --vis_method perturb --num_iter 2000 --perturb_area 0.05`
`$ python main.py --videos_dir VideoVisual/test_data/epic-kitchens-noun/sampled_frames --model tsm --pretrain_dataset epic-kitchens-noun --vis_method perturb --num_iter 2000 --perturb_area 0.05`

#### Integrated Gradients + R(2+1)D (pretrained on Kinetics-400)
`$ python main.py --videos_dir /home/acb11711tx/lzq/VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method integrated_grad`
`$ python main.py --videos_dir VideoVisual/test_data/kinetics/sampled_frames --model r2plus1d --pretrain_dataset kinetics --vis_method integrated_grad`


## Results
Expand Down
10 changes: 9 additions & 1 deletion model_def/tsm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import torch
from torch import nn

import os
import sys
sys.path.append(".")
sys.path.append("..")

class tsm (nn.Module):
def __init__ (self, num_classes, segment_count, pretrained):
super(tsm, self).__init__()
Expand All @@ -14,7 +19,10 @@ def __init__ (self, num_classes, segment_count, pretrained):
kinetics_classes_num = 400
self.model = torch.hub.load(self.repo, 'TSM', kinetics_classes_num, segment_count, 'RGB',
base_model='resnet50')
checkpoint_path = '/home/acb11711tx/lzq/ModelVisualization/model_param/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth'
checkpoint_path = 'model_param/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth'
assert os.path.isfile(checkpoint_path), \
f'Something wrong with pretrained parameters of TSM-Kinetics, Given {checkpoint_path}.'
print(f'Load checkpoint of TSM from {checkpoint_path}.')
state_dict = torch.load(checkpoint_path)['state_dict']
state_dict = {k[7:]: v for k, v in state_dict.items()}
self.model.load_state_dict(state_dict)
Expand Down
4 changes: 0 additions & 4 deletions path_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ class PathDict (object):
def __init__ (self):
hostname = socket.gethostname()
if 'abci' in hostname:
self.surgery_ds_root = '/groups1/gcb50205/lzq/dataset/JIGSAWS'
self.proj_root = os.path.dirname(os.path.abspath(__file__))
print(self.proj_root)
else:
# self.surgery_ds_root = '/home/shinkyo/dataset/JIGSAWS'
# self.proj_root = '/home/shinkyo/lzq/ModelVisualization'
self.proj_root = os.path.dirname(os.path.abspath(__file__))

0 comments on commit c9ec5b0

Please sign in to comment.