mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rough prototype of TiledStableDiffusionRefineInvocation is working.
This commit is contained in:
parent
08ca03ef9f
commit
6a7a26f1bf
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from PIL import Image
|
||||
@ -12,17 +13,16 @@ from invokeai.app.invocations.fields import (
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
||||
from invokeai.app.invocations.model import UNetField, VAEField
|
||||
from invokeai.app.invocations.noise import get_noise
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
|
||||
from invokeai.backend.tiles.utils import Tile
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -34,17 +34,11 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class TiledStableDiffusionRefine(BaseInvocation):
|
||||
class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
||||
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
|
||||
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
|
||||
"""
|
||||
|
||||
# Implementation order:
|
||||
# - Basic tiled denoising. Support text prompts, but no other features.
|
||||
# - Support LoRA + TI
|
||||
# - Support ControlNet
|
||||
# - IP-Adapter? (It has to run on each tile independently. Could be complicated to support batching.)
|
||||
|
||||
image: ImageField = InputField(description="Image to be refined.")
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
@ -53,11 +47,6 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
noise: LatentsField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
)
|
||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
denoising_start: float = InputField(
|
||||
@ -84,11 +73,6 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
||||
cfg_rescale_multiplier: float = InputField(
|
||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
latents: LatentsField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
@ -120,7 +104,7 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
||||
f"The tile coordinates must all be divisible by the latent scale factor"
|
||||
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
|
||||
)
|
||||
assert latents.dim == 4 # We expect: (batch_size, channels, height, width).
|
||||
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||
|
||||
top = image_tile.coords.top // LATENT_SCALE_FACTOR
|
||||
left = image_tile.coords.left // LATENT_SCALE_FACTOR
|
||||
@ -129,13 +113,14 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
||||
return latents[..., top:bottom, left:right]
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# TODO(ryand): Expose the seed parameter.
|
||||
seed = 0
|
||||
|
||||
# Load the input image.
|
||||
input_image = context.images.get_pil(self.image.image_name)
|
||||
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
||||
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
||||
|
||||
# Calculate the tile locations to cover the image.
|
||||
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
||||
@ -236,6 +221,9 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
|
||||
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
|
||||
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
|
||||
refined_latent_tile = pipeline.latents_from_embeddings(
|
||||
latents=latent_tile,
|
||||
timesteps=timesteps,
|
||||
@ -268,12 +256,19 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
||||
)
|
||||
refined_image_tiles.append(refined_image_tile)
|
||||
|
||||
# Merge the refined image tiles back into a single image.
|
||||
...
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
result_latents = result_latents.to("cpu")
|
||||
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||
# Merge the refined image tiles back into a single image.
|
||||
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
|
||||
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
|
||||
# TODO(ryand): Expose the blend_amount parameter, or set it based on the value of min_overlap used earlier.
|
||||
merge_tiles_with_linear_blending(
|
||||
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=32
|
||||
)
|
||||
|
||||
# Save the refined image and return its reference.
|
||||
merged_image_pil = Image.fromarray(merged_image_np)
|
||||
image_dto = context.images.save(image=merged_image_pil)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
Loading…
Reference in New Issue
Block a user