mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[enhancement] import .safetensors ckpt files directly (#2353)
This small fix makes it possible to import and run `.safetensors` checkpoint (ckpt) files directly, without requiring a conversion step first.
This commit is contained in:
commit
87c9398266
@ -573,7 +573,7 @@ def import_model(model_path:str, gen, opt, completer):
|
|||||||
|
|
||||||
if model_path.startswith(('http:','https:','ftp:')):
|
if model_path.startswith(('http:','https:','ftp:')):
|
||||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||||
elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
|
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
|
||||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||||
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
|
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
|
||||||
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
||||||
@ -627,9 +627,9 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
|
|||||||
model_description=default_description
|
model_description=default_description
|
||||||
)
|
)
|
||||||
config_file = None
|
config_file = None
|
||||||
|
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||||
completer.complete_extensions(('.yaml','.yml'))
|
completer.complete_extensions(('.yaml','.yml'))
|
||||||
completer.set_line('configs/stable-diffusion/v1-inference.yaml')
|
completer.set_line(str(default))
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
config_file = input('Configuration file for this model: ').strip()
|
config_file = input('Configuration file for this model: ').strip()
|
||||||
|
@ -147,7 +147,7 @@ class ModelManager(object):
|
|||||||
Return true if this is a legacy (.ckpt) model
|
Return true if this is a legacy (.ckpt) model
|
||||||
'''
|
'''
|
||||||
info = self.model_info(model_name)
|
info = self.model_info(model_name)
|
||||||
if 'weights' in info and info['weights'].endswith('.ckpt'):
|
if 'weights' in info and info['weights'].endswith(('.ckpt','.safetensors')):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -366,8 +366,14 @@ class ModelManager(object):
|
|||||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||||
if os.path.exists(vae):
|
if os.path.exists(vae):
|
||||||
print(f' | Loading VAE weights from: {vae}')
|
print(f' | Loading VAE weights from: {vae}')
|
||||||
|
vae_ckpt = None
|
||||||
|
vae_dict = None
|
||||||
|
if vae.endswith('.safetensors'):
|
||||||
|
vae_ckpt = safetensors.torch.load_file(vae)
|
||||||
|
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
|
||||||
|
else:
|
||||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
|
||||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||||
else:
|
else:
|
||||||
print(f' | VAE file {vae} not found. Skipping.')
|
print(f' | VAE file {vae} not found. Skipping.')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user