From d8e8385983491ff99e5742e0a822677b51c198e3 Mon Sep 17 00:00:00 2001 From: Yingkai Sha Date: Sat, 9 Nov 2024 22:22:04 -0700 Subject: [PATCH] add drop_path to fuxi --- credit/models/fuxi.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/credit/models/fuxi.py b/credit/models/fuxi.py index 162e880..945b1a5 100644 --- a/credit/models/fuxi.py +++ b/credit/models/fuxi.py @@ -226,7 +226,13 @@ class UTransformer(nn.Module): """ def __init__( - self, embed_dim, num_groups, input_resolution, num_heads, window_size, depth + self, embed_dim, + num_groups, + input_resolution, + num_heads, + window_size, + depth, + drop_path ): super().__init__() num_groups = to_2tuple(num_groups) @@ -248,7 +254,13 @@ def __init__( # SwinT block self.layer = SwinTransformerV2Stage( - embed_dim, embed_dim, input_resolution, depth, num_heads, window_size[0] + embed_dim, + embed_dim, + input_resolution, + depth, + num_heads, + window_size[0], + drop_path=drop_path ) # <--- window_size[0] get window_size[int] from tuple # up-sampling block @@ -315,6 +327,7 @@ def __init__( window_size=7, use_spectral_norm=True, interp=True, + drop_path=0, padding_conf=None, post_conf=None, **kwargs, @@ -363,7 +376,12 @@ def __init__( # Downsampling --> SwinTransformerV2 stacks --> Upsampling self.u_transformer = UTransformer( - dim, num_groups, input_resolution, num_heads, window_size, depth=depth + dim, num_groups, + input_resolution, + num_heads, + window_size, + depth=depth, + drop_path=drop_path ) # dense layer applied on channel dmension