mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactor TiledStableDiffusionRefineInvocation to more closely mirror TiledMultiDiffusionDenoiseLatents. The biggest improvement is in the handling of the ControlNets - global ControlNet info can now be passed in and it is tiled within the node.
This commit is contained in:
parent
b74bc77347
commit
8379feeb8a
@ -2,14 +2,14 @@ from contextlib import ExitStack
|
|||||||
from typing import Iterator, Tuple
|
from typing import Iterator, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||||
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
ConditioningField,
|
ConditioningField,
|
||||||
@ -17,22 +17,24 @@ from invokeai.app.invocations.fields import (
|
|||||||
ImageField,
|
ImageField,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
|
LatentsField,
|
||||||
UIType,
|
UIType,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||||
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
|
from invokeai.app.invocations.model import UNetField, VAEField
|
||||||
from invokeai.app.invocations.noise import get_noise
|
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
|
from invokeai.app.invocations.tiled_multi_diffusion_denoise_latents import crop_controlnet_data
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
|
||||||
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
|
from invokeai.backend.tiles.tiles import (
|
||||||
from invokeai.backend.tiles.utils import Tile
|
calc_tiles_min_overlap,
|
||||||
|
merge_tiles_with_linear_blending,
|
||||||
|
)
|
||||||
|
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -40,6 +42,7 @@ from invokeai.backend.util.hotfixes import ControlNetModel
|
|||||||
title="Tiled Stable Diffusion Refine",
|
title="Tiled Stable Diffusion Refine",
|
||||||
tags=["upscale", "denoise"],
|
tags=["upscale", "denoise"],
|
||||||
category="latents",
|
category="latents",
|
||||||
|
classification=Classification.Beta,
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
||||||
@ -55,13 +58,21 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
negative_conditioning: ConditioningField = InputField(
|
negative_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||||
)
|
)
|
||||||
# TODO(ryand): Add multiple-of validation.
|
noise: LatentsField = InputField(
|
||||||
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
|
description=FieldDescriptions.noise,
|
||||||
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
|
input=Input.Connection,
|
||||||
|
)
|
||||||
|
tile_height: int = InputField(
|
||||||
|
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
|
||||||
|
)
|
||||||
|
tile_width: int = InputField(
|
||||||
|
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
|
||||||
|
)
|
||||||
tile_overlap: int = InputField(
|
tile_overlap: int = InputField(
|
||||||
default=16,
|
default=32,
|
||||||
|
multiple_of=LATENT_SCALE_FACTOR,
|
||||||
gt=0,
|
gt=0,
|
||||||
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
|
description="Target overlap between adjacent tiles in image space.",
|
||||||
)
|
)
|
||||||
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||||
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||||
@ -92,16 +103,10 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
vae_fp32: bool = InputField(
|
vae_fp32: bool = InputField(
|
||||||
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
|
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
|
||||||
)
|
)
|
||||||
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
|
control: ControlField | list[ControlField] | None = InputField(
|
||||||
# don't want to use the image field. Figure out how best to handle this.
|
default=None,
|
||||||
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
|
input=Input.Connection,
|
||||||
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
|
|
||||||
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
|
|
||||||
# etc.
|
|
||||||
control_model: ModelIdentifierField = InputField(
|
|
||||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
|
||||||
)
|
)
|
||||||
control_weight: float = InputField(default=0.6)
|
|
||||||
|
|
||||||
@field_validator("cfg_scale")
|
@field_validator("cfg_scale")
|
||||||
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||||
@ -115,90 +120,72 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@staticmethod
|
def _scale_tile(self, tile: Tile, scale: int) -> Tile:
|
||||||
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
|
"""Scale the tile by the given factor."""
|
||||||
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
|
return Tile(
|
||||||
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
|
coords=TBLR(
|
||||||
"""
|
top=tile.coords.top * scale,
|
||||||
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
|
bottom=tile.coords.bottom * scale,
|
||||||
if coord % LATENT_SCALE_FACTOR != 0:
|
left=tile.coords.left * scale,
|
||||||
raise ValueError(
|
right=tile.coords.right * scale,
|
||||||
f"The tile coordinates must all be divisible by the latent scale factor"
|
),
|
||||||
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
|
overlap=TBLR(
|
||||||
)
|
top=tile.overlap.top * scale,
|
||||||
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
|
bottom=tile.overlap.bottom * scale,
|
||||||
|
left=tile.overlap.left * scale,
|
||||||
top = image_tile.coords.top // LATENT_SCALE_FACTOR
|
right=tile.overlap.right * scale,
|
||||||
left = image_tile.coords.left // LATENT_SCALE_FACTOR
|
),
|
||||||
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
|
|
||||||
right = image_tile.coords.right // LATENT_SCALE_FACTOR
|
|
||||||
return latents[..., top:bottom, left:right]
|
|
||||||
|
|
||||||
def run_controlnet(
|
|
||||||
self,
|
|
||||||
image: Image.Image,
|
|
||||||
controlnet_model: ControlNetModel,
|
|
||||||
weight: float,
|
|
||||||
do_classifier_free_guidance: bool,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
device: torch.device,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
|
||||||
) -> ControlNetData:
|
|
||||||
control_image = prepare_control_image(
|
|
||||||
image=image,
|
|
||||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
control_mode=control_mode,
|
|
||||||
resize_mode=resize_mode,
|
|
||||||
)
|
|
||||||
return ControlNetData(
|
|
||||||
model=controlnet_model,
|
|
||||||
image_tensor=control_image,
|
|
||||||
weight=weight,
|
|
||||||
begin_step_percent=0.0,
|
|
||||||
end_step_percent=1.0,
|
|
||||||
control_mode=control_mode,
|
|
||||||
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
|
|
||||||
# ControlNetData in case needed in the future.
|
|
||||||
resize_mode=resize_mode,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# TODO(ryand): Expose the seed parameter.
|
# Convert tile image-space dimensions to latent-space dimensions.
|
||||||
seed = 0
|
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
|
||||||
|
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
|
||||||
|
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
|
||||||
|
|
||||||
# Load the input image.
|
# Load the input image.
|
||||||
input_image = context.images.get_pil(self.image.image_name)
|
input_image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
# Calculate the tile locations to cover the image.
|
|
||||||
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
|
|
||||||
# facilitates conversions between image space and latent space.
|
|
||||||
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
|
||||||
tiles = calc_tiles_with_overlap(
|
|
||||||
image_height=input_image.height,
|
|
||||||
image_width=input_image.width,
|
|
||||||
tile_height=self.tile_height,
|
|
||||||
tile_width=self.tile_width,
|
|
||||||
overlap=self.tile_overlap,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert the input image to a torch.Tensor.
|
# Convert the input image to a torch.Tensor.
|
||||||
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
||||||
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
||||||
# Validate our assumptions about the shape of input_image_torch.
|
# Validate our assumptions about the shape of input_image_torch.
|
||||||
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
|
batch_size, channels, image_height, image_width = input_image_torch.shape
|
||||||
assert input_image_torch.shape[:2] == (1, 3)
|
assert batch_size == 1
|
||||||
|
assert channels == 3
|
||||||
|
|
||||||
|
# Load the noise tensor.
|
||||||
|
noise = context.tensors.load(self.noise.latents_name)
|
||||||
|
if list(noise.shape) != [
|
||||||
|
batch_size,
|
||||||
|
4,
|
||||||
|
image_height // LATENT_SCALE_FACTOR,
|
||||||
|
image_width // LATENT_SCALE_FACTOR,
|
||||||
|
]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Incompatible noise and image dimensions. Image shape: {input_image_torch.shape}. "
|
||||||
|
f"Noise shape: {noise.shape}. Expected noise shape: [1, 1, "
|
||||||
|
f"{image_height // LATENT_SCALE_FACTOR}, {image_width // LATENT_SCALE_FACTOR}]. "
|
||||||
|
)
|
||||||
|
latent_height, latent_width = noise.shape[2:]
|
||||||
|
|
||||||
|
# Extract the seed from the noise field.
|
||||||
|
assert self.noise.seed is not None
|
||||||
|
seed = self.noise.seed or 0
|
||||||
|
|
||||||
|
# Calculate the tile locations in both latent space and image space.
|
||||||
|
latent_space_tiles = calc_tiles_min_overlap(
|
||||||
|
image_height=latent_height,
|
||||||
|
image_width=latent_width,
|
||||||
|
tile_height=latent_tile_height,
|
||||||
|
tile_width=latent_tile_width,
|
||||||
|
min_overlap=latent_tile_overlap,
|
||||||
|
)
|
||||||
|
image_space_tiles = [self._scale_tile(tile, LATENT_SCALE_FACTOR) for tile in latent_space_tiles]
|
||||||
|
|
||||||
# Split the input image into tiles in torch.Tensor format.
|
# Split the input image into tiles in torch.Tensor format.
|
||||||
image_tiles_torch: list[torch.Tensor] = []
|
image_tiles_torch: list[torch.Tensor] = []
|
||||||
for tile in tiles:
|
for tile in image_space_tiles:
|
||||||
image_tile = input_image_torch[
|
image_tile = input_image_torch[
|
||||||
:,
|
:,
|
||||||
:,
|
:,
|
||||||
@ -207,22 +194,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
]
|
]
|
||||||
image_tiles_torch.append(image_tile)
|
image_tiles_torch.append(image_tile)
|
||||||
|
|
||||||
# Split the input image into tiles in numpy format.
|
|
||||||
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
|
|
||||||
# with torch.Tensor tiles.
|
|
||||||
input_image_np = np.array(input_image)
|
|
||||||
image_tiles_np: list[npt.NDArray[np.uint8]] = []
|
|
||||||
for tile in tiles:
|
|
||||||
image_tile_np = input_image_np[
|
|
||||||
tile.coords.top : tile.coords.bottom,
|
|
||||||
tile.coords.left : tile.coords.right,
|
|
||||||
:,
|
|
||||||
]
|
|
||||||
image_tiles_np.append(image_tile_np)
|
|
||||||
|
|
||||||
# VAE-encode each image tile independently.
|
# VAE-encode each image tile independently.
|
||||||
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
|
|
||||||
# about for decoding?
|
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(self.vae.vae)
|
||||||
latent_tiles: list[torch.Tensor] = []
|
latent_tiles: list[torch.Tensor] = []
|
||||||
for image_tile_torch in image_tiles_torch:
|
for image_tile_torch in image_tiles_torch:
|
||||||
@ -232,23 +204,16 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate noise with dimensions corresponding to the full image in latent space.
|
|
||||||
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
|
|
||||||
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
|
|
||||||
# noise.
|
|
||||||
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
|
|
||||||
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
|
|
||||||
global_noise = get_noise(
|
|
||||||
width=input_image_torch.shape[3],
|
|
||||||
height=input_image_torch.shape[2],
|
|
||||||
device=TorchDevice.choose_torch_device(),
|
|
||||||
seed=seed,
|
|
||||||
downsampling_factor=LATENT_SCALE_FACTOR,
|
|
||||||
use_cpu=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Crop the global noise into tiles.
|
# Crop the global noise into tiles.
|
||||||
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
|
noise_tiles: list[torch.Tensor] = []
|
||||||
|
for tile in latent_space_tiles:
|
||||||
|
noise_tile = noise[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
tile.coords.top : tile.coords.bottom,
|
||||||
|
tile.coords.left : tile.coords.right,
|
||||||
|
]
|
||||||
|
noise_tiles.append(noise_tile)
|
||||||
|
|
||||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
@ -273,53 +238,42 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
|
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
|
||||||
|
|
||||||
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||||
# Assume that all tiles have the same shape.
|
|
||||||
_, _, latent_height, latent_width = latent_tiles[0].shape
|
|
||||||
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||||
context=context,
|
context=context,
|
||||||
positive_conditioning_field=self.positive_conditioning,
|
positive_conditioning_field=self.positive_conditioning,
|
||||||
negative_conditioning_field=self.negative_conditioning,
|
negative_conditioning_field=self.negative_conditioning,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
latent_height=latent_height,
|
latent_height=latent_tile_height,
|
||||||
latent_width=latent_width,
|
latent_width=latent_tile_width,
|
||||||
cfg_scale=self.cfg_scale,
|
cfg_scale=self.cfg_scale,
|
||||||
steps=self.steps,
|
steps=self.steps,
|
||||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load the ControlNet model.
|
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
|
||||||
# TODO(ryand): Support multiple ControlNet models.
|
context=context,
|
||||||
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
|
control_input=self.control,
|
||||||
assert isinstance(controlnet_model, ControlNetModel)
|
# NOTE: We use the shape of the global noise tensor here, because this is a global ControlNet. We tile
|
||||||
|
# it later.
|
||||||
|
latents_shape=list(noise.shape),
|
||||||
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
exit_stack=exit_stack,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split the controlnet_data into tiles.
|
||||||
|
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
|
||||||
|
controlnet_data_tiles: list[list[ControlNetData]] = []
|
||||||
|
for tile in latent_space_tiles:
|
||||||
|
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
|
||||||
|
controlnet_data_tiles.append(tile_controlnet_data)
|
||||||
|
|
||||||
# Denoise (i.e. "refine") each tile independently.
|
# Denoise (i.e. "refine") each tile independently.
|
||||||
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
|
for latent_tile, noise_tile, controlnet_data_tile in zip(
|
||||||
|
latent_tiles, noise_tiles, controlnet_data_tiles, strict=True
|
||||||
|
):
|
||||||
assert latent_tile.shape == noise_tile.shape
|
assert latent_tile.shape == noise_tile.shape
|
||||||
|
|
||||||
# Prepare a PIL Image for ControlNet processing.
|
|
||||||
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
|
|
||||||
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
|
|
||||||
image_tile_pil = Image.fromarray(image_tile_np)
|
|
||||||
|
|
||||||
# Run the ControlNet on the image tile.
|
|
||||||
height, width, _ = image_tile_np.shape
|
|
||||||
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
|
|
||||||
# validate this assumption here.
|
|
||||||
assert height % LATENT_SCALE_FACTOR == 0
|
|
||||||
assert width % LATENT_SCALE_FACTOR == 0
|
|
||||||
controlnet_data = self.run_controlnet(
|
|
||||||
image=image_tile_pil,
|
|
||||||
controlnet_model=controlnet_model,
|
|
||||||
weight=self.control_weight,
|
|
||||||
do_classifier_free_guidance=True,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
device=controlnet_model.device,
|
|
||||||
dtype=controlnet_model.dtype,
|
|
||||||
control_mode="balanced",
|
|
||||||
resize_mode="just_resize_simple",
|
|
||||||
)
|
|
||||||
|
|
||||||
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||||
scheduler,
|
scheduler,
|
||||||
device=unet.device,
|
device=unet.device,
|
||||||
@ -342,7 +296,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
masked_latents=None,
|
masked_latents=None,
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=[controlnet_data],
|
control_data=controlnet_data_tile,
|
||||||
ip_adapter_data=None,
|
ip_adapter_data=None,
|
||||||
t2i_adapter_data=None,
|
t2i_adapter_data=None,
|
||||||
callback=lambda x: None,
|
callback=lambda x: None,
|
||||||
@ -368,9 +322,11 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
|||||||
# Merge the refined image tiles back into a single image.
|
# Merge the refined image tiles back into a single image.
|
||||||
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
|
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
|
||||||
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
|
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
|
||||||
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
|
|
||||||
merge_tiles_with_linear_blending(
|
merge_tiles_with_linear_blending(
|
||||||
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
|
dst_image=merged_image_np,
|
||||||
|
tiles=image_space_tiles,
|
||||||
|
tile_images=refined_image_tiles_np,
|
||||||
|
blend_amount=self.tile_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save the refined image and return its reference.
|
# Save the refined image and return its reference.
|
||||||
|
Loading…
Reference in New Issue
Block a user