Skip to content

Commit

Permalink
add drop_path to fuxi
Browse files Browse the repository at this point in the history
  • Loading branch information
yingkaisha committed Nov 10, 2024
1 parent f7b72a4 commit d8e8385
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions credit/models/fuxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,13 @@ class UTransformer(nn.Module):
"""

def __init__(
self, embed_dim, num_groups, input_resolution, num_heads, window_size, depth
self, embed_dim,
num_groups,
input_resolution,
num_heads,
window_size,
depth,
drop_path
):
super().__init__()
num_groups = to_2tuple(num_groups)
Expand All @@ -248,7 +254,13 @@ def __init__(

# SwinT block
self.layer = SwinTransformerV2Stage(
embed_dim, embed_dim, input_resolution, depth, num_heads, window_size[0]
embed_dim,
embed_dim,
input_resolution,
depth,
num_heads,
window_size[0],
drop_path=drop_path
) # <--- window_size[0] get window_size[int] from tuple

# up-sampling block
Expand Down Expand Up @@ -315,6 +327,7 @@ def __init__(
window_size=7,
use_spectral_norm=True,
interp=True,
drop_path=0,
padding_conf=None,
post_conf=None,
**kwargs,
Expand Down Expand Up @@ -363,7 +376,12 @@ def __init__(

# Downsampling --> SwinTransformerV2 stacks --> Upsampling
self.u_transformer = UTransformer(
dim, num_groups, input_resolution, num_heads, window_size, depth=depth
dim, num_groups,
input_resolution,
num_heads,
window_size,
depth=depth,
drop_path=drop_path
)

# dense layer applied on channel dmension
Expand Down

0 comments on commit d8e8385

Please sign in to comment.