From 8a31e5c5e3dd7be361fdd984170fd1bac801712a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 17 Jan 2023 00:18:09 -0500 Subject: [PATCH] allow safetensors models to be imported --- ldm/invoke/CLI.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index cfe9a64ed5..83b5281847 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -572,7 +572,7 @@ def import_model(model_path:str, gen, opt, completer): if model_path.startswith(('http:','https:','ftp:')): model_name = import_ckpt_model(model_path, gen, opt, completer) - elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path): + elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path): model_name = import_ckpt_model(model_path, gen, opt, completer) elif re.match('^[\w.+-]+/[\w.+-]+$',model_path): model_name = import_diffuser_model(model_path, gen, opt, completer) @@ -628,9 +628,9 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str: model_description=default_description ) config_file = None - + default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml') completer.complete_extensions(('.yaml','.yml')) - completer.set_line('configs/stable-diffusion/v1-inference.yaml') + completer.set_line(str(default)) done = False while not done: config_file = input('Configuration file for this model: ').strip()