support conversion of v2 models

- This PR introduces a CLI prompt for the proper configuration file to
  use when converting a ckpt file, in order to support both inpainting
  and v2 model files.

- When user tries to directly !import a v2 model, it prints out a proper
  warning that v2 ckpts are not directly supported.
This commit is contained in:
Lincoln Stein 2023-02-11 09:39:41 -05:00
parent 097e41e8d2
commit b71e675e8d

View File

@ -657,15 +657,8 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
model_name=default_name, model_name=default_name,
model_description=default_description model_description=default_description
) )
config_file = None if not (config_file := _ask_for_config_file(path_or_url, completer)):
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml') return
completer.complete_extensions(('.yaml','.yml'))
completer.set_line(str(default))
done = False
while not done:
config_file = input('Configuration file for this model: ').strip()
done = os.path.exists(config_file)
completer.complete_extensions(('.ckpt','.safetensors')) completer.complete_extensions(('.ckpt','.safetensors'))
vae = None vae = None
@ -693,8 +686,14 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
def _verify_load(model_name:str, gen)->bool: def _verify_load(model_name:str, gen)->bool:
print('>> Verifying that new model loads...') print('>> Verifying that new model loads...')
current_model = gen.model_name current_model = gen.model_name
if not gen.model_manager.get_model(model_name): try:
if not gen.model_manager.get_model(model_name):
return False
except Exception as e:
print(f'** model failed to load: {str(e)}')
print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.')
return False return False
do_switch = input('Keep model loaded? [y] ') do_switch = input('Keep model loaded? [y] ')
if len(do_switch)==0 or do_switch[0] in ('y','Y'): if len(do_switch)==0 or do_switch[0] in ('y','Y'):
gen.set_model(model_name) gen.set_model(model_name)
@ -709,6 +708,35 @@ def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_des
model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
return model_name, model_description return model_name, model_description
def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
    """Prompt the user for the stable-diffusion config file matching a ckpt.

    Presents a numbered menu of the known config files (v1, v2, v1-inpainting),
    pre-selecting the inpainting config when the model filename contains
    "inpaint". Choice 4 (or any unrecognized answer) falls back to a free-form
    path prompt with tab-completion.

    Returns a Path to the chosen config file, or None if the user aborts
    (callers only test the result for truthiness, so None is a safe sentinel).
    """
    # BUG FIX: the default must be a *string* — `choices` is keyed by str,
    # so an int default made `choices.get(choice)` miss whenever the user
    # simply pressed Enter to accept the default.
    default = '1'
    if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
        default = '3'
    choices={
        '1': 'v1-inference.yaml',
        '2': 'v2-inference-v.yaml',
        '3': 'v1-inpainting-inference.yaml',
    }
    print(f'''What type of model is this?:
[1] A model based on Stable Diffusion 1.X
[2] A model based on Stable Diffusion 2.X
[3] An inpainting model based on Stable Diffusion 1.X
[4] Something else''')
    choice = input(f'Your choice: [{default}] ')
    choice = choice.strip() or default
    if config_file := choices.get(choice,None):
        return Path('configs','stable-diffusion',config_file)

    # otherwise ask the user to type (with tab completion) an explicit path
    completer.complete_extensions(('.yaml','.yml'))
    completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
    while True:
        config_path = input('Configuration file for this model (leave blank to abort): ').strip()
        if not config_path:
            return None          # user aborted
        if os.path.exists(config_path):
            return Path(config_path)
        # nonexistent path: silently re-prompt, matching original behavior
def optimize_model(model_name_or_path:str, gen, opt, completer): def optimize_model(model_name_or_path:str, gen, opt, completer):
manager = gen.model_manager manager = gen.model_manager
ckpt_path = None ckpt_path = None
@ -731,12 +759,9 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
ckpt_path.stem, ckpt_path.stem,
f'Converted model {ckpt_path.stem}' f'Converted model {ckpt_path.stem}'
) )
is_inpainting = input('Is this an inpainting model? [n] ').startswith(('y','Y')) original_config_file = _ask_for_config_file(model_name_or_path, completer)
original_config_file = Path( if not original_config_file:
'configs', return
'stable-diffusion',
'v1-inpainting-inference.yaml' if is_inpainting else 'v1-inference.yaml'
)
else: else:
print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file') print(f'** {model_name_or_path} is neither an existing model nor the path to a .ckpt file')
return return