diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 40f7af8703..ee6ed63da8 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -1,11 +1,13 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
 from contextlib import ExitStack
+from functools import singledispatchmethod
 from typing import List, Literal, Optional, Union
 
 import einops
 import torch
 import torchvision.transforms as T
+from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
@@ -15,7 +17,7 @@ from diffusers.models.attention_processor import (
 )
 from diffusers.schedulers import DPMSolverSDEScheduler
 from diffusers.schedulers import SchedulerMixin as Scheduler
-from pydantic import BaseModel, Field, validator
+from pydantic import validator
 from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.metadata import CoreMetadata
@@ -29,8 +31,21 @@ from invokeai.app.invocations.primitives import (
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
-
-from ...backend.model_management import BaseModelType, ModelPatcher
+from .baseinvocation import (
+    BaseInvocation,
+    FieldDescriptions,
+    Input,
+    InputField,
+    InvocationContext,
+    UIType,
+    tags,
+    title,
+)
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .model import ModelInfo, UNetField, VaeField
+from ..models.image import ImageCategory, ResourceOrigin
+from ...backend.model_management import BaseModelType
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -42,22 +57,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
-from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    FieldDescriptions,
-    Input,
-    InputField,
-    InvocationContext,
-    OutputField,
-    UIType,
-    tags,
-    title,
-)
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .model import ModelInfo, UNetField, VaeField
+from ...backend.util.logging import InvokeAILogger
 
 
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
@@ -514,10 +514,17 @@ class LatentsToImageInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 latents = latents.half()
 
-            if self.tiled or context.services.configuration.tiled_decode:
-                vae.enable_tiling()
-            else:
-                vae.disable_tiling()
+            try:
+                if self.tiled or context.services.configuration.tiled_decode:
+                    vae.enable_tiling()
+                else:
+                    vae.disable_tiling()
+            except AttributeError as err:
+                # FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
+                if err.name.endswith("_tiling"):
+                    InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
+                else:
+                    raise
 
             # clear memory as vae decode can request a lot
             torch.cuda.empty_cache()
@@ -704,16 +711,22 @@ class ImageToLatentsInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 # latents = latents.half()
 
-            if self.tiled:
-                vae.enable_tiling()
-            else:
-                vae.disable_tiling()
+            try:
+                if self.tiled:
+                    vae.enable_tiling()
+                else:
+                    vae.disable_tiling()
+            except AttributeError as err:
+                # FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
+                if err.name.endswith("_tiling"):
+                    InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
+                else:
+                    raise
 
             # non_noised_latents_from_image
             image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
             with torch.inference_mode():
-                image_tensor_dist = vae.encode(image_tensor).latent_dist
-                latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+                latents = self._encode_to_tensor(vae, image_tensor)
 
             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
@@ -722,3 +735,13 @@
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents, seed=None)
+
+    @singledispatchmethod
+    def _encode_to_tensor(self, vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        image_tensor_dist = vae.encode(image_tensor).latent_dist
+        latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+        return latents
+
+    @_encode_to_tensor.register
+    def _(self, vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        return vae.encode(image_tensor).latents
diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py
index 3045849065..f157fb177a 100644
--- a/invokeai/backend/model_management/model_probe.py
+++ b/invokeai/backend/model_management/model_probe.py
@@ -51,6 +51,7 @@ class ModelProbe(object):
         "StableDiffusionXLPipeline": ModelType.Main,
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.Vae,
+        "AutoencoderTiny": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
     }
 