improve support for V2 variant legacy checkpoints (#2926)

This commit enhances support for importing V2 variant (epsilon- and
v-prediction) legacy checkpoints and converting them to diffusers, by
prompting the user to select the proper config file during startup-time
autoimport as well as in the InvokeAI installer script. Previously the
user was prompted only when doing an `!import` from the command line or
when using the WebUI Model Manager.
This commit is contained in:
Lincoln Stein 2023-03-11 20:54:01 -05:00 committed by GitHub
commit 8d80802a35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 22 deletions

View File

@ -158,10 +158,16 @@ def main():
except Exception as e:
report_model_error(opt, e)
# completer is the readline object
completer = get_completer(opt, models=gen.model_manager.list_models())
# try to autoconvert new models
if path := opt.autoimport:
gen.model_manager.heuristic_import(
str(path), convert=False, commit_to_conf=opt.conf
str(path),
convert=False,
commit_to_conf=opt.conf,
config_file_callback=lambda x: _pick_configuration_file(completer,x),
)
if path := opt.autoconvert:
@ -180,7 +186,7 @@ def main():
)
try:
main_loop(gen, opt)
main_loop(gen, opt, completer)
except KeyboardInterrupt:
print(
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
@ -191,7 +197,7 @@ def main():
# TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, opt):
def main_loop(gen, opt, completer):
"""prompt/read/execute loop"""
global infile
done = False
@ -202,7 +208,6 @@ def main_loop(gen, opt):
# The readline completer reads history from the .dream_history file located in the
# output directory specified at the time of script launch. We do not currently support
# changing the history file midstream when the output directory is changed.
completer = get_completer(opt, models=gen.model_manager.list_models())
set_default_output_dir(opt, completer)
if gen.model:
add_embedding_terms(gen, completer)
@ -661,16 +666,7 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
model_name=model_name,
description=model_desc,
convert=convert,
)
if not imported_name:
if config_file := _pick_configuration_file(completer):
imported_name = gen.model_manager.heuristic_import(
model_path,
model_name=model_name,
description=model_desc,
convert=convert,
model_config_file=config_file,
config_file_callback=lambda x: _pick_configuration_file(completer,x),
)
if not imported_name:
print("** Aborting import.")
@ -687,14 +683,14 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
completer.update_models(gen.model_manager.list_models())
print(f">> {imported_name} successfully installed")
def _pick_configuration_file(completer)->Path:
def _pick_configuration_file(completer, checkpoint_path: Path)->Path:
print(
"""
Please select the type of this model:
f"""
Please select the type of the model at checkpoint {checkpoint_path}:
[1] A Stable Diffusion v1.x ckpt/safetensors model
[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
[3] A Stable Diffusion v2.x base model (512 pixels)
[4] A Stable Diffusion v2.x v-predictive model (768 pixels)
[3] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
[5] Other (you will be prompted to enter the config file path)
[Q] I have no idea! Skip the import.
""")

View File

@ -109,6 +109,7 @@ def install_requested_models(
model_manager.heuristic_import(
path_url_or_repo,
convert=convert_to_diffusers,
config_file_callback=_pick_configuration_file,
commit_to_conf=config_file_path
)
except KeyboardInterrupt:
@ -138,6 +139,45 @@ def yes_or_no(prompt: str, default_yes=True):
else:
return response[0] in ("y", "Y")
# -------------------------------------
def _pick_configuration_file(checkpoint_path: Path) -> Path:
    """Interactively ask the user what kind of legacy checkpoint this is.

    Args:
        checkpoint_path: path of the .ckpt/.safetensors file being imported.

    Returns:
        Path of the legacy .yaml config file matching the user's choice, or
        None when the user quits (Q) or sends EOF, which tells the caller to
        skip the import.
    """
    # Mention the checkpoint being classified, consistent with the CLI
    # version of this prompt in invoke.py.
    print(
        f"""
Please select the type of the model at checkpoint {checkpoint_path}:
[1] A Stable Diffusion v1.x ckpt/safetensors model
[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
[3] A Stable Diffusion v2.x base model (512 pixels; no 'parameterization:' in its yaml file)
[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for 'parameterization: "v"' in its yaml file)
[5] Other (you will be prompted to enter the config file path)
[Q] I have no idea! Skip the import.
""")
    choices = [
        global_config_dir() / 'stable-diffusion' / x
        for x in [
            'v1-inference.yaml',
            'v1-inpainting-inference.yaml',
            'v2-inference.yaml',
            'v2-inference-v.yaml',
        ]
    ]
    while True:
        try:
            # Prompt matches the valid range: the old "select 0-5" text let
            # '0' slip through as choices[-1], silently returning the
            # v2 v-predictive config.
            choice = input('select [1-5], Q > ').strip()
            if choice.lower().startswith('q'):
                return None
            if choice == '5':
                # User supplies an explicit config file path; re-prompt
                # (with feedback) until it exists.
                path = Path(input('Select config file for this model> ').strip()).absolute()
                if path.exists():
                    return path
                print(f'{path} does not exist. Please try again.')
            else:
                index = int(choice)
                if not 1 <= index <= 4:
                    raise IndexError  # reuse the invalid-choice message below
                return choices[index - 1]
        except (ValueError, IndexError):
            print(f'{choice} is not a valid choice')
        except EOFError:
            # Ctrl-D: treat like Q and skip the import.
            return None
# -------------------------------------
def get_root(root: str = None) -> str:

View File

@ -19,7 +19,7 @@ import warnings
from enum import Enum
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Callable
import safetensors
import safetensors.torch
@ -765,6 +765,7 @@ class ModelManager(object):
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path],Path] = None,
) -> str:
"""
Accept a string which could be:
@ -838,7 +839,10 @@ class ModelManager(object):
Path(thing).rglob("*.safetensors")
):
if model_name := self.heuristic_import(
str(m), convert, commit_to_conf=commit_to_conf
str(m),
convert,
commit_to_conf=commit_to_conf,
config_file_callback=config_file_callback,
):
print(f" >> {model_name} successfully imported")
return model_name
@ -901,11 +905,14 @@ class ModelManager(object):
print(
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
)
return
else:
print(
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
)
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)
if not model_config_file:
return
if model_config_file.name.startswith('v2'):