Skip to content

Commit

Permalink
Should address the failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleHerndon committed Feb 19, 2025
1 parent 84a2a3a commit 2899a70
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
9 changes: 9 additions & 0 deletions sharktank/sharktank/models/vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def forward(
"latent_embeds": latent_embeds,
},
)
if not self.hp.use_post_quant_conv:
sample = rearrange(
sample,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(1024 / 16),
w=math.ceil(1024 / 16),
ph=2,
pw=2,
)
sample = sample / self.hp.scaling_factor + self.hp.shift_factor

if self.hp.use_post_quant_conv:
Expand Down
10 changes: 1 addition & 9 deletions sharktank/sharktank/pipelines/flux/export_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,7 @@ def __init__(self, weight_file, height=1024, width=1024, precision="fp32"):
self.width = width

def forward(self, z):
d_in = rearrange(
z,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(self.height / 16),
w=math.ceil(self.width / 16),
ph=2,
pw=2,
)
return self.ae.forward(d_in)
return self.ae.forward(z)


def get_ae_model_and_inputs(
Expand Down

0 comments on commit 2899a70

Please sign in to comment.