feature: support TAESD - Tiny Autoencoder for Stable Diffusion (#4316)

[TAESD - Tiny Autoencoder for Stable
Diffusion](https://github.com/madebyollin/taesd) is a tiny VAE that
produces significantly better results than my single-multiplication hack
while still being very fast.

The entire set of TAESD model weights is under 10 MB!
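For anyone who wants to try TAESD outside InvokeAI first, here is a minimal sketch of swapping it in as a drop-in VAE with diffusers (assumes diffusers ≥ 0.20, a CUDA device, and the `runwayml/stable-diffusion-v1-5` base model, none of which are part of this PR):

```python
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

# Load a standard SD 1.5 pipeline, then replace its VAE with TAESD.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Decoding now goes through the ~10 MB tiny autoencoder instead of the full VAE.
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("taesd_smoke_test.png")
```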

This PR requires diffusers 0.20:
- [x] #4311 

## To Do

Test with
- [x] SD 1.x
- [ ] SD 2.x: #4415 
- [x] SDXL

## Have you discussed this change with the InvokeAI team?
- See [TAESD Invocation
API](https://discord.com/channels/1020123559063990373/1137857402453119166)
      
## Have you updated all relevant documentation?
- [ ] No


## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

Should be able to import these models:
- [madebyollin/taesd](https://huggingface.co/madebyollin/taesd)
- [madebyollin/taesdxl](https://huggingface.co/madebyollin/taesdxl)

and use them as VAEs (a quick probe sanity check is sketched below).
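As an extra sanity check of the new probe itself (a sketch only; it assumes the two repos above were downloaded into local folders named `taesd` and `taesdxl` — the paths are hypothetical):

```python
from pathlib import Path

from invokeai.backend.model_management.model_probe import VaeFolderProbe

# Point these at wherever the Hugging Face repos were downloaded.
for folder in ("taesd", "taesdxl"):
    base_type = VaeFolderProbe(Path("downloads") / folder).get_base_type()
    print(folder, base_type)  # expect StableDiffusion1 for taesd, StableDiffusionXL for taesdxl
```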

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [x] Some. There are new tests for VaeFolderProbe based on VAE
configurations, but no tests that require the full model weights.
Commit b64ade586d by Kent Keirsey, 2023-09-20 17:23:20 -04:00 (committed via GitHub)
7 changed files with 202 additions and 8 deletions


@@ -1,12 +1,14 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import List, Literal, Optional, Union
import einops
import numpy as np
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import (
@@ -857,8 +859,7 @@ class ImageToLatentsInvocation(BaseInvocation):
        # non_noised_latents_from_image
        image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
        with torch.inference_mode():
            image_tensor_dist = vae.encode(image_tensor).latent_dist
            latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
            latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)

            latents = vae.config.scaling_factor * latents
            latents = latents.to(dtype=orig_dtype)
@@ -885,6 +886,18 @@ class ImageToLatentsInvocation(BaseInvocation):
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents, seed=None)

    @singledispatchmethod
    @staticmethod
    def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
        image_tensor_dist = vae.encode(image_tensor).latent_dist
        latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
        return latents

    @_encode_to_tensor.register
    @staticmethod
    def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
        return vae.encode(image_tensor).latents


@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
class BlendLatentsInvocation(BaseInvocation):
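Context for reviewers: the single dispatch in `_encode_to_tensor` above exists because the two autoencoder classes return different things from `encode()` in diffusers — `AutoencoderKL` returns a latent distribution that has to be sampled, while `AutoencoderTiny` returns latents directly. A minimal standalone sketch of the difference (assumes diffusers ≥ 0.20; `stabilityai/sd-vae-ft-mse` is just a convenient KL VAE for comparison):

```python
import torch
from diffusers import AutoencoderKL, AutoencoderTiny

image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed RGB batch

kl_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")

with torch.inference_mode():
    kl_latents = kl_vae.encode(image).latent_dist.sample()  # distribution -> sample
    tiny_latents = tiny_vae.encode(image).latents  # latents returned directly

print(kl_latents.shape, tiny_latents.shape)  # both should be (1, 4, 64, 64)
```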


@@ -1,4 +1,5 @@
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union
@@ -53,6 +54,7 @@ class ModelProbe(object):
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
    }
@@ -177,6 +179,7 @@ class ModelProbe(object):
        Get the model type of a hugging-face style folder.
        """
        class_name = None
        error_hint = None
        if model:
            class_name = model.__class__.__name__
        else:
@@ -202,12 +205,18 @@ class ModelProbe(object):
                    class_name = conf["architectures"][0]
                else:
                    class_name = None
            else:
                error_hint = f"No model_index.json or config.json found in {folder_path}."

        if class_name and (type := cls.CLASS2TYPE.get(class_name)):
            return type
        else:
            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

        # give up
        raise InvalidModelException(f"Unable to determine model type for {folder_path}")
        raise InvalidModelException(
            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
        )

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
@@ -461,16 +470,32 @@ class PipelineFolderProbe(FolderProbeBase):
class VaeFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self._config_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        elif self._name_looks_like_sdxl():
            # but SD and SDXL VAEs are the same shape (3-channel RGB to 4-channel float, scaled down
            # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters.
            return BaseModelType.StableDiffusionXL
        else:
            return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return (
            BaseModelType.StableDiffusionXL
            if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
            else BaseModelType.StableDiffusion1
        )
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        name = self.folder_path.name
        if name == "vae":
            name = self.folder_path.parent.name
        return name


class TextualInversionFolderProbe(FolderProbeBase):

tests/test_model_probe.py (new file, 22 lines)

@@ -0,0 +1,22 @@
from pathlib import Path

import pytest

from invokeai.backend import BaseModelType
from invokeai.backend.model_management.model_probe import VaeFolderProbe


@pytest.mark.parametrize(
    "vae_path,expected_type",
    [
        ("sd-vae-ft-mse", BaseModelType.StableDiffusion1),
        ("sdxl-vae", BaseModelType.StableDiffusionXL),
        ("taesd", BaseModelType.StableDiffusion1),
        ("taesdxl", BaseModelType.StableDiffusionXL),
    ],
)
def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Path):
    sd1_vae_path = datadir / "vae" / vae_path
    probe = VaeFolderProbe(sd1_vae_path)
    base_type = probe.get_base_type()
    assert base_type == expected_type
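Note on the fixture: `datadir` is presumably provided by the pytest-datadir plugin, which copies the per-test data directory (containing the four `config.json` fixtures shown below) into a temp directory for each test. The test can be run in isolation with `pytest tests/test_model_probe.py`.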


@@ -0,0 +1,29 @@
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.4.2",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 256,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ]
}


@@ -0,0 +1,31 @@
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.18.0.dev0",
  "_name_or_path": ".",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 1024,
  "scaling_factor": 0.13025,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ]
}


@@ -0,0 +1,37 @@
{
  "_class_name": "AutoencoderTiny",
  "_diffusers_version": "0.20.0.dev0",
  "act_fn": "relu",
  "decoder_block_out_channels": [
    64,
    64,
    64,
    64
  ],
  "encoder_block_out_channels": [
    64,
    64,
    64,
    64
  ],
  "force_upcast": false,
  "in_channels": 3,
  "latent_channels": 4,
  "latent_magnitude": 3,
  "latent_shift": 0.5,
  "num_decoder_blocks": [
    3,
    3,
    3,
    1
  ],
  "num_encoder_blocks": [
    1,
    3,
    3,
    3
  ],
  "out_channels": 3,
  "scaling_factor": 1.0,
  "upsampling_scaling_factor": 2
}


@@ -0,0 +1,37 @@
{
  "_class_name": "AutoencoderTiny",
  "_diffusers_version": "0.20.0.dev0",
  "act_fn": "relu",
  "decoder_block_out_channels": [
    64,
    64,
    64,
    64
  ],
  "encoder_block_out_channels": [
    64,
    64,
    64,
    64
  ],
  "force_upcast": false,
  "in_channels": 3,
  "latent_channels": 4,
  "latent_magnitude": 3,
  "latent_shift": 0.5,
  "num_decoder_blocks": [
    3,
    3,
    3,
    1
  ],
  "num_encoder_blocks": [
    1,
    3,
    3,
    3
  ],
  "out_channels": 3,
  "scaling_factor": 1.0,
  "upsampling_scaling_factor": 2
}
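One design note worth calling out: the two `AutoencoderTiny` configs above are identical — unlike the KL VAEs, there is no `scaling_factor`/`sample_size` difference between the SD 1.x and SDXL variants — so `VaeFolderProbe` cannot tell `taesd` and `taesdxl` apart from the config alone and falls back to the folder-name heuristic. A tiny illustration of that fallback (assumes the local folder names match the Hugging Face repo names):

```python
import re

def name_looks_like_sdxl(name: str) -> bool:
    # same pattern VaeFolderProbe._name_looks_like_sdxl uses
    return bool(re.search(r"xl\b", name, re.IGNORECASE))

assert name_looks_like_sdxl("taesdxl")
assert name_looks_like_sdxl("sdxl-vae")
assert not name_looks_like_sdxl("taesd")
assert not name_looks_like_sdxl("sd-vae-ft-mse")
```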