mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Get a rough version of FLUX inpainting working.
This commit is contained in:
parent
e0f12c762e
commit
7d854f32b0
@ -185,7 +185,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
ui_order=8,
|
||||
)
|
||||
|
@ -181,7 +181,7 @@ class FieldDescriptions:
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
denoise_mask = "A mask of the region to apply the denoising process to."
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
@ -16,8 +19,15 @@ from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
|
||||
from invokeai.backend.flux.sampling_utils import (
|
||||
generate_img_ids,
|
||||
get_noise,
|
||||
get_schedule,
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -41,6 +51,11 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
@ -95,7 +110,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
init_latents = init_latents.to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)
|
||||
|
||||
# Prepare input noise.
|
||||
x = get_noise(
|
||||
noise = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
@ -107,14 +122,16 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
is_schnell = "schnell" in transformer_info.config.config_path
|
||||
|
||||
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=x.shape[1],
|
||||
image_seq_len=image_seq_len,
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
# Prepare inputs for image-to-image case.
|
||||
# Prepare input latent image.
|
||||
if self.denoising_start > EPS:
|
||||
# If denoising_start > 0, we are doing image-to-image.
|
||||
if init_latents is None:
|
||||
raise ValueError("latents must be provided if denoising_start > 0.")
|
||||
|
||||
@ -125,13 +142,34 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
# Noise the orig_latents by the appropriate amount for the first timestep.
|
||||
t_0 = timesteps[0]
|
||||
x = t_0 * x + (1.0 - t_0) * init_latents
|
||||
x = t_0 * noise + (1.0 - t_0) * init_latents
|
||||
else:
|
||||
# We are not doing image-to-image, so we are starting from noise.
|
||||
x = noise
|
||||
|
||||
x, img_ids = prepare_latent_img_patches(x)
|
||||
# Prepare inpaint mask.
|
||||
inpaint_mask = self._prep_inpaint_mask(context, x)
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
# Expand the inpaint mask to the same shape as the init_latents so that when we pack inpaint_mask it lines
|
||||
# up with the init_latents.
|
||||
inpaint_mask = inpaint_mask.expand_as(init_latents)
|
||||
|
||||
b, _c, h, w = x.shape
|
||||
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# Pack all latent tensors.
|
||||
init_latents = pack(init_latents) if init_latents is not None else None
|
||||
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
|
||||
noise = pack(noise)
|
||||
x = pack(x)
|
||||
|
||||
# Verify that we calculated the image_seq_len correctly.
|
||||
assert image_seq_len == x.shape[1]
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
@ -174,8 +212,33 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
timesteps=timesteps,
|
||||
step_callback=step_callback,
|
||||
guidance=self.guidance,
|
||||
init_latents=init_latents,
|
||||
noise=noise,
|
||||
inpaint_mask=inpaint_mask,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
return x
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask.
|
||||
|
||||
Loads the mask, resizes if necessary, casts to same device/dtype as latents.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor | None, bool]: (mask, is_gradient_mask)
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
return mask
|
||||
|
55
invokeai/backend/flux/denoise.py
Normal file
55
invokeai/backend/flux/denoise.py
Normal file
@ -0,0 +1,55 @@
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.inpaint import merge_intermediate_latents_with_init_latents
|
||||
from invokeai.backend.flux.model import Flux
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float,
|
||||
# For inpainting:
|
||||
init_latents: torch.Tensor | None,
|
||||
noise: torch.Tensor,
|
||||
inpaint_mask: torch.Tensor | None,
|
||||
):
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
if inpaint_mask is not None:
|
||||
assert init_latents is not None
|
||||
img = merge_intermediate_latents_with_init_latents(
|
||||
init_latents=init_latents,
|
||||
intermediate_latents=img,
|
||||
timestep=t_prev,
|
||||
noise=noise,
|
||||
inpaint_mask=inpaint_mask,
|
||||
)
|
||||
|
||||
step_callback()
|
||||
|
||||
return img
|
15
invokeai/backend/flux/inpaint.py
Normal file
15
invokeai/backend/flux/inpaint.py
Normal file
@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
init_latents: torch.Tensor,
|
||||
intermediate_latents: torch.Tensor,
|
||||
timestep: float,
|
||||
noise: torch.Tensor,
|
||||
inpaint_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Noise the init_latents for the current timestep.
|
||||
noised_init_latents = noise * timestep + (1.0 - timestep) * init_latents
|
||||
|
||||
# Merge the intermediate_latents with the noised_init_latents using the inpaint_mask.
|
||||
return intermediate_latents * inpaint_mask + noised_init_latents * (1.0 - inpaint_mask)
|
@ -6,10 +6,6 @@ from typing import Callable
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
|
||||
|
||||
def get_noise(
|
||||
@ -35,40 +31,6 @@ def get_noise(
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
@ -98,39 +60,6 @@ def get_schedule(
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
step_callback: Callable[[], None],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
# guidance_vec is ignored for schnell.
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
step_callback()
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
@ -142,21 +71,34 @@ def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
)
|
||||
|
||||
|
||||
def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def pack(x: Tensor) -> Tensor:
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
|
||||
|
||||
def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtype: torch.dtype) -> Tensor:
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
||||
return img_ids
|
||||
|
||||
|
||||
def prepare_latent_img_patches(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert an input image in latent space to patches for diffusion.
|
||||
|
||||
This implementation was extracted from:
|
||||
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Input image in latent space.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
|
||||
"""
|
||||
bs, c, h, w = latent_img.shape
|
||||
bs, c, h, w = img.shape
|
||||
|
||||
# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
|
||||
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
img = pack(img)
|
||||
|
||||
# Generate patch position ids.
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device, dtype=img.dtype)
|
Loading…
Reference in New Issue
Block a user