From 534640ccdebf845f8719056345a2bf76acd67c61 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 7 Jun 2024 15:05:57 -0400 Subject: [PATCH] Rough prototype of TiledStableDiffusionRefineInvocation is working. --- .../tiled_stable_diffusion_refine.py | 53 +++++++++---------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/invokeai/app/invocations/tiled_stable_diffusion_refine.py b/invokeai/app/invocations/tiled_stable_diffusion_refine.py index 2b983c14a7..5313bc1c2c 100644 --- a/invokeai/app/invocations/tiled_stable_diffusion_refine.py +++ b/invokeai/app/invocations/tiled_stable_diffusion_refine.py @@ -1,3 +1,4 @@ +import numpy as np import torch from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from PIL import Image @@ -12,17 +13,16 @@ from invokeai.app.invocations.fields import ( ImageField, Input, InputField, - LatentsField, UIType, ) from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation from invokeai.app.invocations.model import UNetField, VAEField from invokeai.app.invocations.noise import get_noise -from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor -from invokeai.backend.tiles.tiles import calc_tiles_min_overlap +from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending from invokeai.backend.tiles.utils import Tile from invokeai.backend.util.devices import TorchDevice @@ -34,17 +34,11 @@ from invokeai.backend.util.devices import TorchDevice category="latents", version="1.0.0", ) -class TiledStableDiffusionRefine(BaseInvocation): +class TiledStableDiffusionRefineInvocation(BaseInvocation): """A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow. """ - # Implementation order: - # - Basic tiled denoising. Support text prompts, but no other features. - # - Support LoRA + TI - # - Support ControlNet - # - IP-Adapter? (It has to run on each tile independently. Could be complicated to support batching.) - image: ImageField = InputField(description="Image to be refined.") positive_conditioning: ConditioningField = InputField( @@ -53,11 +47,6 @@ class TiledStableDiffusionRefine(BaseInvocation): negative_conditioning: ConditioningField = InputField( description=FieldDescriptions.negative_cond, input=Input.Connection ) - noise: LatentsField | None = InputField( - default=None, - description=FieldDescriptions.noise, - input=Input.Connection, - ) steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale") denoising_start: float = InputField( @@ -84,11 +73,6 @@ class TiledStableDiffusionRefine(BaseInvocation): cfg_rescale_multiplier: float = InputField( title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier ) - latents: LatentsField | None = InputField( - default=None, - description=FieldDescriptions.latents, - input=Input.Connection, - ) vae: VAEField = InputField( description=FieldDescriptions.vae, input=Input.Connection, @@ -120,7 +104,7 @@ class TiledStableDiffusionRefine(BaseInvocation): f"The tile coordinates must all be divisible by the latent scale factor" f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}." ) - assert latents.dim == 4 # We expect: (batch_size, channels, height, width). + assert latents.dim() == 4 # We expect: (batch_size, channels, height, width). top = image_tile.coords.top // LATENT_SCALE_FACTOR left = image_tile.coords.left // LATENT_SCALE_FACTOR @@ -129,13 +113,14 @@ class TiledStableDiffusionRefine(BaseInvocation): return latents[..., top:bottom, left:right] @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context: InvocationContext) -> ImageOutput: # TODO(ryand): Expose the seed parameter. seed = 0 # Load the input image. input_image = context.images.get_pil(self.image.image_name) input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR) + input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension. # Calculate the tile locations to cover the image. # TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.) @@ -236,6 +221,9 @@ class TiledStableDiffusionRefine(BaseInvocation): ) ) + # TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM. + latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype) + noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype) refined_latent_tile = pipeline.latents_from_embeddings( latents=latent_tile, timesteps=timesteps, @@ -268,12 +256,19 @@ class TiledStableDiffusionRefine(BaseInvocation): ) refined_image_tiles.append(refined_image_tile) - # Merge the refined image tiles back into a single image. - ... - - # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 - result_latents = result_latents.to("cpu") + # TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important. TorchDevice.empty_cache() - name = context.tensors.save(tensor=result_latents) - return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None) + # Merge the refined image tiles back into a single image. + refined_image_tiles_np = [np.array(t) for t in refined_image_tiles] + merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8) + # TODO(ryand): Expose the blend_amount parameter, or set it based on the value of min_overlap used earlier. + merge_tiles_with_linear_blending( + dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=32 + ) + + # Save the refined image and return its reference. + merged_image_pil = Image.fromarray(merged_image_np) + image_dto = context.images.save(image=merged_image_pil) + + return ImageOutput.build(image_dto)