diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index ddee01e7b8..6c7f3da366 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -1,12 +1,14 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

 from contextlib import ExitStack
+from functools import singledispatchmethod
 from typing import List, Literal, Optional, Union

 import einops
 import numpy as np
 import torch
 import torchvision.transforms as T
+from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import UNet2DConditionModel
 from diffusers.models.attention_processor import (
@@ -857,8 +859,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         # non_noised_latents_from_image
         image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
         with torch.inference_mode():
-            image_tensor_dist = vae.encode(image_tensor).latent_dist
-            latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+            latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)

         latents = vae.config.scaling_factor * latents
         latents = latents.to(dtype=orig_dtype)
@@ -885,6 +886,18 @@ class ImageToLatentsInvocation(BaseInvocation):
         context.services.latents.save(name, latents)
         return build_latents_output(latents_name=name, latents=latents, seed=None)

+    @singledispatchmethod
+    @staticmethod
+    def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        image_tensor_dist = vae.encode(image_tensor).latent_dist
+        latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
+        return latents
+
+    @_encode_to_tensor.register
+    @staticmethod
+    def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
+        return vae.encode(image_tensor).latents
+

 @invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
 class BlendLatentsInvocation(BaseInvocation):
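Why the dispatch: `AutoencoderKL.encode()` returns an output whose `.latent_dist` must be sampled (which is where the non-reproducible `torch.randn` in the FIXME comes from), while `AutoencoderTiny.encode()` is deterministic and exposes the tensor directly as `.latents`. Below is a minimal standalone sketch of the same pattern using plain `functools.singledispatch` with hypothetical stand-in classes (`FakeKL` and `FakeTiny` are illustrative, not diffusers types). The `singledispatchmethod` + `staticmethod` stacking used in the diff behaves the same way, though that combination has been bug-prone in some CPython releases, so it is worth verifying on the Python version you target.

```python
from functools import singledispatch


class FakeKL:  # stand-in for diffusers.AutoencoderKL
    pass


class FakeTiny:  # stand-in for diffusers.AutoencoderTiny
    pass


@singledispatch
def encode_to_tensor(vae) -> str:
    # fallback when no implementation is registered for this VAE class
    raise NotImplementedError(f"no encoder registered for {type(vae).__name__}")


@encode_to_tensor.register
def _(vae: FakeKL) -> str:
    # KL VAEs return a distribution; sampling it is the non-deterministic step
    return "latent_dist.sample()"


@encode_to_tensor.register
def _(vae: FakeTiny) -> str:
    # AutoencoderTiny returns the latents tensor directly -- no sampling step
    return "latents"


assert encode_to_tensor(FakeKL()) == "latent_dist.sample()"
assert encode_to_tensor(FakeTiny()) == "latents"
```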
diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py
index c9a55f3888..1fc4a51354 100644
--- a/invokeai/backend/model_management/model_probe.py
+++ b/invokeai/backend/model_management/model_probe.py
@@ -1,4 +1,5 @@
 import json
+import re
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable, Dict, Literal, Optional, Union
@@ -53,6 +54,7 @@ class ModelProbe(object):
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "StableDiffusionXLInpaintPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.Vae,
+        "AutoencoderTiny": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
         "CLIPVisionModelWithProjection": ModelType.CLIPVision,
     }
@@ -177,6 +179,7 @@ class ModelProbe(object):
         Get the model type of a hugging-face style folder.
         """
         class_name = None
+        error_hint = None
         if model:
             class_name = model.__class__.__name__
         else:
@@ -202,12 +205,18 @@ class ModelProbe(object):
                     class_name = conf["architectures"][0]
                 else:
                     class_name = None
+            else:
+                error_hint = f"No model_index.json or config.json found in {folder_path}."

         if class_name and (type := cls.CLASS2TYPE.get(class_name)):
             return type
+        else:
+            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

         # give up
-        raise InvalidModelException(f"Unable to determine model type for {folder_path}")
+        raise InvalidModelException(
+            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
+        )

     @classmethod
     def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
@@ -461,16 +470,32 @@ class PipelineFolderProbe(FolderProbeBase):

 class VaeFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
+        if self._config_looks_like_sdxl():
+            return BaseModelType.StableDiffusionXL
+        elif self._name_looks_like_sdxl():
+            # SD and SDXL VAEs are the same shape (3-channel RGB to 4-channel float scaled down
+            # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters.
+            return BaseModelType.StableDiffusionXL
+        else:
+            return BaseModelType.StableDiffusion1
+
+    def _config_looks_like_sdxl(self) -> bool:
+        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
         config_file = self.folder_path / "config.json"
         if not config_file.exists():
             raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
         with open(config_file, "r") as file:
             config = json.load(file)
-        return (
-            BaseModelType.StableDiffusionXL
-            if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
-            else BaseModelType.StableDiffusion1
-        )
+        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
+
+    def _name_looks_like_sdxl(self) -> bool:
+        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
+
+    def _guess_name(self) -> str:
+        name = self.folder_path.name
+        if name == "vae":
+            name = self.folder_path.parent.name
+        return name


 class TextualInversionFolderProbe(FolderProbeBase):
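The probe now decides in two stages: first by config fingerprint (Stability's SDXL VAE ships with `scaling_factor == 0.13025` and a 512/1024 `sample_size`), then by folder name, since configs alone can't always separate the two families -- the taesd and taesdxl fixture configs below are byte-identical. `_guess_name` also falls back to the parent directory when the folder is literally named `vae`, as in a diffusers pipeline layout. A quick check of the `xl\b` word-boundary regex, using hypothetical folder names chosen only to illustrate the matching:

```python
import re


def name_looks_like_sdxl(name: str) -> bool:
    # same check as VaeFolderProbe._name_looks_like_sdxl above
    return bool(re.search(r"xl\b", name, re.IGNORECASE))


assert name_looks_like_sdxl("sdxl-vae")           # "xl" ends at a hyphen
assert name_looks_like_sdxl("taesdxl")            # "xl" ends the string
assert name_looks_like_sdxl("SDXL-VAE")           # match is case-insensitive
assert not name_looks_like_sdxl("sd-vae-ft-mse")  # no "xl" anywhere
assert not name_looks_like_sdxl("xlarge-vae")     # "xl" not at a word boundary
assert not name_looks_like_sdxl("sdxl_vae")       # "_" is a word char, so no boundary after "xl"
```

Note the last case: a folder named `sdxl_vae` would fall through to `StableDiffusion1` unless its config matches the SDXL fingerprint, which is a known limit of the word-boundary heuristic.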
"UpDecoderBlock2D" + ] +} diff --git a/tests/test_model_probe/vae/sdxl-vae/config.json b/tests/test_model_probe/vae/sdxl-vae/config.json new file mode 100644 index 0000000000..2c7267b492 --- /dev/null +++ b/tests/test_model_probe/vae/sdxl-vae/config.json @@ -0,0 +1,31 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.18.0.dev0", + "_name_or_path": ".", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 1024, + "scaling_factor": 0.13025, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ] +} diff --git a/tests/test_model_probe/vae/taesd/config.json b/tests/test_model_probe/vae/taesd/config.json new file mode 100644 index 0000000000..62f01c3eb4 --- /dev/null +++ b/tests/test_model_probe/vae/taesd/config.json @@ -0,0 +1,37 @@ +{ + "_class_name": "AutoencoderTiny", + "_diffusers_version": "0.20.0.dev0", + "act_fn": "relu", + "decoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "encoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "force_upcast": false, + "in_channels": 3, + "latent_channels": 4, + "latent_magnitude": 3, + "latent_shift": 0.5, + "num_decoder_blocks": [ + 3, + 3, + 3, + 1 + ], + "num_encoder_blocks": [ + 1, + 3, + 3, + 3 + ], + "out_channels": 3, + "scaling_factor": 1.0, + "upsampling_scaling_factor": 2 +} diff --git a/tests/test_model_probe/vae/taesdxl/config.json b/tests/test_model_probe/vae/taesdxl/config.json new file mode 100644 index 0000000000..62f01c3eb4 --- /dev/null +++ b/tests/test_model_probe/vae/taesdxl/config.json @@ -0,0 +1,37 @@ +{ + "_class_name": "AutoencoderTiny", + "_diffusers_version": "0.20.0.dev0", + "act_fn": "relu", + "decoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "encoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "force_upcast": false, + "in_channels": 3, + "latent_channels": 4, + "latent_magnitude": 3, + "latent_shift": 0.5, + "num_decoder_blocks": [ + 3, + 3, + 3, + 1 + ], + "num_encoder_blocks": [ + 1, + 3, + 3, + 3 + ], + "out_channels": 3, + "scaling_factor": 1.0, + "upsampling_scaling_factor": 2 +}