mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Only replace vae when it is the broken SDXL 1.0 version
This commit is contained in:
parent
a20e17330b
commit
ffa05a0bb3
31
invokeai/backend/model_management/detect_baked_in_vae.py
Normal file
31
invokeai/backend/model_management/detect_baked_in_vae.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
This module exports the function has_baked_in_sdxl_vae().
|
||||
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
|
||||
which doesn't work properly in fp16 mode.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
|
||||
|
||||
|
||||
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
|
||||
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
|
||||
hash = _vae_hash(checkpoint_path)
|
||||
return hash != SDXL_1_0_VAE_HASH
|
||||
|
||||
|
||||
def _vae_hash(checkpoint_path: Path) -> str:
|
||||
checkpoint = load_file(checkpoint_path, device="cpu")
|
||||
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
|
||||
hash = hashlib.new("sha256")
|
||||
for key in vae_keys:
|
||||
value = checkpoint[key]
|
||||
hash.update(bytes(key, "UTF-8"))
|
||||
hash.update(bytes(str(value), "UTF-8"))
|
||||
|
||||
return hash.hexdigest()
|
@ -1,11 +1,16 @@
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
DiffusersModel,
|
||||
@ -116,17 +121,20 @@ class StableDiffusionXLModel(DiffusersModel):
|
||||
# The convert script adapted from the diffusers package uses
|
||||
# strings for the base model type. To avoid making too many
|
||||
# source code changes, we simply translate here
|
||||
if Path(output_path).exists():
|
||||
return output_path
|
||||
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
||||
|
||||
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
|
||||
# then we bake it into the converted model.
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
kwargs = dict()
|
||||
# then we bake it into the converted model unless there is already
|
||||
# a nonstandard VAE installed.
|
||||
kwargs = {}
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
|
||||
if vae_path.exists():
|
||||
if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
|
||||
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
|
||||
kwargs["vae_path"] = vae_path
|
||||
|
||||
return _convert_ckpt_and_cache(
|
||||
|
Loading…
Reference in New Issue
Block a user