fix ckpt_convert module to work with dreambooth v2 models

- Discord member @marcus.llewellyn reported that some Civitai 2.1-derived
  checkpoints (probably dreambooth-generated) were not converting properly:
  https://discord.com/channels/1020123559063990373/1078386197589655582/1078387806122025070

- @blessedcoolant tracked this down to a missing key: the converter derived
  the vector length of the CLIP model by fetching the second dimension of
  the tensor at "cond_stage_model.model.text_projection", a key these
  checkpoints do not contain. His proposed solution was to hardcode a
  value of 1024.

- On inspection, I found that the same dimension can be recovered from the
  key 'cond_stage_model.model.ln_final.bias', so I now fall back to that
  instead (see the sketch after this list). I hope this is correct; I
  tested multiple v1, v2 and inpainting models, and they all converted
  correctly.

- While debugging this, I found and fixed several other issues:

  - model download script was not pre-downloading the OpenCLIP
    text_encoder or text_tokenizer. This is fixed.
  - got rid of legacy code in `ckpt_to_diffuser.py` and replaced it with
    calls into `model_manager`
  - more consistent status reporting in the CLI.
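
For reference, the d_model recovery boils down to the following (a minimal
standalone sketch, not the exact `convert_open_clip_checkpoint()` code;
`state_dict` stands for an already-loaded legacy checkpoint and the helper
name is illustrative):

```python
def clip_hidden_size(state_dict: dict) -> int:
    """Derive the CLIP text-encoder width (d_model) from a legacy checkpoint."""
    if "cond_stage_model.model.text_projection" in state_dict:
        # The projection matrix is square for SD 2.x (1024 x 1024), so
        # either dimension yields d_model.
        return int(state_dict["cond_stage_model.model.text_projection"].shape[0])
    if "cond_stage_model.model.ln_final.bias" in state_dict:
        # Dreambooth-style checkpoints may drop the projection key; the
        # final LayerNorm bias is a 1-D tensor of length d_model.
        return int(state_dict["cond_stage_model.model.ln_final.bias"].shape[0])
    raise KeyError("cannot derive CLIP hidden size from this checkpoint")
```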
Author: Lincoln Stein
Date:   2023-02-23 15:43:58 -05:00
Commit: 4f44b64052
Parent: 2c9b29725b

4 changed files with 45 additions and 38 deletions

File 1 of 4: CLI

@@ -625,7 +625,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
     completer.set_default_dir(opt.outdir)


-def import_model(model_path: str, gen, opt, completer, convert=False) -> str:
+def import_model(model_path: str, gen, opt, completer, convert=False):
     """
     model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
     (3) a huggingface repository id; or (4) a local directory containing a
@@ -679,7 +679,7 @@ def _verify_load(model_name: str, gen) -> bool:
     current_model = gen.model_name
     try:
         if not gen.set_model(model_name):
-            return False
+            return
     except Exception as e:
         print(f"** model failed to load: {str(e)}")
         print(
@@ -706,7 +706,7 @@ def _get_model_name_and_desc(
     )
     return model_name, model_description

-def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer) -> str:
+def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
     model_name_or_path = model_name_or_path.replace("\\", "/")  # windows
     manager = gen.model_manager
     ckpt_path = None
@@ -740,19 +740,14 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer) ->
         )
     else:
         try:
-            model_name = import_model(model_name_or_path, gen, opt, completer, convert=True)
+            import_model(model_name_or_path, gen, opt, completer, convert=True)
         except KeyboardInterrupt:
             return
-        if not model_name:
-            print("** Conversion failed. Aborting.")
-            return

     manager.commit(opt.conf)
     if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
         ckpt_path.unlink(missing_ok=True)
         print(f"{ckpt_path} deleted")
-    return model_name


 def del_config(model_name: str, gen, opt, completer):

File 2 of 4: ckpt_to_diffuser.py

@@ -17,16 +17,15 @@
 # Original file at: https://github.com/huggingface/diffusers/blob/main/scripts/convert_ldm_original_checkpoint_to_diffusers.py
 """ Conversion script for the LDM checkpoints. """

-import os
 import re
 import torch
 import warnings
 from pathlib import Path
 from ldm.invoke.globals import (
-    Globals,
     global_cache_dir,
     global_config_dir,
 )
+from ldm.invoke.model_manager import ModelManager, SDLegacyType
 from safetensors.torch import load_file
 from typing import Union
@@ -760,7 +759,12 @@ def convert_open_clip_checkpoint(checkpoint):
     text_model_dict = {}

-    d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+    if 'cond_stage_model.model.text_projection' in keys:
+        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+    elif 'cond_stage_model.model.ln_final.bias' in keys:
+        d_model = int(checkpoint['cond_stage_model.model.ln_final.bias'].shape[0])
+    else:
+        raise KeyError('Expected key "cond_stage_model.model.text_projection" not found in model')

     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
@@ -856,20 +860,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     upcast_attention = False
     if original_config_file is None:
-        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        model_type = ModelManager.probe_model_type(checkpoint)

-        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+        if model_type == SDLegacyType.V2:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v2-inference-v.yaml'
             if global_step == 110000:
                 # v2.1 needs to upcast attention
                 upcast_attention = True

-        elif str(checkpoint_path).lower().find('inpaint') >= 0:  # brittle - please pass original_config_file parameter!
-            print(f'  | checkpoint has "inpaint" in name, assuming an inpainting model')
+        elif model_type == SDLegacyType.V1_INPAINT:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v1-inpainting-inference.yaml'

-        else:
+        elif model_type == SDLegacyType.V1:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v1-inference.yaml'

+        else:
+            raise Exception('Unknown checkpoint type')

     original_config = OmegaConf.load(original_config_file)

     if num_in_channels is not None:
@@ -960,7 +967,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2",
                                                   subfolder="tokenizer",
-                                                  cache_dir=global_cache_dir('diffusers')
+                                                  cache_dir=cache_dir,
                                                   )

     pipe = pipeline_class(
         vae=vae,

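The `ModelManager.probe_model_type()` call introduced above replaces the
inlined key heuristics. A rough sketch of what such a probe does follows;
the V2 test is taken verbatim from the model_manager diff below, while the
`in_channels` test for inpainting is my assumption (inpainting UNets take
9 input channels, 4 latent + 4 masked-image + 1 mask, instead of 4):

```python
from ldm.invoke.model_manager import SDLegacyType

V2_KEY = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"

def probe_model_type_sketch(checkpoint: dict) -> SDLegacyType:
    # V2 checkpoints use the 1024-wide OpenCLIP text encoder, which shows
    # up in the width of the UNet's cross-attention key projection.
    if V2_KEY in checkpoint and checkpoint[V2_KEY].shape[-1] == 1024:
        return SDLegacyType.V2
    # Assumption: distinguish inpainting from standard v1 models by the
    # number of UNet input channels (9 for inpainting, 4 otherwise).
    weight = checkpoint.get("model.diffusion_model.input_blocks.0.0.weight")
    if weight is None:
        return SDLegacyType.UNKNOWN
    if weight.shape[1] == 9:
        return SDLegacyType.V1_INPAINT
    if weight.shape[1] == 4:
        return SDLegacyType.V1
    return SDLegacyType.UNKNOWN
```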
File 3 of 4: model download script

@@ -191,14 +191,18 @@ def download_bert():

 # ---------------------------------------------
-def download_clip():
-    print("Installing CLIP model...", file=sys.stderr)
+def download_sd1_clip():
+    print("Installing SD1 clip model...", file=sys.stderr)
     version = "openai/clip-vit-large-patch14"
-    print("Tokenizer...", file=sys.stderr)
     download_from_hf(CLIPTokenizer, version)
-    print("Text model...", file=sys.stderr)
     download_from_hf(CLIPTextModel, version)

+# ---------------------------------------------
+def download_sd2_clip():
+    version = 'stabilityai/stable-diffusion-2'
+    print("Installing SD2 clip model...", file=sys.stderr)
+    download_from_hf(CLIPTokenizer, version, subfolder='tokenizer')
+    download_from_hf(CLIPTextModel, version, subfolder='text_encoder')

 # ---------------------------------------------
 def download_realesrgan():
@@ -832,7 +836,8 @@ def main():
         else:
             print("\n** DOWNLOADING SUPPORT MODELS **")
             download_bert()
-            download_clip()
+            download_sd1_clip()
+            download_sd2_clip()
             download_realesrgan()
             download_gfpgan()
             download_codeformer()

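For anyone pre-caching by hand, the new `download_sd2_clip()` above is
roughly equivalent to the following direct transformers calls (a sketch;
the script itself routes through its `download_from_hf()` helper):

```python
from transformers import CLIPTextModel, CLIPTokenizer

# Pre-cache the OpenCLIP tokenizer and text encoder used by SD 2.x so a
# later ckpt conversion can load them without a network round-trip.
version = "stabilityai/stable-diffusion-2"
CLIPTokenizer.from_pretrained(version, subfolder="tokenizer")
CLIPTextModel.from_pretrained(version, subfolder="text_encoder")
```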
File 4 of 4: model_manager

@@ -725,7 +725,7 @@ class ModelManager(object):
         SDLegacyType.V1
         SDLegacyType.V1_INPAINT
         SDLegacyType.V2
-        UNKNOWN
+        SDLegacyType.UNKNOWN
         """
         key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
@@ -785,7 +785,7 @@ class ModelManager(object):
         print(f">> Probing {thing} for import")

         if thing.startswith(("http:", "https:", "ftp:")):
             print(f"  | {thing} appears to be a URL")
             model_path = self._resolve_path(
                 thing, "models/ldm/stable-diffusion-v1"
             )  # _resolve_path does a download if needed
@@ -793,15 +793,15 @@ class ModelManager(object):
         elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
             if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
                 print(
                     f"  | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
                 )
                 return
             else:
                 print(f"  | {thing} appears to be a checkpoint file on disk")
                 model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")

         elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
             print(f"  | {thing} appears to be a diffusers file on disk")
             model_name = self.import_diffuser_model(
                 thing,
                 vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@@ -812,13 +812,13 @@ class ModelManager(object):
         elif Path(thing).is_dir():
             if (Path(thing) / "model_index.json").exists():
-                print(f">> {thing} appears to be a diffusers model.")
+                print(f"  | {thing} appears to be a diffusers model.")
                 model_name = self.import_diffuser_model(
                     thing, commit_to_conf=commit_to_conf
                 )
             else:
                 print(
-                    f">> {thing} appears to be a directory. Will scan for models to import"
+                    f"  |{thing} appears to be a directory. Will scan for models to import"
                 )
                 for m in list(Path(thing).rglob("*.ckpt")) + list(
                     Path(thing).rglob("*.safetensors")
@@ -830,7 +830,7 @@ class ModelManager(object):
                 return model_name

         elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
             print(f"  | {thing} appears to be a HuggingFace diffusers repo_id")
             model_name = self.import_diffuser_model(
                 thing, commit_to_conf=commit_to_conf
             )
@@ -847,7 +847,7 @@ class ModelManager(object):
             return

         if model_path.stem in self.config:  # already imported
             print("  | Already imported. Skipping")
             return

         # another round of heuristics to guess the correct config file.
@@ -860,18 +860,18 @@ class ModelManager(object):
         model_config_file = None
         if model_type == SDLegacyType.V1:
             print("  | SD-v1 model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v1-inference.yaml"
             )
         elif model_type == SDLegacyType.V1_INPAINT:
             print("  | SD-v1 inpainting model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
             )
         elif model_type == SDLegacyType.V2:
             print(
                 "  | SD-v2 model detected; model will be converted to diffusers format"
             )
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
@@ -923,7 +923,7 @@ class ModelManager(object):
         vae=None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
-    ) -> dict:
+    ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
         into models.yaml.