mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(TAESD): support TAESD — Tiny Autoencoder for Stable Diffusion
This commit is contained in:
parent
98a4cc20a9
commit
8611ffe32d
@ -1,11 +1,13 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
@ -15,7 +17,7 @@ from diffusers.models.attention_processor import (
|
||||
)
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
@ -29,8 +31,21 @@ from invokeai.app.invocations.primitives import (
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelPatcher
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from ...backend.model_management import BaseModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
@ -42,22 +57,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
tags,
|
||||
title,
|
||||
)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
from ...backend.util.logging import InvokeAILogger
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
@ -514,10 +514,17 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if self.tiled or context.services.configuration.tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
try:
|
||||
if self.tiled or context.services.configuration.tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
except AttributeError as err:
|
||||
# FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
|
||||
if err.name.endswith("_tiling"):
|
||||
InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
|
||||
else:
|
||||
raise
|
||||
|
||||
# clear memory as vae decode can request a lot
|
||||
torch.cuda.empty_cache()
|
||||
@ -704,16 +711,22 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
vae.to(dtype=torch.float16)
|
||||
# latents = latents.half()
|
||||
|
||||
if self.tiled:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
try:
|
||||
if self.tiled:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
except AttributeError as err:
|
||||
# FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
|
||||
if err.name.endswith("_tiling"):
|
||||
InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
|
||||
else:
|
||||
raise
|
||||
|
||||
# non_noised_latents_from_image
|
||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||
with torch.inference_mode():
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
latents = self._encode_to_tensor(vae, image_tensor)
|
||||
|
||||
latents = vae.config.scaling_factor * latents
|
||||
latents = latents.to(dtype=orig_dtype)
|
||||
@ -722,3 +735,12 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
latents = latents.to("cpu")
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _encode_to_tensor(self, vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
return latents
|
||||
|
||||
@singledispatchmethod
|
||||
def _encode_to_tensor(self, vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(image_tensor).latents
|
||||
|
@ -51,6 +51,7 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user