Mirror of https://github.com/invoke-ai/InvokeAI
improve ability to bulk import .ckpt and .safetensors
This commit cleans up the code that handles bulk imports of legacy model files. The import logic has been refactored, and the user is now offered the choice of importing all the model files found in the directory or selecting individual files to import.
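In rough terms, the new flow scans the target directory for *.ckpt and *.safetensors files and then either imports everything or asks about each file individually. Below is a minimal, hypothetical sketch of that flow for illustration only; it is not the InvokeAI code itself, and bulk_import() / import_file() are illustrative stand-ins for the real model-manager calls shown in the diff.

# Hypothetical sketch of the "install all or selected" bulk-import flow.
from pathlib import Path
from typing import List

import click


def import_file(path: Path) -> bool:
    # Placeholder: the real code registers the checkpoint with the model manager.
    print(f'>> importing {path.name}')
    return True


def bulk_import(directory: Path) -> List[str]:
    # Gather every legacy checkpoint/safetensors file under the directory.
    models = sorted(list(directory.rglob('*.ckpt')) + list(directory.rglob('*.safetensors')))
    imported: List[str] = []
    if not models:
        return imported
    choice = input('Install <a>ll or <s>elected models? [a] ') or 'a'
    do_all = choice.startswith('a')
    for model in models:
        # In "selected" mode, confirm each file individually.
        if do_all or click.confirm(f'Import {model.stem}?', default=True):
            if import_file(model):
                imported.append(model.stem)
    return imported


if __name__ == '__main__':
    bulk_import(Path('.'))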
@@ -1,3 +1,4 @@
+import click
 import os
 import re
 import sys
@@ -6,7 +7,7 @@ import traceback

 from argparse import Namespace
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Union, List

 if sys.platform == "darwin":
     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -21,7 +22,6 @@ from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from ldm.invoke.model_manager import ModelManager

-import click # type: ignore
 import ldm.invoke
 import pyparsing # type: ignore

@@ -592,12 +592,8 @@ def import_model(model_path: str, gen, opt, completer):
     models = list(Path(model_path).rglob('*.ckpt')) + list(Path(model_path).rglob('*.safetensors'))

     if models:
-        # Only the last model name will be used below.
-        for model in sorted(models):
-
-            if click.confirm(f'Import {model.stem} ?', default=True):
-                model_name = import_ckpt_model(model, gen, opt, completer)
-                print()
+        models = import_checkpoint_list(models, gen, opt, completer)
+        model_name = models[0] if len(models) == 1 else None
     else:
         model_name = import_diffuser_model(Path(model_path), gen, opt, completer)

@@ -614,13 +610,49 @@ def import_model(model_path: str, gen, opt, completer):
         print('** model failed to load. Discarding configuration entry')
         gen.model_manager.del_model(model_name)
         return
-    if input('Make this the default model? [n] ').strip() in ('y','Y'):
+    if click.confirm('Make this the default model?', default=False):
         gen.model_manager.set_default_model(model_name)

     gen.model_manager.commit(opt.conf)
     completer.update_models(gen.model_manager.list_models())
     print(f'>> {model_name} successfully installed')

+def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
+    '''
+    Does a mass import of all the checkpoint/safetensors files on a path list
+    '''
+    model_names = list()
+    choice = input('** Directory of checkpoint/safetensors models detected. Install <a>ll or <s>elected models? [a] ') or 'a'
+    do_all = choice.startswith('a')
+    if do_all:
+        config_file = _ask_for_config_file(models[0], completer, plural=True)
+        manager = gen.model_manager
+        for model in sorted(models):
+            model_name = f'{model.stem}'
+            model_description = f'Imported model {model_name}'
+            if model_name in manager.model_names():
+                print(f'** {model_name} is already imported. Skipping.')
+            elif manager.import_ckpt_model(
+                    model,
+                    config = config_file,
+                    model_name = model_name,
+                    model_description = model_description,
+                    commit_to_conf = opt.conf):
+                model_names.append(model_name)
+                print(f'>> Model {model_name} imported successfully')
+            else:
+                print(f'** Model {model} failed to import')
+    else:
+        for model in sorted(models):
+            if click.confirm(f'Import {model.stem} ?', default=True):
+                if model_name := import_ckpt_model(model, gen, opt, completer):
+                    print(f'>> Model {model.stem} imported successfully')
+                    model_names.append(model_name)
+                else:
+                    print(f'** Model {model} failed to import')
+            print()
+    return model_names
+
 def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) -> Optional[str]:
     manager = gen.model_manager
     default_name = Path(path_or_repo).stem
@@ -632,7 +664,7 @@ def import_diffuser_model(path_or_repo: Union[Path, str], gen, _, completer) ->
         model_description=default_description
     )
     vae = None
-    if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'):
+    if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
         vae = dict(repo_id='stabilityai/sd-vae-ft-mse')

     if not manager.import_diffuser_model(
@@ -696,8 +728,7 @@ def _verify_load(model_name:str, gen)->bool:
         print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.')
         return False

-    do_switch = input('Keep model loaded? [y] ')
-    if len(do_switch)==0 or do_switch[0] in ('y','Y'):
+    if click.confirm('Keep model loaded?', default=True):
         gen.set_model(model_name)
     else:
         print('>> Restoring previous model')
@@ -710,20 +741,26 @@ def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_des
     model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
     return model_name, model_description

-def _ask_for_config_file(model_path: Union[str,Path], completer)->Path:
-    default = 1
+def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path:
+    default = '1'
     if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
-        default = 3
+        default = '3'
     choices={
         '1': 'v1-inference.yaml',
         '2': 'v2-inference-v.yaml',
         '3': 'v1-inpainting-inference.yaml',
     }
-    print('''What type of model is this?:
+
+    prompt = '''What type of models are these?:
+[1] Models based on Stable Diffusion 1.X
+[2] Models based on Stable Diffusion 2.X
+[3] Inpainting models based on Stable Diffusion 1.X
+[4] Something else''' if plural else '''What type of model is this?:
 [1] A model based on Stable Diffusion 1.X
 [2] A model based on Stable Diffusion 2.X
-[3] An inpainting model based on Stable Diffusion 1.X
-[4] Something else''')
+[3] An inpainting models based on Stable Diffusion 1.X
+[4] Something else'''
+    print(prompt)
     choice = input(f'Your choice: [{default}] ')
     choice = choice.strip() or default
     if config_file := choices.get(choice,None):
@@ -782,7 +819,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer, original_config_
         return

     vae = None
-    if input('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] ').strip() in ('y','Y'):
+    if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
         vae = dict(repo_id='stabilityai/sd-vae-ft-mse')

     new_config = gen.model_manager.convert_and_import(
@@ -798,11 +835,10 @@ def optimize_model(model_name_or_path:str, gen, opt, completer, original_config_
         return

     completer.update_models(gen.model_manager.list_models())
-    if input(f'Load optimized model {model_name}? [y] ').strip() not in ('n','N'):
+    if click.confirm(f'Load optimized model {model_name}?', default=True):
         gen.set_model(model_name)

-    response = input(f'Delete the original .ckpt file at ({ckpt_path} ? [n] ')
-    if response.startswith(('y','Y')):
+    if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False):
         ckpt_path.unlink(missing_ok=True)
         print(f'{ckpt_path} deleted')

@@ -815,10 +851,10 @@ def del_config(model_name:str, gen, opt, completer):
         print(f"** Unknown model {model_name}")
         return

-    if input(f'Remove {model_name} from the list of models known to InvokeAI? [y] ').strip().startswith(('n','N')):
+    if not click.confirm(f'Remove {model_name} from the list of models known to InvokeAI?',default=True):
         return

-    delete_completely = input('Completely remove the model file or directory from disk? [n] ').startswith(('y','Y'))
+    delete_completely = click.confirm('Completely remove the model file or directory from disk?',default=False)
     gen.model_manager.del_model(model_name,delete_files=delete_completely)
     gen.model_manager.commit(opt.conf)
     print(f'** {model_name} deleted')
@@ -847,7 +883,7 @@ def edit_model(model_name:str, gen, opt, completer):
     # this does the update
     manager.add_model(new_name, info, True)

-    if input('Make this the default model? [n] ').startswith(('y','Y')):
+    if click.confirm('Make this the default model?',default=False):
         manager.set_default_model(new_name)
     manager.commit(opt.conf)
     completer.update_models(manager.list_models())
@@ -1179,8 +1215,7 @@ def report_model_error(opt:Namespace, e:Exception):
     if yes_to_all:
         print('** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE')
     else:
-        response = input('Do you want to run invokeai-configure script to select and/or reinstall models? [y] ')
-        if response.startswith(('n', 'N')):
+        if not click.confirm('Do you want to run invokeai-configure script to select and/or reinstall models?', default=True):
             return

     print('invokeai-configure is launching....\n')