Get FLUX non-masked image-to-image working - still rough.

Ryan Dick 2024-08-29 14:17:08 +00:00
parent e3a7bf12c1
commit b33cba500c
3 changed files with 44 additions and 5 deletions

invokeai/app/invocations/flux_text_to_image.py

@@ -1,3 +1,5 @@
from typing import Optional
import torch
from einops import rearrange
from PIL import Image
@@ -8,6 +10,7 @@ from invokeai.app.invocations.fields import (
FluxConditioningField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
@@ -21,6 +24,8 @@ from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
EPS = 1e-6
@invocation(
"flux_text_to_image",
@@ -33,6 +38,18 @@ from invokeai.backend.util.devices import TorchDevice
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
@@ -78,7 +95,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
transformer_info = context.models.load(self.transformer.transformer)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
# Prepare input noise.
x = get_noise(
@@ -90,8 +110,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
seed=self.seed,
)
x, img_ids = prepare_latent_img_patches(x)
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
@@ -100,6 +119,22 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
shift=not is_schnell,
)
# Prepare inputs for image-to-image case.
if self.denoising_start > EPS:
if init_latents is None:
raise ValueError("latents must be provided if denoising_start > 0.")
# Clip the timesteps schedule based on denoising_start.
# TODO(ryand): Should we apply denoising_start in timestep-space rather than timestep-index-space?
start_idx = int(self.denoising_start * len(timesteps))
timesteps = timesteps[start_idx:]
# Noise init_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
x = t_0 * x + (1.0 - t_0) * init_latents
x, img_ids = prepare_latent_img_patches(x)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
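
For readers following the new image-to-image path: the block above clips the timestep schedule in index space, then noises the encoded image latents with the rectified-flow interpolation x_t = t * noise + (1 - t) * x_0. The sketch below isolates that logic; the toy linear schedule and tensor shapes are illustrative assumptions (the real schedule comes from get_schedule(...)):

```python
import torch

def apply_denoising_start(
    noise: torch.Tensor,          # pure Gaussian noise, same shape as the latents
    init_latents: torch.Tensor,   # VAE-encoded input image latents
    timesteps: list[float],       # descending schedule over [1.0, 0.0]
    denoising_start: float,       # 0.0 = pure text-to-image, near 1.0 = tiny edit
) -> tuple[torch.Tensor, list[float]]:
    """Clip the schedule and noise the init latents for image-to-image."""
    # Index-space clipping, per the TODO above about timestep- vs index-space.
    start_idx = int(denoising_start * len(timesteps))
    timesteps = timesteps[start_idx:]
    # Rectified-flow interpolation: at time t, x_t = t * noise + (1 - t) * x_0.
    t_0 = timesteps[0]
    x = t_0 * noise + (1.0 - t_0) * init_latents
    return x, timesteps

# Toy usage (shapes are illustrative; the real latents are packed separately by
# prepare_latent_img_patches after this step):
noise = torch.randn(1, 16, 64, 64)
init_latents = torch.randn(1, 16, 64, 64)
schedule = torch.linspace(1.0, 0.0, steps=11).tolist()
x, ts = apply_denoising_start(noise, init_latents, schedule, denoising_start=0.3)
print(len(ts), ts[0])  # 8 steps remain, starting at t_0 = 0.7
```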

invokeai/app/invocations/image_to_latents.py

@@ -28,6 +28,7 @@ from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@invocation(
@@ -59,9 +60,12 @@ class ImageToLatentsInvocation(BaseInvocation):
# TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
# There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
# should be used for VAE encode sampling.
generator = torch.Generator().manual_seed(0)
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents
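
The one-line generator change above pins the RNG to the execution device: a CPU torch.Generator cannot be passed to a sampling call that runs on CUDA, and (as the TODO notes) CPU and CUDA generators produce different sequences for the same seed, which is why a cross-device util is still wanted. A minimal sketch of seeded, device-bound sampling (seeded_sample is a hypothetical helper for illustration, not InvokeAI's API):

```python
import torch

def seeded_sample(shape: tuple[int, ...], seed: int, device: torch.device) -> torch.Tensor:
    """Draw reproducible Gaussian noise directly on the target device."""
    # The generator must live on the same device the tensor is created on.
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=generator, device=device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
a = seeded_sample((1, 4, 8, 8), seed=0, device=device)
b = seeded_sample((1, 4, 8, 8), seed=0, device=device)
assert torch.equal(a, b)  # same seed + same device -> identical noise
```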

invokeai/backend/flux/sampling.py

@@ -91,7 +91,7 @@ def get_schedule(
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
# estimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
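
For context on the shift applied here: get_lin_function linearly interpolates mu between base_shift and max_shift as a function of the image token count, and time_shift then warps the schedule toward t = 1 for longer sequences (schnell models skip this, hence shift=not is_schnell earlier in the diff). The sketch below assumes the FLUX reference definitions, i.e. time_shift(mu, sigma, t) = exp(mu) / (exp(mu) + (1/t - 1)**sigma) with endpoints x1=256, x2=4096; treat it as an illustration rather than this repo's exact code:

```python
import math
import torch

def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    # Line through (x1, y1) and (x2, y2); maps sequence length -> mu.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # Pushes timesteps toward 1.0 as mu grows (more steps spent at high noise).
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# A 1024x1024 image -> 128x128 latent -> 2x2 patches -> 64*64 = 4096 tokens,
# so mu lands at max_shift for the assumed default endpoints.
image_seq_len = 4096
mu = get_lin_function(y1=0.5, y2=1.15)(image_seq_len)
timesteps = torch.linspace(1.0, 0.0, 5)[:-1]  # drop t=0, where 1/t diverges
print(time_shift(mu, 1.0, timesteps))  # schedule bent toward t=1
```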