Prevent crash when converting models from within CLI using legacy model URL (#2846)

- Crash would occur at the end of this sequence:
  - launch CLI
  - !convert <URL pointing to a legacy ckpt file>
  - Answer "Y" when asked to delete original .ckpt file

- This commit modifies model_manager.heuristic_import() to silently
delete the downloaded legacy file after it has been converted into a
diffusers model. The user is no longer asked to approve deletion.

NB: This should be cherry-picked into main once refactor is done.
This commit is contained in:
Lincoln Stein 2023-03-07 00:09:11 -05:00 committed by GitHub
commit 68c2722c02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions

View File

@ -744,8 +744,8 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
except KeyboardInterrupt:
return
manager.commit(opt.conf)
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
manager.commit(opt.conf)
if ckpt_path and click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
ckpt_path.unlink(missing_ok=True)
print(f"{ckpt_path} deleted")

View File

@ -781,6 +781,7 @@ class ModelManager(object):
"""
model_path: Path = None
thing = path_url_or_repo # to save typing
is_temporary = False
print(f">> Probing {thing} for import")
@ -789,7 +790,7 @@ class ModelManager(object):
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
) # _resolve_path does a download if needed
is_temporary = True
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
print(
@ -896,6 +897,10 @@ class ModelManager(object):
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
)
# in the event that this file was downloaded automatically prior to conversion
# we do not keep the original .ckpt/.safetensors around
if is_temporary:
model_path.unlink(missing_ok=True)
else:
model_name = self.import_ckpt_model(
model_path,