Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Get FLUX non-masked image-to-image working - still rough.
commit b33cba500c (parent e3a7bf12c1)
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from einops import rearrange
 from PIL import Image
@@ -8,6 +10,7 @@ from invokeai.app.invocations.fields import (
     FluxConditioningField,
     Input,
     InputField,
+    LatentsField,
     WithBoard,
     WithMetadata,
 )
@@ -21,6 +24,8 @@ from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
 
+EPS = 1e-6
+
 
 @invocation(
     "flux_text_to_image",
@@ -33,6 +38,18 @@ from invokeai.backend.util.devices import TorchDevice
 class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""
 
+    # If latents is provided, this means we are doing image-to-image.
+    latents: Optional[LatentsField] = InputField(
+        default=None,
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    denoising_start: float = InputField(
+        default=0.0,
+        ge=0,
+        le=1,
+        description=FieldDescriptions.denoising_start,
+    )
     transformer: TransformerField = InputField(
         description=FieldDescriptions.flux_model,
         input=Input.Connection,
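The two added fields are what turn this text-to-image invocation into an optional image-to-image one: connecting latents supplies a starting image, and denoising_start controls how much of the schedule is skipped. For readers used to the SD-style "strength" knob, denoising_start behaves as its complement. A hypothetical convenience helper (the name and mapping are ours, not part of the diff):

def denoising_start_from_strength(strength: float) -> float:
    """Map an SD-style img2img strength (1.0 = ignore the init image,
    0.0 = keep it untouched) onto the `denoising_start` field above.
    Hypothetical helper; the invocation only exposes `denoising_start`."""
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be in [0, 1]")
    return 1.0 - strength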
@@ -78,7 +95,10 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         t5_embeddings = flux_conditioning.t5_embeds
         clip_embeddings = flux_conditioning.clip_embeds
 
+        transformer_info = context.models.load(self.transformer.transformer)
+        # Load the input latents, if provided.
+        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
+        if init_latents is not None:
+            init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
 
         # Prepare input noise.
         x = get_noise(
@@ -90,8 +110,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             seed=self.seed,
         )
 
-        x, img_ids = prepare_latent_img_patches(x)
 
-        transformer_info = context.models.load(self.transformer.transformer)
         is_schnell = "schnell" in transformer_info.config.config_path
 
         timesteps = get_schedule(
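Two moves here, no new logic: the transformer load migrates up next to the conditioning (previous hunk), and prepare_latent_img_patches migrates below the image-to-image noising (next hunk). The second move matters because the noising blend must happen on unpacked [B, C, H, W] latents, the same layout init_latents arrives in, while the transformer consumes packed token sequences. A rough sketch of what the packing step does, paraphrased from invokeai.backend.flux.sampling (details such as dtype handling may differ from the real implementation):

import torch
from einops import rearrange, repeat

def prepare_latent_img_patches_sketch(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Pack 2x2 latent patches into a token sequence and build the
    (text, row, col) position ids FLUX uses for rotary embeddings."""
    bs, c, h, w = latent_img.shape
    # Each token covers a 2x2 spatial patch across all channels.
    img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    # Position ids: axis 0 is reserved for text tokens, 1 is row, 2 is column.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=latent_img.device)
    img_ids[..., 1] = torch.arange(h // 2, device=latent_img.device)[:, None]
    img_ids[..., 2] = torch.arange(w // 2, device=latent_img.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
    return img, img_ids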
@@ -100,6 +119,22 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             shift=not is_schnell,
         )
 
+        # Prepare inputs for image-to-image case.
+        if self.denoising_start > EPS:
+            if init_latents is None:
+                raise ValueError("latents must be provided if denoising_start > 0.")
+
+            # Clip the timesteps schedule based on denoising_start.
+            # TODO(ryand): Should we apply denoising_start in timestep-space rather than timestep-index-space?
+            start_idx = int(self.denoising_start * len(timesteps))
+            timesteps = timesteps[start_idx:]
+
+            # Noise the orig_latents by the appropriate amount for the first timestep.
+            t_0 = timesteps[0]
+            x = t_0 * x + (1.0 - t_0) * init_latents
+
+        x, img_ids = prepare_latent_img_patches(x)
+
         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
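This hunk is the heart of the commit. Because FLUX is a rectified-flow model trained on straight-line paths x_t = t * noise + (1 - t) * x_clean, partially noising the init latents is plain linear interpolation at the first kept timestep; no DDPM-style sqrt(alpha-bar) scheduling is needed. A toy walkthrough with made-up shapes and a linear stand-in for get_schedule(...):

import torch

# Stand-in schedule: get_schedule(...) returns ~num_steps + 1 values from 1.0
# (pure noise) down to 0.0 (clean latents); the real one may also be shifted.
timesteps = torch.linspace(1.0, 0.0, steps=11)
denoising_start = 0.4

# Clip in index-space, as the diff does (hence the TODO above: with a shifted
# schedule, skipping 40% of the indices is generally not the same as t = 0.6).
start_idx = int(denoising_start * len(timesteps))
timesteps = timesteps[start_idx:]

noise = torch.randn(1, 16, 32, 32)         # x from get_noise(...), 16 latent channels
init_latents = torch.zeros(1, 16, 32, 32)  # stand-in for VAE-encoded image latents

# Noise the init latents exactly as far as the first kept timestep expects.
t_0 = timesteps[0].item()
x = t_0 * noise + (1.0 - t_0) * init_latents
print(f"resuming at t_0 = {t_0:.2f} with {len(timesteps) - 1} steps remaining")

With denoising_start = 0.0 the branch is skipped entirely and the invocation behaves exactly as it did before this commit.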
@@ -28,6 +28,7 @@ from invokeai.backend.model_manager import LoadedModel
 from invokeai.backend.model_manager.config import BaseModelType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
+from invokeai.backend.util.devices import TorchDevice
 
 
 @invocation(
@@ -59,9 +60,12 @@ class ImageToLatentsInvocation(BaseInvocation):
+        # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
+        # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
+        # should be used for VAE encode sampling.
-        generator = torch.Generator().manual_seed(0)
+        generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
         with vae_info as vae:
             assert isinstance(vae, AutoEncoder)
             image_tensor = image_tensor.to(
                 device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
             )
             latents = vae.encode(image_tensor, sample=True, generator=generator)
         return latents
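The FLUX VAE samples from the posterior during encode (sample=True), so this hunk pins that sampling to a fixed seed and, crucially, moves the Generator onto the inference device: a default torch.Generator lives on the CPU and cannot drive sampling of CUDA tensors directly. A minimal sketch of the failure mode (assumed context, not taken from the diff):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Correct: generator and output tensor live on the same device.
gen = torch.Generator(device=device).manual_seed(0)
ok = torch.randn(4, 4, device=device, generator=gen)

# Broken on CUDA: a CPU generator cannot produce a CUDA tensor.
cpu_gen = torch.Generator().manual_seed(0)
try:
    torch.randn(4, 4, device=device, generator=cpu_gen)
except RuntimeError as err:
    print(f"device mismatch: {err}")

Note the caveat flagged in the TODO: the same seed does not produce the same values on CPU and CUDA, which is why the comment asks for a device-consistent sampling util.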
@@ -91,7 +91,7 @@ def get_schedule(
 
     # shifting the schedule to favor high timesteps for higher signal images
     if shift:
-        # eastimate mu based on linear estimation between two points
+        # estimate mu based on linear estimation between two points
         mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
         timesteps = time_shift(mu, 1.0, timesteps)
 
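For context on the two helpers in this hunk: get_lin_function fits a line through two (sequence length, shift) points, and time_shift warps the uniform schedule toward the noisy end, so larger images spend more steps at high noise. The sketch below is paraphrased from the FLUX reference sampling code; the default anchor points are the upstream ones, but verify against invokeai.backend.flux.sampling:

import math

def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    # Line through (x1, y1) and (x2, y2): mu grows with the latent sequence length.
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def time_shift(mu: float, sigma: float, t: float) -> float:
    # Warps t in (0, 1] toward 1 (the high-noise end) as mu increases.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# A 1024x1024 image -> (1024 / 16) ** 2 = 4096 packed latent tokens.
mu = get_lin_function(y1=0.5, y2=1.15)(4096)
print(time_shift(mu, 1.0, 0.5))  # ~0.76: the mid-schedule step is pushed toward noise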