mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
if importing a v2 ckpt model, convert to diffusers
This commit is contained in:
parent
717d53a773
commit
4f7af55bc3
@ -646,6 +646,13 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) ->
|
|||||||
|
|
||||||
def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
|
def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
|
||||||
manager = gen.model_manager
|
manager = gen.model_manager
|
||||||
|
|
||||||
|
if not (config_file := _ask_for_config_file(path_or_url, completer)):
|
||||||
|
return
|
||||||
|
if config_file.stem == 'v2-inference-v':
|
||||||
|
print('** InvokeAI cannot run SD 2.X checkpoints directly. Model will be converted into diffusers format')
|
||||||
|
return optimize_model(path_or_url, gen, opt, completer, config_file)
|
||||||
|
|
||||||
default_name = Path(path_or_url).stem
|
default_name = Path(path_or_url).stem
|
||||||
default_description = f'Imported model {default_name}'
|
default_description = f'Imported model {default_name}'
|
||||||
model_name, model_description = _get_model_name_and_desc(
|
model_name, model_description = _get_model_name_and_desc(
|
||||||
@ -654,8 +661,7 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
|
|||||||
model_name=default_name,
|
model_name=default_name,
|
||||||
model_description=default_description
|
model_description=default_description
|
||||||
)
|
)
|
||||||
if not (config_file := _ask_for_config_file(path_or_url, completer)):
|
|
||||||
return
|
|
||||||
completer.complete_extensions(('.ckpt','.safetensors'))
|
completer.complete_extensions(('.ckpt','.safetensors'))
|
||||||
vae = None
|
vae = None
|
||||||
default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
|
default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
|
||||||
@ -713,7 +719,7 @@ def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
|
|||||||
'2': 'v2-inference-v.yaml',
|
'2': 'v2-inference-v.yaml',
|
||||||
'3': 'v1-inpainting-inference.yaml',
|
'3': 'v1-inpainting-inference.yaml',
|
||||||
}
|
}
|
||||||
print(f'''What type of model is this?:
|
print('''What type of model is this?:
|
||||||
[1] A model based on Stable Diffusion 1.X
|
[1] A model based on Stable Diffusion 1.X
|
||||||
[2] A model based on Stable Diffusion 2.X
|
[2] A model based on Stable Diffusion 2.X
|
||||||
[3] An inpainting model based on Stable Diffusion 1.X
|
[3] An inpainting model based on Stable Diffusion 1.X
|
||||||
@ -733,10 +739,9 @@ def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
|
|||||||
return config_path
|
return config_path
|
||||||
|
|
||||||
|
|
||||||
def optimize_model(model_name_or_path:str, gen, opt, completer):
|
def optimize_model(model_name_or_path:str, gen, opt, completer, original_config_file: Path=None):
|
||||||
manager = gen.model_manager
|
manager = gen.model_manager
|
||||||
ckpt_path = None
|
ckpt_path = None
|
||||||
original_config_file = None
|
|
||||||
|
|
||||||
if model_name_or_path == gen.model_name:
|
if model_name_or_path == gen.model_name:
|
||||||
print("** Can't convert the active model. !switch to another model first. **")
|
print("** Can't convert the active model. !switch to another model first. **")
|
||||||
@ -751,6 +756,9 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
|
|||||||
print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
|
print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
|
||||||
return
|
return
|
||||||
elif os.path.exists(model_name_or_path):
|
elif os.path.exists(model_name_or_path):
|
||||||
|
original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
|
||||||
|
if not original_config_file:
|
||||||
|
return
|
||||||
ckpt_path = Path(model_name_or_path)
|
ckpt_path = Path(model_name_or_path)
|
||||||
model_name, model_description = _get_model_name_and_desc(
|
model_name, model_description = _get_model_name_and_desc(
|
||||||
manager,
|
manager,
|
||||||
@ -758,9 +766,6 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
|
|||||||
ckpt_path.stem,
|
ckpt_path.stem,
|
||||||
f'Converted model {ckpt_path.stem}'
|
f'Converted model {ckpt_path.stem}'
|
||||||
)
|
)
|
||||||
original_config_file = _ask_for_config_file(model_name_or_path, completer)
|
|
||||||
if not original_config_file:
|
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
|
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user