From ffa05a0bb33d70213d5ed10326b6410c3b229e46 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 6 Jan 2024 13:41:35 -0500 Subject: [PATCH] Only replace vae when it is the broken SDXL 1.0 version --- .../model_management/detect_baked_in_vae.py | 31 +++++++++++++++++++ .../backend/model_management/models/sdxl.py | 18 ++++++++--- 2 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 invokeai/backend/model_management/detect_baked_in_vae.py diff --git a/invokeai/backend/model_management/detect_baked_in_vae.py b/invokeai/backend/model_management/detect_baked_in_vae.py new file mode 100644 index 0000000000..9118438548 --- /dev/null +++ b/invokeai/backend/model_management/detect_baked_in_vae.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team +""" +This module exports the function has_baked_in_sdxl_vae(). +It returns True if an SDXL checkpoint model has a custom (non SDXL-1.0) VAE +baked in; the original SDXL 1.0 VAE doesn't work properly in fp16 mode. +""" + +import hashlib +from pathlib import Path + +from safetensors.torch import load_file + +SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51" + + +def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool: +    """Return true if the checkpoint contains a custom (non SDXL-1.0) VAE.""" +    hash = _vae_hash(checkpoint_path) +    return hash != SDXL_1_0_VAE_HASH + + +def _vae_hash(checkpoint_path: Path) -> str: +    checkpoint = load_file(checkpoint_path, device="cpu") +    vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")] +    hash = hashlib.new("sha256") +    for key in vae_keys: +        value = checkpoint[key] +        hash.update(bytes(key, "UTF-8")) +        hash.update(bytes(str(value), "UTF-8")) + +    return hash.hexdigest() diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_management/models/sdxl.py index 1cb5971c33..53c080fa66 100644 --- a/invokeai/backend/model_management/models/sdxl.py +++
b/invokeai/backend/model_management/models/sdxl.py @@ -1,11 +1,16 @@ import json import os from enum import Enum +from pathlib import Path from typing import Literal, Optional from omegaconf import OmegaConf from pydantic import Field +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae +from invokeai.backend.util.logging import InvokeAILogger + from .base import ( BaseModelType, DiffusersModel, @@ -116,17 +121,20 @@ class StableDiffusionXLModel(DiffusersModel): # The convert script adapted from the diffusers package uses # strings for the base model type. To avoid making too many # source code changes, we simply translate here + if Path(output_path).exists(): + return output_path + if isinstance(config, cls.CheckpointConfig): from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache # Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed, - # then we bake it into the converted model. - from invokeai.app.services.config import InvokeAIAppConfig - - kwargs = dict() + # then we bake it into the converted model unless there is already + # a nonstandard VAE installed. + kwargs = {} app_config = InvokeAIAppConfig.get_config() vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix" - if vae_path.exists(): + if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)): + InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.") kwargs["vae_path"] = vae_path return _convert_ckpt_and_cache(