Work around missing core conversion model issue

- This adds additional logic to the safetensors->diffusers conversion script
  to check for and install missing core conversion models at runtime.

- Fixes #5934
This commit is contained in:
Lincoln Stein 2024-03-13 07:57:03 -04:00 committed by psychedelicious
parent e3f29ed320
commit 1bd8e33f8c

View File

@ -76,6 +76,25 @@ logger = InvokeAILogger.get_logger(__name__)
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert"
def install_dependencies():
"""
Check for, and install, missing model dependencies.
"""
conversion_models = [
"clip-vit-large-patch14",
"CLIP-ViT-H-14-laion2B-s32B-b79K",
"stable-diffusion-2-clip",
"stable-diffusion-safety-checker",
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
"bert-base-uncased",
]
if any(not (CONVERT_MODEL_ROOT / x).exists() for x in conversion_models):
logger.warning("Installing missing core safetensor conversion models")
from invokeai.backend.install.invokeai_configure import download_conversion_models # noqa
download_conversion_models()
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
@ -1697,6 +1716,8 @@ def download_controlnet_from_original_ckpt(
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
install_dependencies()
vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
@ -1717,6 +1738,7 @@ def convert_ckpt_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
install_dependencies()
pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
# TO DO: save correct repo variant
@ -1736,6 +1758,7 @@ def convert_controlnet_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
install_dependencies()
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
# TO DO: save correct repo variant