From 2899a7053d96275b2e539312af056e8d8da83efc Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Wed, 19 Feb 2025 19:12:43 +0000 Subject: [PATCH] Should address the failing tests --- sharktank/sharktank/models/vae/model.py | 9 +++++++++ .../sharktank/pipelines/flux/export_components.py | 10 +--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sharktank/sharktank/models/vae/model.py b/sharktank/sharktank/models/vae/model.py index 126c05f2c..e4dd0845c 100644 --- a/sharktank/sharktank/models/vae/model.py +++ b/sharktank/sharktank/models/vae/model.py @@ -74,6 +74,15 @@ def forward( "latent_embeds": latent_embeds, }, ) + if not self.hp.use_post_quant_conv: + sample = rearrange( + sample, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(1024 / 16), + w=math.ceil(1024 / 16), + ph=2, + pw=2, + ) sample = sample / self.hp.scaling_factor + self.hp.shift_factor if self.hp.use_post_quant_conv: diff --git a/sharktank/sharktank/pipelines/flux/export_components.py b/sharktank/sharktank/pipelines/flux/export_components.py index 85c7888de..334dbdadb 100644 --- a/sharktank/sharktank/pipelines/flux/export_components.py +++ b/sharktank/sharktank/pipelines/flux/export_components.py @@ -267,15 +267,7 @@ def __init__(self, weight_file, height=1024, width=1024, precision="fp32"): self.width = width def forward(self, z): - d_in = rearrange( - z, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=math.ceil(self.height / 16), - w=math.ceil(self.width / 16), - ph=2, - pw=2, - ) - return self.ae.forward(d_in) + return self.ae.forward(z) def get_ae_model_and_inputs(