Merge branch 'v2.3' into bugfix/webui-accurate-intermediates

commit 2d990c1f54
Lincoln Stein, 2023-02-23 22:07:18 -05:00, committed by GitHub
No known key found for this signature in database; GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 38 deletions

==== File 1 of 4 ====

@@ -625,7 +625,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
     completer.set_default_dir(opt.outdir)
 
-def import_model(model_path: str, gen, opt, completer, convert=False) -> str:
+def import_model(model_path: str, gen, opt, completer, convert=False):
     """
     model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
     (3) a huggingface repository id; or (4) a local directory containing a
@@ -679,7 +679,7 @@ def _verify_load(model_name: str, gen) -> bool:
     current_model = gen.model_name
     try:
         if not gen.set_model(model_name):
-            return False
+            return
     except Exception as e:
         print(f"** model failed to load: {str(e)}")
         print(
@@ -706,7 +706,7 @@ def _get_model_name_and_desc(
     )
     return model_name, model_description
 
-def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer) -> str:
+def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
     model_name_or_path = model_name_or_path.replace("\\", "/")  # windows
     manager = gen.model_manager
     ckpt_path = None
@@ -740,19 +740,14 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer) -> str:
         )
     else:
         try:
-            model_name = import_model(model_name_or_path, gen, opt, completer, convert=True)
+            import_model(model_name_or_path, gen, opt, completer, convert=True)
         except KeyboardInterrupt:
             return
 
-    if not model_name:
-        print("** Conversion failed. Aborting.")
-        return
     manager.commit(opt.conf)
     if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
         ckpt_path.unlink(missing_ok=True)
         print(f"{ckpt_path} deleted")
-    return model_name
 
 def del_config(model_name: str, gen, opt, completer):
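
Note on the _verify_load hunk above: replacing `return False` with a bare `return` makes the failure path yield None, which is still falsy, so callers that truth-test the result behave the same; only the `-> bool` annotation is now loose. A minimal illustration (hypothetical names, not code from this commit):

    # A bare `return` yields None; None is falsy, so `if not verify_sketch(...)`
    # takes the failure branch exactly as it did with `return False`.
    def verify_sketch(ok: bool):
        if not ok:
            return          # returns None
        return True

    assert not verify_sketch(False)   # None is falsy
    assert verify_sketch(True)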

==== File 2 of 4 ====

@@ -17,16 +17,15 @@
 # Original file at: https://github.com/huggingface/diffusers/blob/main/scripts/convert_ldm_original_checkpoint_to_diffusers.py
 """ Conversion script for the LDM checkpoints. """
 
-import os
 import re
 import torch
 import warnings
 from pathlib import Path
 from ldm.invoke.globals import (
-    Globals,
     global_cache_dir,
     global_config_dir,
 )
+from ldm.invoke.model_manager import ModelManager, SDLegacyType
 from safetensors.torch import load_file
 from typing import Union
@@ -760,7 +759,12 @@ def convert_open_clip_checkpoint(checkpoint):
     text_model_dict = {}
 
-    d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+    if 'cond_stage_model.model.text_projection' in keys:
+        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+    elif 'cond_stage_model.model.ln_final.bias' in keys:
+        d_model = int(checkpoint['cond_stage_model.model.ln_final.bias'].shape[0])
+    else:
+        raise KeyError('Expected key "cond_stage_model.model.text_projection" not found in model')
 
     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
@@ -856,20 +860,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     upcast_attention = False
     if original_config_file is None:
-        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
-        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+        model_type = ModelManager.probe_model_type(checkpoint)
+        if model_type == SDLegacyType.V2:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v2-inference-v.yaml'
             if global_step == 110000:
                 # v2.1 needs to upcast attention
                 upcast_attention = True
-        elif str(checkpoint_path).lower().find('inpaint') >= 0:  # brittle - please pass original_config_file parameter!
-            print(f' | checkpoint has "inpaint" in name, assuming an inpainting model')
+        elif model_type == SDLegacyType.V1_INPAINT:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v1-inpainting-inference.yaml'
-        else:
+        elif model_type == SDLegacyType.V1:
             original_config_file = global_config_dir() / 'stable-diffusion' / 'v1-inference.yaml'
+        else:
+            raise Exception('Unknown checkpoint type')
 
     original_config = OmegaConf.load(original_config_file)
     if num_in_channels is not None:
@@ -960,7 +967,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2",
                                                   subfolder="tokenizer",
-                                                  cache_dir=global_cache_dir('diffusers')
+                                                  cache_dir=cache_dir,
                                                   )
         pipe = pipeline_class(
             vae=vae,
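
Note on the tokenizer hunk: the call now honors the caller-supplied cache_dir instead of hard-coding global_cache_dir('diffusers'). A standalone sketch of the same download with an explicit cache (the repo id and subfolders are the real Hugging Face layout; the cache path is only an example):

    from transformers import CLIPTokenizer, CLIPTextModel

    cache_dir = "/tmp/invokeai-cache"  # example path, not the project default
    tokenizer = CLIPTokenizer.from_pretrained(
        "stabilityai/stable-diffusion-2", subfolder="tokenizer", cache_dir=cache_dir
    )
    text_encoder = CLIPTextModel.from_pretrained(
        "stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
    )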

==== File 3 of 4 ====

@@ -191,14 +191,18 @@ def download_bert():
 
 # ---------------------------------------------
-def download_clip():
-    print("Installing CLIP model...", file=sys.stderr)
+def download_sd1_clip():
+    print("Installing SD1 clip model...", file=sys.stderr)
     version = "openai/clip-vit-large-patch14"
-    print("Tokenizer...", file=sys.stderr)
     download_from_hf(CLIPTokenizer, version)
-    print("Text model...", file=sys.stderr)
     download_from_hf(CLIPTextModel, version)
 
+# ---------------------------------------------
+def download_sd2_clip():
+    version = 'stabilityai/stable-diffusion-2'
+    print("Installing SD2 clip model...", file=sys.stderr)
+    download_from_hf(CLIPTokenizer, version, subfolder='tokenizer')
+    download_from_hf(CLIPTextModel, version, subfolder='text_encoder')
+
 # ---------------------------------------------
 def download_realesrgan():
@@ -832,7 +836,8 @@ def main():
     else:
         print("\n** DOWNLOADING SUPPORT MODELS **")
         download_bert()
-        download_clip()
+        download_sd1_clip()
+        download_sd2_clip()
         download_realesrgan()
         download_gfpgan()
         download_codeformer()
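
Note: download_from_hf is a helper defined elsewhere in this script; its body is not part of this diff. Judging only from the call sites above, it plausibly forwards keyword arguments such as subfolder to the class's from_pretrained. A guess at its minimal shape, offered as an assumption rather than the project's actual code:

    def download_from_hf(model_class, model_name, **kwargs):
        # e.g. download_from_hf(CLIPTokenizer, "stabilityai/stable-diffusion-2",
        #                       subfolder="tokenizer")
        return model_class.from_pretrained(model_name, **kwargs)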

==== File 4 of 4 ====

@@ -725,7 +725,7 @@ class ModelManager(object):
         SDLegacyType.V1
         SDLegacyType.V1_INPAINT
         SDLegacyType.V2
-        UNKNOWN
+        SDLegacyType.UNKNOWN
         """
         key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
@@ -812,13 +812,13 @@ class ModelManager(object):
         elif Path(thing).is_dir():
             if (Path(thing) / "model_index.json").exists():
-                print(f">> {thing} appears to be a diffusers model.")
+                print(f" | {thing} appears to be a diffusers model.")
                 model_name = self.import_diffuser_model(
                     thing, commit_to_conf=commit_to_conf
                 )
             else:
                 print(
-                    f">> {thing} appears to be a directory. Will scan for models to import"
+                    f" | {thing} appears to be a directory. Will scan for models to import"
                 )
                 for m in list(Path(thing).rglob("*.ckpt")) + list(
                     Path(thing).rglob("*.safetensors")
@@ -923,7 +923,7 @@ class ModelManager(object):
         vae=None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
-    ) -> dict:
+    ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
         into models.yaml.
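
Note on probe_model_type, which the conversion script in File 2 now delegates to: the V2 test visible in the first hunk above keys off the 1024-wide attention projection that SD2's OpenCLIP encoder produces. A self-contained sketch of that kind of probing; the V2 check is taken from the hunk, while the inpaint test (9 input channels on the UNet's first conv: 4 latent + 4 masked-image + 1 mask) is an assumption about the implementation, not code shown in this diff:

    from enum import Enum
    import torch

    class SDLegacyType(Enum):
        V1 = 1
        V1_INPAINT = 2
        V2 = 3
        UNKNOWN = 99

    def probe_model_type_sketch(checkpoint: dict) -> SDLegacyType:
        key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key in checkpoint and checkpoint[key].shape[-1] == 1024:
            return SDLegacyType.V2                      # SD2 OpenCLIP context width
        first_conv = checkpoint.get("model.diffusion_model.input_blocks.0.0.weight")
        if first_conv is None:
            return SDLegacyType.UNKNOWN
        if first_conv.shape[1] == 9:                    # assumed inpainting signature
            return SDLegacyType.V1_INPAINT
        if first_conv.shape[1] == 4:
            return SDLegacyType.V1
        return SDLegacyType.UNKNOWN

    # Example: a stand-in V1-style checkpoint with a 4-channel first conv.
    fake_ckpt = {"model.diffusion_model.input_blocks.0.0.weight": torch.zeros(320, 4, 3, 3)}
    assert probe_model_type_sketch(fake_ckpt) is SDLegacyType.V1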