diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 248122d8cd..dacc282543 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from einops import rearrange
 from PIL import Image
@@ -8,6 +10,7 @@ from invokeai.app.invocations.fields import (
     FluxConditioningField,
     Input,
     InputField,
+    LatentsField,
     WithBoard,
     WithMetadata,
 )
@@ -21,6 +24,8 @@ from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, pre
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
 
+EPS = 1e-6
+
 
 @invocation(
     "flux_text_to_image",
@@ -33,6 +38,18 @@ from invokeai.backend.util.devices import TorchDevice
 class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""
 
+    # If latents is provided, this means we are doing image-to-image.
+    latents: Optional[LatentsField] = InputField(
+        default=None,
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    denoising_start: float = InputField(
+        default=0.0,
+        ge=0,
+        le=1,
+        description=FieldDescriptions.denoising_start,
+    )
     transformer: TransformerField = InputField(
         description=FieldDescriptions.flux_model,
         input=Input.Connection,
@@ -78,7 +95,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         t5_embeddings = flux_conditioning.t5_embeds
         clip_embeddings = flux_conditioning.clip_embeds
 
-        transformer_info = context.models.load(self.transformer.transformer)
+        # Load the input latents, if provided.
+        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
+        if init_latents is not None:
+            init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
 
         # Prepare input noise.
         x = get_noise(
@@ -90,8 +110,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )
 
-        x, img_ids = prepare_latent_img_patches(x)
-
+        transformer_info = context.models.load(self.transformer.transformer)
         is_schnell = "schnell" in transformer_info.config.config_path
 
         timesteps = get_schedule(
@@ -100,6 +119,22 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             shift=not is_schnell,
         )
 
+        # Prepare inputs for the image-to-image case.
+        if self.denoising_start > EPS:
+            if init_latents is None:
+                raise ValueError("latents must be provided if denoising_start > 0.")
+
+            # Clip the timesteps schedule based on denoising_start.
+            # TODO(ryand): Should we apply denoising_start in timestep-space rather than timestep-index-space?
+            start_idx = int(self.denoising_start * len(timesteps))
+            timesteps = timesteps[start_idx:]
+
+            # Noise the init_latents by the appropriate amount for the first timestep.
+            t_0 = timesteps[0]
+            x = t_0 * x + (1.0 - t_0) * init_latents
+
+        x, img_ids = prepare_latent_img_patches(x)
+
         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
 
diff --git a/invokeai/app/invocations/image_to_latents.py b/invokeai/app/invocations/image_to_latents.py
index e277173b70..8eba975898 100644
--- a/invokeai/app/invocations/image_to_latents.py
+++ b/invokeai/app/invocations/image_to_latents.py
@@ -28,6 +28,7 @@ from invokeai.backend.model_manager import LoadedModel
 from invokeai.backend.model_manager.config import BaseModelType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
+from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation(
@@ -59,9 +60,12 @@ class ImageToLatentsInvocation(BaseInvocation):
         # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
         # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
         # should be used for VAE encode sampling.
-        generator = torch.Generator().manual_seed(0)
+        generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
+            image_tensor = image_tensor.to(
+                device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
+            )
             latents = vae.encode(image_tensor, sample=True, generator=generator)
 
         return latents
diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py
index 7a35b0aedf..3c3103411a 100644
--- a/invokeai/backend/flux/sampling.py
+++ b/invokeai/backend/flux/sampling.py
@@ -91,7 +91,7 @@ def get_schedule(
     # shifting the schedule to favor high timesteps for higher signal images
 
     if shift:
-        # eastimate mu based on linear estimation between two points
+        # estimate mu based on linear estimation between two points
         mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
         timesteps = time_shift(mu, 1.0, timesteps)
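
Illustrative sketch (not part of the patch) of the image-to-image preparation in flux_text_to_image.py: denoising_start clips the schedule in timestep-index-space, and the first remaining timestep t_0 linearly blends fresh noise with the encoded init latents. The shapes and 4-step schedule below are hypothetical, and the sketch assumes get_schedule's convention that t = 1.0 is pure noise and t = 0.0 is fully denoised.

import torch

denoising_start = 0.5
timesteps = [1.0, 0.75, 0.5, 0.25, 0.0]  # hypothetical 4-step schedule (num_steps + 1 values)

x = torch.randn(1, 16, 32, 32)             # fresh noise
init_latents = torch.randn(1, 16, 32, 32)  # stand-in for VAE-encoded image latents

# Clip the schedule: skip the fraction of steps treated as already denoised.
start_idx = int(denoising_start * len(timesteps))
timesteps = timesteps[start_idx:]

# Blend at the first remaining timestep: t_0 = 1.0 keeps pure noise,
# while t_0 near 0.0 keeps mostly the init latents.
t_0 = timesteps[0]
x = t_0 * x + (1.0 - t_0) * init_latents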
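
On the image_to_latents.py change: torch.Generator defaults to the CPU, and PyTorch sampling ops require the generator to live on the same device as the tensors being sampled, hence constructing it with TorchDevice.choose_torch_device(). A minimal sketch of the constraint (device choice is illustrative):

import torch

# A CPU generator paired with CUDA tensors raises a RuntimeError, so the
# generator is created on the same device the sampled tensor will live on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = torch.Generator(device=device).manual_seed(0)
noise = torch.randn(1, 4, 8, 8, device=device, generator=generator)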
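
On the sampling.py comment fix: mu is obtained by evaluating a straight line through two (image sequence length, shift) anchor points. A sketch of that linear estimation follows; the anchor x-values 256 and 4096 are assumptions borrowed from the reference FLUX implementation of get_lin_function, not verified against the vendored copy.

def lin_function(x1: float = 256.0, y1: float = 0.5, x2: float = 4096.0, y2: float = 1.15):
    """Return the line through (x1, y1) and (x2, y2)."""
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

# Longer image sequences get a larger mu, shifting the schedule toward
# high (noisier) timesteps, as the comment above describes.
mu = lin_function(y1=0.5, y2=1.15)(1024.0)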