feature: support TAESD - Tiny Autoencoder for Stable Diffusion (#4316)

[TAESD - Tiny Autoencoder for Stable Diffusion](https://github.com/madebyollin/taesd) is a tiny VAE that provides significantly better results than my single-multiplication hack while still being very fast.

The entire TAESD model weights are under 10 MB!

This PR requires diffusers 0.20:
- [x] #4311 

## To Do

Test with
- [x] SD 1.x
- [ ] SD 2.x: #4415 
- [x] SDXL

## Have you discussed this change with the InvokeAI team?
- See [TAESD Invocation
API](https://discord.com/channels/1020123559063990373/1137857402453119166)
      
## Have you updated all relevant documentation?
- [ ] No


## Related Tickets & Documents


- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

Should be able to import these models:
- [madebyollin/taesd](https://huggingface.co/madebyollin/taesd)
- [madebyollin/taesdxl](https://huggingface.co/madebyollin/taesdxl)

and use them as VAE.
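
For reference, a rough diffusers-level sketch of what these repos contain and how a tiny VAE slots into a pipeline. This is plain diffusers usage rather than InvokeAI code (in InvokeAI you import the repo through the model manager and select it as the VAE); it assumes diffusers >= 0.20 and a CUDA device, and the base checkpoint, prompt, and output filename are just placeholders:

```python
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

# Standard SD 1.5 pipeline, but with the full AutoencoderKL swapped for TAESD.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a corgi wearing a top hat, oil painting").images[0]
image.save("taesd_smoke_test.png")
```

The same swap should work for SDXL pipelines with `madebyollin/taesdxl`.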


## Added/updated tests?

- [x] Some. There are new tests for VaeFolderProbe based on VAE
configurations, but no tests that require the full model weights.
commit b64ade586d by Kent Keirsey, 2023-09-20 17:23:20 -04:00 (committed by GitHub)
7 changed files with 202 additions and 8 deletions

File: latents invocation module (invokeai/app/invocations)

@@ -1,12 +1,14 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

 from contextlib import ExitStack
+from functools import singledispatchmethod
 from typing import List, Literal, Optional, Union

 import einops
 import numpy as np
 import torch
 import torchvision.transforms as T
+from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import UNet2DConditionModel
 from diffusers.models.attention_processor import (
@@ -857,8 +859,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         # non_noised_latents_from_image
         image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
         with torch.inference_mode():
-            image_tensor_dist = vae.encode(image_tensor).latent_dist
-            latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+            latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)

             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
@@ -885,6 +886,18 @@ class ImageToLatentsInvocation(BaseInvocation):
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents, seed=None)

+    @singledispatchmethod
+    @staticmethod
+    def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        image_tensor_dist = vae.encode(image_tensor).latent_dist
+        latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+        return latents
+
+    @_encode_to_tensor.register
+    @staticmethod
+    def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        return vae.encode(image_tensor).latents

 @invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
 class BlendLatentsInvocation(BaseInvocation):
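
For context on the `_encode_to_tensor` dispatch above: `AutoencoderKL.encode()` returns an output whose `latent_dist` has to be sampled, while `AutoencoderTiny.encode()` returns `latents` directly, so the tiny-VAE path skips the sampling step entirely. A small standalone sketch of that difference (assumes diffusers >= 0.20; the `stabilityai/sd-vae-ft-mse` and `madebyollin/taesd` checkpoints are just example weights):

```python
import torch
from diffusers import AutoencoderKL, AutoencoderTiny

# Dummy image batch, only used here to show the two encode() return types/shapes.
image = torch.randn(1, 3, 512, 512)

kl_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")

with torch.inference_mode():
    # Full VAE: encode() yields a posterior distribution; sampling it uses torch.randn,
    # which is the reproducibility FIXME carried over in the diff above.
    kl_latents = kl_vae.encode(image).latent_dist.sample()
    # Tiny VAE: encode() is deterministic and returns the latents directly.
    tiny_latents = tiny_vae.encode(image).latents

print(kl_latents.shape, tiny_latents.shape)  # both torch.Size([1, 4, 64, 64])
```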

File: invokeai/backend/model_management/model_probe.py

@@ -1,4 +1,5 @@
 import json
+import re
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable, Dict, Literal, Optional, Union
@@ -53,6 +54,7 @@ class ModelProbe(object):
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "StableDiffusionXLInpaintPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.Vae,
+        "AutoencoderTiny": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
         "CLIPVisionModelWithProjection": ModelType.CLIPVision,
     }
@@ -177,6 +179,7 @@ class ModelProbe(object):
         Get the model type of a hugging-face style folder.
         """
         class_name = None
+        error_hint = None
         if model:
             class_name = model.__class__.__name__
         else:
@@ -202,12 +205,18 @@ class ModelProbe(object):
                     class_name = conf["architectures"][0]
                 else:
                     class_name = None
+            else:
+                error_hint = f"No model_index.json or config.json found in {folder_path}."

         if class_name and (type := cls.CLASS2TYPE.get(class_name)):
             return type
+        else:
+            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

         # give up
-        raise InvalidModelException(f"Unable to determine model type for {folder_path}")
+        raise InvalidModelException(
+            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
+        )

     @classmethod
     def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
@@ -461,16 +470,32 @@ class PipelineFolderProbe(FolderProbeBase):

 class VaeFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
+        if self._config_looks_like_sdxl():
+            return BaseModelType.StableDiffusionXL
+        elif self._name_looks_like_sdxl():
+            # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
+            # by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
+            return BaseModelType.StableDiffusionXL
+        else:
+            return BaseModelType.StableDiffusion1
+
+    def _config_looks_like_sdxl(self) -> bool:
+        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
         config_file = self.folder_path / "config.json"
         if not config_file.exists():
             raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
         with open(config_file, "r") as file:
             config = json.load(file)
-        return (
-            BaseModelType.StableDiffusionXL
-            if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
-            else BaseModelType.StableDiffusion1
-        )
+        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
+
+    def _name_looks_like_sdxl(self) -> bool:
+        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
+
+    def _guess_name(self) -> str:
+        name = self.folder_path.name
+        if name == "vae":
+            name = self.folder_path.parent.name
+        return name

 class TextualInversionFolderProbe(FolderProbeBase):

File: tests/test_model_probe.py (new file, 22 lines)

@@ -0,0 +1,22 @@
from pathlib import Path

import pytest

from invokeai.backend import BaseModelType
from invokeai.backend.model_management.model_probe import VaeFolderProbe


@pytest.mark.parametrize(
    "vae_path,expected_type",
    [
        ("sd-vae-ft-mse", BaseModelType.StableDiffusion1),
        ("sdxl-vae", BaseModelType.StableDiffusionXL),
        ("taesd", BaseModelType.StableDiffusion1),
        ("taesdxl", BaseModelType.StableDiffusionXL),
    ],
)
def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Path):
    sd1_vae_path = datadir / "vae" / vae_path
    probe = VaeFolderProbe(sd1_vae_path)
    base_type = probe.get_base_type()
    assert base_type == expected_type

File: config.json test fixture for the sd-vae-ft-mse case (SD 1.x AutoencoderKL, new file)

@@ -0,0 +1,29 @@
{
"_class_name": "AutoencoderKL",
"_diffusers_version": "0.4.2",
"act_fn": "silu",
"block_out_channels": [
128,
256,
512,
512
],
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"norm_num_groups": 32,
"out_channels": 3,
"sample_size": 256,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D"
]
}

File: config.json test fixture for the sdxl-vae case (SDXL AutoencoderKL, new file)

@@ -0,0 +1,31 @@
{
"_class_name": "AutoencoderKL",
"_diffusers_version": "0.18.0.dev0",
"_name_or_path": ".",
"act_fn": "silu",
"block_out_channels": [
128,
256,
512,
512
],
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"norm_num_groups": 32,
"out_channels": 3,
"sample_size": 1024,
"scaling_factor": 0.13025,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D"
]
}

File: config.json test fixture for the taesd case (AutoencoderTiny, new file)

@@ -0,0 +1,37 @@
{
"_class_name": "AutoencoderTiny",
"_diffusers_version": "0.20.0.dev0",
"act_fn": "relu",
"decoder_block_out_channels": [
64,
64,
64,
64
],
"encoder_block_out_channels": [
64,
64,
64,
64
],
"force_upcast": false,
"in_channels": 3,
"latent_channels": 4,
"latent_magnitude": 3,
"latent_shift": 0.5,
"num_decoder_blocks": [
3,
3,
3,
1
],
"num_encoder_blocks": [
1,
3,
3,
3
],
"out_channels": 3,
"scaling_factor": 1.0,
"upsampling_scaling_factor": 2
}

File: config.json test fixture for the taesdxl case (AutoencoderTiny, new file; identical to the taesd config, which is why the probe falls back to the folder name)

@@ -0,0 +1,37 @@
{
"_class_name": "AutoencoderTiny",
"_diffusers_version": "0.20.0.dev0",
"act_fn": "relu",
"decoder_block_out_channels": [
64,
64,
64,
64
],
"encoder_block_out_channels": [
64,
64,
64,
64
],
"force_upcast": false,
"in_channels": 3,
"latent_channels": 4,
"latent_magnitude": 3,
"latent_shift": 0.5,
"num_decoder_blocks": [
3,
3,
3,
1
],
"num_encoder_blocks": [
1,
3,
3,
3
],
"out_channels": 3,
"scaling_factor": 1.0,
"upsampling_scaling_factor": 2
}