mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Only replace vae when it is the broken SDXL 1.0 version
This commit is contained in:
parent
a20e17330b
commit
ffa05a0bb3
31
invokeai/backend/model_management/detect_baked_in_vae.py
Normal file
31
invokeai/backend/model_management/detect_baked_in_vae.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
This module exports the function has_baked_in_sdxl_vae().
|
||||||
|
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
|
||||||
|
which doesn't work properly in fp16 mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
|
||||||
|
|
||||||
|
|
||||||
|
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
|
||||||
|
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
|
||||||
|
hash = _vae_hash(checkpoint_path)
|
||||||
|
return hash != SDXL_1_0_VAE_HASH
|
||||||
|
|
||||||
|
|
||||||
|
def _vae_hash(checkpoint_path: Path) -> str:
|
||||||
|
checkpoint = load_file(checkpoint_path, device="cpu")
|
||||||
|
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
|
||||||
|
hash = hashlib.new("sha256")
|
||||||
|
for key in vae_keys:
|
||||||
|
value = checkpoint[key]
|
||||||
|
hash.update(bytes(key, "UTF-8"))
|
||||||
|
hash.update(bytes(str(value), "UTF-8"))
|
||||||
|
|
||||||
|
return hash.hexdigest()
|
@ -1,11 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.model_management.detect_baked_in_vae import has_baked_in_sdxl_vae
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
DiffusersModel,
|
DiffusersModel,
|
||||||
@ -116,17 +121,20 @@ class StableDiffusionXLModel(DiffusersModel):
|
|||||||
# The convert script adapted from the diffusers package uses
|
# The convert script adapted from the diffusers package uses
|
||||||
# strings for the base model type. To avoid making too many
|
# strings for the base model type. To avoid making too many
|
||||||
# source code changes, we simply translate here
|
# source code changes, we simply translate here
|
||||||
|
if Path(output_path).exists():
|
||||||
|
return output_path
|
||||||
|
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
||||||
|
|
||||||
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
|
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
|
||||||
# then we bake it into the converted model.
|
# then we bake it into the converted model unless there is already
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
# a nonstandard VAE installed.
|
||||||
|
kwargs = {}
|
||||||
kwargs = dict()
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
|
vae_path = app_config.models_path / "sdxl/vae/sdxl-vae-fp16-fix"
|
||||||
if vae_path.exists():
|
if vae_path.exists() and not has_baked_in_sdxl_vae(Path(model_path)):
|
||||||
|
InvokeAILogger.get_logger().warning("No baked-in VAE detected. Inserting sdxl-vae-fp16-fix.")
|
||||||
kwargs["vae_path"] = vae_path
|
kwargs["vae_path"] = vae_path
|
||||||
|
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user