fix bugs in online ckpt conversion of 2.0 models

This commit fixes bugs related to the on-the-fly conversion and loading of
legacy checkpoint models built on SD-2.0 base.

- When legacy checkpoints built on SD-2.0 models were converted
  on-the-fly using --ckpt_convert, generation would crash with a
  precision incompatibility error.

- In addition, broken logic was causing some 2.0-derived ckpt files to
  be converted into diffusers and then processed through the legacy
  generation routines - not good.
This commit is contained in:
Lincoln Stein 2023-03-28 00:11:37 -04:00
parent 8e2fd4c96a
commit 41a8fdea53
3 changed files with 9 additions and 7 deletions

View File

@ -772,11 +772,11 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
original_config_file = Path(model_info["config"])
model_name = model_name_or_path
model_description = model_info["description"]
vae = model_info["vae"]
vae = model_info.get("vae")
else:
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
return
if vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem):
if vae and (vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem)):
vae_repo = dict(repo_id=vae_repo)
else:
vae_repo = None

View File

@ -1264,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
cache_dir=cache_dir,
)
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
vae=vae.to(precision),
text_encoder=text_model.to(precision),
tokenizer=tokenizer,
unet=unet,
unet=unet.to(precision),
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,

View File

@ -172,9 +172,9 @@ class ModelManager(object):
"""
# if we are converting legacy files automatically, then
# there are no legacy ckpts!
if Globals.ckpt_convert:
return False
info = self.model_info(model_name)
if Globals.ckpt_convert or info.format=='diffusers' or self.is_v2_config(info.config):
return False
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
return True
return False
@ -544,6 +544,8 @@ class ModelManager(object):
return pipeline, width, height, model_hash
def is_v2_config(self, config: Path) -> bool:
if not os.path.isabs(config):
config = os.path.join(Globals.root, config)
try:
mconfig = OmegaConf.load(config)
return (