diff --git a/threestudio/models/prompt_processors/base.py b/threestudio/models/prompt_processors/base.py index a85d3719..9a1ccf76 100644 --- a/threestudio/models/prompt_processors/base.py +++ b/threestudio/models/prompt_processors/base.py @@ -56,7 +56,6 @@ def get_text_embeddings( azimuth: Float[Tensor, "B"], camera_distances: Float[Tensor, "B"], view_dependent_prompting: bool = True, - return_prompt: bool = False, ) -> Float[Tensor, "BB N Nf"]: batch_size = elevation.shape[0] @@ -71,19 +70,14 @@ def get_text_embeddings( # Get text embeddings text_embeddings = self.text_embeddings_vd[direction_idx] # type: ignore uncond_text_embeddings = self.uncond_text_embeddings_vd[direction_idx] # type: ignore - prompts = self.prompts_vd[direction_idx] else: text_embeddings = self.text_embeddings.expand(batch_size, -1, -1) # type: ignore uncond_text_embeddings = self.uncond_text_embeddings.expand( # type: ignore batch_size, -1, -1 ) - prompts = self.prompt # IMPORTANT: we return (cond, uncond), which is in different order than other implementations! - if not return_prompt: - return torch.cat([text_embeddings, uncond_text_embeddings], dim=0) - else: - return torch.cat([text_embeddings, uncond_text_embeddings], dim=0), prompts + return torch.cat([text_embeddings, uncond_text_embeddings], dim=0) def get_text_embeddings_perp_neg( self,