diff --git a/README.md b/README.md index 4a32fa6..8361688 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ Our implementation is mainly based on the following codebases. We gratefully tha If you find this repository helpful, please consider citing our paper: ``` -@inproceedings{zbontar2021barlow, +@inproceedings{icml2023a2mim, title={Architecture-Agnostic Masked Image Modeling -- From ViT back to CNN}, author={Li, Siyuan and Wu, Di and Wu, Fang and Zang, Zelin and Li, Stan. Z.}, booktitle={International Conference on Machine Learning}, diff --git a/analysis_tools/fourier_analysis/fourier_analysis.py b/analysis_tools/fourier_analysis/fourier_analysis.py new file mode 100644 index 0000000..461f2f0 --- /dev/null +++ b/analysis_tools/fourier_analysis/fourier_analysis.py @@ -0,0 +1,548 @@ +import os +import copy +import math +import mmcv +import numpy as np +from einops import rearrange + +import torch +import torch.nn as nn +from tqdm import tqdm + +import matplotlib.cm as cm +import matplotlib.pyplot as plt +from matplotlib.collections import LineCollection + +from openmixup import datasets as openmixup_datasets + +import sys +sys.path.append('./') +from utils import get_model, parse_args + + +class PatchEmbed(nn.Module): + def __init__(self, model): + super().__init__() + self.model = copy.deepcopy(model) + + def forward(self, x, **kwargs): + x = self.model.patch_embed(x) + cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x = self.model.pos_drop(x + self.model.pos_embed) + return x + + +class Residual(nn.Module): + def __init__(self, *fn): + super().__init__() + self.fn = nn.Sequential(*fn) + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) + x + + +class Lambda(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + + +def flatten(xs_list): + return [x for xs in xs_list for x in xs] + + +def fourier(x): # 2D Fourier transform + f = torch.fft.fft2(x) + f = f.abs() + 1e-6 + f = f.log() + return f + + +def fft_shift(x): # shift Fourier transformed feature map + b, c, h, w = x.shape + return torch.roll(x, shifts=(int(h/2), int(w/2)), dims=(2, 3)) + + +def make_segments(x, y): # make segment for `plot_segment` + points = np.array([x, y]).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + return segments + + +def plot_segment(ax, xs, ys, marker, liner='solid', cmap_name="plasma", alpha=1.0, zorder=1): + # plot with cmap segments + z = np.linspace(0.0, 1.0, len(ys)) + z = np.asarray(z) + + cmap = cm.get_cmap(cmap_name) + norm = plt.Normalize(0.0, 1.0) + segments = make_segments(xs, ys) + lc = LineCollection(segments, array=z, cmap=cmap_name, norm=norm, + linewidth=2.0, linestyles=liner, alpha=alpha) + ax.add_collection(lc) + + colors = [cmap(x) for x in xs] + ax.scatter(xs, ys, color=colors, marker=marker, zorder=100 + zorder) + return lc + + +def create_cmap(color_name, end=0.95): + """ create custom cmap """ + from matplotlib import cm + from matplotlib.colors import ListedColormap, LinearSegmentedColormap + + color = cm.get_cmap(color_name, 200) + if end == 0.8: + newcolors = color(np.linspace(0.75, end, 200)) + else: + newcolors = color(np.linspace(max(0.5, end-0.4), end, 200)) + newcmp = ListedColormap(newcolors, name=color_name+"05_09") + return newcmp + + +def make_proxy(color, marker, liner, **kwargs): + """ add custom legend """ + from matplotlib.lines import Line2D + cmap = cm.get_cmap(color) + color = 
cmap(np.arange(4) / 4) + return Line2D([0, 1], [0, 1], color=color[3], marker=marker, linestyle=liner) + + +def plot_fourier_features(latents): + # Fourier transform feature maps + fourier_latents = [] + for latent in latents: # `latents` is a list of hidden feature maps in latent spaces + latent = latent.cpu() + + if len(latent.shape) == 3: # for ViT + b, n, c = latent.shape + h, w = int(math.sqrt(n)), int(math.sqrt(n)) + latent = rearrange(latent, "b (h w) c -> b c h w", h=h, w=w) + elif len(latent.shape) == 4: # for CNN + b, c, h, w = latent.shape + else: + raise Exception("shape: %s" % str(latent.shape)) + latent = fourier(latent) + latent = fft_shift(latent).mean(dim=(0, 1)) + latent = latent.diag()[int(h/2):] # only use the half-diagonal components + latent = latent - latent[0] # visualize 'relative' log amplitudes + # (i.e., low-freq amp - high freq amp) + fourier_latents.append(latent) + + return fourier_latents + + +def plot_variance_features(latents): + # aggregate feature map variances + variances = [] + for latent in latents: # `latents` is a list of hidden feature maps in latent spaces + latent = latent.cpu() + + if len(latent.shape) == 3: # for ViT + b, n, c = latent.shape + h, w = int(math.sqrt(n)), int(math.sqrt(n)) + latent = rearrange(latent, "b (h w) c -> b c h w", h=h, w=w) + elif len(latent.shape) == 4: # for CNN + b, c, h, w = latent.shape + else: + raise Exception("shape: %s" % str(latent.shape)) + variances.append(latent.var(dim=[-1, -2]).mean(dim=[0, 1])) + + return variances + + +def forward_model(args, device): + + # ======================== build model ======================== + model, mean, std = get_model(args=args) + model = model.to(device) + model.eval() + + if "resnet" in args.model_name: + # model → blocks. `blocks` is a sequence of blocks + blocks = [ + nn.Sequential(model.conv1, model.bn1, model.act1, model.maxpool), + *model.layer1, + *model.layer2, + *model.layer3, + *model.layer4, + nn.Sequential(model.global_pool, model.fc) + ] + elif "vit" in args.model_name or "deit" in args.model_name: + # `blocks` is a sequence of blocks + blocks = [ + PatchEmbed(model), + *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] + for b in model.blocks]), + nn.Sequential(Lambda(lambda x: x[:, 0]), model.norm, model.head), + ] + else: + raise NotImplementedError + # print('blocks:', len(blocks)) + + # ======================== build dataloader ======================== + test_dir = args.test_dir + assert os.path.isdir(test_dir) and args.test_list is not None + test_set = openmixup_datasets.build_dataset( + dict( + type='ClassificationDataset', + data_source=dict( + list_file=args.test_list, root=test_dir, **dict(type='ImageNet')), + pipeline=[ + dict(type='Resize', size=256), + dict(type='CenterCrop', size=224), + dict(type='ToTensor'), + dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ], + prefetch=False) + ) + test_loader = openmixup_datasets.build_dataloader(test_set, imgs_per_gpu=args.batch_size, + workers_per_gpu=2, dist=False, shuffle=False, + drop_last=True, prefetch=False, img_norm_cfg=dict()) + + # ======================== build dataloader ======================== + # load a sample ImageNet-1K image -- use the full val dataset for precise results + latents = dict() + for i, data in tqdm(enumerate(test_loader)): + if isinstance(data, tuple): + assert len(data) == 2 + img, label = data + else: + assert isinstance(data, dict) + img = data['img'] + label = data['gt_label'] + + with torch.no_grad(): + xs, label = 
img.to(device), label.to(device) + # accumulate `latents` by collecting hidden states of a model + for b,block in enumerate(blocks): + if b == len(blocks) - 1: # drop logit (output) + break + xs = block(xs) + if i == 0: + latents[str(b)] = list() + latents[str(b)].append(xs.detach().cpu()) + else: + latents[str(b)].append(xs.detach().cpu()) + if i == 25: + break + + latent_list = list() + for i in range(len(blocks)-1): + l = torch.cat(latents[str(i)], dim=0) + latent_list.append(l) + latents = latent_list + + # for ViT/DeiT/pit_ti_224: Drop CLS token + if "vit" in args.model_name or "deit" in args.model_name or "pit" in args.model_name: + latents = [latent[:,1:] for latent in latents] + + return latents + + +def set_plot_args(model_name, idx=0, alpha_base=0.9): + # setup + linear_mapping = dict(cl="dashed", mim="solid", ssl="dashed", sl="dashdot") + marker_mapping = dict(cl="s", mim="p", ssl="D", sl="o") + colour_mapping = dict(cl=["YlGnBu", "Blues", "GnBu", "Greens", "YlGn", "winter"], + # mim=["Reds", "OrRd", "YlOrRd", "RdPu",], # ResNet + mim=["Reds", "YlOrRd", "OrRd", "RdPu",], # ViT + ssl=["PuRd",], # red + sl=["autumn", "winter", ], + ) + zorder_mapping = dict(cl=3, mim=4, ssl=2, sl=1) + + prefix = model_name.split("_")[0] + alpha = alpha_base if prefix != 'sl' else 0.7 + marker = marker_mapping[prefix] + liner = linear_mapping[prefix] + cmap_list = colour_mapping[prefix] + cmap_name = create_cmap(cmap_list[idx % len(cmap_list)], end=0.8 if prefix == 'sl' else 0.95) + zorder = zorder_mapping[prefix] + # refine model_name + model_name = model_name.split("_")[-1].replace("+", " \ ") + model_name = r"$\mathrm{" + model_name + "}$" + + return model_name, alpha, marker, liner, cmap_name, zorder + + +def plot_fft_A(args, fourier_latents, save_path, save_format='png'): + # A. Plot Fig 2a: "Relative log amplitudes of Fourier transformed feature maps" + fig, ax1 = plt.subplots(1, 1, figsize=(3.3, 4), dpi=150) + + for i, latent in enumerate(reversed(fourier_latents[:-1])): + freq = np.linspace(0, 1, len(latent)) + ax1.plot(freq, latent, color=cm.plasma_r(i / len(fourier_latents))) + + ax1.set_xlim(left=0, right=1) + ax1.set_xlabel("Frequency") + ax1.set_ylabel("$\Delta$ Log amplitude") + + from matplotlib.ticker import FormatStrFormatter + ax1.yaxis.set_major_formatter(FormatStrFormatter('%.1f')) + ax1.xaxis.set_major_formatter(FormatStrFormatter('%.1fπ')) + + plt.show() + plt.savefig(os.path.join(save_path, f'fft_features.{save_format}')) + plt.close() + + +def plot_fft_B(args, fourier_latents, save_path, model_names=None, save_format='png'): + # B. 
Plot Fig 8: "Relative log amplitudes of high-frequency feature maps" + + # plot settings + alpha_base = 0.9 + font_size = 13 + cmap_name = "plasma" + liner = "solid" + + if model_names is None: + dpi = 120 + model_names = ['ssl_' + args.model_name] + fourier_latents = [fourier_latents] + else: + dpi = 400 + assert isinstance(model_names, list) and len(model_names) >= 1 + zipped = zip(model_names, fourier_latents) + zipped = sorted(zipped, key=lambda x:x[0]) + zipped = zip(*zipped) + model_names, fourier_latents = [list(x) for x in zipped] + + fig, ax2 = plt.subplots(1, 1, figsize=(6.5, 5), dpi=dpi) + proxy_list = [] + for i in range(len(model_names)): + print(i, model_names[i], len(fourier_latents[i])) + if "resnet" in args.model_name: + pools = [4, 8, 14] + msas = [] + marker = "D" + elif "vit" in args.model_name or "deit" in args.model_name: + pools = [] + msas = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23,] # vit-tiny + marker = "o" + else: + import warnings + warnings.warn("The configuration for %s are not implemented." % args.model_name, Warning) + pools, msas = [], [] + marker = "s" + + # setup + model_names[i], alpha, marker, liner, cmap_name, zorder = set_plot_args(model_names[i], i, alpha_base) + # add legend + proxy_list.append(make_proxy(cmap_name, marker, liner, linewidth=2)) + + # Normalize + depths = range(len(fourier_latents[i])) + depth = len(depths) - 1 + depths = (np.array(depths)) / depth + pools = (np.array(pools)) / depth + msas = (np.array(msas)) / depth + + lc = plot_segment(ax2, depths, [latent[-1] for latent in fourier_latents[i]], + marker=marker, liner=liner, alpha=alpha, cmap_name=cmap_name, zorder=zorder) + + # ploting + for pool in pools: + ax2.axvspan(pool - 1.0 / depth, pool + 0.0 / depth, color="tab:blue", alpha=0.15, lw=0) + for msa in msas: + ax2.axvspan(msa - 1.0 / depth, msa + 0.0 / depth, color="tab:gray", alpha=0.15, lw=0) + + ax2.set_xlabel(r"$\mathrm{Normalized \ Depth}$", fontsize=font_size+2) + ax2.set_ylabel(r"$\mathrm{\Delta \ Log \ Amplitude}$", fontsize=font_size+2) + ax2.set_xlim(-0.01, 1.01) + + if len(model_names) > 1: + # ax2.legend(proxy_list, model_names, loc='upper left', fontsize=font_size) + ax2.legend(proxy_list, model_names, fontsize=font_size) + plt.grid(ls='--', alpha=0.5, axis='y') + + from matplotlib.ticker import FormatStrFormatter + ax2.yaxis.set_major_formatter(FormatStrFormatter('%.1f')) + + plt.show() + plt.savefig( + os.path.join(save_path, f'high_freq_fft_features.{save_format}'), + dpi=dpi, bbox_inches='tight', format=save_format) + plt.close() + + +def plot_var_feat(args, variances, save_path, model_names=None, save_format='png'): + # Plot Fig 9: "Feature map variance" + + # plot settings + alpha_base = 0.9 + font_size = 13 + cmap_name = "plasma" + liner = "solid" + + if model_names is None: + dpi = 120 + model_names = ['ssl_' + args.model_name] + variances = [variances] + else: + dpi = 400 + assert isinstance(model_names, list) and len(model_names) >= 1 + zipped = zip(model_names, variances) + zipped = sorted(zipped, key=lambda x:x[0]) + zipped = zip(*zipped) + model_names, variances = [list(x) for x in zipped] + + fig, ax2 = plt.subplots(1, 1, figsize=(6.5, 5), dpi=dpi) + proxy_list = [] + for i in range(len(model_names)): + print(i, model_names[i], len(variances[i])) + if "resnet" in args.model_name: + pools = [4, 8, 14] + msas = [] + marker = "D" + color = "tab:blue" + elif "vit" in args.model_name or "deit" in args.model_name: + pools = [] + msas = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23,] # vit-tiny + marker = 
"o" + color = "tab:red" + else: + import warnings + warnings.warn("The configuration for %s are not implemented." % args.model_name, Warning) + pools, msas = [], [] + marker = "s" + color = "tab:green" + + # setup + model_names[i], alpha, marker, liner, cmap_name, zorder = set_plot_args(model_names[i], i, alpha_base) + # add legend + proxy_list.append(make_proxy(cmap_name, marker, liner, linewidth=2)) + + # Normalize + depths = range(len(variances[i])) + depth = len(depths) - 1 + depths = (np.array(depths)) / depth + pools = (np.array(pools)) / depth + msas = (np.array(msas)) / depth + + lc = plot_segment(ax2, depths, variances[i], + marker=marker, liner=liner, alpha=alpha, cmap_name=cmap_name, zorder=zorder) + # cmap = cm.get_cmap(cmap_name) + # color = cmap(np.arange(4) / 4)[3] + # ax2.plot(depths, variances[i], marker=marker, color=color, markersize=7) + + # ploting + for pool in pools: + ax2.axvspan(pool - 1.0 / depth, pool + 0.0 / depth, color="tab:blue", alpha=0.15, lw=0) + for msa in msas: + ax2.axvspan(msa - 1.0 / depth, msa + 0.0 / depth, color="tab:gray", alpha=0.15, lw=0) + + ax2.set_xlabel(r"$\mathrm{Normalized \ Depth}$", fontsize=font_size+2) + ax2.set_ylabel(r"$\mathrm{Feature \ Map \ Variance}$", fontsize=font_size+2) + ax2.set_xlim(-0.01, 1.01) + + if len(model_names) > 1: + # ax2.legend(proxy_list, model_names, loc='upper left', fontsize=font_size) + ax2.legend(proxy_list, model_names, fontsize=font_size) + plt.grid(ls='--', alpha=0.5, axis='y') + + from matplotlib.ticker import FormatStrFormatter + ax2.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) + + plt.show() + plt.savefig( + os.path.join(save_path, f'variance_features.{save_format}'), + dpi=dpi, bbox_inches='tight', format=save_format) + plt.close() + + +def main(args, device, load_path=None, name_mapping=None, save_format='png'): + + exp_name = args.exp_name + if isinstance(load_path, str): + save_path = f"report/{exp_name}/{opt.model_name}/summary/" + mmcv.mkdir_or_exist(save_path) + assert os.path.exists(load_path) + model_list = os.listdir(load_path) + model_list.sort() + + if name_mapping is None: + name_mapping = dict() + for m in model_list: + name_mapping[m] = "mim_" + m.split(".")[0] + + latents = [] + model_names = [] + for m in model_list: + cur_model = m.split(".")[0] + if cur_model not in name_mapping.keys(): + continue + file_path = os.path.join(load_path, cur_model, f"{exp_name}_latents.pt") + try: + latents.append(torch.load(file_path)[f"{exp_name}_latents"]) + except: + continue + model_names.append(name_mapping[cur_model]) + + if exp_name == "fourier": + plot_fft_B(args, latents, save_path, model_names, save_format) + else: + plot_var_feat(args, latents, save_path, model_names, save_format) + + else: + if opt.pretrained_path is not None: + save_name = opt.pretrained_path.split("/") + save_name = "{}_{}".format(save_name[-2].split(".pth")[0], save_name[-1].split(".pth")[0]) + else: + save_name = opt.model_name + save_path = f"report/{exp_name}/{opt.model_name}/{save_name}" + print('start experiment:', save_name) + + latents = forward_model(args, device) + save_path = save_path.split('.pth')[0] + mmcv.mkdir_or_exist(save_path) + + if exp_name == "fourier": + latents = plot_fourier_features(latents) + torch.save( + dict(fourier_latents=latents), os.path.join(save_path, f"{exp_name}_latents.pt")) + plot_fft_A(args, latents, save_path, save_format) + plot_fft_B(args, latents, save_path, save_format) + else: + latents = plot_variance_features(latents) + torch.save( + 
dict(variance_latents=latents), os.path.join(save_path, f"{exp_name}_latents.pt")) + plot_var_feat(args, latents, save_path, save_format) + + +if __name__ == '__main__': + opt = parse_args() + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + assert opt.exp_name in ['fourier', 'variance', None,] + + # save_format = 'pdf' + save_format = 'png' + load_path = None + name_mapping = None + + load_path = f"./{opt.exp_name}/resnet50/" + + name_mapping = dict( + ### ResNet ### + randinit="sl_Random", + model_zoo_barlowtwins_r50_bs2048_ep300="cl_Barlow+Twins", + model_zoo_byol_r50_bs4096_ep200="cl_BYOL", + model_zoo_barlowtwins_r50_official_bs2048_ep1000="cl_BYOL", + model_zoo_colorization_r50_vissl_in1k="mim_Inpainting", + model_zoo_dino_r50_224_ep800="cl_DINO", + model_zoo_timm_resnet50_rsb_a2_224_ep300="sl_DeiT+(Sup.)", + model_zoo_mocov3_r50_official_ep300="cl_MoCoV3", + r50_r50_m07_rgb_m_learn_l3_res_fc_k1_l1_sz224_fft05_re_fun0_4xb256_accu2_cos_fp16_ep100="mim_SimMIM", + ### ViT ### + deit_small_no_aug_smth_mix0_8_cut1_0_4xb256_ema_fp16_ep300_latest="sl_DeiT", + model_zoo_vit_dino_deit_small_p16_224_ep300="cl_DINO", + model_zoo_vit_mae_vit_base_p16_224_ep400="mim_MAE", + model_zoo_vit_cae_vit_base_p16_224_ep300="mim_CAE", + rand_model_zoo_vit_dino_deit_small_p16_224_ep300="sl_Random", + ) + + main(args=opt, device=device, load_path=load_path, name_mapping=name_mapping, save_format=save_format) diff --git a/analysis_tools/fourier_analysis/utils.py b/analysis_tools/fourier_analysis/utils.py new file mode 100644 index 0000000..a620750 --- /dev/null +++ b/analysis_tools/fourier_analysis/utils.py @@ -0,0 +1,524 @@ +import argparse +import copy +import math +import sys + +import numpy as np +import torch +import torchvision +import torchvision.models as models +import timm +from timm.models import create_model +from torchvision import transforms +from torchvision.transforms import transforms +from tqdm import tqdm + +from mmcv.runner import load_state_dict + +import vit_models + + +def get_voc_dataset(voc_root=None): + if voc_root is None: + voc_root = "data/voc" # path to VOCdevkit for VOC2012 + data_transform = transforms.Compose([ + transforms.Resize((512, 512)), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + + def load_target(image): + image = np.array(image) + image = torch.from_numpy(image) + return image + + target_transform = transforms.Compose([ + transforms.Resize((512, 512)), + transforms.Lambda(load_target), + ]) + + dataset = torchvision.datasets.VOCSegmentation(root=voc_root, image_set="val", transform=data_transform, + target_transform=target_transform) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, drop_last=False) + + return dataset, data_loader + + +def load_checkpoint(model_name, checkpoint_path, model): + """ load pretrained openmixup models """ + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if not ('state_dict' in checkpoint): + sd = checkpoint + else: + sd = checkpoint['state_dict'] + + # load with models trained on a single gpu or multiple gpus + if 'module.' 
in list(sd.keys())[0]: + sd = {k[len('module.'):]: v for k, v in sd.items()} + + # preprocess for openmixup ckpt + load_dict = dict() + + for key, value in sd.items(): + new_key = copy.copy(key) + replace = False + # convert ema + if ('ema_' in new_key) and ('deit' in model_name or 'vit' in model_name or 'swin' in model_name): + replace = True + new_key = new_key.replace("ema_", "") + new_key = new_key.replace("_", ".") + ema_key_list = [ + 'cls_token', 'pos_embed', 'patch_embed', + 'absolute_pos_embed', 'relative_position_bias_table', 'w_msa',] + for ori_key in ema_key_list: + dot_key = ori_key.replace("_", ".") + if dot_key in new_key: + new_key = new_key.replace(dot_key, ori_key) + + # remove backbone keys + for prefix_k in ['encoder_q', 'encoder', 'backbone', 'timm_model',]: + if new_key.startswith(prefix_k): + new_key = new_key[len(prefix_k) + 1: ] + # remove head keys + if "resnet" not in model_name: + if "mlpmixer" not in model_name: + for head_k in ['head.layers.', 'fc_cls.', 'fc.',]: + start_idx = new_key.find(head_k) + if start_idx != -1: + new_key = new_key[:start_idx] + new_key[start_idx + len(head_k): ] + else: + for head_k in ['head.',]: + start_idx = new_key.find(head_k) + if start_idx != -1: + new_key = new_key[:start_idx] + new_key[start_idx + len(head_k): ] + if 'fc_cls.' in new_key: + new_key = new_key.replace('fc_cls.', 'fc.') + + # replace as timm + if 'deit' in model_name or 'vit' in model_name: + if new_key.find('projection') != -1: + new_key = new_key.replace('projection', 'proj') + if new_key.find('ffn.layers.0.0.') != -1: + new_key = new_key.replace('ffn.layers.0.0.', 'mlp.fc1.') + if new_key.find('ffn.layers.1.') != -1: + new_key = new_key.replace('ffn.layers.1.', 'mlp.fc2.') + if new_key.find('gamma_') != -1: + new_key = new_key.replace('gamma_', 'ls') + + if new_key.find('layers') != -1: + new_key = new_key.replace('layers', 'blocks') + if new_key.find('.ln') != -1: + new_key = new_key.replace('.ln', '.norm') + if new_key == 'ln1.weight': + new_key = 'norm.weight' + if new_key == 'ln1.bias': + new_key = 'norm.bias' + + elif 'mlpmixer' in model_name: + if new_key.find('projection') != -1: + new_key = new_key.replace('projection', 'proj') + if new_key.find('ffn.layers.0.0.') != -1: + new_key = new_key.replace('ffn.layers.0.0.', 'mlp.fc1.') + if new_key.find('ffn.layers.1.') != -1: + new_key = new_key.replace('ffn.layers.1.', 'mlp.fc2.') + + if new_key.find('layers') != -1: + new_key = new_key.replace('layers', 'blocks') + if new_key.find('.ln') != -1: + new_key = new_key.replace('.ln', '.norm') + if new_key == 'ln1.weight': + new_key = 'norm.weight' + if new_key == 'ln1.bias': + new_key = 'norm.bias' + + if new_key.find('head.fc.') != -1: + new_key = new_key.replace('head.fc.', 'head.') + + elif 'swin' in model_name: + for _k in ['w_msa',]: + start_idx = new_key.find(_k) + if start_idx != -1: + new_key = new_key[:start_idx] + new_key[start_idx + len(_k)+1: ] + if new_key.find('projection') != -1: + new_key = new_key.replace('projection', 'proj') + if new_key.find('ffn.layers.0.0.') != -1: + new_key = new_key.replace('ffn.layers.0.0.', 'mlp.fc1.') + if new_key.find('ffn.layers.1.') != -1: + new_key = new_key.replace('ffn.layers.1.', 'mlp.fc2.') + + if new_key.find('stages.') != -1: + new_key = new_key.replace('stages.', 'layers.') + if new_key == 'norm3.weight': + new_key = 'norm.weight' + if new_key == 'norm3.bias': + new_key = 'norm.bias' + + # replace + if new_key in load_dict.keys(): + if not replace: + # print(f"EMA: remove the original key {new_key}") + 
continue + + load_dict[new_key] = value + print(f"keep key {key} -> {new_key}") + + load_state_dict(model, load_dict, strict=False) + return model + + +def get_model(args, pretrained=True): + model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + timm_model_names = timm.list_models(pretrained=True) + + pretrained_path = None + if args.pretrained_path is not None: + pretrained_path = args.pretrained_path + pretrained = False + + if args.model_name in model_names: + # model = models.__dict__[args.model_name](pretrained=pretrained) + model = timm.create_model(args.model_name, pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif 'resnet_drop' in args.model_name: + model = vit_models.drop_resnet50(pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif 'deit' in args.model_name: + model = create_model(args.model_name, pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif 'dino_small_dist' in args.model_name: + model = vit_models.dino_small_dist(patch_size=vars(args).get("patch_size", 16), pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif 'dino_tiny_dist' in args.model_name: + model = vit_models.dino_tiny_dist(patch_size=vars(args).get("patch_size", 16), pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif 'dino_small' in args.model_name: + model = vit_models.dino_small(patch_size=vars(args).get("patch_size", 16), pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif 'dino_tiny' in args.model_name: + model = vit_models.dino_tiny(patch_size=vars(args).get("patch_size", 16), pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif 'vit' in args.model_name and not 'T2t' in args.model_name: + model = create_model(args.model_name, pretrained=pretrained) + mean = (0.5, 0.5, 0.5) + std = (0.5, 0.5, 0.5) + elif 'T2t' in args.model_name: + model = create_model(args.model_name, pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + elif 'tnt' in args.model_name: + model = create_model(args.model_name, pretrained=pretrained) + mean = (0.5, 0.5, 0.5) + std = (0.5, 0.5, 0.5) + elif args.model_name in timm_model_names: + model = create_model(args.model_name, pretrained=pretrained) + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + else: + raise NotImplementedError(f'Please provide correct model names: {model_names}') + + # reset classifier + if args.num_classes is not None: + model.reset_classifier(num_classes=args.num_classes) + + # load pretrained model from path + if pretrained_path is not None: + model = load_checkpoint(args.model_name, pretrained_path, model) + + return model, mean, std + + +def parse_args(): + parser = argparse.ArgumentParser(description='Transformers') + parser.add_argument('--test_dir', default='/home/kanchanaranasinghe/data/raw/imagenet/val', + help='ImageNet Validation Data') + parser.add_argument('--test_list', type=str, default=None, help='Meta file for custom dataset') + parser.add_argument('--exp_name', default=None, help='pretrained weight path') + parser.add_argument('--model_name', type=str, default='deit_small_patch16_224', help='Model Name') + parser.add_argument('--num_classes', type=int, default=None, help='Reset class number') + parser.add_argument('--pretrained_path', type=str, 
default=None, help='Path to custom pretrained model') + parser.add_argument('--scale_size', type=int, default=256, help='') + parser.add_argument('--img_size', type=int, default=224, help='') + parser.add_argument('--batch_size', type=int, default=256, help='Batch Size') + parser.add_argument('--drop_count', type=int, default=180, help='How many patches to drop') + parser.add_argument('--drop_best', action='store_true', default=False, help="set True to drop the best matching") + parser.add_argument('--test_image', action='store_true', default=False, help="set True to output test images") + parser.add_argument('--shuffle', action='store_true', default=False, help="shuffle instead of dropping") + parser.add_argument('--shuffle_size', type=int, default=14, help='nxn grid size of n', nargs='*') + parser.add_argument('--shuffle_h', type=int, default=None, help='h of hxw grid', nargs='*') + parser.add_argument('--shuffle_w', type=int, default=None, help='w of hxw grid', nargs='*') + parser.add_argument('--random_drop', action='store_true', default=False, help="randomly drop patches") + parser.add_argument('--random_offset_drop', action='store_true', default=False, help="randomly drop patches") + parser.add_argument('--cascade', action='store_true', default=False, help="run cascade evaluation") + parser.add_argument('--exp_count', type=int, default=1, help='random experiment count to average over') + parser.add_argument('--saliency', action='store_true', default=False, help="drop using saliency") + parser.add_argument('--saliency_box', action='store_true', default=False, help="drop using saliency") + parser.add_argument('--drop_lambda', type=float, default=0.2, help='percentage of image to drop for box') + parser.add_argument('--standard_box', action='store_true', default=False, help="drop using standard model") + parser.add_argument('--dino', action='store_true', default=False, help="drop using dino model saliency") + + parser.add_argument('--lesion', action='store_true', default=False, help="drop using dino model saliency") + parser.add_argument('--block_index', type=int, default=0, help='block index for lesion method', nargs='*') + + parser.add_argument('--draw_plots', action='store_true', default=False, help="draw plots") + parser.add_argument('--select_im', action='store_true', default=False, help="select robust images") + parser.add_argument('--save_path', type=str, default=None, help='save path') + + # segmentation evaluation arguments + parser.add_argument('--threshold', type=float, default=0.9, help='threshold for segmentation') + parser.add_argument('--pretrained_weights', default=None, help='pretrained weights path') + parser.add_argument('--patch_size', type=int, default=16, help='nxn grid size of n') + parser.add_argument('--use_shape', action='store_true', default=False, help="use shape token for prediction") + parser.add_argument('--rand_init', action='store_true', default=False, help="use randomly initialized model") + parser.add_argument('--generate_images', action='store_true', default=False, help="generate images instead of eval") + + return parser.parse_args() + + +def accuracy(output, target, top_k=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + max_k = max(top_k) + batch_size = target.size(0) + + _, pred = output.topk(max_k, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in top_k: + correct_k = correct[:k].reshape(-1).float().sum(0, 
keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +def train_epoch(dataloader, model, criterion, optimizer, device, mixup_fn=None, model_ema=None, fine_tune=False): + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + + with tqdm(dataloader) as p_bar: + for samples, targets in p_bar: + samples = samples.to(device) + targets = targets.to(device) + + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + outputs = model(samples, fine_tune=fine_tune) + loss = criterion(outputs, targets) + + loss_value = loss.item() + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + if model_ema is not None: + model_ema.update(model) + + if mixup_fn is None: + acc1, acc5 = accuracy(outputs, targets, top_k=(1, 5)) + else: + acc1, acc5 = [0], [0] + losses.update(loss.item(), samples.size(0)) + top1.update(acc1[0], samples.size(0)) + top5.update(acc5[0], samples.size(0)) + + p_bar.set_postfix({"Loss": f'{losses.avg:.3f}', + "Top1": f'{top1.avg:.3f}', + "Top5": f'{top5.avg:.3f}', }) + + return losses.avg, top1.avg, top5.avg + + +def validate_epoch(dataloader, model, criterion, device): + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + + with tqdm(dataloader) as p_bar: + for samples, targets in p_bar: + samples = samples.to(device) + targets = targets.to(device) + + with torch.no_grad(): + outputs = model(samples) + loss = criterion(outputs, targets) + + acc1, acc5 = accuracy(outputs, targets, top_k=(1, 5)) + losses.update(loss.item(), samples.size(0)) + top1.update(acc1[0], samples.size(0)) + top5.update(acc5[0], samples.size(0)) + + p_bar.set_postfix({"Loss": f'{losses.avg:.3f}', + "Top1": f'{top1.avg:.3f}', + "Top5": f'{top5.avg:.3f}', }) + + return losses.avg, top1.avg, top5.avg + + +def parse_train_arguments(): + parser = argparse.ArgumentParser('default argument parser') + + # model architecture arguments + parser.add_argument('--model', type=str, default='deit') + parser.add_argument('--use_top_n_heads', type=int, default=1, help="use class token from intermediate layers") + parser.add_argument('--use_patch_outputs', action='store_true', default=False, help='use patch tokens') + + # default evaluation arguments + parser.add_argument('--datasets', type=str, default=None, metavar='DATASETS', nargs='+', + help="Datasets for evaluation") + parser.add_argument('--classifier', type=str, default='LR', choices=['LR', 'NN']) + parser.add_argument('--runs', type=int, default=600) + parser.add_argument('--num-support', type=int, default=1) + parser.add_argument('--save', type=str, default='logs') + parser.add_argument('--norm', action='store_true', default=False, help='use normalized features') + + # episodic dataset params + parser.add_argument('--n_test_runs', type=int, default=600, metavar='N', + 
help='Number of test runs') + parser.add_argument('--n_ways', type=int, default=5, metavar='N', + help='Number of classes for doing each classification run') + parser.add_argument('--n_shots', type=int, default=1, metavar='N', + help='Number of shots in test') + parser.add_argument('--n_queries', type=int, default=15, metavar='N', + help='Number of query in test') + parser.add_argument('--n_aug_support_samples', default=5, type=int, + help='The number of augmented samples for each meta test sample') + parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size', + help='Size of test batch)') + + # arguments for training + parser.add_argument('--data-path', type=str, default=None) + parser.add_argument('--data', type=str, default='CIFAR-FS') + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--project', type=str, default='vit_fsl') + parser.add_argument('--exp', type=str, default='exp_001') + parser.add_argument('--load', type=str, default=None, help="path to model to load") + parser.add_argument('--image_size', type=int, default=84) + + parser.add_argument('--model-ema', action='store_true') + parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') + parser.set_defaults(model_ema=True) + parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') + parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') + + # arguments for data augmentation + parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') + parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". " + \ + "(default: rand-m9-mstd0.5-inc1)'), + parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') + parser.add_argument('--train-interpolation', type=str, default='bicubic', + help='Training interpolation (random, bilinear, bicubic default: "bicubic")') + # Random Erase params + parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', + help='Random erase prob (default: 0.25)') + parser.add_argument('--remode', type=str, default='pixel', + help='Random erase mode (default: "pixel")') + parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') + parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') + # Mix-up params + parser.add_argument('--mixup', type=float, default=0.8, + help='mixup alpha, mixup enabled if > 0. (default: 0.8)') + parser.add_argument('--cutmix', type=float, default=1.0, + help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') + parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + parser.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') + parser.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') + parser.add_argument('--mixup-mode', type=str, default='batch', + help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') + + # Optimizer parameters + parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') + parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') + parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') + parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') + parser.add_argument('--weight-decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + # Learning rate schedule parameters + parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "cosine"') + parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', + help='learning rate (default: 5e-4)') + parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') + parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') + parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') + parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') + parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') + parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', + help='epoch interval to decay LR') + parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', + help='epochs to warmup LR, if scheduler supports') + parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') + parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') + parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') + + parser.add_argument('--repeated-aug', action='store_true') + parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') + + return parser.parse_args() + + +def normalize(t, mean, std): + t[:, 0, :, :] = (t[:, 0, :, :] - mean[0]) / std[0] + t[:, 1, :, :] = (t[:, 1, :, :] - mean[1]) / std[1] + t[:, 2, :, :] = (t[:, 2, :, :] - mean[2]) / std[2] + return t diff --git a/analysis_tools/fourier_analysis/vit_models/__init__.py b/analysis_tools/fourier_analysis/vit_models/__init__.py new file mode 100644 index 0000000..1a52745 --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/__init__.py @@ -0,0 +1,11 @@ +from .deit import * +from .deit_ensemble import * +from .deit_modified import * +from .dino import * +from .t2t_vit import * +from .t2t_vit_dense import * +from .t2t_vit_ghost import * +from .t2t_vit_se import * +from .tnt import * +from .vit import * +from .resnet import drop_resnet50 diff --git a/analysis_tools/fourier_analysis/vit_models/deit.py b/analysis_tools/fourier_analysis/vit_models/deit.py new file mode 100644 index 0000000..50087cb --- /dev/null +++ 
b/analysis_tools/fourier_analysis/vit_models/deit.py @@ -0,0 +1,285 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +import math + +import numpy as np +import torch +import torch.nn as nn +from functools import partial + +from timm.models.vision_transformer import VisionTransformer, _cfg +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_ + +import random + +__all__ = [ + 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', + 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224', + 'deit_base_distilled_patch16_224', 'deit_base_patch16_384', + 'deit_base_distilled_patch16_384', +] + + +class DistilledVisionTransformer(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + + trunc_normal_(self.dist_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + self.head_dist.apply(self._init_weights) + + def forward_features(self, x): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add the dist_token + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + + x = x + self.pos_embed + x = self.pos_drop(x) + + layer_wise_tokens = [] + for blk in self.blocks: + x = blk(x) + layer_wise_tokens.append(x) + + layer_wise_tokens = [self.norm(x) for x in layer_wise_tokens] + return [(x[:, 0], x[:, 1]) for x in layer_wise_tokens] + + def forward(self, x): + list_out = self.forward_features(x) + x = [self.head(x) for x, _ in list_out] + x_dist = [self.head_dist(x_dist) for _, x_dist in list_out] + if self.training: + return [(out, out_dist) for out, out_dist in zip(x, x_dist)] + else: + # during inference, return the average of both classifier predictions + return [(out + out_dist) / 2 for out, out_dist in zip(x, x_dist)] + + +class VanillaVisionTransformer(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward_features(self, x, block_index=None, drop_rate=0): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add the dist_token + B, nc, w, h = x.shape + x = self.patch_embed(x) + + # interpolate patch embeddings + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size[0] + h0 = h // self.patch_embed.patch_size[1] + class_pos_embed = self.pos_embed[:, 0] + N = self.pos_embed.shape[1] - 1 + patch_pos_embed = self.pos_embed[:, 1:] + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + if w0 != patch_pos_embed.shape[-2]: + helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device) + patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2) + if h0 != patch_pos_embed.shape[-1]: + helper = torch.zeros(w0)[None, None, :, 
None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device) + patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + # interpolate patch embeddings finish + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + pos_embed + x = self.pos_drop(x) + + layer_wise_tokens = [] + for idx, blk in enumerate(self.blocks): + + if block_index is not None and idx == block_index: + token = x[:, :1, :] + features = x[:, 1:, :] + row = np.random.choice(range(x.shape[1] - 1), size=int(drop_rate*x.shape[1]), replace=False) + features[:, row, :] = 0.0 + x = torch.cat((token, features), dim=1) + + x = blk(x) + layer_wise_tokens.append(x) + + layer_wise_tokens = [self.norm(x) for x in layer_wise_tokens] + + return [x[:, 0] for x in layer_wise_tokens], [x for x in layer_wise_tokens] + + def forward(self, x, block_index=None, drop_rate=0, patches=False): + list_out, patch_out = self.forward_features(x, block_index, drop_rate) + x = [self.head(x) for x in list_out] + if patches: + return x, patch_out + else: + return x + + +class NonSpatialVisionTransformer(VisionTransformer): + def __init__(self, *args, **kwargs): + super(NonSpatialVisionTransformer, self).__init__(*args, **kwargs) + self.pos_embed = None + self.pos_drop = None + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + # x = x + self.pos_embed + # x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + +@register_model +def deit_tiny_patch16_224(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_tiny_patch16_224_no_pos(pretrained=False, **kwargs): + model = NonSpatialVisionTransformer( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://github.com/Muzammal-Naseer/Intriguing-Properties-of-Vision-Transformers/releases/download/" + "v0/no_pos_deit_t.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def deit_small_patch16_224(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_base_patch16_224(pretrained=False, 
**kwargs): + model = VanillaVisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): + model = DistilledVisionTransformer( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_small_distilled_patch16_224(pretrained=False, **kwargs): + model = DistilledVisionTransformer( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_base_distilled_patch16_224(pretrained=False, **kwargs): + model = DistilledVisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_base_patch16_384(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model + + +@register_model +def deit_base_distilled_patch16_384(pretrained=False, **kwargs): + model = DistilledVisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth", + map_location="cpu", check_hash=True + ) + model.load_state_dict(checkpoint["model"]) + return model diff --git a/analysis_tools/fourier_analysis/vit_models/deit_ensemble.py b/analysis_tools/fourier_analysis/vit_models/deit_ensemble.py new file mode 100644 index 0000000..414daed --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/deit_ensemble.py @@ -0,0 +1,171 @@ +from functools import partial + 
+import torch +import torch.nn as nn +import math +from einops import reduce, rearrange +from timm.models.registry import register_model +from timm.models.vision_transformer import VisionTransformer, _cfg + +import torch.nn.functional as F + +__all__ = [ + "tiny_patch16_224_ensemble", "small_patch16_224_ensemble", "base_patch16_224_ensemble" +] + + +class FinalHead(nn.Module): + def __init__(self, token_dim=192): + super(FinalHead, self).__init__() + + self.token_dim = token_dim + self.fc = nn.Linear(self.token_dim, self.token_dim) + + def forward(self, x): + x = x.mean(dim=1) + return self.fc(x) + + +class TransformerHead(nn.Module): + expansion = 1 + + def __init__(self, token_dim, num_patches=196, num_classes=1000, stride=1): + super(TransformerHead, self).__init__() + + self.token_dim = token_dim + self.num_patches = num_patches + self.num_classes = num_classes + + # To process patches + self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn = nn.BatchNorm2d(self.token_dim) + self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=1, padding=1, bias=False) + self.bn = nn.BatchNorm2d(self.token_dim) + + self.shortcut = nn.Sequential() + if stride != 1 or self.token_dim != self.expansion * self.token_dim: + self.shortcut = nn.Sequential( + nn.Conv2d(self.token_dim, self.expansion * self.token_dim, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * self.token_dim) + ) + + self.token_fc = nn.Linear(self.token_dim, self.token_dim) + + def forward(self, x): + """ + x : (B, num_patches + 1, D) -> (B, C=num_classes) + """ + cls_token, patch_tokens = x[:, 0], x[:, 1:] + size = int(math.sqrt(x.shape[1])) + + patch_tokens = rearrange(patch_tokens, 'b (h w) d -> b d h w', h=size, w=size) # B, D, H, W + features = F.relu(self.bn(self.conv(patch_tokens))) + features = self.bn(self.conv(features)) + features += self.shortcut(patch_tokens) + features = F.relu(features) + patch_tokens = F.avg_pool2d(features, 14).view(-1, self.token_dim) + cls_token = self.token_fc(cls_token) + + out = patch_tokens + cls_token + + return out + + +class VisionTransformerEnsemble(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Transformer heads + self.transformerheads = nn.Sequential(*[ + TransformerHead(self.embed_dim) + for i in range(11)]) + self.spatialheads = nn.Sequential(*[FinalHead(self.embed_dim) for _ in range(4)]) + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + # Store transformer outputs + transformerheads_outputs = [] + + for idx, blk in enumerate(self.blocks): + x = blk(x) + if idx <= 10: + out = self.norm(x) + out = self.transformerheads[idx](out) + transformerheads_outputs.append(out) + + x = self.norm(x) + return x, transformerheads_outputs + + def forward(self, x, get_average=False): + x, transformer_heads_outputs = self.forward_features(x) + final_heads_outputs = [self.head(x) for x in transformer_heads_outputs] + patches = x[:, 1:, :] + for idx in range(4): + final_heads_outputs.append(self.head(self.spatialheads[idx](patches[:, idx * 49:(idx + 1) * 49, :]))) + final_heads_outputs.append(self.head(x[:, 0])) + if get_average: + return torch.mean(torch.stack(final_heads_outputs, 0), dim=0) + return final_heads_outputs + + 
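+# The factory functions below follow timm's registry convention: @register_model adds each
+# function to timm's model registry by name, so once this module has been imported the
+# ensemble variants can also be built via timm.create_model("small_patch16_224_ensemble"), etc.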
+@register_model +def tiny_patch16_224_ensemble(pretrained=False, index=0, **kwargs): + model = VisionTransformerEnsemble( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.load('pretrained_models/ckpts/joint_tiny_01/checkpoint.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict"]) + + return model + + +@register_model +def small_patch16_224_ensemble(pretrained=False, **kwargs): + model = VisionTransformerEnsemble( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + checkpoint = torch.load('pretrained_models/ckpts/joint_small_01/checkpoint.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict"]) + + return model + + +@register_model +def base_patch16_224_ensemble(pretrained=False, **kwargs): + model = VisionTransformerEnsemble( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + + if pretrained: + checkpoint = torch.load('pretrained_models/ckpts/joint_base_01/checkpoint.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict"]) + + return model + + +if __name__ == '__main__': + net = small_patch16_224_ensemble(pretrained=True) + + sample = torch.randn(1, 3, 224, 224) + pred, _ = net(sample) + + print('Parameters:', sum(p.numel() for p in net.parameters()) / 1000000) + print(f"Output shape: {pred.shape}") diff --git a/analysis_tools/fourier_analysis/vit_models/deit_modified.py b/analysis_tools/fourier_analysis/vit_models/deit_modified.py new file mode 100644 index 0000000..1951fc5 --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/deit_modified.py @@ -0,0 +1,167 @@ +from functools import partial + +import torch +import torch.nn as nn +import math +from einops import reduce, rearrange +from timm.models.registry import register_model +from timm.models.vision_transformer import VisionTransformer, _cfg + +import torch.nn.functional as F + +__all__ = [ + "tiny_patch16_224_hierarchical", "small_patch16_224_hierarchical", "base_patch16_224_hierarchical" +] + + +class TransformerHead(nn.Module): + expansion = 1 + + def __init__(self, token_dim, num_patches=196, num_classes=1000, stride=1): + super(TransformerHead, self).__init__() + + self.token_dim = token_dim + self.num_patches = num_patches + self.num_classes = num_classes + + # To process patches + self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn = nn.BatchNorm2d(self.token_dim) + self.conv = nn.Conv2d(self.token_dim, self.token_dim, kernel_size=3, stride=1, padding=1, bias=False) + self.bn = nn.BatchNorm2d(self.token_dim) + + self.shortcut = nn.Sequential() + if stride != 1 or self.token_dim != self.expansion * self.token_dim: + self.shortcut = nn.Sequential( + nn.Conv2d(self.token_dim, self.expansion * self.token_dim, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * self.token_dim) + ) + + self.token_fc = nn.Linear(self.token_dim, self.token_dim) + + def forward(self, x): + """ + x : (B, num_patches + 1, D) -> (B, C=num_classes) + """ + cls_token, patch_tokens = x[:, 0], x[:, 1:] + size = int(math.sqrt(x.shape[1])) + + patch_tokens = 
rearrange(patch_tokens, 'b (h w) d -> b d h w', h=size, w=size) # B, D, H, W + features = F.relu(self.bn(self.conv(patch_tokens))) + features = self.bn(self.conv(features)) + features += self.shortcut(patch_tokens) + features = F.relu(features) + patch_tokens = F.avg_pool2d(features, 14).view(-1, self.token_dim) + cls_token = self.token_fc(cls_token) + + out = patch_tokens + cls_token + + return out + + +class VisionTransformer_hierarchical(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Transformer heads + self.transformerheads = nn.Sequential(*[ + TransformerHead(self.embed_dim) + for i in range(11)]) + + def forward_features(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) + + # interpolate patch embeddings + # dim = x.shape[-1] + # w0 = w // self.patch_embed.patch_size[0] + # h0 = h // self.patch_embed.patch_size[1] + # class_pos_embed = self.pos_embed[:, 0] + # N = self.pos_embed.shape[1] - 1 + # patch_pos_embed = self.pos_embed[:, 1:] + # patch_pos_embed = nn.functional.interpolate( + # patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + # scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + # mode='bicubic', + # ) + # if w0 != patch_pos_embed.shape[-2]: + # helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device) + # patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2) + # if h0 != patch_pos_embed.shape[-1]: + # helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device) + # patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1) + # patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + # pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + # interpolate patch embeddings finish + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + # Store transformer outputs + transformerheads_outputs = [] + + for idx, blk in enumerate(self.blocks): + x = blk(x) + if idx <= 10: + out = self.norm(x) + out = self.transformerheads[idx](out) + transformerheads_outputs.append(out) + + x = self.norm(x) + return x, transformerheads_outputs + + def forward(self, x): + x, transformerheads_outputs = self.forward_features(x) + output = [] + for y in transformerheads_outputs: + output.append(self.head(y)) + output.append(self.head(x[:, 0])) + return output + +@register_model +def tiny_patch16_224_hierarchical(pretrained=False, index=0, **kwargs): + model = VisionTransformer_hierarchical( + patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + model = torch.nn.DataParallel(model) + checkpoint = torch.load('pretrained_models/ckpts/heir_tiny_001/model_best.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict"]) + return model + + +@register_model +def small_patch16_224_hierarchical(pretrained=False, **kwargs): + model = VisionTransformer_hierarchical( + patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + + if pretrained: + model = torch.nn.DataParallel(model) + checkpoint = torch.load('pretrained_models/ckpts/heir_small_001/model_best.pth.tar', 
+ map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict"]) + return model + + +@register_model +def base_patch16_224_hierarchical(pretrained=False, **kwargs): + model = VisionTransformer_hierarchical( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + + if pretrained: + model = torch.nn.DataParallel(model) + checkpoint = torch.load('pretrained_models/ckpts/heir_base_001/model_best.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict"]) + return model diff --git a/analysis_tools/fourier_analysis/vit_models/dino.py b/analysis_tools/fourier_analysis/vit_models/dino.py new file mode 100644 index 0000000..41b9369 --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/dino.py @@ -0,0 +1,465 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +from functools import partial + +import torch +import torch.nn as nn + +import warnings +from timm.models.registry import register_model + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.num_classes = num_classes + + self.patch_embed = PatchEmbed( + img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + # convert to list + if not isinstance(x, list): + x = [x] + # Perform forward pass separately on each resolution input. + # The inputs corresponding to a single resolution are clubbed and single + # forward is run on the same resolution inputs. Hence we do several + # forward passes = number of different resolutions used. We then + # concatenate all the output features. + idx_crops = torch.cumsum(torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in x]), + return_counts=True, + )[1], 0) + start_idx = 0 + for end_idx in idx_crops: + _out = self.forward_features(torch.cat(x[start_idx: end_idx])) + if start_idx == 0: + output = _out + else: + output = torch.cat((output, _out)) + start_idx = end_idx + # Run the head forward on the concatenated features. 
+ return self.head(output) + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + if self.norm is not None: + x = self.norm(x) + + return x[:, 0] + + def interpolate_pos_encoding(self, x, pos_embed): + npatch = x.shape[1] - 1 + N = pos_embed.shape[1] - 1 + if npatch == N: + return pos_embed + class_emb = pos_embed[:, 0] + pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=math.sqrt(npatch / N), + mode='bicubic', + ) + pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) + + def forward_selfattention(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) + + # interpolate patch embeddings + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + class_pos_embed = self.pos_embed[:, 0] + if self.pos_embed.shape[1] == 198: + N = self.pos_embed.shape[1] - 2 + dist_pos_embed = self.pos_embed[:, 1] + patch_pos_embed = self.pos_embed[:, 2:] + else: + N = self.pos_embed.shape[1] - 1 + patch_pos_embed = self.pos_embed[:, 1:] + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + if w0 != patch_pos_embed.shape[-2]: + helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device) + patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2) + if h0 != patch_pos_embed.shape[-1]: + helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device) + patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + if self.pos_embed.shape[1] == 198: + pos_embed = torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + else: + pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + cls_tokens = self.cls_token.expand(B, -1, -1) + if self.pos_embed.shape[1] == 198: + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + x = torch.cat((cls_tokens, x), dim=1) + x = x + pos_embed + x = self.pos_drop(x) + + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + return blk(x, return_attention=True) + + def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) + x = x + pos_embed + x = self.pos_drop(x) + + # we will return the [CLS] tokens from the `n` last blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)[:, 0]) + if return_patch_avgpool: + x = self.norm(x) + # In addition to the [CLS] tokens from the `n` last blocks, we also return + # the patch tokens from the last block. This is useful for linear eval. 
+ output.append(torch.mean(x[:, 1:], dim=1)) + return torch.cat(output, dim=-1) + + +class DistilledVisionTransformer(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + + trunc_normal_(self.dist_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + self.head_dist.apply(self._init_weights) + + def forward_features(self, x): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add the dist_token + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0], x[:, 1] + + def forward(self, x): + x, x_dist = self.forward_features(x) + x = self.head(x) + x_dist = self.head_dist(x_dist) + if self.training: + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 + + +@register_model +def dino_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + + return model + + +@register_model +def dino_tiny_dist(patch_size=16, pretrained=False, **kwargs): + model = DistilledVisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + + return model + + +@register_model +def dino_small(patch_size=16, pretrained=False, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model_url = { + 16: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + 8: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" + } + if pretrained: + state_dict = torch.hub.load_state_dict_from_url(model_url[patch_size]) + model.load_state_dict(state_dict, strict=False) + + return model + + +@register_model +def dino_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +@register_model +def dino_small_dist(patch_size=16, pretrained=False, **kwargs): + model = DistilledVisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + + return model + + +class DINOHead(nn.Module): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, + bottleneck_dim=256): + super().__init__() + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = nn.Linear(in_dim, bottleneck_dim) + else: + layers = [nn.Linear(in_dim, hidden_dim)] + if use_bn: + 
layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x diff --git a/analysis_tools/fourier_analysis/vit_models/resnet.py b/analysis_tools/fourier_analysis/vit_models/resnet.py new file mode 100644 index 0000000..0118e79 --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/resnet.py @@ -0,0 +1,78 @@ +import torch +import torchvision +from torchvision.models.resnet import Bottleneck, load_state_dict_from_url, model_urls + + +class NewResnet(torchvision.models.ResNet): + + def _forward_impl(self, x, drop_percent=None, drop_layer=0): + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + if drop_layer == 1: + mask = torch.rand(x.shape[2:], device=x.device) + mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) + x = x * mask + x = self.layer1(x) + + if drop_layer == 2: + mask = torch.rand(x.shape[2:], device=x.device) + mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) + x = x * mask + x = self.layer2(x) + + if drop_layer == 3: + mask = torch.rand(x.shape[2:], device=x.device) + mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) + x = x * mask + x = self.layer3(x) + + if drop_layer == 4: + mask = torch.rand(x.shape[2:], device=x.device) + mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) + x = x * mask + x = self.layer4(x) + + if drop_layer == 5: + mask = torch.rand(x.shape[2:], device=x.device) + mask = (mask > drop_percent).unsqueeze(0).unsqueeze(0) + x = x * mask + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x, drop_percent=None, drop_layer=None): + return self._forward_impl(x, drop_percent=drop_percent, drop_layer=drop_layer) + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + model = NewResnet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def drop_resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +if __name__ == '__main__': + model = drop_resnet50(pretrained=True) + sample = torch.randn((1, 3, 224, 224)) + out = model(sample, drop_layer=1, drop_percent=0.25) diff --git a/analysis_tools/fourier_analysis/vit_models/t2t_vit.py b/analysis_tools/fourier_analysis/vit_models/t2t_vit.py new file mode 100644 index 0000000..001aebb --- /dev/null +++ 
b/analysis_tools/fourier_analysis/vit_models/t2t_vit.py @@ -0,0 +1,301 @@ +# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. +# +# This source code is licensed under the Clear BSD License +# LICENSE file in the root directory of this file +# All rights reserved. +""" +T2T-ViT +""" +import torch +import torch.nn as nn + +from timm.models.helpers import load_pretrained +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_ +import numpy as np +from .token_transformer import Token_transformer +from .token_performer import Token_performer +from .transformer_block import Block, get_sinusoid_encoding + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), + 'classifier': 'head', + **kwargs + } + +default_cfgs = { + 'T2t_vit_7': _cfg(), + 'T2t_vit_10': _cfg(), + 'T2t_vit_12': _cfg(), + 'T2t_vit_14': _cfg(), + 'T2t_vit_19': _cfg(), + 'T2t_vit_24': _cfg(), + 'T2t_vit_t_14': _cfg(), + 'T2t_vit_t_19': _cfg(), + 'T2t_vit_t_24': _cfg(), + 'T2t_vit_14_resnext': _cfg(), + 'T2t_vit_14_wide': _cfg(), +} + +class T2T_module(nn.Module): + """ + Tokens-to-Token encoding module + """ + def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64): + super().__init__() + + if tokens_type == 'transformer': + print('adopt transformer encoder for tokens-to-token') + self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) + self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + + self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) + self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) + self.project = nn.Linear(token_dim * 3 * 3, embed_dim) + + elif tokens_type == 'performer': + print('adopt performer encoder for tokens-to-token') + self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) + self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) + + #self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5) + #self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5) + self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5) + self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5) + self.project = nn.Linear(token_dim * 3 * 3, embed_dim) + + elif tokens_type == 'convolution': # just for comparison with conolution, not our model + # for this tokens type, you need change forward as three convolution operation + print('adopt convolution layers for tokens-to-token') + self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution + self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution + self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution + + self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 sfot split, stride are 4,2,2 seperately + + def 
forward(self, x): + # step0: soft split + x = self.soft_split0(x).transpose(1, 2) + + # iteration1: re-structurization/reconstruction + x = self.attention1(x) + B, new_HW, C = x.shape + x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) + # iteration1: soft split + x = self.soft_split1(x).transpose(1, 2) + + # iteration2: re-structurization/reconstruction + x = self.attention2(x) + B, new_HW, C = x.shape + x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) + # iteration2: soft split + x = self.soft_split2(x).transpose(1, 2) + + # final tokens + x = self.project(x) + + return x + +class T2T_ViT(nn.Module): + def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, token_dim=64): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.tokens_to_token = T2T_module( + img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim, token_dim=token_dim) + num_patches = self.tokens_to_token.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.tokens_to_token(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + block_heads = [] + + for idx, blk in enumerate(self.blocks): + x = blk(x) + out_temp = self.norm(x) + block_heads.append(out_temp) + + x = self.norm(x) + return x[:, 0], block_heads + + def forward(self, x, get_average=False): + x, block_heads = self.forward_features(x) + if get_average: + return torch.mean(torch.stack([self.head(x[:, 0]) for x in block_heads], 0), dim=0) + x = self.head(x) + return x + +@register_model +def T2t_vit_7(pretrained=False, **kwargs): # adopt performer for tokens to token + model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=7, num_heads=4, mlp_ratio=2., **kwargs) + model.default_cfg = 
default_cfgs['T2t_vit_7'] + if pretrained: + checkpoint = torch.load('pretrained_models/71.7_T2T_ViT_7.pth.tar' + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + +@register_model +def T2t_vit_10(pretrained=False, **kwargs): # adopt performer for tokens to token + model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=10, num_heads=4, mlp_ratio=2., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_10'] + if pretrained: + checkpoint = torch.load('pretrained_models/75.2_T2T_ViT_10.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + +@register_model +def T2t_vit_12(pretrained=False, **kwargs): # adopt performer for tokens to token + model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=12, num_heads=4, mlp_ratio=2., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_12'] + if pretrained: + checkpoint = torch.load('pretrained_models/76.5_T2T_ViT_12.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + + +@register_model +def T2t_vit_14(pretrained=False, **kwargs): # adopt performer for tokens to token + model = T2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_14'] + if pretrained: + checkpoint = torch.load('pretrained_models/81.5_T2T_ViT_14.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + +@register_model +def T2t_vit_19(pretrained=False, **kwargs): # adopt performer for tokens to token + model = T2T_ViT(tokens_type='performer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_19'] + if pretrained: + checkpoint = torch.load('pretrained_models/81.9_T2T_ViT_19.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + +@register_model +def T2t_vit_24(pretrained=False, **kwargs): # adopt performer for tokens to token + model = T2T_ViT(tokens_type='performer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_24'] + if pretrained: + checkpoint = torch.load('pretrained_models/82.3_T2T_ViT_24.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + +@register_model +def T2t_vit_t_14(pretrained=False, **kwargs): # adopt transformers for tokens to token + model = T2T_ViT(tokens_type='transformer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_t_14'] + if pretrained: + checkpoint = torch.load('pretrained_models/81.7_T2T_ViTt_14.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + +@register_model +def T2t_vit_t_19(pretrained=False, **kwargs): # adopt transformers for tokens to token + model = T2T_ViT(tokens_type='transformer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_t_19'] + if pretrained: + checkpoint = torch.load('pretrained_models/82.4_T2T_ViTt_19.pth.tar', + map_location="cpu" + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + +@register_model +def T2t_vit_t_24(pretrained=False, **kwargs): # adopt transformers for tokens to token + model = T2T_ViT(tokens_type='transformer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_t_24'] + if pretrained: 
+ if pretrained: + checkpoint = torch.load('pretrained_models/82.6_T2T_ViTt_24.pth.tar' + ) + model.load_state_dict(checkpoint["state_dict_ema"]) + return model + +# rexnext and wide structure +@register_model +def T2t_vit_14_resnext(pretrained=False, **kwargs): + if pretrained: + kwargs.setdefault('qk_scale', 384 ** -0.5) + model = T2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=32, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_14_resnext'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + +@register_model +def T2t_vit_14_wide(pretrained=False, **kwargs): + if pretrained: + kwargs.setdefault('qk_scale', 512 ** -0.5) + model = T2T_ViT(tokens_type='performer', embed_dim=768, depth=4, num_heads=12, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_14_wide'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model diff --git a/analysis_tools/fourier_analysis/vit_models/t2t_vit_dense.py b/analysis_tools/fourier_analysis/vit_models/t2t_vit_dense.py new file mode 100644 index 0000000..febf24f --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/t2t_vit_dense.py @@ -0,0 +1,169 @@ +# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. +# +# This source code is licensed under the Clear BSD License +# LICENSE file in the root directory of this file +# All rights reserved. +""" +T2T-ViT-Dense +""" +import torch +import torch.nn as nn + +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, trunc_normal_ +from timm.models.registry import register_model + +from .transformer_block import Mlp, Block, get_sinusoid_encoding +from .t2t_vit import T2T_module, _cfg + +default_cfgs = { + 't2t_vit_dense': _cfg(), +} + +class Transition(nn.Module): + def __init__(self, in_features, out_features, act_layer=nn.GELU): + super(Transition, self).__init__() + self.act = act_layer() + self.linear = nn.Linear(in_features, out_features) + def forward(self, x): + x = self.linear(x) + x = self.act(x) + + return x + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, growth_rate, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) #, out_features=growth_rate + self.dense_linear = nn.Linear(dim, growth_rate) + + def forward(self, x): + new_x = x + self.drop_path(self.attn(self.norm1(x))) + new_x = new_x + self.drop_path(self.mlp(self.norm2(new_x))) + new_x = self.dense_linear(new_x) + x = torch.cat([x, new_x], 2) # dense connnection: concate all the old features with new features in channel dimension + return x + +class T2T_ViT_Dense(nn.Module): + def __init__(self, growth_rate=32, tokens_type='performer', block_config=(3, 4, 6, 3), img_size=224, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.tokens_to_token = T2T_module( + img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.tokens_to_token.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList() + + start_dim = embed_dim + for i, num_layers in enumerate(block_config): + for j in range(num_layers): + new_dim = start_dim + j * growth_rate + block = Block(growth_rate=growth_rate, dim=new_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + self.blocks.append(block) + if i != len(block_config)-1: + transition = Transition(new_dim+growth_rate, (new_dim+growth_rate)//2) + self.blocks.append(transition) + start_dim = int((new_dim+growth_rate)//2) + out_dim = new_dim + growth_rate + print(f'end dim:{out_dim}') + self.norm = norm_layer(out_dim) + + # Classifier head + self.head = nn.Linear(out_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.tokens_to_token(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed # self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + +@register_model +def T2t_vit_dense(pretrained=False, 
**kwargs): + model = T2T_ViT_Dense(growth_rate=64, block_config=(3, 6, 6, 4), embed_dim=128, num_heads=8, mlp_ratio=2., **kwargs) + model.default_cfg = default_cfgs['t2t_vit_dense'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model diff --git a/analysis_tools/fourier_analysis/vit_models/t2t_vit_ghost.py b/analysis_tools/fourier_analysis/vit_models/t2t_vit_ghost.py new file mode 100644 index 0000000..a002374 --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/t2t_vit_ghost.py @@ -0,0 +1,196 @@ +# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. +# +# This source code is licensed under the Clear BSD License +# LICENSE file in the root directory of this file +# All rights reserved. +""" +T2T-ViT-Ghost +""" +import torch +import torch.nn as nn + +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, trunc_normal_ +from timm.models.registry import register_model + +from .transformer_block import Block, get_sinusoid_encoding +from .t2t_vit import T2T_module, _cfg + + +default_cfgs = { + 'T2t_vit_16_ghost': _cfg(), +} + +class Mlp_ghost(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, in_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + self.ratio = hidden_features//in_features + self.cheap_operation2 = nn.Conv1d(in_features, in_features, kernel_size=1, groups=in_features, bias=False) + self.cheap_operation3 = nn.Conv1d(in_features, in_features, kernel_size=1, groups=in_features, bias=False) + + def forward(self, x): # x: [B, N, C] + x1 = self.fc1(x) # x1: [B, N, C] + x1 = self.act(x1) + + x2 = self.cheap_operation2(x1.transpose(1,2)) # x2: [B, N, C] + x2 = x2.transpose(1,2) + x2 = self.act(x2) + + x3 = self.cheap_operation3(x1.transpose(1, 2)) # x3: [B, N, C] + x3 = x3.transpose(1, 2) + x3 = self.act(x3) + + x = torch.cat((x1, x2, x3), dim=2) # x: [B, N, 3C] + x = self.drop(x) + + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention_ghost(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + half_dim = int(0.5*dim) + self.q = nn.Linear(dim, half_dim, bias=qkv_bias) + self.k = nn.Linear(dim, half_dim, bias=qkv_bias) + self.v = nn.Linear(dim, half_dim, bias=qkv_bias) + + self.cheap_operation_q = nn.Conv1d(half_dim, half_dim, kernel_size=1, groups=half_dim, bias=False) + self.cheap_operation_k = nn.Conv1d(half_dim, half_dim, kernel_size=1, groups=half_dim, bias=False) + self.cheap_operation_v = nn.Conv1d(half_dim, half_dim, kernel_size=1, groups=half_dim, bias=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + q = self.q(x) + k = self.k(x) + v = self.v(x) + + q1 = self.cheap_operation_q(q.transpose(1,2)).transpose(1,2) + k1 = self.cheap_operation_k(k.transpose(1,2)).transpose(1,2) + v1 = self.cheap_operation_v(v.transpose(1,2)).transpose(1,2) + + q = torch.cat((q, q1), dim=2).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = torch.cat((k, 
k1), dim=2).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = torch.cat((v, v1), dim=2).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention_ghost( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp_ghost(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class T2T_ViT_Ghost(nn.Module): + def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.tokens_to_token = T2T_module( + img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.tokens_to_token.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.tokens_to_token(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = 
self.head(x) + return x + + +@register_model +def T2t_vit_16_ghost(pretrained=False, **kwargs): + if pretrained: + kwargs.setdefault('qk_scale', 384 ** -0.5) + model = T2T_ViT_Ghost(tokens_type='performer', embed_dim=384, depth=16, num_heads=6, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_16_ghost'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model diff --git a/analysis_tools/fourier_analysis/vit_models/t2t_vit_se.py b/analysis_tools/fourier_analysis/vit_models/t2t_vit_se.py new file mode 100644 index 0000000..e76160c --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/t2t_vit_se.py @@ -0,0 +1,168 @@ +# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. +# +# This source code is licensed under the Clear BSD License +# LICENSE file in the root directory of this file +# All rights reserved. +""" +T2T-ViT-SE +""" +import torch +import torch.nn as nn + +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, trunc_normal_ +from timm.models.registry import register_model +from .transformer_block import Block, Mlp, get_sinusoid_encoding +from .t2t_vit import T2T_module, _cfg + +default_cfgs = { + 'T2t_vit_14_se': _cfg(), +} + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool1d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): # x: [B, N, C] + x = torch.transpose(x, 1, 2) # [B, C, N] + b, c, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1) + x = x * y.expand_as(x) + x = torch.transpose(x, 1, 2) # [B, N, C] + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.se_layer = SELayer(dim) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.se_layer(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class T2T_ViT_SE(nn.Module): + def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.tokens_to_token = T2T_module( + img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.tokens_to_token.num_patches + print(num_patches) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.tokens_to_token(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + +@register_model +def T2t_vit_14_se(pretrained=False, **kwargs): + if pretrained: + kwargs.setdefault('qk_scale', 384 ** -0.5) + model = T2T_ViT_SE(tokens_type='performer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['T2t_vit_14_se'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model diff --git a/analysis_tools/fourier_analysis/vit_models/tnt.py b/analysis_tools/fourier_analysis/vit_models/tnt.py new file mode 100644 index 0000000..1ed95a8 --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/tnt.py @@ -0,0 +1,249 @@ +""" Transformer in Transformer (TNT) in PyTorch +A PyTorch implement of TNT as described in +'Transformer in Transformer' - 
https://arxiv.org/abs/2103.00112 +The official mindspore code is released and available at +https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT +""" +import math +import torch +import torch.nn as nn +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, trunc_normal_ +from timm.models.vision_transformer import Mlp +from timm.models.registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'pixel_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'tnt_s_patch16_224': _cfg( + url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'tnt_b_patch16_224': _cfg( + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), +} + + +class Attention(nn.Module): + """ Multi-Head Attention + """ + + def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + head_dim = hidden_dim // num_heads + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + + self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop, inplace=True) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop, inplace=True) + + def forward(self, x): + B, N, C = x.shape + qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple) + v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + """ TNT Block + """ + + def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., + qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + # Inner transformer + self.norm_in = norm_layer(in_dim) + self.attn_in = Attention( + in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + + self.norm_mlp_in = norm_layer(in_dim) + self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), + out_features=in_dim, act_layer=act_layer, drop=drop) + + self.norm1_proj = norm_layer(in_dim) + self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) + # Outer transformer + self.norm_out = norm_layer(dim) + self.attn_out = Attention( + dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm_mlp = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), + out_features=dim, act_layer=act_layer, drop=drop) + + def forward(self, pixel_embed, patch_embed): + # inner + pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed))) + pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) + # outer + B, N, C = patch_embed.size() + patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) + patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed))) + patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) + return pixel_embed, patch_embed + + +class PixelEmbed(nn.Module): + """ Image to Pixel Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): + super().__init__() + num_patches = (img_size // patch_size) ** 2 + self.img_size = img_size + self.num_patches = num_patches + self.in_dim = in_dim + new_patch_size = math.ceil(patch_size / stride) + self.new_patch_size = new_patch_size + + self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) + self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) + + def forward(self, x, pixel_pos): + B, C, H, W = x.shape + assert H == self.img_size and W == self.img_size, \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})." + x = self.proj(x) + x = self.unfold(x) + x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size) + x = x + pixel_pos + x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) + return x + + +class TNT(nn.Module): + """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, + num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.pixel_embed = PixelEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride) + num_patches = self.pixel_embed.num_patches + self.num_patches = num_patches + new_patch_size = self.pixel_embed.new_patch_size + num_pixel = new_patch_size ** 2 + + self.norm1_proj = norm_layer(num_pixel * in_dim) + self.proj = nn.Linear(num_pixel * in_dim, embed_dim) + self.norm2_proj = norm_layer(embed_dim) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + blocks = [] + for i in range(depth): + blocks.append(Block( + dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, + mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[i], norm_layer=norm_layer)) + self.blocks = nn.ModuleList(blocks) + self.norm = norm_layer(embed_dim) + + self.head = nn.Linear(embed_dim, num_classes) if num_classes 
> 0 else nn.Identity() + + trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.patch_pos, std=.02) + trunc_normal_(self.pixel_pos, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'patch_pos', 'pixel_pos', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + pixel_embed = self.pixel_embed(x, self.pixel_pos) + + patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) + patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1) + patch_embed = patch_embed + self.patch_pos + patch_embed = self.pos_drop(patch_embed) + + for blk in self.blocks: + pixel_embed, patch_embed = blk(pixel_embed, patch_embed) + + patch_embed = self.norm(patch_embed) + return patch_embed[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +@register_model +def tnt_s_patch16_224(pretrained=False, **kwargs): + model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, + qkv_bias=False, **kwargs) + model.default_cfg = default_cfgs['tnt_s_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def tnt_b_patch16_224(pretrained=False, **kwargs): + model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, + qkv_bias=False, **kwargs) + model.default_cfg = default_cfgs['tnt_b_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model \ No newline at end of file diff --git a/analysis_tools/fourier_analysis/vit_models/token_performer.py b/analysis_tools/fourier_analysis/vit_models/token_performer.py new file mode 100644 index 0000000..e2e0f90 --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/token_performer.py @@ -0,0 +1,60 @@ +""" +Take Performer as T2T Transformer +""" +import math +import torch +import torch.nn as nn + +class Token_performer(nn.Module): + def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1): + super().__init__() + self.emb = in_dim * head_cnt # we use 1, so it is no need here + self.kqv = nn.Linear(dim, 3 * self.emb) + self.dp = nn.Dropout(dp1) + self.proj = nn.Linear(self.emb, self.emb) + self.head_cnt = head_cnt + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(self.emb) + self.epsilon = 1e-8 # for stable in division + + self.mlp = nn.Sequential( + nn.Linear(self.emb, 1 * self.emb), + nn.GELU(), + nn.Linear(1 * self.emb, self.emb), + nn.Dropout(dp2), + ) + + self.m = int(self.emb * kernel_ratio) + self.w = torch.randn(self.m, self.emb) + self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False) + + def prm_exp(self, x): + # part of the function is borrow from https://github.com/lucidrains/performer-pytorch + # and Simo Ryu (https://github.com/cloneofsimo) + # ==== positive 
random features for gaussian kernels ==== + # x = (B, T, hs) + # w = (m, hs) + # return : x : B, T, m + # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] + # therefore return exp(w^Tx - |x|/2)/sqrt(m) + xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2 + wtx = torch.einsum('bti,mi->btm', x.float(), self.w) + + return torch.exp(wtx - xd) / math.sqrt(self.m) + + def single_attn(self, x): + k, q, v = torch.split(self.kqv(x), self.emb, dim=-1) + kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, T, m), (B, T, m) + D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * (B, m) -> (B, T, 1) + kptv = torch.einsum('bin,bim->bnm', v.float(), kp) # (B, emb, m) + y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag + # skip connection + y = v + self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection + + return y + + def forward(self, x): + x = self.single_attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + diff --git a/analysis_tools/fourier_analysis/vit_models/token_transformer.py b/analysis_tools/fourier_analysis/vit_models/token_transformer.py new file mode 100644 index 0000000..ab81e8d --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/token_transformer.py @@ -0,0 +1,60 @@ +# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. +# +# This source code is licensed under the Clear BSD License +# LICENSE file in the root directory of this file +# All rights reserved. +""" +Take the standard Transformer as T2T Transformer +""" +import torch.nn as nn +from timm.models.layers import DropPath +from .transformer_block import Mlp + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + self.in_dim = in_dim + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(in_dim, in_dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim) + x = self.proj(x) + x = self.proj_drop(x) + + # skip connection + x = v.squeeze(1) + x # because the original x has different size with current x, use v to do skip connection + + return x + +class Token_transformer(nn.Module): + + def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(in_dim) + self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio), out_features=in_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = self.attn(self.norm1(x)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x diff --git a/analysis_tools/fourier_analysis/vit_models/transformer_block.py b/analysis_tools/fourier_analysis/vit_models/transformer_block.py new file mode 100644 index 0000000..6b28526 --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/transformer_block.py @@ -0,0 +1,88 @@ +# Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. +# +# This source code is licensed under the Clear BSD License +# LICENSE file in the root directory of this file +# All rights reserved. +""" +Borrow from timm(https://github.com/rwightman/pytorch-image-models) +""" +import torch +import torch.nn as nn +import numpy as np +from timm.models.layers import DropPath + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +def get_sinusoid_encoding(n_position, d_hid): + ''' Sinusoid position encoding table ''' + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) diff --git a/analysis_tools/fourier_analysis/vit_models/vit.py b/analysis_tools/fourier_analysis/vit_models/vit.py new file mode 100644 index 0000000..b44d2aa --- /dev/null +++ b/analysis_tools/fourier_analysis/vit_models/vit.py @@ -0,0 +1,150 @@ + +import torch +import torch.nn as nn +from functools import partial + +from timm.models.vision_transformer import VisionTransformer, _cfg +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_ +from timm.models.helpers import * + + +default_cfgs = { + # patch vit_models + 'vit_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth' + ), + 'vit_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth' + ), + 'vit_base_patch16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth' + ), + 'vit_base_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth' + ), + 'vit_large_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth' + ), + 'vit_large_patch16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth' + ), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth' + ) +} + +class VanillaVisionTransformer(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward_features(self, x): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add the dist_token + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + layer_wise_tokens = [] + for blk in self.blocks: + x = blk(x) + layer_wise_tokens.append(x) + + layer_wise_tokens = [self.norm(x) for x in layer_wise_tokens] + return [x[:, 0] for x in layer_wise_tokens] + + def forward(self, x): + list_out = self.forward_features(x) + x = [self.head(x) for x in list_out] + return [out for out in x] + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + 
for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): + if pretrained: + # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model + kwargs.setdefault('qk_scale', 768 ** -0.5) + model = VanillaVisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs) + model.default_cfg = default_cfgs['vit_small_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + return model + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch16_224'] + if pretrained: + load_pretrained( + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + return model + + +@register_model +def vit_base_patch16_384(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch16_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_base_patch32_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_large_patch16_224(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch16_224'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_large_patch16_384(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch16_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model + + +@register_model +def vit_large_patch32_384(pretrained=False, **kwargs): + model = VanillaVisionTransformer( + img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = default_cfgs['vit_large_patch32_384'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + return model \ No newline at end of file
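
All of the registered variants above route through `VanillaVisionTransformer`, whose overridden `forward_features` keeps the normalized CLS token after every transformer block instead of only the last one, so `forward` returns one logit tensor per block. Below is a minimal sketch (not part of the patch) of how those per-block outputs could be inspected; the import path `vit_models.vit`, the direct class construction, and the dummy input are assumptions made for illustration, and the hyper-parameters simply mirror `vit_base_patch16_224` in `vit.py`.

```
# Hedged sketch: probe the per-block CLS outputs of VanillaVisionTransformer.
# Assumes analysis_tools/fourier_analysis is on PYTHONPATH so that
# `vit_models.vit` is importable; hyper-parameters copied from vit_base_patch16_224.
from functools import partial

import torch
import torch.nn as nn

from vit_models.vit import VanillaVisionTransformer  # assumed import path

model = VanillaVisionTransformer(
    patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6))
model.eval()

with torch.no_grad():
    # forward() returns a list with one entry per block (12 for this depth),
    # each of shape (batch, num_classes) from the normalized CLS token.
    per_block_logits = model(torch.randn(1, 3, 224, 224))

print(len(per_block_logits))        # 12
print(per_block_logits[-1].shape)   # torch.Size([1, 1000]) with the default num_classes
```

Exposing every block's CLS token this way is what lets the Fourier/variance analysis plot how feature statistics evolve with depth, rather than looking at the final representation alone.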