mirror of https://github.com/invoke-ai/InvokeAI
fix ckpt_convert module to work with dreambooth v2 models (#2776)
- Discord member @marcus.llewellyn reported that some civitai 2.1-derived checkpoints (probably dreambooth-generated) were not converting properly: https://discord.com/channels/1020123559063990373/1078386197589655582/1078387806122025070
- @blessedcoolant tracked this down to a missing key: the vector length of the CLIP model was derived by fetching the second dimension of the tensor at "cond_stage_model.model.text_projection", which these checkpoints lack.
- On inspection, I found that the same dimension can be recovered from the key "cond_stage_model.model.ln_final.bias", and used that as a fallback. I hope this is correct; tested on multiple v1, v2 and inpainting models, and they all converted correctly.
- While debugging this, I found and fixed several other issues:
  - The model download script was not pre-downloading the OpenCLIP text_encoder or text_tokenizer. This is fixed.
  - Got rid of legacy code in `ckpt_to_diffuser.py` and replaced it with calls into `model_manager`.
  - More consistent status reporting in the CLI.
This commit is contained in commit c69fcb1c10.
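The heart of the fix, as a standalone sketch; the real change lives in `convert_open_clip_checkpoint()` in the diff below, and `clip_embedding_width` plus the commented usage lines are illustrative names, not part of the commit:

```python
import torch

def clip_embedding_width(state_dict: dict) -> int:
    """Recover the CLIP embedding width (d_model) from a legacy SD state dict.

    Dreambooth-generated v2 checkpoints may lack "text_projection", so fall
    back to "ln_final.bias", whose length is the same dimension.
    """
    if "cond_stage_model.model.text_projection" in state_dict:
        return int(state_dict["cond_stage_model.model.text_projection"].shape[0])
    if "cond_stage_model.model.ln_final.bias" in state_dict:
        return int(state_dict["cond_stage_model.model.ln_final.bias"].shape[0])
    raise KeyError('Expected key "cond_stage_model.model.text_projection" not found in model')

# Hypothetical usage: yields 1024 for SD 2.x checkpoints, 768 for SD 1.x.
# state_dict = torch.load("some-model.ckpt", map_location="cpu")["state_dict"]
# d_model = clip_embedding_width(state_dict)
```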
@@ -625,7 +625,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
     completer.set_default_dir(opt.outdir)
 
 
-def import_model(model_path: str, gen, opt, completer, convert=False) -> str:
+def import_model(model_path: str, gen, opt, completer, convert=False):
     """
     model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
     (3) a huggingface repository id; or (4) a local directory containing a
@@ -679,7 +679,7 @@ def _verify_load(model_name: str, gen) -> bool:
     current_model = gen.model_name
     try:
         if not gen.set_model(model_name):
-            return False
+            return
     except Exception as e:
         print(f"** model failed to load: {str(e)}")
         print(
@@ -706,7 +706,7 @@ def _get_model_name_and_desc(
     )
     return model_name, model_description
 
-def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer) -> str:
+def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
     model_name_or_path = model_name_or_path.replace("\\", "/")  # windows
     manager = gen.model_manager
     ckpt_path = None
@@ -740,19 +740,14 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer) ->
         )
     else:
         try:
-            model_name = import_model(model_name_or_path, gen, opt, completer, convert=True)
+            import_model(model_name_or_path, gen, opt, completer, convert=True)
         except KeyboardInterrupt:
             return
 
-    if not model_name:
-        print("** Conversion failed. Aborting.")
-        return
-
     manager.commit(opt.conf)
     if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
         ckpt_path.unlink(missing_ok=True)
         print(f"{ckpt_path} deleted")
-    return model_name
 
 
 def del_config(model_name: str, gen, opt, completer):
@@ -17,16 +17,15 @@
 # Original file at: https://github.com/huggingface/diffusers/blob/main/scripts/convert_ldm_original_checkpoint_to_diffusers.py
 """ Conversion script for the LDM checkpoints. """
 
 import os
 import re
 import torch
 import warnings
 from pathlib import Path
 from ldm.invoke.globals import (
     Globals,
     global_cache_dir,
     global_config_dir,
 )
+from ldm.invoke.model_manager import ModelManager, SDLegacyType
 from safetensors.torch import load_file
 from typing import Union
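Aside: `load_file` and `torch.load` above are the two entry points for reading legacy checkpoints. A minimal sketch of the pattern, with `load_legacy_checkpoint` as an illustrative name rather than a function from this module:

```python
import torch
from safetensors.torch import load_file

def load_legacy_checkpoint(path: str) -> dict:
    """Load a legacy Stable Diffusion checkpoint as a flat state dict.

    .safetensors files already map key -> tensor; pickled .ckpt files
    usually nest the weights under a "state_dict" entry.
    """
    if path.endswith(".safetensors"):
        return load_file(path, device="cpu")
    checkpoint = torch.load(path, map_location="cpu")
    return checkpoint.get("state_dict", checkpoint)
```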
@@ -760,7 +759,12 @@ def convert_open_clip_checkpoint(checkpoint):
 
     text_model_dict = {}
 
-    d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+    if 'cond_stage_model.model.text_projection' in keys:
+        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+    elif 'cond_stage_model.model.ln_final.bias' in keys:
+        d_model = int(checkpoint['cond_stage_model.model.ln_final.bias'].shape[0])
+    else:
+        raise KeyError('Expected key "cond_stage_model.model.text_projection" not found in model')
 
     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
@@ -856,20 +860,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
 
     upcast_attention = False
     if original_config_file is None:
-        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
-
-        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+        model_type = ModelManager.probe_model_type(checkpoint)
+
+        if model_type == SDLegacyType.V2:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v2-inference-v.yaml'
 
             if global_step == 110000:
                 # v2.1 needs to upcast attention
                 upcast_attention = True
-        elif str(checkpoint_path).lower().find('inpaint') >= 0:  # brittle - please pass original_config_file parameter!
-            print(f' | checkpoint has "inpaint" in name, assuming an inpainting model')
+
+        elif model_type == SDLegacyType.V1_INPAINT:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v1-inpainting-inference.yaml'
-        else:
+
+        elif model_type == SDLegacyType.V1:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v1-inference.yaml'
+        else:
+            raise Exception('Unknown checkpoint type')
 
     original_config = OmegaConf.load(original_config_file)
 
     if num_in_channels is not None:
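`ModelManager.probe_model_type()` classifies a checkpoint by shape heuristics on well-known state-dict keys. A sketch of the idea under stated assumptions (the exact method body may differ; `probe` is a free function for illustration):

```python
def probe(state_dict: dict) -> str:
    """Guess the legacy SD variant from tensor shapes in the state dict."""
    key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
        return "V2"  # SD 2.x uses a 1024-wide OpenCLIP text encoder
    try:
        # The first UNet conv reveals the latent input channels:
        # 9 for inpainting models (4 latent + 4 masked-image + 1 mask), 4 otherwise.
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return "V1_INPAINT"
        if in_channels == 4:
            return "V1"
    except KeyError:
        pass
    return "UNKNOWN"
```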
@@ -960,7 +967,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2",
                                                   subfolder="tokenizer",
-                                                  cache_dir=global_cache_dir('diffusers')
+                                                  cache_dir=cache_dir,
                                                   )
         pipe = pipeline_class(
             vae=vae,
@@ -191,14 +191,18 @@ def download_bert():
 
 
 # ---------------------------------------------
-def download_clip():
-    print("Installing CLIP model...", file=sys.stderr)
+def download_sd1_clip():
+    print("Installing SD1 clip model...", file=sys.stderr)
     version = "openai/clip-vit-large-patch14"
     print("Tokenizer...", file=sys.stderr)
     download_from_hf(CLIPTokenizer, version)
     print("Text model...", file=sys.stderr)
     download_from_hf(CLIPTextModel, version)
 
+# ---------------------------------------------
+def download_sd2_clip():
+    version = 'stabilityai/stable-diffusion-2'
+    print("Installing SD2 clip model...", file=sys.stderr)
+    download_from_hf(CLIPTokenizer, version, subfolder='tokenizer')
+    download_from_hf(CLIPTextModel, version, subfolder='text_encoder')
+
 # ---------------------------------------------
 def download_realesrgan():
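`download_from_hf` is an InvokeAI helper; stripped of it, the new pre-fetch amounts to two plain `transformers` calls. A minimal sketch, assuming default cache settings:

```python
from transformers import CLIPTextModel, CLIPTokenizer

# Pre-download the SD2 text pieces so a later ckpt conversion can run
# without network access. `subfolder` picks the component out of the
# stabilityai/stable-diffusion-2 repo.
version = "stabilityai/stable-diffusion-2"
tokenizer = CLIPTokenizer.from_pretrained(version, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(version, subfolder="text_encoder")
```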
@@ -832,7 +836,8 @@ def main():
         else:
             print("\n** DOWNLOADING SUPPORT MODELS **")
             download_bert()
-            download_clip()
+            download_sd1_clip()
+            download_sd2_clip()
             download_realesrgan()
             download_gfpgan()
             download_codeformer()
@@ -725,7 +725,7 @@ class ModelManager(object):
         SDLegacyType.V1
         SDLegacyType.V1_INPAINT
         SDLegacyType.V2
-        UNKNOWN
+        SDLegacyType.UNKNOWN
         """
         key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
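For reference, `SDLegacyType` is a plain enum over these variants; a sketch (the member values here are an assumption, only the names appear in the diff):

```python
from enum import Enum

class SDLegacyType(Enum):
    V1 = 1
    V1_INPAINT = 2
    V2 = 3
    UNKNOWN = 99
```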
@@ -785,7 +785,7 @@ class ModelManager(object):
         print(f">> Probing {thing} for import")
 
         if thing.startswith(("http:", "https:", "ftp:")):
-            print(f" | {thing} appears to be a URL")
+            print(f"   | {thing} appears to be a URL")
             model_path = self._resolve_path(
                 thing, "models/ldm/stable-diffusion-v1"
             )  # _resolve_path does a download if needed
@@ -793,15 +793,15 @@ class ModelManager(object):
         elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
             if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
                 print(
-                    f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
+                    f"   | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
                 )
                 return
             else:
-                print(f" | {thing} appears to be a checkpoint file on disk")
+                print(f"   | {thing} appears to be a checkpoint file on disk")
                 model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
 
         elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
-            print(f" | {thing} appears to be a diffusers file on disk")
+            print(f"   | {thing} appears to be a diffusers file on disk")
             model_name = self.import_diffuser_model(
                 thing,
                 vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@@ -812,13 +812,13 @@ class ModelManager(object):
 
         elif Path(thing).is_dir():
             if (Path(thing) / "model_index.json").exists():
-                print(f">> {thing} appears to be a diffusers model.")
+                print(f"   | {thing} appears to be a diffusers model.")
                 model_name = self.import_diffuser_model(
                     thing, commit_to_conf=commit_to_conf
                 )
             else:
                 print(
-                    f">> {thing} appears to be a directory. Will scan for models to import"
+                    f"   |{thing} appears to be a directory. Will scan for models to import"
                 )
                 for m in list(Path(thing).rglob("*.ckpt")) + list(
                     Path(thing).rglob("*.safetensors")
@@ -830,7 +830,7 @@ class ModelManager(object):
             return model_name
 
         elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
-            print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
+            print(f"   | {thing} appears to be a HuggingFace diffusers repo_id")
             model_name = self.import_diffuser_model(
                 thing, commit_to_conf=commit_to_conf
             )
@@ -847,7 +847,7 @@ class ModelManager(object):
             return
 
         if model_path.stem in self.config:  # already imported
-            print(" | Already imported. Skipping")
+            print("   | Already imported. Skipping")
             return
 
         # another round of heuristics to guess the correct config file.
@@ -860,18 +860,18 @@ class ModelManager(object):
 
         model_config_file = None
         if model_type == SDLegacyType.V1:
-            print(" | SD-v1 model detected")
+            print("   | SD-v1 model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v1-inference.yaml"
             )
         elif model_type == SDLegacyType.V1_INPAINT:
-            print(" | SD-v1 inpainting model detected")
+            print("   | SD-v1 inpainting model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
             )
         elif model_type == SDLegacyType.V2:
             print(
-                " | SD-v2 model detected; model will be converted to diffusers format"
+                "   | SD-v2 model detected; model will be converted to diffusers format"
             )
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
@@ -923,7 +923,7 @@ class ModelManager(object):
         vae=None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
-    ) -> dict:
+    ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
         into models.yaml.