mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add naive ControlNet support to TiledStableDiffusionRefineInvocation
This commit is contained in:
parent
d08e405017
commit
5301770525
@ -1,4 +1,7 @@
|
|||||||
|
from contextlib import ExitStack
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -17,14 +20,16 @@ from invokeai.app.invocations.fields import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||||
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
||||||
from invokeai.app.invocations.model import UNetField, VAEField
|
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
|
||||||
from invokeai.app.invocations.noise import get_noise
|
from invokeai.app.invocations.noise import get_noise
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||||
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
|
||||||
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
|
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
|
||||||
from invokeai.backend.tiles.utils import Tile
|
from invokeai.backend.tiles.utils import Tile
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -66,10 +71,6 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
title="UNet",
|
title="UNet",
|
||||||
)
|
)
|
||||||
# control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
|
||||||
# default=None,
|
|
||||||
# input=Input.Connection,
|
|
||||||
# )
|
|
||||||
cfg_rescale_multiplier: float = InputField(
|
cfg_rescale_multiplier: float = InputField(
|
||||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||||
)
|
)
|
||||||
@ -80,6 +81,15 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
vae_fp32: bool = InputField(
|
vae_fp32: bool = InputField(
|
||||||
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
|
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
|
||||||
)
|
)
|
||||||
|
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
|
||||||
|
# don't want to use the image field. Figure out how best to handle this.
|
||||||
|
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
|
||||||
|
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
|
||||||
|
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
|
||||||
|
# etc.
|
||||||
|
control_model: ModelIdentifierField = InputField(
|
||||||
|
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("cfg_scale")
|
@field_validator("cfg_scale")
|
||||||
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||||
@ -112,6 +122,41 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
right = image_tile.coords.right // LATENT_SCALE_FACTOR
|
right = image_tile.coords.right // LATENT_SCALE_FACTOR
|
||||||
return latents[..., top:bottom, left:right]
|
return latents[..., top:bottom, left:right]
|
||||||
|
|
||||||
|
def run_controlnet(
|
||||||
|
self,
|
||||||
|
image: Image.Image,
|
||||||
|
controlnet_model: ControlNetModel,
|
||||||
|
weight: float,
|
||||||
|
do_classifier_free_guidance: bool,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
||||||
|
) -> ControlNetData:
|
||||||
|
control_image = prepare_control_image(
|
||||||
|
image=image,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
control_mode=control_mode,
|
||||||
|
resize_mode=resize_mode,
|
||||||
|
)
|
||||||
|
return ControlNetData(
|
||||||
|
model=controlnet_model,
|
||||||
|
image_tensor=control_image,
|
||||||
|
weight=weight,
|
||||||
|
begin_step_percent=0.0,
|
||||||
|
end_step_percent=1.0,
|
||||||
|
control_mode=control_mode,
|
||||||
|
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
|
||||||
|
# ControlNetData in case needed in the future.
|
||||||
|
resize_mode=resize_mode,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# TODO(ryand): Expose the seed parameter.
|
# TODO(ryand): Expose the seed parameter.
|
||||||
@ -119,8 +164,6 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Load the input image.
|
# Load the input image.
|
||||||
input_image = context.images.get_pil(self.image.image_name)
|
input_image = context.images.get_pil(self.image.image_name)
|
||||||
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
|
||||||
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
|
||||||
|
|
||||||
# Calculate the tile locations to cover the image.
|
# Calculate the tile locations to cover the image.
|
||||||
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
||||||
@ -132,12 +175,15 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
min_overlap=128,
|
min_overlap=128,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Convert the input image to a torch.Tensor.
|
||||||
|
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
||||||
|
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
||||||
# Validate our assumptions about the shape of input_image_torch.
|
# Validate our assumptions about the shape of input_image_torch.
|
||||||
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
|
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||||
assert input_image_torch.shape[:2] == (1, 3)
|
assert input_image_torch.shape[:2] == (1, 3)
|
||||||
|
|
||||||
# Split the input image into tiles.
|
# Split the input image into tiles in torch.Tensor format.
|
||||||
image_tiles: list[torch.Tensor] = []
|
image_tiles_torch: list[torch.Tensor] = []
|
||||||
for tile in tiles:
|
for tile in tiles:
|
||||||
image_tile = input_image_torch[
|
image_tile = input_image_torch[
|
||||||
:,
|
:,
|
||||||
@ -145,17 +191,30 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
tile.coords.top : tile.coords.bottom,
|
tile.coords.top : tile.coords.bottom,
|
||||||
tile.coords.left : tile.coords.right,
|
tile.coords.left : tile.coords.right,
|
||||||
]
|
]
|
||||||
image_tiles.append(image_tile)
|
image_tiles_torch.append(image_tile)
|
||||||
|
|
||||||
|
# Split the input image into tiles in numpy format.
|
||||||
|
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
|
||||||
|
# with torch.Tensor tiles.
|
||||||
|
input_image_np = np.array(input_image)
|
||||||
|
image_tiles_np: list[npt.NDArray[np.uint8]] = []
|
||||||
|
for tile in tiles:
|
||||||
|
image_tile_np = input_image_np[
|
||||||
|
tile.coords.top : tile.coords.bottom,
|
||||||
|
tile.coords.left : tile.coords.right,
|
||||||
|
:,
|
||||||
|
]
|
||||||
|
image_tiles_np.append(image_tile_np)
|
||||||
|
|
||||||
# VAE-encode each image tile independently.
|
# VAE-encode each image tile independently.
|
||||||
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
|
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
|
||||||
# about for decoding?
|
# about for decoding?
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(self.vae.vae)
|
||||||
latent_tiles: list[torch.Tensor] = []
|
latent_tiles: list[torch.Tensor] = []
|
||||||
for image_tile in image_tiles:
|
for image_tile_torch in image_tiles_torch:
|
||||||
latent_tiles.append(
|
latent_tiles.append(
|
||||||
ImageToLatentsInvocation.vae_encode(
|
ImageToLatentsInvocation.vae_encode(
|
||||||
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile
|
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -181,7 +240,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
unet_info = context.models.load(self.unet.unet)
|
unet_info = context.models.load(self.unet.unet)
|
||||||
|
|
||||||
refined_latent_tiles: list[torch.Tensor] = []
|
refined_latent_tiles: list[torch.Tensor] = []
|
||||||
with unet_info as unet:
|
with ExitStack() as exit_stack, unet_info as unet:
|
||||||
assert isinstance(unet, UNet2DConditionModel)
|
assert isinstance(unet, UNet2DConditionModel)
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
context=context,
|
context=context,
|
||||||
@ -206,10 +265,39 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Load the ControlNet model.
|
||||||
|
# TODO(ryand): Support multiple ControlNet models.
|
||||||
|
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
|
||||||
|
assert isinstance(controlnet_model, ControlNetModel)
|
||||||
|
|
||||||
# Denoise (i.e. "refine") each tile independently.
|
# Denoise (i.e. "refine") each tile independently.
|
||||||
for latent_tile, noise_tile in zip(latent_tiles, noise_tiles, strict=True):
|
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
|
||||||
assert latent_tile.shape == noise_tile.shape
|
assert latent_tile.shape == noise_tile.shape
|
||||||
|
|
||||||
|
# Prepare a PIL Image for ControlNet processing.
|
||||||
|
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
|
||||||
|
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
|
||||||
|
image_tile_pil = Image.fromarray(image_tile_np)
|
||||||
|
|
||||||
|
# Run the ControlNet on the image tile.
|
||||||
|
height, width, _ = image_tile_np.shape
|
||||||
|
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
|
||||||
|
# validate this assumption here.
|
||||||
|
assert height % LATENT_SCALE_FACTOR == 0
|
||||||
|
assert width % LATENT_SCALE_FACTOR == 0
|
||||||
|
controlnet_data = self.run_controlnet(
|
||||||
|
image=image_tile_pil,
|
||||||
|
controlnet_model=controlnet_model,
|
||||||
|
weight=1.0,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
device=controlnet_model.device,
|
||||||
|
dtype=controlnet_model.dtype,
|
||||||
|
control_mode="balanced",
|
||||||
|
resize_mode="just_resize_simple",
|
||||||
|
)
|
||||||
|
|
||||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
|
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
|
||||||
DenoiseLatentsInvocation.init_scheduler(
|
DenoiseLatentsInvocation.init_scheduler(
|
||||||
scheduler,
|
scheduler,
|
||||||
@ -236,7 +324,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=None,
|
control_data=[controlnet_data],
|
||||||
ip_adapter_data=None,
|
ip_adapter_data=None,
|
||||||
t2i_adapter_data=None,
|
t2i_adapter_data=None,
|
||||||
callback=lambda x: None,
|
callback=lambda x: None,
|
||||||
|
@ -289,7 +289,7 @@ def prepare_control_image(
|
|||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
num_channels: int = 3,
|
num_channels: int = 3,
|
||||||
device: str = "cuda",
|
device: str | torch.device = "cuda",
|
||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.float16,
|
||||||
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
||||||
@ -304,7 +304,7 @@ def prepare_control_image(
|
|||||||
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
|
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
|
||||||
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
|
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
|
||||||
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
|
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
|
||||||
device (str, optional): The target device for the output image. Defaults to "cuda".
|
device (str | torch.Device, optional): The target device for the output image. Defaults to "cuda".
|
||||||
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
|
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
|
||||||
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
|
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
Loading…
Reference in New Issue
Block a user