Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00.
feat(TAESD): support TAESD — Tiny Autoencoder for Stable Diffusion
parent 98a4cc20a9
commit 8611ffe32d
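For context: TAESD (Tiny AutoEncoder for Stable Diffusion) is a distilled drop-in replacement for Stable Diffusion's KL VAE. It shares the same 4-channel latent space but is a small fraction of the size, trading a little image fidelity for much faster latent encode/decode. A minimal sketch of the idea using the diffusers API directly, not InvokeAI code; the repo ids and latent shape are illustrative, and "madebyollin/taesd" is the published TAESD checkpoint:

import torch
from diffusers import AutoencoderKL, AutoencoderTiny

device = "cuda" if torch.cuda.is_available() else "cpu"

# The full KL autoencoder that ships with SD 1.x (hundreds of MB of weights).
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
# TAESD: a distilled autoencoder on the order of 10 MB, sharing the same
# 4-channel latent space.
taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device)

latents = torch.randn(1, 4, 64, 64, device=device)  # stand-in for denoised SD latents

with torch.inference_mode():
    # AutoencoderKL expects latents divided by its scaling_factor (0.18215).
    full_image = vae.decode(latents / vae.config.scaling_factor).sample
    # AutoencoderTiny's scaling_factor is 1.0, so latents are decoded as-is;
    # the result is close to the full decode but much cheaper to compute.
    fast_image = taesd.decode(latents).sample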
invokeai/app/invocations/latent.py

@@ -1,11 +1,13 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
 from contextlib import ExitStack
+from functools import singledispatchmethod
 from typing import List, Literal, Optional, Union
 
 import einops
 import torch
 import torchvision.transforms as T
+from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
@@ -15,7 +17,7 @@ from diffusers.models.attention_processor import (
 )
 from diffusers.schedulers import DPMSolverSDEScheduler
 from diffusers.schedulers import SchedulerMixin as Scheduler
-from pydantic import BaseModel, Field, validator
+from pydantic import validator
 from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.metadata import CoreMetadata
@@ -29,8 +31,21 @@ from invokeai.app.invocations.primitives import (
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
-
-from ...backend.model_management import BaseModelType, ModelPatcher
+from .baseinvocation import (
+    BaseInvocation,
+    FieldDescriptions,
+    Input,
+    InputField,
+    InvocationContext,
+    UIType,
+    tags,
+    title,
+)
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .model import ModelInfo, UNetField, VaeField
+from ..models.image import ImageCategory, ResourceOrigin
+from ...backend.model_management import BaseModelType
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -42,22 +57,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
-from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    FieldDescriptions,
-    Input,
-    InputField,
-    InvocationContext,
-    OutputField,
-    UIType,
-    tags,
-    title,
-)
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .model import ModelInfo, UNetField, VaeField
+from ...backend.util.logging import InvokeAILogger
 
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
@@ -514,10 +514,17 @@ class LatentsToImageInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 latents = latents.half()
 
-            if self.tiled or context.services.configuration.tiled_decode:
-                vae.enable_tiling()
-            else:
-                vae.disable_tiling()
+            try:
+                if self.tiled or context.services.configuration.tiled_decode:
+                    vae.enable_tiling()
+                else:
+                    vae.disable_tiling()
+            except AttributeError as err:
+                # FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
+                if err.name.endswith("_tiling"):
+                    InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
+                else:
+                    raise
 
             # clear memory as vae decode can request a lot
             torch.cuda.empty_cache()
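A note on this pattern: until diffusers PR #4627 lands, AutoencoderTiny has no enable_tiling/disable_tiling, so calling them raises AttributeError. The handler inspects err.name (populated on AttributeError since Python 3.10) so that only the missing tiling attributes are swallowed while unrelated attribute errors still propagate. An equivalent look-before-you-leap formulation, as a sketch with illustrative names:

# Probe for the capability instead of catching the AttributeError afterwards.
# `vae` is whatever autoencoder the model manager returned (AutoencoderKL or
# AutoencoderTiny); `tiled` stands in for the invocation's tiling flag.
if hasattr(vae, "enable_tiling"):
    if tiled:
        vae.enable_tiling()
    else:
        vae.disable_tiling()
# else: a pre-#4627 AutoencoderTiny; tiling is simply unavailable.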
@@ -704,16 +711,22 @@ class ImageToLatentsInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 # latents = latents.half()
 
-            if self.tiled:
-                vae.enable_tiling()
-            else:
-                vae.disable_tiling()
+            try:
+                if self.tiled:
+                    vae.enable_tiling()
+                else:
+                    vae.disable_tiling()
+            except AttributeError as err:
+                # FIXME: This is a TEMPORARY measure until AutoencoderTiny gets tiling support from https://github.com/huggingface/diffusers/pull/4627
+                if err.name.endswith("_tiling"):
+                    InvokeAILogger.getLogger(self.__class__.__name__).debug("ignoring tiling error for %s", vae.__class__, exc_info=err)
+                else:
+                    raise
 
             # non_noised_latents_from_image
             image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
             with torch.inference_mode():
-                image_tensor_dist = vae.encode(image_tensor).latent_dist
-                latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+                latents = self._encode_to_tensor(vae, image_tensor)
 
             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
@@ -722,3 +735,13 @@ class ImageToLatentsInvocation(BaseInvocation):
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents, seed=None)
+
+    @singledispatchmethod
+    def _encode_to_tensor(self, vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        image_tensor_dist = vae.encode(image_tensor).latent_dist
+        latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+        return latents
+
+    @_encode_to_tensor.register
+    def _(self, vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        return vae.encode(image_tensor).latents
invokeai/backend/model_management/model_probe.py

@@ -51,6 +51,7 @@ class ModelProbe(object):
         "StableDiffusionXLPipeline": ModelType.Main,
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.Vae,
+        "AutoencoderTiny": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
     }
 
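This table is what allows a TAESD folder to be installed and classified as a VAE. A diffusers-format model folder records its Python class name in its config; the probe maps that name to an InvokeAI model type. A rough sketch of that classification step, with illustrative names rather than the probe's actual internals:

import json
from pathlib import Path

# Maps the `_class_name` found in a diffusers config.json to a model type.
CLASS_TO_TYPE = {
    "StableDiffusionXLPipeline": "main",
    "AutoencoderKL": "vae",
    "AutoencoderTiny": "vae",  # the entry this commit adds
    "ControlNetModel": "controlnet",
}

def classify_model_folder(folder: Path) -> str:
    config = json.loads((folder / "config.json").read_text())
    class_name = config["_class_name"]  # e.g. "AutoencoderTiny" for TAESD
    return CLASS_TO_TYPE[class_name]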