Only replace vae when it is the broken SDXL 1.0 version

This commit is contained in:
Lincoln Stein 2024-01-06 13:41:35 -05:00 committed by Kent Keirsey
parent a20e17330b
commit ffa05a0bb3
2 changed files with 44 additions and 5 deletions

View File

@ -0,0 +1,31 @@
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
"""
This module exports the function has_baked_in_sdxl_vae().
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
which doesn't work properly in fp16 mode.
"""
import hashlib
from pathlib import Path
from safetensors.torch import load_file
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
hash = _vae_hash(checkpoint_path)
return hash != SDXL_1_0_VAE_HASH
def _vae_hash(checkpoint_path: Path) -> str:
checkpoint = load_file(checkpoint_path, device="cpu")
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
hash = hashlib.new("sha256")
for key in vae_keys:
value = checkpoint[key]
hash.update(bytes(key, "UTF-8"))
hash.update(bytes(str(value), "UTF-8"))
return hash.hexdigest()

View File

@ -1,11 +1,16 @@
import json
import os
from enum import Enum
from pathlib import Path
from typing import Literal, Optional
from omegaconf import OmegaConf
from pydantic import Field
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
from invokeai.backend.util.logging import InvokeAILogger
from .base import (
BaseModelType,
DiffusersModel,
@ -116,17 +121,20 @@ class StableDiffusionXLModel(DiffusersModel):
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
if Path(output_path).exists():
return output_path
if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
# then we bake it into the converted model.
from invokeai.app.services.config import InvokeAIAppConfig
kwargs = dict()
# then we bake it into the converted model unless there is already
# a nonstandard VAE installed.
kwargs = {}
app_config = InvokeAIAppConfig.get_config()
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
if vae_path.exists():
if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
kwargs["vae_path"] = vae_path
return _convert_ckpt_and_cache(