From 8918501869b6cafa6c70f6b1c09e1ebdd3e3e495 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 22 Sep 2023 16:37:42 +1000
Subject: [PATCH] feat(nodes): add config to limit size of images in denoising

This serves as a relatively crude way to prevent OOM errors during
denoising (and any operations downstream of the denoising step, like the
VAE decode in Linear UI graphs).

- Add `max_image_size` config option - this is the total number of
  pixels, i.e. the image area
- Add logic to `denoise_latents` to scale the `latents` and `noise` to
  fit this
- Add logic to `color_correct` to scale the reference and mask to fit
  the image
---
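Worked example of the sizing math (illustrative only, not part of the
diff): with the default `max_image_size` of 512 * 512 = 262144 pixels,
the latent budget passed to `fit_latents()` is 262144 // 64 = 4096,
since latents are 8x smaller than the image in each spatial dimension.
A 1024x1024 image is therefore scaled down to fit:

    import numpy as np
    import torch

    max_image_size = 512 * 512               # default from invokeai_config.py
    max_latents_size = max_image_size // 64  # 4096, as passed to fit_latents()

    # A 1024x1024 image has 128x128 latents (area 16384), which exceeds 4096.
    latents = torch.randn(1, 4, 128, 128)
    latents_area = latents.shape[2] * latents.shape[3]

    scale_factor = np.sqrt(max_latents_size / latents_area)  # sqrt(0.25) = 0.5
    scaled = torch.nn.functional.interpolate(
        latents, scale_factor=scale_factor, mode="bilinear", antialias=True
    )
    print(scaled.shape)  # torch.Size([1, 4, 64, 64]) -> a 512x512 image
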
 invokeai/app/invocations/image.py           | 14 +++++++---
 invokeai/app/invocations/latent.py          | 28 +++++++++++++++++--
 .../app/services/config/invokeai_config.py  |  1 +
 3 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py
index 0403fa71e3..c849915062 100644
--- a/invokeai/app/invocations/image.py
+++ b/invokeai/app/invocations/image.py
@@ -645,13 +645,19 @@ class ColorCorrectInvocation(BaseInvocation):
     mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
+        result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
+
+        init_image = context.services.images.get_pil_image(self.reference.image_name)
+        # fit reference image to the input image
+        if init_image.size != result.size:
+            init_image = init_image.resize((result.width, result.height), Image.BILINEAR)
+
         pil_init_mask = None
         if self.mask is not None:
             pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L")
-
-        init_image = context.services.images.get_pil_image(self.reference.image_name)
-
-        result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
+            # fit mask to the input image
+            if pil_init_mask.size != result.size:
+                pil_init_mask = pil_init_mask.resize((result.width, result.height), Image.BILINEAR)
 
         # if init_image is None or init_mask is None:
         #    return result
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 10caeee67a..cea948e813 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -77,6 +77,25 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
 SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
 
 
+def fit_latents(latents: torch.Tensor, max_latents_size: int, device: torch.device) -> torch.Tensor:
+    if max_latents_size == 0:
+        return latents
+
+    latents_area = latents.shape[2] * latents.shape[3]
+    if latents_area <= max_latents_size:
+        return latents
+
+    scale_factor = np.sqrt(max_latents_size / latents_area)
+    scaled_latents = torch.nn.functional.interpolate(
+        latents.to(device),
+        scale_factor=scale_factor,
+        mode="bilinear",
+        antialias=True,
+    )
+
+    return scaled_latents
+
+
 @invocation_output("scheduler_output")
 class SchedulerOutput(BaseInvocationOutput):
     scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@@ -500,14 +519,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
             seed = None
             noise = None
+            max_image_size = context.services.configuration.max_image_size
+
             if self.noise is not None:
                 noise = context.services.latents.get(self.noise.latents_name)
                 seed = self.noise.seed
+                noise = fit_latents(latents=noise, max_latents_size=max_image_size // 64, device=choose_torch_device())
 
             if self.latents is not None:
                 latents = context.services.latents.get(self.latents.latents_name)
                 if seed is None:
                     seed = self.latents.seed
+                latents = fit_latents(
+                    latents=latents, max_latents_size=max_image_size // 64, device=choose_torch_device()
+                )
 
                 if noise is not None and noise.shape[1:] != latents.shape[1:]:
                     raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
@@ -532,8 +557,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}),
-                    context=context,
+                    **lora.dict(exclude={"weight"}), context=context
                 )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
diff --git a/invokeai/app/services/config/invokeai_config.py b/invokeai/app/services/config/invokeai_config.py
index 51ccf45704..63d14b7848 100644
--- a/invokeai/app/services/config/invokeai_config.py
+++ b/invokeai/app/services/config/invokeai_config.py
@@ -255,6 +255,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
     force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
+    max_image_size : int = Field(default=512 * 512, description="The maximum size of images, in pixels. The maximum size for latents is inferred from this by evaluating `max_image_size // 64`. If the size is exceeded during denoising, the latents will be resized.", category="Generation", )
 
     # QUEUE
     max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
 