From 045caddee134d358f0f4f1b1f95b67ac56520879 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 5 Jun 2024 13:53:53 -0400
Subject: [PATCH] Move LatentsToImageInvocation to its own file. No functional changes.

---
 invokeai/app/invocations/latent.py           |  93 +---------------
 invokeai/app/invocations/latents_to_image.py | 107 +++++++++++++++++++
 2 files changed, 108 insertions(+), 92 deletions(-)
 create mode 100644 invokeai/app/invocations/latents_to_image.py

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index c41dca195e..7102a0a4eb 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -8,16 +8,7 @@ import torch
 import torchvision
 import torchvision.transforms as T
 from diffusers.configuration_utils import ConfigMixin
-from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.adapter import T2IAdapter
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
 from diffusers.schedulers.scheduling_tcd import TCDScheduler
@@ -38,11 +29,9 @@ from invokeai.app.invocations.fields import (
     LatentsField,
     OutputField,
     UIType,
-    WithBoard,
-    WithMetadata,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
+from invokeai.app.invocations.primitives import DenoiseMaskOutput, LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -1033,83 +1022,3 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
-
-
-@invocation(
-    "l2i",
-    title="Latents to Image",
-    tags=["latents", "image", "vae", "l2i"],
-    category="latents",
-    version="1.2.2",
-)
-class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
-    """Generates an image from latents."""
-
-    latents: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    vae: VAEField = InputField(
-        description=FieldDescriptions.vae,
-        input=Input.Connection,
-    )
-    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
-    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        latents = context.tensors.load(self.latents.latents_name)
-
-        vae_info = context.models.load(self.vae.vae)
-        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
-            assert isinstance(vae, torch.nn.Module)
-            latents = latents.to(vae.device)
-            if self.fp32:
-                vae.to(dtype=torch.float32)
-
-                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
-                    vae.decoder.mid_block.attentions[0].processor,
-                    (
-                        AttnProcessor2_0,
-                        XFormersAttnProcessor,
-                        LoRAXFormersAttnProcessor,
-                        LoRAAttnProcessor2_0,
-                    ),
-                )
-                # if xformers or torch_2_0 is used attention block does not need
-                # to be in float32 which can save lots of memory
-                if use_torch_2_0_or_xformers:
-                    vae.post_quant_conv.to(latents.dtype)
-                    vae.decoder.conv_in.to(latents.dtype)
-                    vae.decoder.mid_block.to(latents.dtype)
-                else:
-                    latents = latents.float()
-
-            else:
-                vae.to(dtype=torch.float16)
-                latents = latents.half()
-
-            if self.tiled or context.config.get().force_tiled_decode:
-                vae.enable_tiling()
-            else:
-                vae.disable_tiling()
-
-            # clear memory as vae decode can request a lot
-            TorchDevice.empty_cache()
-
-            with torch.inference_mode():
-                # copied from diffusers pipeline
-                latents = latents / vae.config.scaling_factor
-                image = vae.decode(latents, return_dict=False)[0]
-                image = (image / 2 + 0.5).clamp(0, 1)  # denormalize
-                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-                np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-                image = VaeImageProcessor.numpy_to_pil(np_image)[0]
-
-        TorchDevice.empty_cache()
-
-        image_dto = context.images.save(image=image)
-
-        return ImageOutput.build(image_dto)
diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py
new file mode 100644
index 0000000000..e5038869dd
--- /dev/null
+++ b/invokeai/app/invocations/latents_to_image.py
@@ -0,0 +1,107 @@
+import torch
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    LatentsField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.latent import DEFAULT_PRECISION
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "l2i",
+    title="Latents to Image",
+    tags=["latents", "image", "vae", "l2i"],
+    category="latents",
+    version="1.2.2",
+)
+class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Generates an image from latents."""
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+
+        vae_info = context.models.load(self.vae.vae)
+        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
+        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+            assert isinstance(vae, torch.nn.Module)
+            latents = latents.to(vae.device)
+            if self.fp32:
+                vae.to(dtype=torch.float32)
+
+                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
+                    vae.decoder.mid_block.attentions[0].processor,
+                    (
+                        AttnProcessor2_0,
+                        XFormersAttnProcessor,
+                        LoRAXFormersAttnProcessor,
+                        LoRAAttnProcessor2_0,
+                    ),
+                )
+                # if xformers or torch_2_0 is used attention block does not need
+                # to be in float32 which can save lots of memory
+                if use_torch_2_0_or_xformers:
+                    vae.post_quant_conv.to(latents.dtype)
+                    vae.decoder.conv_in.to(latents.dtype)
+                    vae.decoder.mid_block.to(latents.dtype)
+                else:
+                    latents = latents.float()
+
+            else:
+                vae.to(dtype=torch.float16)
+                latents = latents.half()
+
+            if self.tiled or context.config.get().force_tiled_decode:
+                vae.enable_tiling()
+            else:
+                vae.disable_tiling()
+
+            # clear memory as vae decode can request a lot
+            TorchDevice.empty_cache()
+
+            with torch.inference_mode():
+                # copied from diffusers pipeline
+                latents = latents / vae.config.scaling_factor
+                image = vae.decode(latents, return_dict=False)[0]
+                image = (image / 2 + 0.5).clamp(0, 1)  # denormalize
+                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+                np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+                image = VaeImageProcessor.numpy_to_pil(np_image)[0]
+
+        TorchDevice.empty_cache()
+
+        image_dto = context.images.save(image=image)
+
+        return ImageOutput.build(image_dto)
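
Since the invocation moves to a new module, any downstream code importing it
from the old location would need updating. A minimal sketch of the before and
after imports, assuming no backwards-compatible re-export is left behind in
latent.py (the snippet below is illustrative, not part of the patch):

    # Hypothetical downstream code; module paths taken from the diff above.
    # Before this patch:
    from invokeai.app.invocations.latent import LatentsToImageInvocation
    # After this patch:
    from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation

Note that the new module imports DEFAULT_PRECISION from
invokeai.app.invocations.latent, so the dependency between the two files now
runs one way (latents_to_image -> latent); latent.py must not import the new
module at module scope, or the pair would form an import cycle.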