import torch
import torch.nn.parallel
import torch.nn as nn
import torch.nn.functional as F
class AntiAliasDownsampleLayer(nn.Module):
    """Anti-aliased stride-2 downsampling layer.

    Blurs the input with a fixed 3x3 binomial filter before subsampling,
    then delegates the actual work to one of two interchangeable ops:
    a TorchScript-compiled implementation (default) or a plain nn.Module
    fallback selected via ``remove_aa_jit``.
    """

    def __init__(self, remove_aa_jit: bool = False, filt_size: int = 3, stride: int = 2,
                 channels: int = 0):
        super(AntiAliasDownsampleLayer, self).__init__()
        # Prefer the JIT-scripted op; the non-JIT module is the opt-out path.
        if remove_aa_jit:
            self.op = Downsample(filt_size, stride, channels)
        else:
            self.op = DownsampleJIT(filt_size, stride, channels)

    def forward(self, x):
        # Pure delegation — all logic lives in the selected op.
        return self.op(x)
@torch.jit.script
class DownsampleJIT(object):
    """TorchScript implementation of 3x3 binomial blur + stride-2 subsample.

    Only the (filt_size=3, stride=2) configuration is supported; the kernel
    is the normalized outer product of [1, 2, 1]. NOTE(review): the filter is
    created with ``.cuda().half()``, so this op assumes a CUDA device — the
    non-JIT ``Downsample`` module is the CPU-friendly alternative.
    """

    def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
        self.stride = stride
        self.filt_size = filt_size
        self.channels = channels

        # Hard-wired configuration — anything else is unsupported.
        assert self.filt_size == 3
        assert stride == 2
        a = torch.tensor([1., 2., 1.])

        # Outer product of the 1-D binomial gives the 2-D blur kernel;
        # normalize so it sums to 1 (preserves mean intensity).
        filt = (a[:, None] * a[None, :]).clone().detach()
        filt = filt / torch.sum(filt)
        # One kernel per channel for a depthwise (grouped) convolution.
        self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()

    def __call__(self, input: torch.Tensor):
        # Bug fix: cast the filter to the *input's* dtype. The previous code
        # unconditionally called ``self.filt.float()``, which never resolved
        # the mismatch for fp16 inputs (filter stayed float32 forever and
        # conv2d raised a dtype error on every call).
        if input.dtype != self.filt.dtype:
            self.filt = self.filt.to(input.dtype)
        # Reflect-pad by 1 on each side so stride-2 output is ceil(H/2) x ceil(W/2).
        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
        # groups == channel count -> depthwise per-channel blur.
        return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])
class Downsample(nn.Module):
    """Non-JIT 3x3 binomial blur followed by strided subsampling.

    Builds a fixed, normalized [1, 2, 1] x [1, 2, 1] kernel, replicated once
    per channel, and applies it as a depthwise convolution with reflect
    padding. Only filt_size == 3 is supported.
    """

    def __init__(self, filt_size=3, stride=2, channels=None):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.stride = stride
        self.channels = channels

        # Hard-wired to the 3x3 binomial kernel.
        assert self.filt_size == 3

        # 2-D kernel = outer product of the 1-D binomial, normalized to sum 1.
        coeffs = torch.tensor([1., 2., 1.])
        kernel = coeffs[:, None] * coeffs[None, :]
        kernel = kernel / kernel.sum()

        # Buffer (not a Parameter): moves with .to()/.cuda()/.half() and is
        # saved in state_dict, but is never trained.
        self.register_buffer('filt', kernel[None, None, :, :].repeat((self.channels, 1, 1, 1)))

    def forward(self, input):
        padded = F.pad(input, (1, 1, 1, 1), 'reflect')
        # Depthwise conv: groups equals the input channel count.
        return F.conv2d(padded, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
0 commit comments