mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into stalker-modular_controlnet
This commit is contained in:
@ -59,7 +59,9 @@ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import Custom
|
||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@ -823,6 +825,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
ext_manager.add_extension(PreviewExt(step_callback))
|
||||
|
||||
### cfg rescale
|
||||
if self.cfg_rescale_multiplier > 0:
|
||||
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
|
||||
|
||||
### freeu
|
||||
if self.unet.freeu_config:
|
||||
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
||||
|
||||
# context for loading additional models
|
||||
with ExitStack() as exit_stack:
|
||||
# later should be smth like:
|
||||
@ -837,12 +847,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
unet_info.model_on_device() as (model_state_dict, unet),
|
||||
unet_info.model_on_device() as (cached_weights, unet),
|
||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||
# ext: controlnet
|
||||
ext_manager.patch_extensions(denoise_ctx),
|
||||
# ext: freeu, seamless, ip adapter, lora
|
||||
ext_manager.patch_unet(model_state_dict, unet),
|
||||
ext_manager.patch_unet(unet, cached_weights),
|
||||
):
|
||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||
denoise_ctx.unet = unet
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
@ -21,7 +23,7 @@ from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||
|
||||
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
|
||||
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
|
||||
|
||||
@ -34,8 +36,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tile_size: int = InputField(
|
||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||
)
|
||||
scale: float = InputField(
|
||||
default=4.0,
|
||||
gt=0.0,
|
||||
le=16.0,
|
||||
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
|
||||
)
|
||||
fit_to_multiple_of_8: bool = InputField(
|
||||
default=False,
|
||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
||||
)
|
||||
|
||||
def _scale_tile(self, tile: Tile, scale: int) -> Tile:
|
||||
@classmethod
|
||||
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
|
||||
return Tile(
|
||||
coords=TBLR(
|
||||
top=tile.coords.top * scale,
|
||||
@ -51,20 +64,22 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
),
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
@classmethod
|
||||
def upscale_image(
|
||||
cls,
|
||||
image: Image.Image,
|
||||
tile_size: int,
|
||||
spandrel_model: SpandrelImageToImageModel,
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> Image.Image:
|
||||
# Compute the image tiles.
|
||||
if self.tile_size > 0:
|
||||
if tile_size > 0:
|
||||
min_overlap = 20
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=image.height,
|
||||
image_width=image.width,
|
||||
tile_height=self.tile_size,
|
||||
tile_width=self.tile_size,
|
||||
tile_height=tile_size,
|
||||
tile_width=tile_size,
|
||||
min_overlap=min_overlap,
|
||||
)
|
||||
else:
|
||||
@ -85,60 +100,123 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Prepare input image for inference.
|
||||
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
# Scale the tiles for re-assembling the final image.
|
||||
scale = spandrel_model.scale
|
||||
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
|
||||
|
||||
# Prepare the output tensor.
|
||||
_, channels, height, width = image_tensor.shape
|
||||
output_tensor = torch.zeros(
|
||||
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
|
||||
)
|
||||
|
||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run the model on each tile.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
||||
# Exit early if the invocation has been canceled.
|
||||
if is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# Scale the tiles for re-assembling the final image.
|
||||
scale = spandrel_model.scale
|
||||
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
|
||||
# Extract the current tile from the input tensor.
|
||||
input_tile = image_tensor[
|
||||
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
|
||||
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Prepare the output tensor.
|
||||
_, channels, height, width = image_tensor.shape
|
||||
output_tensor = torch.zeros(
|
||||
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
|
||||
)
|
||||
# Run the model on the tile.
|
||||
output_tile = spandrel_model.run(input_tile)
|
||||
|
||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
# Convert the output tile into the output tensor's format.
|
||||
# (N, C, H, W) -> (C, H, W)
|
||||
output_tile = output_tile.squeeze(0)
|
||||
# (C, H, W) -> (H, W, C)
|
||||
output_tile = output_tile.permute(1, 2, 0)
|
||||
output_tile = output_tile.clamp(0, 1)
|
||||
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
|
||||
|
||||
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
||||
# Exit early if the invocation has been canceled.
|
||||
if context.util.is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# Extract the current tile from the input tensor.
|
||||
input_tile = image_tensor[
|
||||
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
|
||||
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run the model on the tile.
|
||||
output_tile = spandrel_model.run(input_tile)
|
||||
|
||||
# Convert the output tile into the output tensor's format.
|
||||
# (N, C, H, W) -> (C, H, W)
|
||||
output_tile = output_tile.squeeze(0)
|
||||
# (C, H, W) -> (H, W, C)
|
||||
output_tile = output_tile.permute(1, 2, 0)
|
||||
output_tile = output_tile.clamp(0, 1)
|
||||
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
|
||||
|
||||
# Merge the output tile into the output tensor.
|
||||
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
|
||||
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
|
||||
# it seems unnecessary, but we may find a need in the future.
|
||||
top_overlap = scaled_tile.overlap.top // 2
|
||||
left_overlap = scaled_tile.overlap.left // 2
|
||||
output_tensor[
|
||||
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
|
||||
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
|
||||
:,
|
||||
] = output_tile[top_overlap:, left_overlap:, :]
|
||||
# Merge the output tile into the output tensor.
|
||||
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
|
||||
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
|
||||
# it seems unnecessary, but we may find a need in the future.
|
||||
top_overlap = scaled_tile.overlap.top // 2
|
||||
left_overlap = scaled_tile.overlap.left // 2
|
||||
output_tensor[
|
||||
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
|
||||
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
|
||||
:,
|
||||
] = output_tile[top_overlap:, left_overlap:, :]
|
||||
|
||||
# Convert the output tensor to a PIL image.
|
||||
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
||||
pil_image = Image.fromarray(np_image)
|
||||
|
||||
return pil_image
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||
# revisit this.
|
||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
|
||||
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
|
||||
target_width = int(image.width * self.scale)
|
||||
target_height = int(image.height * self.scale)
|
||||
|
||||
# Do the upscaling.
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# First pass of upscaling. Note: `pil_image` will be mutated.
|
||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
|
||||
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
|
||||
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
|
||||
# to be considered an upscale model.
|
||||
is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
|
||||
|
||||
if is_upscale_model:
|
||||
# This is an upscale model, so we should keep upscaling until we reach the target size.
|
||||
iterations = 1
|
||||
while pil_image.width < target_width or pil_image.height < target_height:
|
||||
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||
iterations += 1
|
||||
|
||||
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
|
||||
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
|
||||
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
|
||||
# we should never reach this limit.
|
||||
if iterations >= 5:
|
||||
context.logger.warning(
|
||||
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
|
||||
)
|
||||
break
|
||||
else:
|
||||
# This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
|
||||
# to be the same as the processed image size.
|
||||
|
||||
# The output size is now the size of the processed image.
|
||||
target_width = pil_image.width
|
||||
target_height = pil_image.height
|
||||
|
||||
# Warn the user if they requested a scale greater than 1.
|
||||
if self.scale > 1:
|
||||
context.logger.warning(
|
||||
"Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
|
||||
)
|
||||
|
||||
# We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
|
||||
# in the final resize
|
||||
if self.fit_to_multiple_of_8:
|
||||
target_width = int(target_width // 8 * 8)
|
||||
target_height = int(target_height // 8 * 8)
|
||||
|
||||
# Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
|
||||
# See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
|
||||
pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
Reference in New Issue
Block a user