Skip to content

Commit ebf82b8

Browse files
authored
Merge pull request #122 from mrT23/master
TResNet models
2 parents e15f979 + 8a63c1a commit ebf82b8

File tree

5 files changed

+409
-0
lines changed

5 files changed

+409
-0
lines changed

timm/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .dla import *
1818
from .hrnet import *
1919
from .sknet import *
20+
from .tresnet import *
2021

2122
from .registry import *
2223
from .factory import create_model

timm/models/layers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@
1616
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
1717
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
1818
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
19+
from .anti_aliasing import AntiAliasDownsampleLayer
20+
from .space_to_depth import SpaceToDepthModule

timm/models/layers/anti_aliasing.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import torch
2+
import torch.nn.parallel
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
6+
7+
class AntiAliasDownsampleLayer(nn.Module):
8+
def __init__(self, remove_aa_jit: bool = False, filt_size: int = 3, stride: int = 2,
9+
channels: int = 0):
10+
super(AntiAliasDownsampleLayer, self).__init__()
11+
if not remove_aa_jit:
12+
self.op = DownsampleJIT(filt_size, stride, channels)
13+
else:
14+
self.op = Downsample(filt_size, stride, channels)
15+
16+
def forward(self, x):
17+
return self.op(x)
18+
19+
20+
@torch.jit.script
21+
class DownsampleJIT(object):
22+
def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
23+
self.stride = stride
24+
self.filt_size = filt_size
25+
self.channels = channels
26+
27+
assert self.filt_size == 3
28+
assert stride == 2
29+
a = torch.tensor([1., 2., 1.])
30+
31+
filt = (a[:, None] * a[None, :]).clone().detach()
32+
filt = filt / torch.sum(filt)
33+
self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()
34+
35+
def __call__(self, input: torch.Tensor):
36+
if input.dtype != self.filt.dtype:
37+
self.filt = self.filt.float()
38+
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
39+
return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])
40+
41+
42+
class Downsample(nn.Module):
43+
def __init__(self, filt_size=3, stride=2, channels=None):
44+
super(Downsample, self).__init__()
45+
self.filt_size = filt_size
46+
self.stride = stride
47+
self.channels = channels
48+
49+
50+
assert self.filt_size == 3
51+
a = torch.tensor([1., 2., 1.])
52+
53+
filt = (a[:, None] * a[None, :])
54+
filt = filt / torch.sum(filt)
55+
56+
# self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
57+
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
58+
59+
def forward(self, input):
60+
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
61+
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])

timm/models/layers/space_to_depth.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class SpaceToDepth(nn.Module):
6+
def __init__(self, block_size=4):
7+
super().__init__()
8+
assert block_size == 4
9+
self.bs = block_size
10+
11+
def forward(self, x):
12+
N, C, H, W = x.size()
13+
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
14+
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
15+
x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
16+
return x
17+
18+
19+
@torch.jit.script
20+
class SpaceToDepthJit(object):
21+
def __call__(self, x: torch.Tensor):
22+
# assuming hard-coded that block_size==4 for acceleration
23+
N, C, H, W = x.size()
24+
x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
25+
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
26+
x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
27+
return x
28+
29+
30+
class SpaceToDepthModule(nn.Module):
31+
def __init__(self, remove_model_jit=False):
32+
super().__init__()
33+
if not remove_model_jit:
34+
self.op = SpaceToDepthJit()
35+
else:
36+
self.op = SpaceToDepth()
37+
38+
def forward(self, x):
39+
return self.op(x)
40+
41+
42+
class DepthToSpace(nn.Module):
43+
44+
def __init__(self, block_size):
45+
super().__init__()
46+
self.bs = block_size
47+
48+
def forward(self, x):
49+
N, C, H, W = x.size()
50+
x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
51+
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
52+
x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
53+
return x

0 commit comments

Comments
 (0)