2023-01-26 20:10:16 +00:00
"""
2023-03-03 06:02:00 +00:00
invokeai.frontend.merge exports a function called merge_diffusion_models()
2023-01-22 23:07:53 +00:00
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
2023-01-26 20:10:16 +00:00
Copyright ( c ) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
2023-02-03 01:26:45 +00:00
import curses
2023-01-22 23:07:53 +00:00
import os
2023-01-26 20:10:16 +00:00
import sys
2023-02-03 15:14:51 +00:00
import warnings
2023-01-26 20:10:16 +00:00
from argparse import Namespace
from pathlib import Path
from typing import List , Union
import npyscreen
2023-02-05 23:35:01 +00:00
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from npyscreen import widget
2023-01-22 23:07:53 +00:00
from omegaconf import OmegaConf
2023-04-29 13:43:40 +00:00
import invokeai . backend . util . logging as logger
2023-05-26 00:41:26 +00:00
from invokeai . services . config import InvokeAIAppConfig
2023-03-03 05:02:15 +00:00
from . . . backend . model_management import ModelManager
2023-03-03 06:02:00 +00:00
from . . . frontend . install . widgets import FloatTitleSlider
2023-01-26 20:10:16 +00:00
# Sub-directory (created under config.models_dir) into which merged models are saved.
DEST_MERGED_MODEL_DIR = "merged_models"
# Application configuration object shared by every function in this module.
config = InvokeAIAppConfig.get_config()
2023-03-03 06:02:00 +00:00
2023-01-26 20:10:16 +00:00
def merge_diffusion_models(
    model_ids_or_paths: List[Union[str, Path]],
    alpha: float = 0.5,
    interp: Union[str, None] = None,
    force: bool = False,
    **kwargs,
) -> DiffusionPipeline:
    """
    Merge two or three diffusers models and return the merged pipeline (not saved).

    model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
    alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
              would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
    interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
              Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
    force  - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
         cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
         If cache_dir is absent, config.cache_dir is used when loading the first model.

    The diffusers logger is silenced while loading/merging, and its previous
    verbosity is restored afterwards, even if loading or merging raises.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        verbosity = dlogging.get_verbosity()
        dlogging.set_verbosity_error()
        try:
            # The first model is loaded through the "checkpoint_merger" custom
            # pipeline, whose merge() method performs the interpolation.
            pipe = DiffusionPipeline.from_pretrained(
                model_ids_or_paths[0],
                cache_dir=kwargs.get("cache_dir", config.cache_dir),
                custom_pipeline="checkpoint_merger",
            )
            merged_pipe = pipe.merge(
                pretrained_model_name_or_path_list=model_ids_or_paths,
                alpha=alpha,
                interp=interp,
                force=force,
                **kwargs,
            )
        finally:
            # Restore the caller's logging level no matter how the merge ended.
            dlogging.set_verbosity(verbosity)
    return merged_pipe
def merge_diffusion_models_and_commit(
    models: List[str],
    merged_model_name: str,
    alpha: float = 0.5,
    interp: Union[str, None] = None,
    force: bool = False,
    **kwargs,
):
    """
    Merge registered models, save the result and register it in models.yaml.

    models - up to three models, designated by their InvokeAI models.yaml model name
    merged_model_name - name for new model
    alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
              would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
    interp - The interpolation method to use for the merging. Supports "weighted_sum", "sigmoid", "inv_sigmoid", "add_difference" and None.
              Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
    force  - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.

    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
         cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map

    The merged pipeline is written to <config.models_dir>/merged_models/<merged_model_name>
    and imported into the model manager, which is then committed to config.model_conf_path.
    If the first model has a configured "vae", the merged model reuses it.

    Raises AssertionError if a model name is unknown or is not in diffusers format.
    """
    config_file = config.model_conf_path
    model_manager = ModelManager(OmegaConf.load(config_file))

    # Query the registered names once; a set also makes the membership test cheap.
    known_models = set(model_manager.model_names())
    for mod in models:
        assert mod in known_models, f'** Unknown model "{mod}"'
        assert (
            model_manager.model_info(mod).get("format", None) == "diffusers"
        ), f"** {mod} is not a diffusers model. It must be optimized before merging."

    model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
    merged_pipe = merge_diffusion_models(
        model_ids_or_paths, alpha, interp, force, **kwargs
    )

    dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
    dump_path.mkdir(parents=True, exist_ok=True)
    dump_path = dump_path / merged_model_name
    merged_pipe.save_pretrained(dump_path, safe_serialization=True)

    import_args = dict(
        model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
    )
    if vae := model_manager.config[models[0]].get("vae", None):
        logger.info(f"Using configured VAE assigned to {models[0]}")
        import_args.update(vae=vae)
    model_manager.import_diffuser_model(dump_path, **import_args)
    model_manager.commit(config_file)
def _parse_args() -> Namespace:
    """Build the merge tool's command-line parser and parse sys.argv.

    Returns the parsed Namespace. --root_dir defaults to config.root;
    --interpolation defaults to "weighted_sum".
    """
    parser = argparse.ArgumentParser(description="InvokeAI model merging")
    parser.add_argument(
        "--root_dir",
        type=Path,
        default=config.root,
        help="Path to the invokeai runtime directory",
    )
    parser.add_argument(
        "--front_end",
        "--gui",
        dest="front_end",
        action="store_true",
        default=False,
        help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        help="Two to three model names to be merged",
    )
    parser.add_argument(
        "--merged_model_name",
        "--destination",
        dest="merged_model_name",
        type=str,
        help="Name of the output model. If not specified, will be the concatenation of the input model names.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2nd and 3rd models",
    )
    parser.add_argument(
        "--interpolation",
        dest="interp",
        type=str,
        choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
        default="weighted_sum",
        help='Interpolation method to use. If three models are present, only "add_difference" will work.',
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Try to merge models even if they are incompatible with each other",
    )
    parser.add_argument(
        "--clobber",
        "--overwrite",
        dest="clobber",
        action="store_true",
        help="Overwrite the merged model if --merged_model_name already exists",
    )
    return parser.parse_args()
2023-01-22 23:07:53 +00:00
2023-01-26 20:10:16 +00:00
# ------------------------- GUI HERE -------------------------
class mergeModelsForm(npyscreen.FormMultiPageAction):
    """npyscreen form that collects the models and parameters for a merge.

    On a validated OK, the collected arguments are stored as a dict on
    ``parentApp.merge_arguments``. Only diffusers-format models are offered.
    """

    interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]

    def __init__(self, parentApp, name):
        self.parentApp = parentApp
        self.ALLOW_RESIZE = True
        self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
        super().__init__(parentApp, name)

    @property
    def model_manager(self):
        """The ModelManager owned by the parent application."""
        return self.parentApp.model_manager

    def afterEditing(self):
        self.parentApp.setNextForm(None)

    def create(self):
        """Lay out the model pickers, name field, force checkbox, method list and alpha slider.

        Raises ValueError if fewer than two diffusers models are available,
        since a merge is impossible (and max() below would fail on an empty list).
        """
        self.model_names = self.get_model_names()
        if len(self.model_names) < 2:
            raise ValueError(
                "You need to have at least two diffusers models defined in models.yaml in order to merge"
            )

        _, window_width = curses.initscr().getmaxyx()
        # Width of each model column; three columns side by side if they fit.
        max_width = max(len(x) for x in self.model_names) + 6
        horizontal_layout = max_width * 3 < window_width

        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value="Select two models to merge and optionally a third.",
            editable=False,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
            editable=False,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 1",
            color="GOOD",
            editable=False,
            rely=4 if horizontal_layout else None,
        )
        self.model1 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            values=self.model_names,
            value=0,
            max_height=len(self.model_names),
            max_width=max_width,
            scroll_exit=True,
            rely=5,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 2",
            color="GOOD",
            editable=False,
            relx=max_width + 3 if horizontal_layout else None,
            rely=4 if horizontal_layout else None,
        )
        self.model2 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            name="(2)",
            values=self.model_names,
            value=1,
            max_height=len(self.model_names),
            max_width=max_width,
            relx=max_width + 3 if horizontal_layout else None,
            rely=5 if horizontal_layout else None,
            scroll_exit=True,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 3",
            color="GOOD",
            editable=False,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=4 if horizontal_layout else None,
        )
        # The third model is optional, so its list starts with a "None" entry.
        models_plus_none = self.model_names.copy()
        models_plus_none.insert(0, "None")
        self.model3 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            name="(3)",
            values=models_plus_none,
            value=0,
            max_height=len(self.model_names) + 1,
            max_width=max_width,
            scroll_exit=True,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=5 if horizontal_layout else None,
        )
        for m in [self.model1, self.model2, self.model3]:
            m.when_value_edited = self.models_changed
        self.merged_model_name = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Name for merged model:",
            labelColor="CONTROL",
            value="",
            scroll_exit=True,
        )
        self.force = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Force merge of incompatible models",
            labelColor="CONTROL",
            value=False,
            scroll_exit=True,
        )
        self.merge_method = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Merge Method:",
            values=self.interpolations,
            value=0,
            labelColor="CONTROL",
            max_height=len(self.interpolations) + 1,
            scroll_exit=True,
        )
        self.alpha = self.add_widget_intelligent(
            FloatTitleSlider,
            name="Weight (alpha) to assign to second and third models:",
            out_of=1.0,
            step=0.01,
            lowest=0,
            value=0.5,
            labelColor="CONTROL",
            scroll_exit=True,
        )
        # Pre-fill the merged-model name from the default selections so the
        # user is not left with an empty name unless they change a model.
        self.models_changed()
        self.model1.editing = True

    def models_changed(self):
        """Refresh the suggested merged-model name and the available merge methods."""
        models = self.model1.values
        selected_model1 = self.model1.value[0]
        selected_model2 = self.model2.value[0]
        selected_model3 = self.model3.value[0]
        self.merged_model_name.value = (
            f"{models[selected_model1]}+{models[selected_model2]}"
        )

        if selected_model3 > 0:
            # Three-model merges only support add_difference.
            self.merge_method.values = ["add_difference ( A+(B-C) )"]
            # model3's list has a leading "None" entry, so shift the index by one.
            self.merged_model_name.value += f"+{models[selected_model3 - 1]}"
        else:
            self.merge_method.values = self.interpolations
        # Reset the selection in both cases so it is always a valid index
        # into whichever values list was just installed.
        self.merge_method.value = 0

    def on_ok(self):
        if self.validate_field_values() and self.check_for_overwrite():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.merge_arguments = self.marshall_arguments()
            npyscreen.notify("Starting the merge...")
        else:
            self.editing = True

    def on_cancel(self):
        sys.exit(0)

    def marshall_arguments(self) -> dict:
        """Collect the form's values into keyword arguments for the merge."""
        model_names = self.model_names
        models = [
            model_names[self.model1.value[0]],
            model_names[self.model2.value[0]],
        ]
        if self.model3.value[0] > 0:
            models.append(model_names[self.model3.value[0] - 1])
            interp = "add_difference"
        else:
            interp = self.interpolations[self.merge_method.value[0]]

        args = dict(
            models=models,
            alpha=self.alpha.value,
            interp=interp,
            force=self.force.value,
            merged_model_name=self.merged_model_name.value,
        )
        return args

    def check_for_overwrite(self) -> bool:
        """Return True if the destination name is free, or the user agrees to overwrite it."""
        model_out = self.merged_model_name.value
        # Check against every registered model, not only the diffusers models
        # listed in the form, so a non-diffusers model of the same name is
        # not silently replaced.
        if model_out not in self.model_manager.model_names():
            return True
        return npyscreen.notify_yes_no(
            f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
        )

    def validate_field_values(self) -> bool:
        """Check the selections; show every problem found and return False if any."""
        bad_fields = []
        model_names = self.model_names
        chosen = [
            model_names[self.model1.value[0]],
            model_names[self.model2.value[0]],
        ]
        if self.model3.value[0] > 0:
            chosen.append(model_names[self.model3.value[0] - 1])
        # Every chosen model must be distinct (two or three of them).
        if len(set(chosen)) < len(chosen):
            bad_fields.append(
                f"Please select two or three DIFFERENT models to merge. You selected {chosen}"
            )
        if not self.merged_model_name.value.strip():
            bad_fields.append("Please provide a name for the merged model.")
        if len(bad_fields) > 0:
            message = "The following problems were detected and must be corrected:"
            for problem in bad_fields:
                message += f"\n* {problem}"
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self) -> List[str]:
        """Return the sorted names of all diffusers-format models."""
        model_names = [
            name
            for name in self.model_manager.model_names()
            if self.model_manager.model_info(name).get("format") == "diffusers"
        ]
        return sorted(model_names)
class Mergeapp(npyscreen.NPSAppManaged):
    """npyscreen application hosting the merge settings form.

    After run() returns, ``merge_arguments`` holds the dict stored by the
    form's OK handler, or None if the form never stored one.
    """

    def __init__(self):
        super().__init__()
        # Set by the form on a validated OK; initialised here so readers can
        # detect the "no arguments collected" case instead of hitting
        # AttributeError.
        self.merge_arguments = None
        conf = OmegaConf.load(config.model_conf_path)
        self.model_manager = ModelManager(
            conf, "cpu", "float16"
        )  # precision doesn't really matter here

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
        self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
2023-02-05 23:35:01 +00:00
2023-01-26 20:10:16 +00:00
def run_gui(args: Namespace):
    """Show the interactive merge form, then run the merge it describes.

    ``args`` is not consulted; the form collects its own parameters.
    If the application returns without collected arguments, nothing is merged.
    """
    mergeapp = Mergeapp()
    mergeapp.run()
    # getattr guards against an app that never set merge_arguments.
    merge_args = getattr(mergeapp, "merge_arguments", None)
    if merge_args is None:
        logger.info("No merge parameters were collected; nothing to do.")
        return
    merge_diffusion_models_and_commit(**merge_args)
    logger.info(f'Models merged into new model: "{merge_args["merged_model_name"]}".')
2023-01-26 20:10:16 +00:00
def run_cli(args: Namespace):
    """Validate command-line arguments and perform the merge.

    Requires 0 <= alpha <= 1 and two or three --models. If no
    --merged_model_name was given, it defaults to the model names joined
    with "+". Refuses to replace an existing model unless --clobber is set.

    Raises AssertionError on invalid arguments.
    """
    assert 0 <= args.alpha <= 1.0, "alpha must be between 0 and 1"
    # A merge needs at least two models; the checkpoint merger handles at most three.
    assert (
        args.models and 2 <= len(args.models) <= 3
    ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.models)
        logger.info(
            f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
        )

    model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
    assert (
        args.clobber or args.merged_model_name not in model_manager.model_names()
    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

    merge_diffusion_models_and_commit(**vars(args))
    logger.info(f'Models merged into new model: "{args.merged_model_name}".')
2023-01-26 20:10:16 +00:00
def main():
    """Console entry point: parse arguments, set up caches, then run the GUI or CLI merge.

    Errors are logged and the process exits with status -1.
    """
    args = _parse_args()
    config.root = args.root_dir

    cache_dir = config.cache_dir
    # Set HF_HOME because it is not clear the merge pipeline honors cache_dir.
    # os.environ only accepts str values, and config.cache_dir may be a Path.
    # NOTE(review): libraries that read HF_HOME at import time (diffusers is
    # imported at module load) may not see this value — confirm.
    os.environ["HF_HOME"] = str(cache_dir)
    args.cache_dir = cache_dir

    try:
        if args.front_end:
            run_gui(args)
        else:
            run_cli(args)
    except widget.NotEnoughSpaceForWidget as e:
        # A "Height of 1" allocation is what npyscreen reports when a model
        # list has only one entry to show.
        if str(e).startswith("Height of 1 allocated"):
            logger.error(
                "You need to have at least two diffusers models defined in models.yaml in order to merge"
            )
        else:
            logger.error(
                "Not enough room for the user interface. Try making this window larger."
            )
        sys.exit(-1)
    except Exception as e:
        logger.error(e)
        sys.exit(-1)
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception subclass, so it reaches here.
        sys.exit(-1)
2023-01-26 20:10:16 +00:00
if __name__ == "__main__":
    # Run the merge tool when this module is executed directly.
    main()