mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
improve support for V2 variant legacy checkpoints
This commit enhances support for V2 variant (epsilon and v-predict) import and conversion to diffusers, by prompting the user to select the proper config file during startup time autoimport as well as in the invokeai installer script..
This commit is contained in:
parent
61d5cb2536
commit
694925f427
@ -158,10 +158,16 @@ def main():
|
||||
except Exception as e:
|
||||
report_model_error(opt, e)
|
||||
|
||||
# completer is the readline object
|
||||
completer = get_completer(opt, models=gen.model_manager.list_models())
|
||||
|
||||
# try to autoconvert new models
|
||||
if path := opt.autoimport:
|
||||
gen.model_manager.heuristic_import(
|
||||
str(path), convert=False, commit_to_conf=opt.conf
|
||||
str(path),
|
||||
convert=False,
|
||||
commit_to_conf=opt.conf,
|
||||
config_file_callback=lambda x: _pick_configuration_file(completer,x),
|
||||
)
|
||||
|
||||
if path := opt.autoconvert:
|
||||
@ -180,7 +186,7 @@ def main():
|
||||
)
|
||||
|
||||
try:
|
||||
main_loop(gen, opt)
|
||||
main_loop(gen, opt, completer)
|
||||
except KeyboardInterrupt:
|
||||
print(
|
||||
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
|
||||
@ -191,7 +197,7 @@ def main():
|
||||
|
||||
|
||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||
def main_loop(gen, opt):
|
||||
def main_loop(gen, opt, completer):
|
||||
"""prompt/read/execute loop"""
|
||||
global infile
|
||||
done = False
|
||||
@ -202,7 +208,6 @@ def main_loop(gen, opt):
|
||||
# The readline completer reads history from the .dream_history file located in the
|
||||
# output directory specified at the time of script launch. We do not currently support
|
||||
# changing the history file midstream when the output directory is changed.
|
||||
completer = get_completer(opt, models=gen.model_manager.list_models())
|
||||
set_default_output_dir(opt, completer)
|
||||
if gen.model:
|
||||
add_embedding_terms(gen, completer)
|
||||
@ -661,17 +666,8 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
|
||||
model_name=model_name,
|
||||
description=model_desc,
|
||||
convert=convert,
|
||||
config_file_callback=lambda x: _pick_configuration_file(completer,x),
|
||||
)
|
||||
|
||||
if not imported_name:
|
||||
if config_file := _pick_configuration_file(completer):
|
||||
imported_name = gen.model_manager.heuristic_import(
|
||||
model_path,
|
||||
model_name=model_name,
|
||||
description=model_desc,
|
||||
convert=convert,
|
||||
model_config_file=config_file,
|
||||
)
|
||||
if not imported_name:
|
||||
print("** Aborting import.")
|
||||
return
|
||||
@ -687,14 +683,14 @@ def import_model(model_path: str, gen, opt, completer, convert=False):
|
||||
completer.update_models(gen.model_manager.list_models())
|
||||
print(f">> {imported_name} successfully installed")
|
||||
|
||||
def _pick_configuration_file(completer)->Path:
|
||||
def _pick_configuration_file(completer, checkpoint_path: Path)->Path:
|
||||
print(
|
||||
"""
|
||||
Please select the type of this model:
|
||||
f"""
|
||||
Please select the type of the model at checkpoint {checkpoint_path}:
|
||||
[1] A Stable Diffusion v1.x ckpt/safetensors model
|
||||
[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
|
||||
[3] A Stable Diffusion v2.x base model (512 pixels)
|
||||
[4] A Stable Diffusion v2.x v-predictive model (768 pixels)
|
||||
[3] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
|
||||
[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
|
||||
[5] Other (you will be prompted to enter the config file path)
|
||||
[Q] I have no idea! Skip the import.
|
||||
""")
|
||||
|
@ -109,6 +109,7 @@ def install_requested_models(
|
||||
model_manager.heuristic_import(
|
||||
path_url_or_repo,
|
||||
convert=convert_to_diffusers,
|
||||
config_file_callback=_pick_configuration_file,
|
||||
commit_to_conf=config_file_path
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
@ -138,6 +139,45 @@ def yes_or_no(prompt: str, default_yes=True):
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
# -------------------------------------
|
||||
def _pick_configuration_file(checkpoint_path: Path)->Path:
|
||||
print(
|
||||
"""
|
||||
Please select the type of this model:
|
||||
[1] A Stable Diffusion v1.x ckpt/safetensors model
|
||||
[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
|
||||
[3] A Stable Diffusion v2.x base model (512 pixels; no 'parameterization:' in its yaml file)
|
||||
[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for 'parameterization: "v"' in its yaml file)
|
||||
[5] Other (you will be prompted to enter the config file path)
|
||||
[Q] I have no idea! Skip the import.
|
||||
""")
|
||||
choices = [
|
||||
global_config_dir() / 'stable-diffusion' / x
|
||||
for x in [
|
||||
'v1-inference.yaml',
|
||||
'v1-inpainting-inference.yaml',
|
||||
'v2-inference.yaml',
|
||||
'v2-inference-v.yaml',
|
||||
]
|
||||
]
|
||||
|
||||
ok = False
|
||||
while not ok:
|
||||
try:
|
||||
choice = input('select 0-5, Q > ').strip()
|
||||
if choice.startswith(('q','Q')):
|
||||
return
|
||||
if choice == '5':
|
||||
choice = Path(input('Select config file for this model> ').strip()).absolute()
|
||||
ok = choice.exists()
|
||||
else:
|
||||
choice = choices[int(choice)-1]
|
||||
ok = True
|
||||
except (ValueError, IndexError):
|
||||
print(f'{choice} is not a valid choice')
|
||||
except EOFError:
|
||||
return
|
||||
return choice
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
|
@ -19,7 +19,7 @@ import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, Callable
|
||||
|
||||
import safetensors
|
||||
import safetensors.torch
|
||||
@ -765,6 +765,7 @@ class ModelManager(object):
|
||||
description: str = None,
|
||||
model_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
config_file_callback: Callable[[Path],Path] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Accept a string which could be:
|
||||
@ -838,7 +839,10 @@ class ModelManager(object):
|
||||
Path(thing).rglob("*.safetensors")
|
||||
):
|
||||
if model_name := self.heuristic_import(
|
||||
str(m), convert, commit_to_conf=commit_to_conf
|
||||
str(m),
|
||||
convert,
|
||||
commit_to_conf=commit_to_conf,
|
||||
config_file_callback=config_file_callback,
|
||||
):
|
||||
print(f" >> {model_name} successfully imported")
|
||||
return model_name
|
||||
@ -901,11 +905,14 @@ class ModelManager(object):
|
||||
print(
|
||||
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||
)
|
||||
return
|
||||
else:
|
||||
print(
|
||||
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||
)
|
||||
|
||||
if not model_config_file and config_file_callback:
|
||||
model_config_file = config_file_callback(model_path)
|
||||
if not model_config_file:
|
||||
return
|
||||
|
||||
if model_config_file.name.startswith('v2'):
|
||||
|
Loading…
Reference in New Issue
Block a user