mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
if importing a v2 ckpt model, convert to diffusers
This commit is contained in:
parent
717d53a773
commit
4f7af55bc3
@ -646,6 +646,13 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) ->
|
||||
|
||||
def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Optional[str]:
|
||||
manager = gen.model_manager
|
||||
|
||||
if not (config_file := _ask_for_config_file(path_or_url, completer)):
|
||||
return
|
||||
if config_file.stem == 'v2-inference-v':
|
||||
print('** InvokeAI cannot run SD 2.X checkpoints directly. Model will be converted into diffusers format')
|
||||
return optimize_model(path_or_url, gen, opt, completer, config_file)
|
||||
|
||||
default_name = Path(path_or_url).stem
|
||||
default_description = f'Imported model {default_name}'
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
@ -654,8 +661,7 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
|
||||
model_name=default_name,
|
||||
model_description=default_description
|
||||
)
|
||||
if not (config_file := _ask_for_config_file(path_or_url, completer)):
|
||||
return
|
||||
|
||||
completer.complete_extensions(('.ckpt','.safetensors'))
|
||||
vae = None
|
||||
default = Path(Globals.root,'models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt')
|
||||
@ -713,7 +719,7 @@ def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
|
||||
'2': 'v2-inference-v.yaml',
|
||||
'3': 'v1-inpainting-inference.yaml',
|
||||
}
|
||||
print(f'''What type of model is this?:
|
||||
print('''What type of model is this?:
|
||||
[1] A model based on Stable Diffusion 1.X
|
||||
[2] A model based on Stable Diffusion 2.X
|
||||
[3] An inpainting model based on Stable Diffusion 1.X
|
||||
@ -733,10 +739,9 @@ def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
|
||||
return config_path
|
||||
|
||||
|
||||
def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||
def optimize_model(model_name_or_path:str, gen, opt, completer, original_config_file: Path=None):
|
||||
manager = gen.model_manager
|
||||
ckpt_path = None
|
||||
original_config_file = None
|
||||
|
||||
if model_name_or_path == gen.model_name:
|
||||
print("** Can't convert the active model. !switch to another model first. **")
|
||||
@ -751,6 +756,9 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||
print(f'** {model_name_or_path} is not a legacy .ckpt weights file')
|
||||
return
|
||||
elif os.path.exists(model_name_or_path):
|
||||
original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
|
||||
if not original_config_file:
|
||||
return
|
||||
ckpt_path = Path(model_name_or_path)
|
||||
model_name, model_description = _get_model_name_and_desc(
|
||||
manager,
|
||||
@ -758,9 +766,6 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||
ckpt_path.stem,
|
||||
f'Converted model {ckpt_path.stem}'
|
||||
)
|
||||
original_config_file = _ask_for_config_file(model_name_or_path, completer)
|
||||
if not original_config_file:
|
||||
return
|
||||
else:
|
||||
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user