Change tiling strategy to make TiledStableDiffusionRefineInvocation work with more tile shapes and overlaps.

This commit is contained in:
Ryan Dick 2024-06-10 16:40:13 -04:00
parent 911792f258
commit 59284c707e

View File

@ -29,7 +29,7 @@ from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNE
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel from invokeai.backend.util.hotfixes import ControlNetModel
@ -58,11 +58,15 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
# TODO(ryand): Add multiple-of validation. # TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.") tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.") tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
tile_min_overlap: int = InputField(default=16, gt=0, description="Minimum overlap between tiles.") tile_overlap: int = InputField(
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) default=16,
cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale") gt=0,
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField( denoising_start: float = InputField(
default=0.0, default=0.65,
ge=0, ge=0,
le=1, le=1,
description=FieldDescriptions.denoising_start, description=FieldDescriptions.denoising_start,
@ -174,13 +178,15 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
input_image = context.images.get_pil(self.image.image_name) input_image = context.images.get_pil(self.image.image_name)
# Calculate the tile locations to cover the image. # Calculate the tile locations to cover the image.
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
# facilitates conversions between image space and latent space.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.) # TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
tiles = calc_tiles_min_overlap( tiles = calc_tiles_with_overlap(
image_height=input_image.height, image_height=input_image.height,
image_width=input_image.width, image_width=input_image.width,
tile_height=self.tile_height, tile_height=self.tile_height,
tile_width=self.tile_width, tile_width=self.tile_width,
min_overlap=self.tile_min_overlap, overlap=self.tile_overlap,
) )
# Convert the input image to a torch.Tensor. # Convert the input image to a torch.Tensor.
@ -366,9 +372,9 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
# Merge the refined image tiles back into a single image. # Merge the refined image tiles back into a single image.
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles] refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8) merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Expose the blend_amount parameter, or set it based on the value of min_overlap used earlier. # TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
merge_tiles_with_linear_blending( merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=32 dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
) )
# Save the refined image and return its reference. # Save the refined image and return its reference.