diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 19829c47a4..229c1fdf46 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -118,7 +118,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): img, img_ids = self._prepare_latent_img_patches(x) # HACK(ryand): Find a better way to determine if this is a schnell model or not. - is_schnell = "shnell" in str(flux_transformer_path) + is_schnell = "schnell" in str(flux_transformer_path) timesteps = get_schedule( num_steps=self.num_steps, image_seq_len=img.shape[1], @@ -174,6 +174,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + img_ids = img_ids.to(latent_img.device) return img, img_ids @@ -239,13 +240,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): else: raise ValueError(f"Unsupported quantization type: {self.quantization_type}") - assert isinstance(model, FluxTransformer2DModel) + assert isinstance(model, Flux) return model @staticmethod def _load_flux_vae(path: Path) -> AutoEncoder: # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. - ae_params = flux_configs["flux1-schnell"].ae_params + ae_params = flux_configs["flux-schnell"].ae_params with accelerate.init_empty_weights(): ae = AutoEncoder(ae_params)