From b1a99d772cf7ce24521e17a6ab936ba7dfccdacc Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Thu, 18 May 2023 13:31:11 -0400
Subject: [PATCH] added method to convert vaes

---
 .../convert_ckpt_to_diffusers.py              | 56 +++++++++++-----
 .../backend/model_management/model_cache.py   |  2 +-
 .../backend/model_management/model_manager.py | 64 ++++++++++++++-----
 3 files changed, 87 insertions(+), 35 deletions(-)

diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py
index 54247962b0..5874d35c6b 100644
--- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py
+++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py
@@ -33,6 +33,7 @@ from .model_cache import ModelCache
 try:
     from omegaconf import OmegaConf
+    from omegaconf.dictconfig import DictConfig
 except ImportError:
     raise ImportError(
         "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
     )
@@ -614,16 +615,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     return new_checkpoint
 
 
-
 def convert_ldm_vae_checkpoint(checkpoint, config):
-    # extract state dict for VAE
-    vae_state_dict = {}
-    vae_key = "first_stage_model."
-    keys = list(checkpoint.keys())
-    for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    # Extract state dict for VAE. Works both with burnt-in
+    # VAEs and with standalone VAEs.
+    # checkpoint can either be an all-in-one stable diffusion
+    # model, or an isolated vae .ckpt. This tests for
+    # a key that will be present in the all-in-one model
+    # that isn't present in the isolated ckpt.
+    probe_key = "first_stage_model.encoder.conv_in.weight"
+    if probe_key in checkpoint:
+        vae_state_dict = {}
+        vae_key = "first_stage_model."
+        keys = list(checkpoint.keys())
+        for key in keys:
+            if key.startswith(vae_key):
+                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    else:
+        vae_state_dict = checkpoint
+
+    new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict, config)
+    return new_checkpoint
+
+def convert_ldm_vae_state_dict(vae_state_dict, config):
     new_checkpoint = {}
 
     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
@@ -1049,6 +1063,19 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
         new_key = f'first_stage_model.{vae_key}'
         checkpoint[new_key] = state_dict[vae_key]
 
+def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
+    vae_config = create_vae_diffusers_config(
+        vae_config, image_size=image_size
+    )
+
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+        checkpoint, vae_config
+    )
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae
+
 def load_pipeline_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
@@ -1244,15 +1271,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     if vae:
         logger.debug("Using replacement diffusers VAE")
     else: # convert the original or replacement VAE
-        vae_config = create_vae_diffusers_config(
-            original_config, image_size=image_size
-        )
-        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
-            checkpoint, vae_config
-        )
-
-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+        vae = convert_ldm_vae_to_diffusers(
+            checkpoint,
+            original_config,
+            image_size)
 
     # Convert the text model.
     model_type = pipeline_type

diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 8050eb58b8..5c2b498acf 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -32,7 +32,7 @@ import safetensors.torch
 
 from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
 from diffusers import logging as diffusers_logging
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, scan_cache_dir
 from picklescan.scanner import scan_file_path
 from pydantic import BaseModel
 from transformers import logging as transformers_logging
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index 6fe1c38168..a77f0613f3 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -131,9 +131,7 @@ from __future__ import annotations
 
 import os
 import re
-import sys
 import textwrap
-from contextlib import suppress
 from dataclasses import dataclass
 from enum import Enum, auto
 from packaging import version
@@ -145,6 +143,7 @@ import safetensors
 import safetensors.torch
 import torch
 from diffusers import AutoencoderKL
+from diffusers.utils import is_safetensors_available
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -157,7 +156,6 @@ from invokeai.backend.util import download_with_resume
 from ..util import CUDA_DEVICE
 from .model_cache import (ModelCache, ModelLocker, SDModelType,
                           SilenceWarnings)
-
 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
 # reduce confusion.
@@ -307,17 +305,12 @@ class ModelManager(object):
         """
 
-        # This is a temporary workaround for callers that use "type/name" as the model name
+        # Commented-out workaround for callers that use "type/name" as the model name
         # because they haven't adjusted to the new return format of `list_models()`
-        if "/" in model_name:
-            model_key = model_name
-        else:
-            model_key = self.create_key(model_name, model_type)
-
-        # TODO: delete default model or add check that this stable diffusion model
-        # if not model_name:
-        #     model_name = self.default_model()
-
+        # if "/" in model_name:
+        #     model_key = model_name
+        # else:
+        model_key = self.create_key(model_name, model_type)
         if model_key not in self.config:
             raise InvalidModelError(
                 f'"{model_key}" is not a known model name. Please check your models.yaml file'
             )
@@ -358,7 +351,6 @@
 
         if model_type == SDModelType.Vae and "vae" in mconfig:
             print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")
-
         model_context = self.cache.get_model(
             location,
             model_type = model_type,
@@ -414,7 +406,7 @@
         """
         Given a model name returns the OmegaConf (dict-like) object describing it.
        """
-        if not self.exists(model_name, model_type):
+        if not self.model_exists(model_name, model_type):
            return None
        return self.config[self.create_key(model_name, model_type)]
 
@@ -962,6 +954,44 @@
             )
         return diffusers_path
 
+    def convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
+        """
+        Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
+        object, cache it to disk, and return Path to converted
+        file. If already on disk then just returns Path.
+ """ + weights_file = global_resolve_path(mconfig.weights) + config_file = global_resolve_path(mconfig.config) + diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights_file.stem + image_size = mconfig.get('width') or mconfig.get('height') or 512 + + # return cached version if it exists + if diffusers_path.exists(): + return diffusers_path + + # this avoids circular import error + from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers + checkpoint = torch.load(weights_file, map_location="cpu")\ + if weights_file.suffix in ['.ckpt','.pt'] \ + else safetensors.torch.load_file(weights_file) + + # sometimes weights are hidden under "state_dict", and sometimes not + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + config = OmegaConf.load(config_file) + + vae_model = convert_ldm_vae_to_diffusers( + checkpoint = checkpoint, + vae_config = config, + image_size = image_size + ) + vae_model.save_pretrained( + diffusers_path, + safe_serialization=is_safetensors_available() + ) + return diffusers_path + def _get_vae_for_conversion( self, weights: Path, @@ -1130,7 +1160,7 @@ class ModelManager(object): @classmethod def _delete_model_from_cache(cls,repo_id): - cache_info = scan_cache_dir(global_cache_dir("hub")) + cache_info = scan_cache_dir(global_cache_dir("hub")) # I'm sure there is a way to do this with comprehensions # but the code quickly became incomprehensible! @@ -1178,7 +1208,7 @@ class ModelManager(object): current_version = self.config.get("_version","1.0.0") if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION): self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}') - self.logger.warning(f'The original file will be renamed models.yaml.orig') + self.logger.warning('The original file will be renamed models.yaml.orig') if self.config_path: old_file = Path(self.config_path) new_name = old_file.parent / 'models.yaml.orig'