Skip to content

Commit

Permalink
add atten_drop and proj_drop to fuxi
Browse files — browse the repository at this point in the history
  • Loading branch information
yingkaisha committed Nov 16, 2024
1 parent f7a2c7e commit b843a9d
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion — credit/models/fuxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ def __init__(
num_heads,
window_size,
depth,
proj_drop,
attn_drop,
drop_path
):
super().__init__()
Expand Down Expand Up @@ -260,6 +262,8 @@ def __init__(
depth,
num_heads,
window_size[0],
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path
) # <--- window_size[0] get window_size[int] from tuple

Expand Down Expand Up @@ -327,6 +331,8 @@ def __init__(
window_size=7,
use_spectral_norm=True,
interp=True,
proj_drop=0,
attn_drop=0,
drop_path=0,
padding_conf=None,
post_conf=None,
Expand Down Expand Up @@ -379,14 +385,16 @@ def __init__(
self.cube_embedding = CubeEmbedding(img_size, patch_size, in_chans, dim)

# Downsampling --> SwinTransformerV2 stacks --> Upsampling
logger.info(f"Define UTransforme with drop path: {drop_path}")
logger.info(f"Define UTransforme with proj_drop={proj_drop}, attn_drop={attn_drop}, drop_path={drop_path}")

self.u_transformer = UTransformer(
dim, num_groups,
input_resolution,
num_heads,
window_size,
depth=depth,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path
)

Expand Down

0 comments on commit b843a9d

Please sign in to comment.