diff --git a/docs/installation/050_INSTALLING_MODELS.md b/docs/installation/050_INSTALLING_MODELS.md index 10589098d2..1d9a0b6c7f 100644 --- a/docs/installation/050_INSTALLING_MODELS.md +++ b/docs/installation/050_INSTALLING_MODELS.md @@ -211,6 +211,26 @@ description for the model, whether to make this the default model that is loaded at InvokeAI startup time, and whether to replace its VAE. Generally the answer to the latter question is "no". +### Specifying a configuration file for legacy checkpoints + +Some checkpoint files come with instructions to use a specific .yaml +configuration file. For InvokeAI load this file correctly, please put +the config file in the same directory as the corresponding `.ckpt` or +`.safetensors` file and make sure the file has the same basename as +the weights file. Here is an example: + +```bash +wonderful-model-v2.ckpt +wonderful-model-v2.yaml +``` + +Similarly, to use a custom VAE, name the VAE like this: + +```bash +wonderful-model-v2.vae.pt +``` + + ### Converting legacy models into `diffusers` The CLI `!convert_model` will convert a `.safetensors` or `.ckpt` diff --git a/installer/templates/invoke.sh.in b/installer/templates/invoke.sh.in index 4576c7172f..eebb3a3c5d 100644 --- a/installer/templates/invoke.sh.in +++ b/installer/templates/invoke.sh.in @@ -1,5 +1,8 @@ #!/bin/bash +# coauthored by Lincoln Stein, Eugene Brodsky and JoshuaKimsey +# Copyright 2023, The InvokeAI Development Team + #### # This launch script assumes that: # 1. it is located in the runtime directory, @@ -18,78 +21,135 @@ cd "$scriptdir" . .venv/bin/activate export INVOKEAI_ROOT="$scriptdir" +PARAMS=$@ # set required env var for torch on mac MPS if [ "$(uname -s)" == "Darwin" ]; then export PYTORCH_ENABLE_MPS_FALLBACK=1 fi -if [ "$0" != "bash" ]; then - while true - do - echo "Do you want to generate images using the" - echo "1. command-line interface" - echo "2. browser-based UI" - echo "3. run textual inversion training" - echo "4. merge models (diffusers type only)" - echo "5. download and install models" - echo "6. change InvokeAI startup options" - echo "7. re-run the configure script to fix a broken install" - echo "8. open the developer console" - echo "9. update InvokeAI" - echo "10. command-line help" - echo "Q - Quit" - echo "" - read -p "Please enter 1-10, Q: [2] " yn - choice=${yn:='2'} - case $choice in - 1) - echo "Starting the InvokeAI command-line..." - invokeai $@ +do_choice() { + case $1 in + 1) + echo "Generate images with a browser-based interface" + clear + invokeai --web $PARAMS ;; - 2) - echo "Starting the InvokeAI browser-based UI..." 
- invokeai --web $@ + 2) + echo "Generate images using a command-line interface" + clear + invokeai $PARAMS ;; - 3) - echo "Starting Textual Inversion:" - invokeai-ti --gui $@ + 3) + echo "Textual inversion training" + clear + invokeai-ti --gui $PARAMS ;; - 4) - echo "Merging Models:" - invokeai-merge --gui $@ + 4) + echo "Merge models (diffusers type only)" + clear + invokeai-merge --gui $PARAMS ;; - 5) + 5) + echo "Download and install models" + clear invokeai-model-install --root ${INVOKEAI_ROOT} ;; - 6) + 6) + echo "Change InvokeAI startup options" + clear invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models ;; - 7) + 7) + echo "Re-run the configure script to fix a broken install" + clear invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only ;; - 8) - echo "Developer Console:" + 8) + echo "Open the developer console" + clear file_name=$(basename "${BASH_SOURCE[0]}") bash --init-file "$file_name" ;; - 9) - echo "Update:" + 9) + echo "Update InvokeAI" + clear invokeai-update ;; - 10) + 10) + echo "Command-line help" + clear invokeai --help ;; - [qQ]) - exit 0 + *) + echo "Exiting..." + exit ;; - *) - echo "Invalid selection" - exit;; esac - done + clear +} + +do_dialog() { + while true + do + options=( + 1 "Generate images with a browser-based interface" + 2 "Generate images using a command-line interface" + 3 "Textual inversion training" + 4 "Merge models (diffusers type only)" + 5 "Download and install models" + 6 "Change InvokeAI startup options" + 7 "Re-run the configure script to fix a broken install" + 8 "Open the developer console" + 9 "Update InvokeAI" + 10 "Command-line help") + + choice=$(dialog --clear \ + --backtitle "InvokeAI" \ + --title "What would you like to run?" \ + --menu "Select an option:" \ + 0 0 0 \ + "${options[@]}" \ + 2>&1 >/dev/tty) || clear + do_choice "$choice" + done + clear +} + +do_line_input() { + echo " ** For a more attractive experience, please install the 'dialog' utility. **" + echo "" + while true + do + echo "Do you want to generate images using the" + echo "1. browser-based UI" + echo "2. command-line interface" + echo "3. run textual inversion training" + echo "4. merge models (diffusers type only)" + echo "5. download and install models" + echo "6. change InvokeAI startup options" + echo "7. re-run the configure script to fix a broken install" + echo "8. open the developer console" + echo "9. update InvokeAI" + echo "10. command-line help" + echo "Q - Quit" + echo "" + read -p "Please enter 1-10, Q: [1] " yn + choice=${yn:='1'} + do_choice $choice + done +} + +if [ "$0" != "bash" ]; then + # Dialog seems to be a standard installtion for most Linux distros, but this checks to ensure it is present regardless + if command -v dialog &> /dev/null ; then + do_dialog + else + do_line_input + fi else # in developer console python --version echo "Press ^D to exit" export PS1="(InvokeAI) \u@\h \w> " fi + diff --git a/ldm/invoke/ckpt_to_diffuser.py b/ldm/invoke/ckpt_to_diffuser.py index 1d41fa5bd1..44f48a77cd 100644 --- a/ldm/invoke/ckpt_to_diffuser.py +++ b/ldm/invoke/ckpt_to_diffuser.py @@ -18,17 +18,17 @@ """ Conversion script for the LDM checkpoints. 
""" import re -import torch import warnings from pathlib import Path -from ldm.invoke.globals import ( - global_cache_dir, - global_config_dir, - ) -from ldm.invoke.model_manager import ModelManager, SDLegacyType -from safetensors.torch import load_file from typing import Union +import torch +from safetensors.torch import load_file + +from .globals import global_cache_dir, global_config_dir + +from .model_manager import ModelManager, SDLegacyType + try: from omegaconf import OmegaConf except ImportError: @@ -48,16 +48,31 @@ from diffusers import ( PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel, - logging as dlogging, ) -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers import logging as dlogging +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import ( + LDMBertConfig, + LDMBertModel, +) +from diffusers.pipelines.paint_by_example import ( + PaintByExampleImageEncoder, + PaintByExamplePipeline, +) +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) from diffusers.utils import is_safetensors_available -from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionConfig, +) from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline + def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. 
@@ -83,7 +98,9 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("emb_layers.1", "time_emb_proj") new_item = new_item.replace("skip_connection", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -99,7 +116,9 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): new_item = old_item new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -150,7 +169,9 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("proj_out.weight", "proj_attn.weight") new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -158,7 +179,12 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, ): """ This does the final conversion step: take locally converted weights and apply a global renaming @@ -167,7 +193,9 @@ def assign_to_checkpoint( Assigns the weights to the new checkpoint. """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + assert isinstance( + paths, list + ), "Paths should be a list of dicts containing 'old' and 'new' keys." # Splits the attention layers into three variables. 
if attention_paths_to_split is not None: @@ -179,7 +207,9 @@ def assign_to_checkpoint( num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + old_tensor = old_tensor.reshape( + (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] + ) query, key, value = old_tensor.split(channels // num_heads, dim=1) checkpoint[path_map["query"]] = query.reshape(target_shape) @@ -190,7 +220,10 @@ def assign_to_checkpoint( new_path = path["new"] # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: + if ( + attention_paths_to_split is not None + and new_path in attention_paths_to_split + ): continue # Global renaming happens here @@ -228,19 +261,29 @@ def create_unet_diffusers_config(original_config, image_size: int): unet_params = original_config.model.params.unet_config.params vae_params = original_config.model.params.first_stage_config.params.ddconfig - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [ + unet_params.model_channels * mult for mult in unet_params.channel_mult + ] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = ( + "CrossAttnDownBlock2D" + if resolution in unet_params.attention_resolutions + else "DownBlock2D" + ) down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = ( + "CrossAttnUpBlock2D" + if resolution in unet_params.attention_resolutions + else "UpBlock2D" + ) up_block_types.append(block_type) resolution //= 2 @@ -248,7 +291,9 @@ def create_unet_diffusers_config(original_config, image_size: int): head_dim = unet_params.num_heads if "num_heads" in unet_params else None use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + unet_params.use_linear_in_transformer + if "use_linear_in_transformer" in unet_params + else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 @@ -329,28 +374,46 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False if sum(k.startswith("model_ema") for k in keys) > 100: print(f" | Checkpoint {path} has both EMA and non-EMA weights.") if extract_ema: - print( - ' | Extracting EMA weights (usually better for inference)' - ) + print(" | Extracting EMA weights (usually better for inference)") for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + flat_ema_key_alt = "model_ema." 
+ "".join(key.split(".")[2:]) + if flat_ema_key in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key + ) + elif flat_ema_key_alt in checkpoint: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key_alt + ) + else: + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + key + ) else: print( - ' | Extracting only the non-EMA weights (usually better for fine-tuning)' + " | Extracting only the non-EMA weights (usually better for fine-tuning)" ) for key in keys: - if key.startswith(unet_key): + if key.startswith("model.diffusion_model") and key in checkpoint: unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[ + "time_embed.0.weight" + ] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[ + "time_embed.0.bias" + ] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[ + "time_embed.2.weight" + ] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[ + "time_embed.2.bias" + ] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] @@ -361,21 +424,39 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + num_input_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "input_blocks" in layer + } + ) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + num_middle_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "middle_block" in layer + } + ) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) } # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + num_output_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "output_blocks" in layer + } + ) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) @@ -386,29 +467,45 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" 
in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.weight" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.bias" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) if len(attentions): paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.1", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) resnet_0 = middle_blocks[0] @@ -424,7 +521,11 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + attentions_paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) for i in range(num_output_blocks): @@ -442,25 +543,36 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key + ] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] + index = list(output_block_list.values()).index( + ["conv.bias", 
"conv.weight"] + ) + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.weight" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.bias" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] # Clear attentions as they have been attributed above. if len(attentions) == 2: @@ -473,13 +585,27 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + resnet_0_paths = renew_resnet_paths( + output_block_layers, n_shave_prefix_segments=1 + ) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + new_path = ".".join( + [ + "up_blocks", + str(block_id), + "resnets", + str(layer_in_block_id), + path["new"], + ] + ) new_checkpoint[new_path] = unet_state_dict[old_path] @@ -499,17 +625,29 @@ def convert_ldm_vae_checkpoint(checkpoint, config): new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ + "encoder.conv_out.weight" + ] new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ + "encoder.norm_out.weight" + ] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ + "encoder.norm_out.bias" + ] new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ + "decoder.conv_out.weight" + ] new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ + "decoder.norm_out.weight" + ] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ + "decoder.norm_out.bias" + ] new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] @@ -517,31 +655,55 @@ def convert_ldm_vae_checkpoint(checkpoint, config): new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + num_down_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "encoder.down" 
in layer + } + ) down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] + for layer_id in range(num_down_blocks) } # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + num_up_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "decoder.up" in layer + } + ) up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] + for layer_id in range(num_up_blocks) } for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + resnets = [ + key + for key in down_blocks[i] + if f"down.{i}" in key and f"down.{i}.downsample" not in key + ] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 @@ -550,31 +712,51 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + key + for key in up_blocks[block_id] + if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key ] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - 
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 @@ -583,12 +765,24 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) conv_attn_to_linear(new_checkpoint) return new_checkpoint @@ -630,7 +824,9 @@ def convert_ldm_bert_checkpoint(checkpoint, config): # copy embeds hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + hf_model.model.embed_positions.weight.data = ( + checkpoint.transformer.pos_emb.emb.weight + ) # copy layer norm _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) @@ -644,7 +840,9 @@ def convert_ldm_bert_checkpoint(checkpoint, config): def convert_ldm_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14",cache_dir=global_cache_dir('hub')) + text_model = CLIPTextModel.from_pretrained( + "openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub") + ) keys = list(checkpoint.keys()) @@ -652,7 +850,9 @@ def convert_ldm_clip_checkpoint(checkpoint): for key in keys: if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[ + key + ] text_model.load_state_dict(text_model_dict) @@ -660,8 +860,14 @@ def convert_ldm_clip_checkpoint(checkpoint): textenc_conversion_lst = [ - ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ( + "cond_stage_model.model.positional_embedding", + "text_model.embeddings.position_embedding.weight", + ), + ( + "cond_stage_model.model.token_embedding.weight", + "text_model.embeddings.token_embedding.weight", + ), ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), 
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), ] @@ -676,16 +882,24 @@ textenc_transformer_conversion_lst = [ (".c_proj.", ".fc2."), (".attn", ".self_attn"), ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), + ( + "token_embedding.weight", + "transformer.text_model.embeddings.token_embedding.weight", + ), + ( + "positional_embedding", + "transformer.text_model.embeddings.position_embedding.weight", + ), ] protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} textenc_pattern = re.compile("|".join(protected.keys())) def convert_paint_by_example_checkpoint(checkpoint): - cache_dir = global_cache_dir('hub') - config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir) + cache_dir = global_cache_dir("hub") + config = CLIPVisionConfig.from_pretrained( + "openai/clip-vit-large-patch14", cache_dir=cache_dir + ) model = PaintByExampleImageEncoder(config) keys = list(checkpoint.keys()) @@ -694,7 +908,9 @@ def convert_paint_by_example_checkpoint(checkpoint): for key in keys: if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[ + key + ] # load clip vision model.model.load_state_dict(text_model_dict) @@ -752,24 +968,32 @@ def convert_paint_by_example_checkpoint(checkpoint): def convert_open_clip_checkpoint(checkpoint): - cache_dir=global_cache_dir('hub') - text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir) + cache_dir = global_cache_dir("hub") + text_model = CLIPTextModel.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir + ) keys = list(checkpoint.keys()) text_model_dict = {} - if 'cond_stage_model.model.text_projection' in keys: + if "cond_stage_model.model.text_projection" in keys: d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) - elif 'cond_stage_model.model.ln_final.bias' in keys: - d_model = int(checkpoint['cond_stage_model.model.ln_final.bias'].shape[0]) + elif "cond_stage_model.model.ln_final.bias" in keys: + d_model = int(checkpoint["cond_stage_model.model.ln_final.bias"].shape[0]) else: - raise KeyError('Expected key "cond_stage_model.model.text_projection" not found in model') + raise KeyError( + 'Expected key "cond_stage_model.model.text_projection" not found in model' + ) - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + text_model_dict[ + "text_model.embeddings.position_ids" + ] = text_model.text_model.embeddings.get_buffer("position_ids") for key in keys: - if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer + if ( + "resblocks.23" in key + ): # Diffusers drops the final layer and only uses the penultimate layer continue if key in textenc_conversion_map: text_model_dict[textenc_conversion_map[key]] = checkpoint[key] @@ -777,18 +1001,34 @@ def convert_open_clip_checkpoint(checkpoint): new_key = key[len("cond_stage_model.model.transformer.") :] if new_key.endswith(".in_proj_weight"): new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: 
protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][ + :d_model, : + ] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][ + d_model : d_model * 2, : + ] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][ + d_model * 2 :, : + ] elif new_key.endswith(".in_proj_bias"): new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][ + d_model : d_model * 2 + ] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][ + d_model * 2 : + ] else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) text_model_dict[new_key] = checkpoint[key] @@ -796,22 +1036,33 @@ def convert_open_clip_checkpoint(checkpoint): return text_model +def replace_checkpoint_vae(checkpoint, vae_path:str): + if vae_path.endswith(".safetensors"): + vae_ckpt = load_file(vae_path) + else: + vae_ckpt = torch.load(vae_path, map_location="cpu") + state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt + for vae_key in state_dict: + new_key = f'first_stage_model.{vae_key}' + checkpoint[new_key] = state_dict[vae_key] + def load_pipeline_from_original_stable_diffusion_ckpt( - checkpoint_path:str, - original_config_file:str=None, - num_in_channels:int=None, - scheduler_type:str='pndm', - pipeline_type:str=None, - image_size:int=None, - prediction_type:str=None, - extract_ema:bool=True, - upcast_attn:bool=False, - vae:AutoencoderKL=None, - precision:torch.dtype=torch.float32, - return_generator_pipeline:bool=False, - scan_needed:bool=True, -)->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]: - ''' + checkpoint_path: str, + original_config_file: str = None, + num_in_channels: int = None, + scheduler_type: str = "pndm", + pipeline_type: str = None, + image_size: int = None, + prediction_type: str = None, + extract_ema: bool = True, + upcast_attn: bool = False, + vae: AutoencoderKL = None, + vae_path: str = None, + precision: torch.dtype = torch.float32, + return_generator_pipeline: bool = False, + scan_needed:bool=True, +) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]: + """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file. @@ -819,15 +1070,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt( global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is recommended that you override the default values and/or supply an `original_config_file` wherever possible. - :param checkpoint_path: Path to `.ckpt` file. 
- :param original_config_file: Path to `.yaml` config file corresponding to the original architecture. + :param checkpoint_path: Path to `.ckpt` file. + :param original_config_file: Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models. :param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2. :param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2. :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically - inferred. + inferred. :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for @@ -837,10 +1088,12 @@ def load_pipeline_from_original_stable_diffusion_ckpt( :param precision: precision to use - torch.float16, torch.float32 or torch.autocast :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when running stable diffusion 2.1. - ''' + :param vae: A diffusers VAE to load into the pipeline. + :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline. + """ with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") verbosity = dlogging.get_verbosity() dlogging.set_verbosity_error() @@ -850,8 +1103,13 @@ def load_pipeline_from_original_stable_diffusion_ckpt( checkpoint = torch.load(checkpoint_path) else: checkpoint = load_file(checkpoint_path) - cache_dir = global_cache_dir('hub') - pipeline_class = StableDiffusionGeneratorPipeline if return_generator_pipeline else StableDiffusionPipeline + + cache_dir = global_cache_dir("hub") + pipeline_class = ( + StableDiffusionGeneratorPipeline + if return_generator_pipeline + else StableDiffusionPipeline + ) # Sometimes models don't have the global_step item if "global_step" in checkpoint: @@ -861,36 +1119,45 @@ def load_pipeline_from_original_stable_diffusion_ckpt( global_step = None # sometimes there is a state_dict key and sometimes not - if 'state_dict' in checkpoint: + if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] upcast_attention = False if original_config_file is None: model_type = ModelManager.probe_model_type(checkpoint) - + if model_type == SDLegacyType.V2_v: - original_config_file = global_config_dir() / 'stable-diffusion' / 'v2-inference-v.yaml' + original_config_file = ( + global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml" + ) if global_step == 110000: # v2.1 needs to upcast attention upcast_attention = True - elif model_type == SDLegacyType.V2_e: - original_config_file = ( - global_config_dir() / "stable-diffusion" / "v2-inference.yaml" - ) - + elif model_type == SDLegacyType.V2_e: + original_config_file = ( + global_config_dir() / "stable-diffusion" / "v2-inference.yaml" + ) elif model_type == SDLegacyType.V1_INPAINT: - original_config_file = global_config_dir() / 'stable-diffusion' / 'v1-inpainting-inference.yaml' - + original_config_file = ( + global_config_dir() + / "stable-diffusion" + / 
"v1-inpainting-inference.yaml" + ) + elif model_type == SDLegacyType.V1: - original_config_file = global_config_dir() / 'stable-diffusion' / 'v1-inference.yaml' + original_config_file = ( + global_config_dir() / "stable-diffusion" / "v1-inference.yaml" + ) else: - raise Exception('Unknown checkpoint type') + raise Exception("Unknown checkpoint type") original_config = OmegaConf.load(original_config_file) if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + original_config["model"]["params"]["unet_config"]["params"][ + "in_channels" + ] = num_in_channels if ( "parameterization" in original_config["model"]["params"] @@ -947,7 +1214,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt( raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config = create_unet_diffusers_config( + original_config, image_size=image_size + ) unet_config["upcast_attention"] = upcast_attention unet = UNet2DConditionModel(**unet_config) @@ -957,28 +1226,43 @@ def load_pipeline_from_original_stable_diffusion_ckpt( unet.load_state_dict(converted_unet_checkpoint) - # Convert the VAE model, or use the one passed - if not vae: - print(' | Using checkpoint model\'s original VAE') - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + # If a replacement VAE path was specified, we'll incorporate that into + # the checkpoint model and then convert it + if vae_path: + print(f" | Converting VAE {vae_path}") + replace_checkpoint_vae(checkpoint,vae_path) + # otherwise we use the original VAE, provided that + # an externally loaded diffusers VAE was not passed + elif not vae: + print(" | Using checkpoint model's original VAE") + + if vae: + print(" | Using replacement diffusers VAE") + else: # convert the original or replacement VAE + vae_config = create_vae_diffusers_config( + original_config, image_size=image_size + ) + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + checkpoint, vae_config + ) vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) - else: - print(' | Using VAE specified in config') # Convert the text model. model_type = pipeline_type if model_type is None: - model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + model_type = original_config.model.params.cond_stage_config.target.split( + "." 
+ )[-1] if model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", - subfolder="tokenizer", - cache_dir=cache_dir, - ) + tokenizer = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-2", + subfolder="tokenizer", + cache_dir=cache_dir, + ) pipe = pipeline_class( vae=vae, text_encoder=text_model, @@ -991,8 +1275,12 @@ def load_pipeline_from_original_stable_diffusion_ckpt( ) elif model_type == "PaintByExample": vision_model = convert_paint_by_example_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir) - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir) + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", cache_dir=cache_dir + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir + ) pipe = PaintByExamplePipeline( vae=vae, image_encoder=vision_model, @@ -1001,11 +1289,18 @@ def load_pipeline_from_original_stable_diffusion_ckpt( safety_checker=None, feature_extractor=feature_extractor, ) - elif model_type in ['FrozenCLIPEmbedder','WeightedFrozenCLIPEmbedder']: + elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]: text_model = convert_ldm_clip_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14",cache_dir=cache_dir) - safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker',cache_dir=global_cache_dir("hub")) - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir) + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", cache_dir=cache_dir + ) + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", + cache_dir=global_cache_dir("hub"), + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir + ) pipe = pipeline_class( vae=vae.to(precision), text_encoder=text_model.to(precision), @@ -1018,27 +1313,33 @@ def load_pipeline_from_original_stable_diffusion_ckpt( else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) - tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",cache_dir=cache_dir) - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + tokenizer = BertTokenizerFast.from_pretrained( + "bert-base-uncased", cache_dir=cache_dir + ) + pipe = LDMTextToImagePipeline( + vqvae=vae, + bert=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) dlogging.set_verbosity(verbosity) return pipe -def convert_ckpt_to_diffuser( - checkpoint_path:Union[str,Path], - dump_path:Union[str,Path], - **kwargs, + +def convert_ckpt_to_diffusers( + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + **kwargs, ): - ''' + """ Takes all the arguments of load_pipeline_from_original_stable_diffusion_ckpt(), and in addition a path-like object indicating the location of the desired diffusers model to be written. 
- ''' - pipe = load_pipeline_from_original_stable_diffusion_ckpt( - checkpoint_path, - **kwargs - ) - + """ + pipe = load_pipeline_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs) + pipe.save_pretrained( dump_path, safe_serialization=is_safetensors_available(), diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py index 214ef022bb..ff88ae8c1c 100644 --- a/ldm/invoke/model_manager.py +++ b/ldm/invoke/model_manager.py @@ -19,7 +19,7 @@ import warnings from enum import Enum from pathlib import Path from shutil import move, rmtree -from typing import Any, Optional, Union, Callable +from typing import Any, Callable, Optional, Union import safetensors import safetensors.torch @@ -35,12 +35,7 @@ from picklescan.scanner import scan_file_path from ldm.invoke.devices import CPU_DEVICE from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline from ldm.invoke.globals import Globals, global_cache_dir -from ldm.util import ( - ask_user, - download_with_resume, - instantiate_from_config, - url_attachment_name, -) +from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name class SDLegacyType(Enum): @@ -384,15 +379,16 @@ class ModelManager(object): if not os.path.isabs(weights): weights = os.path.normpath(os.path.join(Globals.root, weights)) + # check whether this is a v2 file and force conversion + convert = Globals.ckpt_convert or self.is_v2_config(config) + # if converting automatically to diffusers, then we do the conversion and return # a diffusers pipeline - if Globals.ckpt_convert: + if convert: print( f">> Converting legacy checkpoint {model_name} into a diffusers model..." ) - from ldm.invoke.ckpt_to_diffuser import ( - load_pipeline_from_original_stable_diffusion_ckpt, - ) + from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt self.offload_model(self.current_model) if vae_config := self._choose_diffusers_vae(model_name): @@ -433,13 +429,13 @@ class ModelManager(object): weight_bytes = f.read() model_hash = self._cached_sha256(weights, weight_bytes) sd = None - + if weights.endswith(".ckpt"): self.scan_model(model_name, weights) sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu") else: sd = safetensors.torch.load(weight_bytes) - + del weight_bytes # merged models from auto11 merge board are flat for some reason if "state_dict" in sd: @@ -462,8 +458,8 @@ class ModelManager(object): vae = os.path.normpath(os.path.join(Globals.root, vae)) if os.path.exists(vae): print(f" | Loading VAE weights from: {vae}") - if vae.endswith((".ckpt",".pt")): - self.scan_model(vae,vae) + if vae.endswith((".ckpt", ".pt")): + self.scan_model(vae, vae) vae_ckpt = torch.load(vae, map_location="cpu") else: vae_ckpt = safetensors.torch.load_file(vae) @@ -547,6 +543,15 @@ class ModelManager(object): return pipeline, width, height, model_hash + def is_v2_config(self, config: Path) -> bool: + try: + mconfig = OmegaConf.load(config) + return ( + mconfig["model"]["params"]["unet_config"]["params"]["context_dim"] > 768 + ) + except: + return False + def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path: if isinstance(model_name, DictConfig) or isinstance(model_name, dict): mconfig = model_name @@ -724,7 +729,7 @@ class ModelManager(object): SDLegacyType.V2_v (V2 using 'v_prediction' prediction type) SDLegacyType.UNKNOWN """ - global_step = checkpoint.get('global_step') + global_step = checkpoint.get("global_step") state_dict = checkpoint.get("state_dict") or 
checkpoint try: @@ -751,14 +756,14 @@ class ModelManager(object): return SDLegacyType.UNKNOWN def heuristic_import( - self, - path_url_or_repo: str, - convert: bool = False, - model_name: str = None, - description: str = None, - model_config_file: Path = None, - commit_to_conf: Path = None, - config_file_callback: Callable[[Path],Path] = None, + self, + path_url_or_repo: str, + convert: bool = False, + model_name: str = None, + description: str = None, + model_config_file: Path = None, + commit_to_conf: Path = None, + config_file_callback: Callable[[Path], Path] = None, ) -> str: """ Accept a string which could be: @@ -833,10 +838,10 @@ class ModelManager(object): Path(thing).rglob("*.safetensors") ): if model_name := self.heuristic_import( - str(m), - convert, - commit_to_conf=commit_to_conf, - config_file_callback=config_file_callback, + str(m), + convert, + commit_to_conf=commit_to_conf, + config_file_callback=config_file_callback, ): print(f" >> {model_name} successfully imported") return model_name @@ -864,57 +869,66 @@ class ModelManager(object): # another round of heuristics to guess the correct config file. checkpoint = None - if model_path.suffix.endswith((".ckpt",".pt")): - self.scan_model(model_path,model_path) + if model_path.suffix.endswith((".ckpt", ".pt")): + self.scan_model(model_path, model_path) checkpoint = torch.load(model_path) else: checkpoint = safetensors.torch.load_file(model_path) # additional probing needed if no config file provided if model_config_file is None: - model_type = self.probe_model_type(checkpoint) - if model_type == SDLegacyType.V1: - print(" | SD-v1 model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v1-inference.yaml" - ) - elif model_type == SDLegacyType.V1_INPAINT: - print(" | SD-v1 inpainting model detected") - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml" - ) - elif model_type == SDLegacyType.V2_v: - print( - " | SD-v2-v model detected" - ) - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v2-inference-v.yaml" - ) - elif model_type == SDLegacyType.V2_e: - print( - " | SD-v2-e model detected" - ) - model_config_file = Path( - Globals.root, "configs/stable-diffusion/v2-inference.yaml" - ) - elif model_type == SDLegacyType.V2: - print( - f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path." - ) + # Is there a like-named .yaml file in the same directory as the + # weights file? If so, we treat this as our model + if model_path.with_suffix(".yaml").exists(): + model_config_file = model_path.with_suffix(".yaml") + print(f" | Using config file {model_config_file.name}") else: - print( - f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path." 
- ) + model_type = self.probe_model_type(checkpoint) + if model_type == SDLegacyType.V1: + print(" | SD-v1 model detected") + model_config_file = Path( + Globals.root, "configs/stable-diffusion/v1-inference.yaml" + ) + elif model_type == SDLegacyType.V1_INPAINT: + print(" | SD-v1 inpainting model detected") + model_config_file = Path( + Globals.root, + "configs/stable-diffusion/v1-inpainting-inference.yaml", + ) + elif model_type == SDLegacyType.V2_v: + print(" | SD-v2-v model detected") + model_config_file = Path( + Globals.root, "configs/stable-diffusion/v2-inference-v.yaml" + ) + elif model_type == SDLegacyType.V2_e: + print(" | SD-v2-e model detected") + model_config_file = Path( + Globals.root, "configs/stable-diffusion/v2-inference.yaml" + ) + elif model_type == SDLegacyType.V2: + print( + f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path." + ) + else: + print( + f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path." + ) if not model_config_file and config_file_callback: model_config_file = config_file_callback(model_path) if not model_config_file: return - if model_config_file.name.startswith('v2'): + if self.is_v2_config(model_config_file): convert = True - print( - " | This SD-v2 model will be converted to diffusers format for use" - ) + print(" | This SD-v2 model will be converted to diffusers format for use") + + # look for a custom vae + vae_path = None + for suffix in ["pt", "ckpt", "safetensors"]: + if (model_path.with_suffix(f".vae.{suffix}")).exists(): + vae_path = model_path.with_suffix(f".vae.{suffix}") + print(f" | Using VAE file {vae_path.name}") + vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse") if convert: diffuser_path = Path( @@ -923,7 +937,8 @@ class ModelManager(object): model_name = self.convert_and_import( model_path, diffusers_path=diffuser_path, - vae=dict(repo_id="stabilityai/sd-vae-ft-mse"), + vae=vae, + vae_path=vae_path, model_name=model_name, model_description=description, original_config_file=model_config_file, @@ -941,7 +956,8 @@ class ModelManager(object): model_name=model_name, model_description=description, vae=str( - Path( + vae_path + or Path( Globals.root, "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt", ) @@ -953,15 +969,16 @@ class ModelManager(object): return model_name def convert_and_import( - self, - ckpt_path: Path, - diffusers_path: Path, - model_name=None, - model_description=None, - vae=None, - original_config_file: Path = None, - commit_to_conf: Path = None, - scan_needed: bool=True, + self, + ckpt_path: Path, + diffusers_path: Path, + model_name=None, + model_description=None, + vae: dict = None, + vae_path: Path = None, + original_config_file: Path = None, + commit_to_conf: Path = None, + scan_needed: bool = True, ) -> str: """ Convert a legacy ckpt weights file to diffuser model and import @@ -975,7 +992,7 @@ class ModelManager(object): new_config = None - from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser + from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffusers if diffusers_path.exists(): print( @@ -990,12 +1007,13 @@ class ModelManager(object): # By passing the specified VAE to the conversion function, the autoencoder # will be built into the model rather than tacked on afterward via the config file vae_model = self._load_vae(vae) if vae else None - convert_ckpt_to_diffuser( + convert_ckpt_to_diffusers( ckpt_path, diffusers_path, 
extract_ema=True, original_config_file=original_config_file, vae=vae_model, + vae_path=str(vae_path) if vae_path else None, scan_needed=scan_needed, ) print( @@ -1048,7 +1066,7 @@ class ModelManager(object): # In the event that the original entry is using a custom ckpt VAE, we try to # map that VAE onto a diffuser VAE using a hard-coded dictionary. # I would prefer to do this differently: We load the ckpt model into memory, swap the - # VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped + # VAE in memory, and then pass that to convert_ckpt_to_diffusers() so that the swapped # VAE is built into the model. However, when I tried this I got obscure key errors. if vae: return vae @@ -1134,14 +1152,14 @@ class ModelManager(object): legacy_locations = [ Path( models_dir, - "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker" + "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker", ), Path("bert-base-uncased/models--bert-base-uncased"), Path( "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14" ), ] - legacy_locations.extend(list(global_cache_dir("diffusers").glob('*'))) + legacy_locations.extend(list(global_cache_dir("diffusers").glob("*"))) legacy_layout = False for model in legacy_locations: legacy_layout = legacy_layout or model.exists() @@ -1185,7 +1203,7 @@ class ModelManager(object): source.unlink() else: move(source, dest) - + # now clean up by removing any empty directories empty = [ root
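Taken together, the documentation hunk at the top of this diff and the `ModelManager` changes describe a sidecar-file convention: a legacy checkpoint's `.yaml` config and custom VAE are discovered by basename, and SD-v2 models are recognized from the config's `context_dim` rather than from a `v2*` filename prefix. Below is a standalone sketch of that lookup, with hypothetical helper names; the actual logic lives in `ModelManager.heuristic_import()` and `ModelManager.is_v2_config()`.

```python
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf


def find_sidecar_files(weights: Path) -> Tuple[Optional[Path], Optional[Path]]:
    """For wonderful-model-v2.ckpt, return wonderful-model-v2.yaml (if present)
    and the first existing wonderful-model-v2.vae.{pt,ckpt,safetensors}."""
    config = weights.with_suffix(".yaml")
    config = config if config.exists() else None

    vae = None
    for suffix in ("pt", "ckpt", "safetensors"):
        candidate = weights.with_suffix(f".vae.{suffix}")
        if candidate.exists():
            vae = candidate
            break
    return config, vae


def is_v2_config(config_path: Path) -> bool:
    """True when the .yaml describes an SD-2.x model (context_dim > 768)."""
    try:
        conf = OmegaConf.load(config_path)
        return conf["model"]["params"]["unet_config"]["params"]["context_dim"] > 768
    except Exception:
        return False


config, vae = find_sidecar_files(Path("wonderful-model-v2.ckpt"))
if config is not None and is_v2_config(config):
    print("SD-v2 checkpoint detected: it will be converted to diffusers on import")
```

Keying v2 detection off `context_dim` means a checkpoint accompanied by a correctly named config still triggers the automatic diffusers conversion even when the config file is not named `v2-inference*.yaml`.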