mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rough prototype of TiledStableDiffusionRefineInvocation is working.
This commit is contained in:
parent
08ca03ef9f
commit
6a7a26f1bf
@ -1,3 +1,4 @@
|
|||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -12,17 +13,16 @@ from invokeai.app.invocations.fields import (
|
|||||||
ImageField,
|
ImageField,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
LatentsField,
|
|
||||||
UIType,
|
UIType,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||||
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
||||||
from invokeai.app.invocations.model import UNetField, VAEField
|
from invokeai.app.invocations.model import UNetField, VAEField
|
||||||
from invokeai.app.invocations.noise import get_noise
|
from invokeai.app.invocations.noise import get_noise
|
||||||
from invokeai.app.invocations.primitives import LatentsOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||||
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
|
||||||
from invokeai.backend.tiles.utils import Tile
|
from invokeai.backend.tiles.utils import Tile
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
@ -34,17 +34,11 @@ from invokeai.backend.util.devices import TorchDevice
|
|||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class TiledStableDiffusionRefine(BaseInvocation):
|
class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
||||||
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
|
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
|
||||||
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
|
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Implementation order:
|
|
||||||
# - Basic tiled denoising. Support text prompts, but no other features.
|
|
||||||
# - Support LoRA + TI
|
|
||||||
# - Support ControlNet
|
|
||||||
# - IP-Adapter? (It has to run on each tile independently. Could be complicated to support batching.)
|
|
||||||
|
|
||||||
image: ImageField = InputField(description="Image to be refined.")
|
image: ImageField = InputField(description="Image to be refined.")
|
||||||
|
|
||||||
positive_conditioning: ConditioningField = InputField(
|
positive_conditioning: ConditioningField = InputField(
|
||||||
@ -53,11 +47,6 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
|||||||
negative_conditioning: ConditioningField = InputField(
|
negative_conditioning: ConditioningField = InputField(
|
||||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||||
)
|
)
|
||||||
noise: LatentsField | None = InputField(
|
|
||||||
default=None,
|
|
||||||
description=FieldDescriptions.noise,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
||||||
cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||||
denoising_start: float = InputField(
|
denoising_start: float = InputField(
|
||||||
@ -84,11 +73,6 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
|||||||
cfg_rescale_multiplier: float = InputField(
|
cfg_rescale_multiplier: float = InputField(
|
||||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||||
)
|
)
|
||||||
latents: LatentsField | None = InputField(
|
|
||||||
default=None,
|
|
||||||
description=FieldDescriptions.latents,
|
|
||||||
input=Input.Connection,
|
|
||||||
)
|
|
||||||
vae: VAEField = InputField(
|
vae: VAEField = InputField(
|
||||||
description=FieldDescriptions.vae,
|
description=FieldDescriptions.vae,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
@ -120,7 +104,7 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
|||||||
f"The tile coordinates must all be divisible by the latent scale factor"
|
f"The tile coordinates must all be divisible by the latent scale factor"
|
||||||
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
|
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
|
||||||
)
|
)
|
||||||
assert latents.dim == 4 # We expect: (batch_size, channels, height, width).
|
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||||
|
|
||||||
top = image_tile.coords.top // LATENT_SCALE_FACTOR
|
top = image_tile.coords.top // LATENT_SCALE_FACTOR
|
||||||
left = image_tile.coords.left // LATENT_SCALE_FACTOR
|
left = image_tile.coords.left // LATENT_SCALE_FACTOR
|
||||||
@ -129,13 +113,14 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
|||||||
return latents[..., top:bottom, left:right]
|
return latents[..., top:bottom, left:right]
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# TODO(ryand): Expose the seed parameter.
|
# TODO(ryand): Expose the seed parameter.
|
||||||
seed = 0
|
seed = 0
|
||||||
|
|
||||||
# Load the input image.
|
# Load the input image.
|
||||||
input_image = context.images.get_pil(self.image.image_name)
|
input_image = context.images.get_pil(self.image.image_name)
|
||||||
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
||||||
|
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
||||||
|
|
||||||
# Calculate the tile locations to cover the image.
|
# Calculate the tile locations to cover the image.
|
||||||
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
||||||
@ -236,6 +221,9 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
|
||||||
|
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
|
||||||
refined_latent_tile = pipeline.latents_from_embeddings(
|
refined_latent_tile = pipeline.latents_from_embeddings(
|
||||||
latents=latent_tile,
|
latents=latent_tile,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
@ -268,12 +256,19 @@ class TiledStableDiffusionRefine(BaseInvocation):
|
|||||||
)
|
)
|
||||||
refined_image_tiles.append(refined_image_tile)
|
refined_image_tiles.append(refined_image_tile)
|
||||||
|
|
||||||
# Merge the refined image tiles back into a single image.
|
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||||
...
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
|
||||||
result_latents = result_latents.to("cpu")
|
|
||||||
TorchDevice.empty_cache()
|
TorchDevice.empty_cache()
|
||||||
|
|
||||||
name = context.tensors.save(tensor=result_latents)
|
# Merge the refined image tiles back into a single image.
|
||||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
|
||||||
|
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
|
||||||
|
# TODO(ryand): Expose the blend_amount parameter, or set it based on the value of min_overlap used earlier.
|
||||||
|
merge_tiles_with_linear_blending(
|
||||||
|
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=32
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the refined image and return its reference.
|
||||||
|
merged_image_pil = Image.fromarray(merged_image_np)
|
||||||
|
image_dto = context.images.save(image=merged_image_pil)
|
||||||
|
|
||||||
|
return ImageOutput.build(image_dto)
|
||||||
|
Loading…
Reference in New Issue
Block a user