mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix bugs in online ckpt conversion of 2.0 models
This commit fixes bugs related to the on-the-fly conversion and loading of legacy checkpoint models built on SD-2.0 base. - When legacy checkpoints built on SD-2.0 models were converted on-the-fly using --ckpt_convert, generation would crash with a precision incompatibility error. - In addition, broken logic was causing some 2.0-derived ckpt files to be converted into diffusers and then processed through the legacy generation routines - not good.
This commit is contained in:
parent
8e2fd4c96a
commit
41a8fdea53
@ -772,11 +772,11 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
||||
original_config_file = Path(model_info["config"])
|
||||
model_name = model_name_or_path
|
||||
model_description = model_info["description"]
|
||||
vae = model_info["vae"]
|
||||
vae = model_info.get("vae")
|
||||
else:
|
||||
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||
return
|
||||
if vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem):
|
||||
if vae and (vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem)):
|
||||
vae_repo = dict(repo_id=vae_repo)
|
||||
else:
|
||||
vae_repo = None
|
||||
|
@ -1264,10 +1264,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
vae=vae.to(precision),
|
||||
text_encoder=text_model.to(precision),
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
unet=unet.to(precision),
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
|
@ -172,9 +172,9 @@ class ModelManager(object):
|
||||
"""
|
||||
# if we are converting legacy files automatically, then
|
||||
# there are no legacy ckpts!
|
||||
if Globals.ckpt_convert:
|
||||
return False
|
||||
info = self.model_info(model_name)
|
||||
if Globals.ckpt_convert or info.format=='diffusers' or self.is_v2_config(info.config):
|
||||
return False
|
||||
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
|
||||
return True
|
||||
return False
|
||||
@ -544,6 +544,8 @@ class ModelManager(object):
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
def is_v2_config(self, config: Path) -> bool:
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(Globals.root, config)
|
||||
try:
|
||||
mconfig = OmegaConf.load(config)
|
||||
return (
|
||||
|
Loading…
Reference in New Issue
Block a user