Rough prototype of TiledStableDiffusionRefineInvocation is working.

This commit is contained in:
Ryan Dick 2024-06-07 15:05:57 -04:00
parent 08ca03ef9f
commit 6a7a26f1bf

View File

@ -1,3 +1,4 @@
import numpy as np
import torch import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image from PIL import Image
@ -12,17 +13,16 @@ from invokeai.app.invocations.fields import (
ImageField, ImageField,
Input, Input,
InputField, InputField,
LatentsField,
UIType, UIType,
) )
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import UNetField, VAEField from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.invocations.noise import get_noise from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@ -34,17 +34,11 @@ from invokeai.backend.util.devices import TorchDevice
category="latents", category="latents",
version="1.0.0", version="1.0.0",
) )
class TiledStableDiffusionRefine(BaseInvocation): class TiledStableDiffusionRefineInvocation(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to """A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow. refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
""" """
# Implementation order:
# - Basic tiled denoising. Support text prompts, but no other features.
# - Support LoRA + TI
# - Support ControlNet
# - IP-Adapter? (It has to run on each tile independently. Could be complicated to support batching.)
image: ImageField = InputField(description="Image to be refined.") image: ImageField = InputField(description="Image to be refined.")
positive_conditioning: ConditioningField = InputField( positive_conditioning: ConditioningField = InputField(
@ -53,11 +47,6 @@ class TiledStableDiffusionRefine(BaseInvocation):
negative_conditioning: ConditioningField = InputField( negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection description=FieldDescriptions.negative_cond, input=Input.Connection
) )
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale") cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField( denoising_start: float = InputField(
@ -84,11 +73,6 @@ class TiledStableDiffusionRefine(BaseInvocation):
cfg_rescale_multiplier: float = InputField( cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
) )
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField( vae: VAEField = InputField(
description=FieldDescriptions.vae, description=FieldDescriptions.vae,
input=Input.Connection, input=Input.Connection,
@ -120,7 +104,7 @@ class TiledStableDiffusionRefine(BaseInvocation):
f"The tile coordinates must all be divisible by the latent scale factor" f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}." f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
) )
assert latents.dim == 4 # We expect: (batch_size, channels, height, width). assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR left = image_tile.coords.left // LATENT_SCALE_FACTOR
@ -129,13 +113,14 @@ class TiledStableDiffusionRefine(BaseInvocation):
return latents[..., top:bottom, left:right] return latents[..., top:bottom, left:right]
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter. # TODO(ryand): Expose the seed parameter.
seed = 0 seed = 0
# Load the input image. # Load the input image.
input_image = context.images.get_pil(self.image.image_name) input_image = context.images.get_pil(self.image.image_name)
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR) input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Calculate the tile locations to cover the image. # Calculate the tile locations to cover the image.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.) # TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
@ -236,6 +221,9 @@ class TiledStableDiffusionRefine(BaseInvocation):
) )
) )
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
refined_latent_tile = pipeline.latents_from_embeddings( refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile, latents=latent_tile,
timesteps=timesteps, timesteps=timesteps,
@ -268,12 +256,19 @@ class TiledStableDiffusionRefine(BaseInvocation):
) )
refined_image_tiles.append(refined_image_tile) refined_image_tiles.append(refined_image_tile)
# Merge the refined image tiles back into a single image. # TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
...
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache() TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents) # Merge the refined image tiles back into a single image.
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None) refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Expose the blend_amount parameter, or set it based on the value of min_overlap used earlier.
merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=32
)
# Save the refined image and return its reference.
merged_image_pil = Image.fromarray(merged_image_np)
image_dto = context.images.save(image=merged_image_pil)
return ImageOutput.build(image_dto)