Rough prototype of TiledStableDiffusionRefineInvocation is working.

This commit is contained in:
Ryan Dick 2024-06-07 15:05:57 -04:00
parent 787e1bbb5f
commit 459d487620

View File

@ -1,3 +1,4 @@
import numpy as np
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
@ -11,7 +12,6 @@ from invokeai.app.invocations.fields import (
ImageField,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
@ -19,10 +19,10 @@ from invokeai.app.invocations.latent import DenoiseLatentsInvocation, get_schedu
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice
@ -34,17 +34,11 @@ from invokeai.backend.util.devices import TorchDevice
category="latents",
version="1.0.0",
)
class TiledStableDiffusionRefine(BaseInvocation):
class TiledStableDiffusionRefineInvocation(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
"""
# Implementation order:
# - Basic tiled denoising. Support text prompts, but no other features.
# - Support LoRA + TI
# - Support ControlNet
# - IP-Adapter? (It has to run on each tile independently. Could be complicated to support batching.)
image: ImageField = InputField(description="Image to be refined.")
positive_conditioning: ConditioningField = InputField(
@ -53,11 +47,6 @@ class TiledStableDiffusionRefine(BaseInvocation):
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
@ -84,11 +73,6 @@ class TiledStableDiffusionRefine(BaseInvocation):
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
@ -120,7 +104,7 @@ class TiledStableDiffusionRefine(BaseInvocation):
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim == 4 # We expect: (batch_size, channels, height, width).
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
@ -129,13 +113,14 @@ class TiledStableDiffusionRefine(BaseInvocation):
return latents[..., top:bottom, left:right]
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter.
seed = 0
# Load the input image.
input_image = context.images.get_pil(self.image.image_name)
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Calculate the tile locations to cover the image.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
@ -236,6 +221,9 @@ class TiledStableDiffusionRefine(BaseInvocation):
)
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile,
timesteps=timesteps,
@ -268,12 +256,19 @@ class TiledStableDiffusionRefine(BaseInvocation):
)
refined_image_tiles.append(refined_image_tile)
# Merge the refined image tiles back into a single image.
...
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
# Merge the refined image tiles back into a single image.
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Expose the blend_amount parameter, or set it based on the value of min_overlap used earlier.
merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=32
)
# Save the refined image and return its reference.
merged_image_pil = Image.fromarray(merged_image_np)
image_dto = context.images.save(image=merged_image_pil)
return ImageOutput.build(image_dto)