Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit b33cba500c (parent e3a7bf12c1)
Get FLUX non-masked image-to-image working - still rough.
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from einops import rearrange
 from PIL import Image
@@ -8,6 +10,7 @@ from invokeai.app.invocations.fields import (
     FluxConditioningField,
     Input,
     InputField,
+    LatentsField,
     WithBoard,
     WithMetadata,
 )
@@ -21,6 +24,8 @@ from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, pre
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
 
+EPS = 1e-6
+
 
 @invocation(
     "flux_text_to_image",
@@ -33,6 +38,18 @@ from invokeai.backend.util.devices import TorchDevice
 class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""
 
+    # If latents is provided, this means we are doing image-to-image.
+    latents: Optional[LatentsField] = InputField(
+        default=None,
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    denoising_start: float = InputField(
+        default=0.0,
+        ge=0,
+        le=1,
+        description=FieldDescriptions.denoising_start,
+    )
     transformer: TransformerField = InputField(
         description=FieldDescriptions.flux_model,
         input=Input.Connection,
@@ -78,7 +95,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         t5_embeddings = flux_conditioning.t5_embeds
         clip_embeddings = flux_conditioning.clip_embeds
 
-        transformer_info = context.models.load(self.transformer.transformer)
+        # Load the input latents, if provided.
+        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
+        if init_latents is not None:
+            init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
 
         # Prepare input noise.
         x = get_noise(
@@ -90,8 +110,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )
 
-        x, img_ids = prepare_latent_img_patches(x)
-
+        transformer_info = context.models.load(self.transformer.transformer)
         is_schnell = "schnell" in transformer_info.config.config_path
 
         timesteps = get_schedule(
@@ -100,6 +119,22 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             shift=not is_schnell,
         )
 
+        # Prepare inputs for image-to-image case.
+        if self.denoising_start > EPS:
+            if init_latents is None:
+                raise ValueError("latents must be provided if denoising_start > 0.")
+
+            # Clip the timesteps schedule based on denoising_start.
+            # TODO(ryand): Should we apply denoising_start in timestep-space rather than timestep-index-space?
+            start_idx = int(self.denoising_start * len(timesteps))
+            timesteps = timesteps[start_idx:]
+
+            # Noise the orig_latents by the appropriate amount for the first timestep.
+            t_0 = timesteps[0]
+            x = t_0 * x + (1.0 - t_0) * init_latents
+
+        x, img_ids = prepare_latent_img_patches(x)
+
         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
 
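Note on the image-to-image path added above: FLUX is a rectified-flow model, so an init image is injected by truncating the timestep schedule at denoising_start and linearly interpolating between noise and the encoded image latents at the first remaining timestep. A minimal, standalone sketch of that idea (the helper name and type hints here are illustrative, not InvokeAI APIs):

import torch


def blend_for_img2img(
    noise: torch.Tensor,  # pure Gaussian noise in latent space
    init_latents: torch.Tensor,  # VAE-encoded input image, same shape as noise
    timesteps: list[float],  # descending schedule, e.g. [1.0, ..., 0.0]
    denoising_start: float,  # 0.0 = generate from scratch, near 1.0 = keep the input image
) -> tuple[torch.Tensor, list[float]]:
    """Truncate the schedule and noise the init latents for the first kept timestep."""
    # Drop the earliest (noisiest) steps in index space, mirroring the diff above.
    start_idx = int(denoising_start * len(timesteps))
    timesteps = timesteps[start_idx:]

    # Rectified-flow interpolation: at time t the state is t * noise + (1 - t) * data.
    t_0 = timesteps[0]
    x = t_0 * noise + (1.0 - t_0) * init_latents
    return x, timesteps

With denoising_start at 0.0 the branch is skipped entirely and x stays pure noise, which reduces to the existing text-to-image behavior.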
@@ -28,6 +28,7 @@ from invokeai.backend.model_manager import LoadedModel
 from invokeai.backend.model_manager.config import BaseModelType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
+from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation(
@@ -59,9 +60,12 @@ class ImageToLatentsInvocation(BaseInvocation):
         # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
         # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
         # should be used for VAE encode sampling.
-        generator = torch.Generator().manual_seed(0)
+        generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
+            image_tensor = image_tensor.to(
+                device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
+            )
             latents = vae.encode(image_tensor, sample=True, generator=generator)
         return latents
 
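The generator change above ties the sampling RNG to the active inference device: a default torch.Generator lives on the CPU and cannot be mixed with noise draws on a CUDA/MPS tensor, and since vae.encode(..., sample=True, ...) samples from the encoder's posterior, a seeded, device-matched generator keeps that step working and reproducible. A small illustration of the constraint (not InvokeAI code):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The generator must live on the same device as the tensors it samples for;
# a CPU generator combined with device="cuda" raises a device-mismatch error.
generator = torch.Generator(device=device).manual_seed(0)
noise = torch.randn(4, 16, 64, 64, device=device, generator=generator)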
@@ -91,7 +91,7 @@ def get_schedule(
 
     # shifting the schedule to favor high timesteps for higher signal images
     if shift:
-        # eastimate mu based on linear estimation between two points
+        # estimate mu based on linear estimation between two points
         mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
         timesteps = time_shift(mu, 1.0, timesteps)
 
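For reference, the corrected comment describes the resolution-dependent schedule shift used by FLUX dev: mu is linearly interpolated from the latent sequence length, then applied through a sigmoid-like time shift that bends the schedule toward the noisy end for larger images. A sketch of those two helpers, modeled on the upstream FLUX reference implementation (the default constants are the reference values, stated here as an assumption rather than quoted from this diff):

import math

import torch


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    # Linear interpolation of mu between a small (256-patch) and a large (4096-patch) image.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # Larger mu pushes timesteps toward t=1, i.e. more steps spent at high noise levels.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


# Example: shift a 5-step linear schedule for a 1024x1024 image (4096 latent patches).
timesteps = torch.linspace(1, 0, 5)
shifted = time_shift(get_lin_function()(4096), 1.0, timesteps)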