feat(TAESD): support TAESD — Tiny Autoencoder for Stable Diffusion

Kevin Turner 2023-08-17 19:59:31 -07:00
parent 98a4cc20a9
commit 8611ffe32d
2 changed files with 52 additions and 29 deletions
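Background (not part of this commit): TAESD is a distilled "tiny" VAE that encodes and decodes Stable Diffusion latents much faster than the full AutoencoderKL, at some cost in fidelity. A minimal sketch of using it through diffusers, assuming the commonly used madebyollin/taesd checkpoint; the checkpoint id, device, and dtype here are illustrative assumptions, not taken from this diff:

    # Illustrative sketch only; checkpoint id and device/dtype choices are assumptions.
    import torch
    from diffusers import AutoencoderTiny

    taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16).to("cuda")

    # Decode a batch of SD latents (shape [N, 4, H/8, W/8]) to images in a single fast pass.
    with torch.inference_mode():
        latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
        images = taesd.decode(latents / taesd.config.scaling_factor).sample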


@@ -1,11 +1,13 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import List, Literal, Optional, Union
import einops
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
@@ -15,7 +17,7 @@ from diffusers.models.attention_processor import (
)
from diffusers.schedulers import DPMSolverSDEScheduler
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata
@@ -29,8 +31,21 @@ from invokeai.app.invocations.primitives import (
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management import BaseModelType, ModelPatcher
from .baseinvocation import (
    BaseInvocation,
    FieldDescriptions,
    Input,
    InputField,
    InvocationContext,
    UIType,
    tags,
    title,
)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
from ..models.image import ImageCategory, ResourceOrigin
from ...backend.model_management import BaseModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -42,22 +57,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    Input,
    InputField,
    InvocationContext,
    OutputField,
    UIType,
    tags,
    title,
)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
from ...backend.util.logging import InvokeAILogger
DEFAULT_PRECISION = choose_precision(choose_torch_device())
@@ -514,10 +514,17 @@ class LatentsToImageInvocation(BaseInvocation):
vae.to(dtype=torch.float16)
latents = latents.half()
try:
    if self.tiled or context.services.configuration.tiled_decode:
        vae.enable_tiling()
    else:
        vae.disable_tiling()
except AttributeError as err:
    # FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
    if err.name.endswith("_tiling"):
        InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
    else:
        raise
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
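A note on the try/except added above (illustrative, not part of the diff): at the time of this commit, AutoencoderTiny did not implement the enable_tiling()/disable_tiling() methods that AutoencoderKL provides, so the call is guarded and the missing-attribute name is inspected via AttributeError.name (available on Python 3.10 and newer). A sketch of the failure mode being handled, assuming a diffusers version predating PR #4627 and the madebyollin/taesd checkpoint:

    # Sketch only; reproduces the AttributeError that the except-branch above swallows.
    from diffusers import AutoencoderTiny

    vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
    try:
        vae.enable_tiling()  # not implemented on AutoencoderTiny before diffusers PR #4627
    except AttributeError as err:
        print(err.name)  # "enable_tiling", hence the err.name.endswith("_tiling") check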
@@ -704,16 +711,22 @@ class ImageToLatentsInvocation(BaseInvocation):
vae.to(dtype=torch.float16)
# latents = latents.half()
try:
    if self.tiled:
        vae.enable_tiling()
    else:
        vae.disable_tiling()
except AttributeError as err:
    # FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
    if err.name.endswith("_tiling"):
        InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
    else:
        raise
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
    image_tensor_dist = vae.encode(image_tensor).latent_dist
    latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
    latents = self._encode_to_tensor(vae, image_tensor)
    latents = vae.config.scaling_factor * latents
    latents = latents.to(dtype=orig_dtype)
@@ -722,3 +735,12 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = latents.to("cpu")
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
def _encode_to_tensor(self, vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
    image_tensor_dist = vae.encode(image_tensor).latent_dist
    latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
    return latents

@_encode_to_tensor.register
def _(self, vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
    return vae.encode(image_tensor).latents
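For reference (illustrative, not part of the diff): functools.singledispatchmethod dispatches on the type of the first argument after self; the decorated method is the generic fallback, and overloads are attached with .register based on their type annotations. A toy sketch of the pattern _encode_to_tensor uses, with a placeholder type standing in for AutoencoderTiny:

    # Minimal sketch of annotation-based singledispatchmethod registration.
    from functools import singledispatchmethod

    class Encoder:
        @singledispatchmethod
        def encode(self, vae) -> str:
            return "generic VAE path"  # fallback, used for AutoencoderKL and anything unregistered

        @encode.register
        def _(self, vae: int) -> str:  # toy type standing in for AutoencoderTiny
            return "tiny VAE path"

    enc = Encoder()
    print(enc.encode(object()))  # generic VAE path
    print(enc.encode(3))         # tiny VAE path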


@@ -51,6 +51,7 @@ class ModelProbe(object):
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
}
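Context for the one-line change above (a sketch, not InvokeAI's actual probe logic): diffusers models record their class under "_class_name" in config.json, so a class-name-to-type table like the one being extended can be used to classify a model folder; adding "AutoencoderTiny" lets TAESD checkpoints be recognized as VAEs. The helper below is hypothetical:

    # Illustrative only; assumes a diffusers-format model folder containing a config.json.
    import json
    from pathlib import Path

    CLASS_NAME_TO_TYPE = {
        "AutoencoderKL": "vae",
        "AutoencoderTiny": "vae",  # the entry added by this commit
        "ControlNetModel": "controlnet",
    }

    def probe_model_type(model_dir: str) -> str:
        config = json.loads((Path(model_dir) / "config.json").read_text())
        return CLASS_NAME_TO_TYPE[config["_class_name"]]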