if importing a v2 ckpt model, convert to diffusers

This commit is contained in:
Lincoln Stein 2023-02-11 16:35:45 -05:00
parent 717d53a773
commit 4f7af55bc3

View File

@ -646,6 +646,13 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) ->
def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]: def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
manager = gen.model_manager manager = gen.model_manager
if not (config_file := _ask_for_config_file(path_or_url, completer)):
return
if config_file.stem == 'v2-inference-v':
print('** InvokeAI cannot run SD 2.X checkpoints directly. Model will be converted into diffusers format')
return optimize_model(path_or_url, gen, opt, completer, config_file)
default_name = Path(path_or_url).stem default_name = Path(path_or_url).stem
default_description = f'Imported model {default_name}' default_description = f'Imported model {default_name}'
model_name, model_description = _get_model_name_and_desc( model_name, model_description = _get_model_name_and_desc(
@ -654,8 +661,7 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
model_name=default_name, model_name=default_name,
model_description=default_description model_description=default_description
) )
if not (config_file := _ask_for_config_file(path_or_url, completer)):
return
completer.complete_extensions(('.ckpt','.safetensors')) completer.complete_extensions(('.ckpt','.safetensors'))
vae = None vae = None
default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt') default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
@ -713,7 +719,7 @@ def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
'2': 'v2-inference-v.yaml', '2': 'v2-inference-v.yaml',
'3': 'v1-inpainting-inference.yaml', '3': 'v1-inpainting-inference.yaml',
} }
print(f'''What type of model is this?: print('''What type of model is this?:
[1] A model based on Stable Diffusion 1.X [1] A model based on Stable Diffusion 1.X
[2] A model based on Stable Diffusion 2.X [2] A model based on Stable Diffusion 2.X
[3] An inpainting model based on Stable Diffusion 1.X [3] An inpainting model based on Stable Diffusion 1.X
@ -733,10 +739,9 @@ def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
return config_path return config_path
def optimize_model(model_name_or_path:str, gen, opt, completer): def optimize_model(model_name_or_path:str, gen, opt, completer, original_config_file: Path=None):
manager = gen.model_manager manager = gen.model_manager
ckpt_path = None ckpt_path = None
original_config_file = None
if model_name_or_path == gen.model_name: if model_name_or_path == gen.model_name:
print("** Can't convert the active model. !switch to another model first. **") print("** Can't convert the active model. !switch to another model first. **")
@ -751,6 +756,9 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
print(f'** {model_name_or_path} is not a legacy .ckpt weights file') print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
return return
elif os.path.exists(model_name_or_path): elif os.path.exists(model_name_or_path):
original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
if not original_config_file:
return
ckpt_path = Path(model_name_or_path) ckpt_path = Path(model_name_or_path)
model_name, model_description = _get_model_name_and_desc( model_name, model_description = _get_model_name_and_desc(
manager, manager,
@ -758,9 +766,6 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
ckpt_path.stem, ckpt_path.stem,
f'Converted model {ckpt_path.stem}' f'Converted model {ckpt_path.stem}'
) )
original_config_file = _ask_for_config_file(model_name_or_path, completer)
if not original_config_file:
return
else: else:
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file') print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
return return