diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 3ac45ba19c..248122d8cd 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -58,13 +58,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-
-        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        latents = self._run_diffusion(context)
         image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
@@ -72,12 +66,20 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _run_diffusion(
         self,
         context: InvocationContext,
-        clip_embeddings: torch.Tensor,
-        t5_embeddings: torch.Tensor,
     ):
-        transformer_info = context.models.load(self.transformer.transformer)
         inference_dtype = torch.bfloat16
 
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+        flux_conditioning = flux_conditioning.to(dtype=inference_dtype)
+        t5_embeddings = flux_conditioning.t5_embeds
+        clip_embeddings = flux_conditioning.clip_embeds
+
+        transformer_info = context.models.load(self.transformer.transformer)
+
         # Prepare input noise.
         x = get_noise(
             num_samples=1,
@@ -88,13 +90,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )
 
-        img, img_ids = prepare_latent_img_patches(x)
+        x, img_ids = prepare_latent_img_patches(x)
 
         is_schnell = "schnell" in transformer_info.config.config_path
 
         timesteps = get_schedule(
             num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
+            image_seq_len=x.shape[1],
             shift=not is_schnell,
         )
 
@@ -135,7 +137,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
         x = denoise(
             model=transformer,
-            img=img,
+            img=x,
             img_ids=img_ids,
             txt=t5_embeddings,
             txt_ids=txt_ids,
diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py
index 19de48ae81..7a35b0aedf 100644
--- a/invokeai/backend/flux/sampling.py
+++ b/invokeai/backend/flux/sampling.py
@@ -111,16 +111,7 @@ def denoise(
     step_callback: Callable[[], None],
     guidance: float = 4.0,
 ):
-    dtype = model.txt_in.bias.dtype
-
-    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
-    img = img.to(dtype=dtype)
-    img_ids = img_ids.to(dtype=dtype)
-    txt = txt.to(dtype=dtype)
-    txt_ids = txt_ids.to(dtype=dtype)
-    vec = vec.to(dtype=dtype)
-
-    # this is ignored for schnell
+    # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -168,9 +159,9 @@ def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     img = repeat(img, "1 ... -> bs ...", bs=bs)
 
     # Generate patch position ids.
-    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :]
+    img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
+    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
+    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]
     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
     return img, img_ids
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index c5fda909c7..b7e9038cf7 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -43,6 +43,11 @@ class FLUXConditioningInfo:
     clip_embeds: torch.Tensor
     t5_embeds: torch.Tensor
 
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
+        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
+        return self
+
 
 @dataclass
 class ConditioningFieldData:
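
Note (not part of the diff): below is a minimal, self-contained sketch of the dtype-handling pattern this change moves into the caller. The conditioning dataclass casts its own tensors once, up front, so denoise() no longer needs to re-cast its inputs on every call. The tensor shapes here are illustrative placeholders, not values taken from the FLUX models.

import torch
from dataclasses import dataclass


@dataclass
class FLUXConditioningInfo:
    clip_embeds: torch.Tensor
    t5_embeds: torch.Tensor

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        # Cast both embedding tensors in place and return self so calls can be chained.
        self.clip_embeds = self.clip_embeds.to(device=device, dtype=dtype)
        self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
        return self


# Illustrative shapes only: a pooled CLIP embedding and a T5 sequence embedding.
cond = FLUXConditioningInfo(
    clip_embeds=torch.randn(1, 768),
    t5_embeds=torch.randn(1, 512, 4096),
).to(dtype=torch.bfloat16)
assert cond.clip_embeds.dtype == cond.t5_embeds.dtype == torch.bfloat16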
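
Likewise, a sketch of the prepare_latent_img_patches change: the patch position ids are now created directly in the latent image's dtype, so they match img without the blanket casts that denoise() used to perform. The h and w values here are illustrative packed-latent dimensions, not real model values.

import torch

# Illustrative packed-latent dims; the real values come from the noise tensor.
h, w = 128, 128
img = torch.randn(1, (h // 2) * (w // 2), 64, dtype=torch.bfloat16)

# Position ids built in img's dtype, mirroring the change in sampling.py.
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
img_ids[..., 1] += torch.arange(h // 2, device=img.device, dtype=img.dtype)[:, None]
img_ids[..., 2] += torch.arange(w // 2, device=img.device, dtype=img.dtype)[None, :]

assert img_ids.dtype == img.dtype == torch.bfloat16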