mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
support conversion of v2 models
- This PR introduces a CLI prompt for the proper configuration file to use when converting a ckpt file, in order to support both inpainting and v2 models files. - When user tries to directly !import a v2 model, it prints out a proper warning that v2 ckpts are not directly supported.
This commit is contained in:
parent
097e41e8d2
commit
b71e675e8d
@ -657,15 +657,8 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
|
||||
model_name=default_name,
|
||||
model_description=default_description
|
||||
)
|
||||
config_file = None
|
||||
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.set_line(str(default))
|
||||
done = False
|
||||
while not done:
|
||||
config_file = input('Configuration file for this model: ').strip()
|
||||
done = os.path.exists(config_file)
|
||||
if not (config_file := _ask_for_config_file(path_or_url, completer)):
|
||||
return
|
||||
|
||||
completer.complete_extensions(('.ckpt','.safetensors'))
|
||||
vae = None
|
||||
@ -693,8 +686,14 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
|
||||
def _verify_load(model_name:str, gen)->bool:
|
||||
print('>> Verifying that new model loads...')
|
||||
current_model = gen.model_name
|
||||
if not gen.model_manager.get_model(model_name):
|
||||
try:
|
||||
if not gen.model_manager.get_model(model_name):
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'** model failed to load: {str(e)}')
|
||||
print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.')
|
||||
return False
|
||||
|
||||
do_switch = input('Keep model loaded? [y] ')
|
||||
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
|
||||
gen.set_model(model_name)
|
||||
@ -709,6 +708,35 @@ def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_des
|
||||
model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
|
||||
return model_name, model_description
|
||||
|
||||
def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
|
||||
default = 1
|
||||
if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
|
||||
default = 3
|
||||
choices={
|
||||
'1': 'v1-inference.yaml',
|
||||
'2': 'v2-inference-v.yaml',
|
||||
'3': 'v1-inpainting-inference.yaml',
|
||||
}
|
||||
print(f'''What type of model is this?:
|
||||
[1] A model based on Stable Diffusion 1.X
|
||||
[2] A model based on Stable Diffusion 2.X
|
||||
[3] An inpainting model based on Stable Diffusion 1.X
|
||||
[4] Something else''')
|
||||
choice = input(f'Your choice: [{default}] ')
|
||||
choice = choice.strip() or default
|
||||
if config_file := choices.get(choice,None):
|
||||
return Path('configs','stable-diffusion',config_file)
|
||||
|
||||
# otherwise ask user to select
|
||||
done = False
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
|
||||
while not done:
|
||||
config_path = input('Configuration file for this model (leave blank to abort): ').strip()
|
||||
done = not config_path or os.path.exists(config_path)
|
||||
return config_path
|
||||
|
||||
|
||||
def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||
manager = gen.model_manager
|
||||
ckpt_path = None
|
||||
@ -731,12 +759,9 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
|
||||
ckpt_path.stem,
|
||||
f'Converted model {ckpt_path.stem}'
|
||||
)
|
||||
is_inpainting = input('Is this an inpainting model? [n] ').startswith(('y','Y'))
|
||||
original_config_file = Path(
|
||||
'configs',
|
||||
'stable-diffusion',
|
||||
'v1-inpainting-inference.yaml' if is_inpainting else 'v1-inference.yaml'
|
||||
)
|
||||
original_config_file = _ask_for_config_file(model_name_or_path, completer)
|
||||
if not original_config_file:
|
||||
return
|
||||
else:
|
||||
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user