feat(TAESD): support TAESD — Tiny Autoencoder for Stable Diffusion

Kevin Turner 2023-08-17 19:59:31 -07:00
parent 98a4cc20a9
commit 8611ffe32d
2 changed files with 52 additions and 29 deletions
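
TAESD (Tiny AutoEncoder for Stable Diffusion) is a distilled VAE that encodes and decodes SD latents much faster and with far less memory than the full AutoencoderKL, at some cost in fidelity. diffusers exposes it as AutoencoderTiny, which is what this commit wires into the invocations and the model probe. As a minimal standalone sketch of the model being supported (the "madebyollin/taesd" repo id is the upstream TAESD release on the Hugging Face Hub, an assumption of this sketch, not something the commit pins):

    import torch
    from diffusers import AutoencoderTiny

    # "madebyollin/taesd" is the upstream TAESD checkpoint (assumed here);
    # the commit itself only teaches InvokeAI to recognize such a model.
    taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd")

    with torch.inference_mode():
        latents = torch.randn(1, 4, 64, 64)   # a [B, 4, H/8, W/8] SD latent batch
        image = taesd.decode(latents).sample  # straight to pixel space; no latent_dist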

invokeai/app/invocations/latent.py

@@ -1,11 +1,13 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

 from contextlib import ExitStack
+from functools import singledispatchmethod
 from typing import List, Literal, Optional, Union

 import einops
 import torch
 import torchvision.transforms as T
+from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
@@ -15,7 +17,7 @@ from diffusers.models.attention_processor import (
 )
 from diffusers.schedulers import DPMSolverSDEScheduler
 from diffusers.schedulers import SchedulerMixin as Scheduler
-from pydantic import BaseModel, Field, validator
+from pydantic import validator
 from torchvision.transforms.functional import resize as tv_resize

 from invokeai.app.invocations.metadata import CoreMetadata
@@ -29,8 +31,21 @@ from invokeai.app.invocations.primitives import (
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
-from ...backend.model_management import BaseModelType, ModelPatcher
+from .baseinvocation import (
+    BaseInvocation,
+    FieldDescriptions,
+    Input,
+    InputField,
+    InvocationContext,
+    UIType,
+    tags,
+    title,
+)
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .model import ModelInfo, UNetField, VaeField
+from ..models.image import ImageCategory, ResourceOrigin
+from ...backend.model_management import BaseModelType
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -42,22 +57,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
-from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    FieldDescriptions,
-    Input,
-    InputField,
-    InvocationContext,
-    OutputField,
-    UIType,
-    tags,
-    title,
-)
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .model import ModelInfo, UNetField, VaeField
+from ...backend.util.logging import InvokeAILogger

 DEFAULT_PRECISION = choose_precision(choose_torch_device())
@@ -514,10 +514,17 @@ class LatentsToImageInvocation(BaseInvocation):
             vae.to(dtype=torch.float16)
             latents = latents.half()

-        if self.tiled or context.services.configuration.tiled_decode:
-            vae.enable_tiling()
-        else:
-            vae.disable_tiling()
+        try:
+            if self.tiled or context.services.configuration.tiled_decode:
+                vae.enable_tiling()
+            else:
+                vae.disable_tiling()
+        except AttributeError as err:
+            # FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
+            if err.name.endswith("_tiling"):
+                InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
+            else:
+                raise

         # clear memory as vae decode can request a lot
         torch.cuda.empty_cache()
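
A note on the AttributeError guard: at this point diffusers' AutoencoderTiny implements neither enable_tiling() nor disable_tiling() (tiling support is pending in the PR linked in the FIXME), so either call on a TAESD VAE raises AttributeError. The handler keys on AttributeError.name, which carries the missing attribute's name and requires Python 3.10+. A self-contained sketch of the pattern:

    # Standalone sketch of the guard above (AttributeError.name is Python 3.10+).
    class TinyVae:  # stand-in for a VAE class without tiling support
        pass

    vae = TinyVae()
    try:
        vae.enable_tiling()
    except AttributeError as err:
        if err.name.endswith("_tiling"):
            print(f"ignoring tiling error for {vae.__class__}")  # tiling simply unsupported
        else:
            raise  # some other missing attribute: a real bug, re-raise it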
@@ -704,16 +711,22 @@ class ImageToLatentsInvocation(BaseInvocation):
             vae.to(dtype=torch.float16)
             # latents = latents.half()

-        if self.tiled:
-            vae.enable_tiling()
-        else:
-            vae.disable_tiling()
+        try:
+            if self.tiled:
+                vae.enable_tiling()
+            else:
+                vae.disable_tiling()
+        except AttributeError as err:
+            # FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
+            if err.name.endswith("_tiling"):
+                InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
+            else:
+                raise

         # non_noised_latents_from_image
         image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
         with torch.inference_mode():
-            image_tensor_dist = vae.encode(image_tensor).latent_dist
-            latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+            latents = self._encode_to_tensor(vae, image_tensor)

             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
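
The encode path changes because the two autoencoders return different things: AutoencoderKL.encode() yields a latent distribution that must be sampled, while AutoencoderTiny.encode() returns deterministic latents directly. A runnable sketch of the two shapes the new dispatch bridges, using randomly-initialized toy instances in place of real weights:

    import torch
    from diffusers import AutoencoderKL, AutoencoderTiny

    # Toy instances just to show the two return types; real weights would
    # come from from_pretrained(...).
    kl_vae = AutoencoderKL()
    tiny_vae = AutoencoderTiny()
    image = torch.randn(1, 3, 64, 64)

    with torch.inference_mode():
        posterior = kl_vae.encode(image).latent_dist   # AutoencoderKL: a distribution
        latents_kl = posterior.sample()                # stochastic: draws via torch.randn
        latents_tiny = tiny_vae.encode(image).latents  # AutoencoderTiny: deterministic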
@@ -722,3 +735,12 @@ class ImageToLatentsInvocation(BaseInvocation):
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents, seed=None)
+
+    @singledispatchmethod
+    def _encode_to_tensor(self, vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        image_tensor_dist = vae.encode(image_tensor).latent_dist
+        latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+        return latents
+
+    @_encode_to_tensor.register
+    def _(self, vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        return vae.encode(image_tensor).latents
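
functools.singledispatchmethod dispatches on the runtime type of the first argument after self, here the vae, so AutoencoderTiny instances take the direct .latents branch while everything else, including AutoencoderKL, falls back to the latent_dist path of the undecorated base implementation. A minimal self-contained sketch of the stdlib pattern:

    from functools import singledispatchmethod

    class Greeter:
        @singledispatchmethod
        def greet(self, who: object) -> str:  # fallback for unregistered types
            return f"hello, {who!r}"

        @greet.register
        def _(self, who: int) -> str:  # chosen when who is an int
            return f"hello, number {who}"

    g = Greeter()
    assert g.greet("ada") == "hello, 'ada'"
    assert g.greet(7) == "hello, number 7"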

invokeai/backend/model_management/model_probe.py

@@ -51,6 +51,7 @@ class ModelProbe(object):
         "StableDiffusionXLPipeline": ModelType.Main,
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.Vae,
+        "AutoencoderTiny": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
     }
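
ModelProbe classifies diffusers-format models by the _class_name their config records, so registering "AutoencoderTiny" here is what lets an installed TAESD folder be recognized as a VAE. A hedged sketch of that lookup, with a hypothetical helper name, not InvokeAI's actual API:

    # Hypothetical sketch of the classification this table drives: a diffusers
    # model folder names its class in config.json, which maps to a model type.
    import json
    from pathlib import Path

    CLASS_TO_TYPE = {
        "AutoencoderKL": "vae",
        "AutoencoderTiny": "vae",  # the entry added by this commit
        "ControlNetModel": "controlnet",
    }

    def probe_model_type(model_dir: str) -> str:
        config = json.loads((Path(model_dir) / "config.json").read_text())
        return CLASS_TO_TYPE[config["_class_name"]]  # e.g. "AutoencoderTiny" -> "vae"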