
Commit

rollback crossformer code to the main branch
yingkaisha committed Oct 19, 2024
1 parent 91d04c3 commit 0cd923b
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions credit/models/crossformer.py
@@ -418,7 +418,7 @@ def __init__(
# define embedding layer using adjusted sizes
# if the original sizes were good, adjusted sizes should == original sizes
self.cube_embedding = CubeEmbedding(
-(frames, self.image_height_adjust, self.image_width_adjust),
+(frames, image_height, image_width),
(frames, patch_height, patch_width),
input_channels,
dim[0]
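
(Aside: the "adjusted sizes" being rolled back in this hunk are presumably the image dimensions rounded up to a multiple of the patch size so the cube embedding tiles evenly. A minimal sketch of that kind of adjustment follows; the helper name adjust_to_patch and the use of math.ceil are illustrative assumptions, not the repository's actual code.)

import math

def adjust_to_patch(size: int, patch: int) -> int:
    # Round a spatial dimension up to the nearest multiple of the patch size.
    return math.ceil(size / patch) * patch

# "if the original sizes were good, adjusted sizes should == original sizes"
assert adjust_to_patch(640, 4) == 640   # already divisible: unchanged
assert adjust_to_patch(181, 2) == 182   # padded up to the next multiple
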
@@ -509,11 +509,7 @@ def integrate_step(x, k, factor):
input_only_channels = 3
frame_patch_size = 2

-input_tensor = torch.randn(1,
-                           channels * levels + surface_channels + input_only_channels,
-                           frames,
-                           image_height,
-                           image_width).to("cuda")
+input_tensor = torch.randn(1, channels * levels + surface_channels + input_only_channels, frames, image_height, image_width).to("cuda")

model = CrossFormer(
image_height=image_height,
@@ -524,12 +520,12 @@ def integrate_step(x, k, factor):
surface_channels=surface_channels,
input_only_channels=input_only_channels,
levels=levels,
-dim=(32, 64, 128, 256),
-depth=(2, 2, 8, 2),
-global_window_size=(10, 5, 2, 1),
+dim=(128, 256, 512, 1024),
+depth=(2, 2, 18, 2),
+global_window_size=(8, 4, 2, 1),
local_window_size=5,
cross_embed_kernel_sizes=((4, 8, 16, 32), (2, 4), (2, 4), (2, 4)),
-cross_embed_strides=(2, 2, 2, 2),
+cross_embed_strides=(4, 2, 2, 2),
attn_dropout=0.,
ff_dropout=0.,
).to("cuda")
@@ -540,4 +536,4 @@ def integrate_step(x, k, factor):
y_pred = model(input_tensor.to("cuda"))
print("Predicted shape:", y_pred.shape)

-# print(model.rk4(input_tensor.to("cpu")).shape)
+# print(model.rk4(input_tensor.to("cpu")).shape)
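
As a quick sanity check on the test block above, the input tensor's channel dimension is channels * levels + surface_channels + input_only_channels. The arithmetic below is a sketch with assumed values for channels, levels, surface_channels, and frames (those are defined above the visible hunks; only input_only_channels = 3 and frame_patch_size = 2 appear in this diff).

# Assumed, illustrative values -- only input_only_channels comes from the diff.
channels, levels = 4, 15        # upper-air variables x vertical levels (assumed)
surface_channels = 7            # assumed
input_only_channels = 3         # from the test block above
frames = 2                      # assumed

in_channels = channels * levels + surface_channels + input_only_channels
# Input tensor shape passed to CrossFormer: (batch, in_channels, frames, H, W)
print(in_channels)              # 70 with these assumed numbers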
