Only replace vae when it is the broken SDXL 1.0 version

This commit is contained in:
Lincoln Stein 2024-01-06 13:41:35 -05:00 committed by Kent Keirsey
parent a20e17330b
commit ffa05a0bb3
2 changed files with 44 additions and 5 deletions

View File

@ -0,0 +1,31 @@
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
"""
This module exports the function has_baked_in_sdxl_vae().
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
which doesn't work properly in fp16 mode.
"""
import hashlib
from pathlib import Path
from safetensors.torch import load_file
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
hash = _vae_hash(checkpoint_path)
return hash != SDXL_1_0_VAE_HASH
def _vae_hash(checkpoint_path: Path) -> str:
checkpoint = load_file(checkpoint_path, device="cpu")
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
hash = hashlib.new("sha256")
for key in vae_keys:
value = checkpoint[key]
hash.update(bytes(key, "UTF-8"))
hash.update(bytes(str(value), "UTF-8"))
return hash.hexdigest()

View File

@ -1,11 +1,16 @@
import json import json
import os import os
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Literal, Optional from typing import Literal, Optional
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pydantic import Field from pydantic import Field
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
from invokeai.backend.util.logging import InvokeAILogger
from .base import ( from .base import (
BaseModelType, BaseModelType,
DiffusersModel, DiffusersModel,
@ -116,17 +121,20 @@ class StableDiffusionXLModel(DiffusersModel):
# The convert script adapted from the diffusers package uses # The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many # strings for the base model type. To avoid making too many
# source code changes, we simply translate here # source code changes, we simply translate here
if Path(output_path).exists():
return output_path
if isinstance(config, cls.CheckpointConfig): if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed, # Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
# then we bake it into the converted model. # then we bake it into the converted model unless there is already
from invokeai.app.services.config import InvokeAIAppConfig # a nonstandard VAE installed.
kwargs = {}
kwargs = dict()
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix" vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
if vae_path.exists(): if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
kwargs["vae_path"] = vae_path kwargs["vae_path"] = vae_path
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(