allow safetensors models to be imported

This commit is contained in:
Lincoln Stein 2023-01-17 00:18:09 -05:00
parent 563196bd03
commit 8a31e5c5e3

View File

@ -572,7 +572,7 @@ def import_model(model_path:str, gen, opt, completer):
if model_path.startswith(('http:','https:','ftp:')):
model_name = import_ckpt_model(model_path, gen, opt, completer)
elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
model_name = import_ckpt_model(model_path, gen, opt, completer)
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
model_name = import_diffuser_model(model_path, gen, opt, completer)
@ -628,9 +628,9 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
model_description=default_description
)
config_file = None
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
completer.complete_extensions(('.yaml','.yml'))
completer.set_line('configs/stable-diffusion/v1-inference.yaml')
completer.set_line(str(default))
done = False
while not done:
config_file = input('Configuration file for this model: ').strip()