2023-01-26 20:10:16 +00:00
"""
2023-03-03 06:02:00 +00:00
invokeai . frontend . merge exports a single function call merge_diffusion_models ( )
2023-01-22 23:07:53 +00:00
used to merge 2 - 3 models together and create a new InvokeAI - registered diffusion model .
2023-01-26 20:10:16 +00:00
Copyright ( c ) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
2023-02-03 01:26:45 +00:00
import curses
2023-01-26 20:10:16 +00:00
import sys
from argparse import Namespace
from pathlib import Path
from typing import List , Union
import npyscreen
2023-02-05 23:35:01 +00:00
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from npyscreen import widget
2023-01-22 23:07:53 +00:00
from omegaconf import OmegaConf
2023-04-29 13:43:40 +00:00
import invokeai . backend . util . logging as logger
2023-07-06 16:21:42 +00:00
from invokeai . app . services . config import InvokeAIAppConfig
from invokeai . backend . model_management import (
ModelMerger ,
MergeInterpolationMethod ,
ModelManager ,
ModelType ,
BaseModelType ,
2023-07-06 00:25:47 +00:00
)
2023-07-06 16:21:42 +00:00
from invokeai . frontend . install . widgets import FloatTitleSlider , TextBox , SingleSelectColumns
2023-01-26 20:10:16 +00:00
2023-05-26 00:41:26 +00:00
# Shared InvokeAI application configuration. _parse_args() takes its default
# runtime directory from config.root, main() re-parses it with --root when
# --root_dir is given, and run_gui()/run_cli() read config.model_conf_path.
config = InvokeAIAppConfig.get_config()
2023-03-03 06:02:00 +00:00
2023-07-27 14:54:01 +00:00
2023-01-26 20:10:16 +00:00
def _parse_args() -> Namespace:
    """Build the command-line parser for model merging and parse sys.argv.

    :returns: the parsed arguments. ``model_names`` holds the ``--models`` values,
        ``interp`` the ``--interpolation`` choice and ``clobber`` the
        ``--clobber``/``--overwrite`` flag.
    """
    parser = argparse.ArgumentParser(description="InvokeAI model merging")
    parser.add_argument(
        "--root_dir",
        type=Path,
        default=config.root,
        help="Path to the invokeai runtime directory",
    )
    parser.add_argument(
        "--front_end",
        "--gui",
        dest="front_end",
        action="store_true",
        default=False,
        help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
    )
    parser.add_argument(
        "--models",
        dest="model_names",
        type=str,
        nargs="+",
        help="Two to three model names to be merged",
    )
    parser.add_argument(
        "--base_model",
        type=str,
        choices=[x.value for x in BaseModelType],
        help="The base model shared by the models to be merged",
    )
    parser.add_argument(
        "--merged_model_name",
        "--destination",
        dest="merged_model_name",
        type=str,
        help="Name of the output model. If not specified, will be the concatenation of the input model names.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        # Fixed wording: the weight goes to the 2nd and 3rd models ("2d and 3d" was a typo).
        help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2nd and 3rd models",
    )
    parser.add_argument(
        "--interpolation",
        dest="interp",
        type=str,
        choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
        default="weighted_sum",
        help='Interpolation method to use. If three models are present, only "add_difference" will work.',
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Try to merge models even if they are incompatible with each other",
    )
    parser.add_argument(
        "--clobber",
        "--overwrite",
        dest="clobber",
        action="store_true",
        help="Overwrite the merged model if --merged_model_name already exists",
    )
    return parser.parse_args()
2023-01-22 23:07:53 +00:00
2023-01-26 20:10:16 +00:00
# ------------------------- GUI HERE -------------------------
class mergeModelsForm(npyscreen.FormMultiPageAction):
    """npyscreen form that collects the parameters for merging two or three models.

    The user picks a base-model family, two models (plus an optional third)
    built on it, a name for the merged model, a merge method and an alpha
    weight. When OK is pressed and the values validate, the choices are stored
    as a keyword-argument dict in ``parentApp.merge_arguments``.
    """

    # Methods offered for two-model merges. A three-model merge always uses
    # "add_difference" (see models_changed() and marshall_arguments()).
    interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]

    def __init__(self, parentApp, name):
        self.parentApp = parentApp
        # Set before super().__init__(), presumably so npyscreen sees them while
        # constructing the form -- confirm against npyscreen's Form.__init__.
        self.ALLOW_RESIZE = True
        self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
        super().__init__(parentApp, name)

    @property
    def model_manager(self):
        """The ModelManager held by the parent application."""
        return self.parentApp.model_manager

    def afterEditing(self):
        # No next form is scheduled; presumably this ends the npyscreen app loop.
        self.parentApp.setNextForm(None)

    @staticmethod
    def _base_model_at(index: int) -> BaseModelType:
        """Return the BaseModelType member corresponding to row ``index`` of base_select."""
        # Assumes BaseModelType declares SD-1 then SD-2 first, matching the row
        # order of base_select built in create().
        return tuple(BaseModelType)[index]

    def create(self):
        """Build all widgets of the form."""
        window_height, window_width = curses.initscr().getmaxyx()

        self.current_base = 0
        # Size the model columns from every diffusers model (as before), so the
        # layout has room whichever base is selected later ...
        sizing_names = self.get_model_names()
        # ... but list only the models of the base that base_select shows as
        # selected. Previously all bases' models were listed until the base was
        # changed, so e.g. an SD-2 model could be merged with base_model SD-1.
        self.model_names = self.get_model_names(self._base_model_at(self.current_base))
        max_width = max([len(x) for x in sizing_names])
        max_width += 6
        # Put the three model columns side by side when they fit the terminal.
        horizontal_layout = max_width * 3 < window_width

        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value="Select two models to merge and optionally a third.",
            editable=False,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
            editable=False,
        )
        self.nextrely += 1
        self.base_select = self.add_widget_intelligent(
            SingleSelectColumns,
            values=[
                "Models Built on SD-1.x",
                "Models Built on SD-2.x",
            ],
            value=[self.current_base],
            columns=4,
            max_height=2,
            relx=8,
            scroll_exit=True,
        )
        self.base_select.on_changed = self._populate_models
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 1",
            color="GOOD",
            editable=False,
            rely=6 if horizontal_layout else None,
        )
        self.model1 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            values=self.model_names,
            value=0,
            max_height=len(sizing_names),
            max_width=max_width,
            scroll_exit=True,
            rely=7,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 2",
            color="GOOD",
            editable=False,
            relx=max_width + 3 if horizontal_layout else None,
            rely=6 if horizontal_layout else None,
        )
        self.model2 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            name="(2)",
            values=self.model_names,
            value=1,
            max_height=len(sizing_names),
            max_width=max_width,
            relx=max_width + 3 if horizontal_layout else None,
            rely=7 if horizontal_layout else None,
            scroll_exit=True,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 3",
            color="GOOD",
            editable=False,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=6 if horizontal_layout else None,
        )
        # The optional third model gets a leading "None" entry (index 0).
        models_plus_none = self.model_names.copy()
        models_plus_none.insert(0, "None")
        self.model3 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            name="(3)",
            values=models_plus_none,
            value=0,
            max_height=len(sizing_names) + 1,
            max_width=max_width,
            scroll_exit=True,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=7 if horizontal_layout else None,
        )
        for m in [self.model1, self.model2, self.model3]:
            m.when_value_edited = self.models_changed
        self.merged_model_name = self.add_widget_intelligent(
            TextBox,
            name="Name for merged model:",
            labelColor="CONTROL",
            max_height=3,
            value="",
            scroll_exit=True,
        )
        self.force = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Force merge of models created by different diffusers library versions",
            labelColor="CONTROL",
            value=True,
            scroll_exit=True,
        )
        self.nextrely += 1
        self.merge_method = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Merge Method:",
            values=self.interpolations,
            value=0,
            labelColor="CONTROL",
            max_height=len(self.interpolations) + 1,
            scroll_exit=True,
        )
        self.alpha = self.add_widget_intelligent(
            FloatTitleSlider,
            name="Weight (alpha) to assign to second and third models:",
            out_of=1.0,
            step=0.01,
            lowest=0,
            value=0.5,
            labelColor="CONTROL",
            scroll_exit=True,
        )
        self.model1.editing = True

    def models_changed(self):
        """Update the suggested merged-model name and the available merge methods.

        Installed as ``when_value_edited`` on the three model selectors.
        """
        models = self.model1.values
        # A name needs two models; this also avoids indexing an empty or
        # one-element list after switching to a sparsely populated base.
        if len(models) < 2:
            return
        selected_model1 = self.model1.value[0]
        selected_model2 = self.model2.value[0]
        selected_model3 = self.model3.value[0]
        merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
        self.merged_model_name.value = merged_model_name

        if selected_model3 > 0:
            self.merge_method.values = ["add_difference ( A+(B-C) )"]
            # model3's list has an extra leading "None" entry, so subtract one.
            self.merged_model_name.value += f"+{models[selected_model3 - 1]}"
        else:
            self.merge_method.values = self.interpolations
        # Reset in both branches: the previous index may be out of range for the
        # new (possibly one-element) method list. A list matches the
        # ``value[0]`` access in marshall_arguments().
        self.merge_method.value = [0]

    def on_ok(self):
        """Store the merge arguments and close the form if the input is acceptable."""
        if self.validate_field_values() and self.check_for_overwrite():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.merge_arguments = self.marshall_arguments()
            npyscreen.notify("Starting the merge...")
        else:
            # Keep the form open so the user can correct the input.
            self.editing = True

    def on_cancel(self):
        sys.exit(0)

    def marshall_arguments(self) -> dict:
        """Translate the widget state into keyword arguments for the merger.

        :returns: dict with keys model_names, base_model, alpha, interp, force
            and merged_model_name.
        """
        model_names = self.model_names
        models = [
            model_names[self.model1.value[0]],
            model_names[self.model2.value[0]],
        ]
        if self.model3.value[0] > 0:
            models.append(model_names[self.model3.value[0] - 1])
            # Three-model merges only support add_difference.
            interp = "add_difference"
        else:
            interp = self.interpolations[self.merge_method.value[0]]

        args = dict(
            model_names=models,
            base_model=self._base_model_at(self.base_select.value[0]),
            alpha=self.alpha.value,
            interp=interp,
            force=self.force.value,
            merged_model_name=self.merged_model_name.value,
        )
        return args

    def check_for_overwrite(self) -> bool:
        """Return True if the merged model may be written under the chosen name.

        When a model of that name, base and type is already registered, the
        user is asked to confirm the overwrite.
        """
        model_out = self.merged_model_name.value
        base_model = self._base_model_at(self.base_select.value[0])
        # Ask the model manager (as run_cli does) instead of checking
        # self.model_names, which holds only diffusers-format models and so would
        # miss, e.g., a checkpoint model registered under the same name.
        if not self.model_manager.model_exists(model_out, base_model, ModelType.Main):
            return True
        return npyscreen.notify_yes_no(
            f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
        )

    def validate_field_values(self) -> bool:
        """Check that two or three distinct models are selected.

        :returns: True when valid. Otherwise the problems are shown to the user
            and False is returned.
        """
        bad_fields = []
        model_names = self.model_names
        if len(model_names) < 2:
            # The selection indices cannot be resolved without at least two models.
            bad_fields.append("Please select a base model that has at least two diffusers models to merge.")
        else:
            selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
            if self.model3.value[0] > 0:
                selected_models.add(model_names[self.model3.value[0] - 1])
            if len(selected_models) < 2:
                bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}")
        if len(bad_fields) > 0:
            message = "The following problems were detected and must be corrected:"
            for problem in bad_fields:
                message += f"\n* {problem}"
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
        """Return the sorted names of diffusers-format main models.

        :param base_model: restrict the listing to this base; None lists all bases.
        """
        model_names = [
            info["model_name"]
            for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
            if info["model_format"] == "diffusers"
        ]
        return sorted(model_names)

    def _populate_models(self, value=None):
        """Reload the model selectors for the base model chosen in base_select.

        :param value: the base_select selection; ``value[0]`` is the row index.
            When None, the current base_select value is used.
        """
        if value is None:
            value = self.base_select.value
        base_model = self._base_model_at(value[0])
        self.model_names = self.get_model_names(base_model)

        models_plus_none = self.model_names.copy()
        models_plus_none.insert(0, "None")
        self.model1.values = self.model_names
        self.model2.values = self.model_names
        self.model3.values = models_plus_none

        # The old selections index the previous lists and may now be out of
        # range or point at unrelated models; reset them to create()'s defaults.
        self.model1.value = [0]
        self.model2.value = [1]
        self.model3.value = [0]
        self.models_changed()

        self.display()
2023-07-27 14:54:01 +00:00
2023-01-26 20:10:16 +00:00
class Mergeapp(npyscreen.NPSAppManaged):
    """npyscreen application hosting the merge-settings form.

    When the user confirms the form, mergeModelsForm.on_ok() stores the
    collected keyword arguments on this object as ``merge_arguments``.
    """

    def __init__(self, model_manager: ModelManager):
        super().__init__()
        # The form reaches this through its ``parentApp.model_manager`` property.
        self.model_manager = model_manager

    def onStart(self):
        """Apply the colour theme and register the form as the "MAIN" form."""
        theme = npyscreen.Themes.ElegantTheme
        npyscreen.setTheme(theme)
        form_title = "Merge Models Settings"
        self.main = self.addForm("MAIN", mergeModelsForm, name=form_title)
2023-02-05 23:35:01 +00:00
2023-01-26 20:10:16 +00:00
def run_gui(args: Namespace):
    """Collect merge parameters through the npyscreen form, then perform the merge.

    :param args: parsed command-line arguments. They are not consulted here;
        the form supplies every merge parameter.
    """
    manager = ModelManager(config.model_conf_path)
    app = Mergeapp(manager)
    app.run()

    # The form leaves its keyword arguments on the app when the user clicks OK.
    merge_kwargs = app.merge_arguments
    ModelMerger(manager).merge_diffusion_models_and_save(**merge_kwargs)
    logger.info(f'Models merged into new model: "{merge_kwargs["merged_model_name"]}".')
2023-01-26 20:10:16 +00:00
def run_cli(args: Namespace):
    """Merge models using the parameters supplied on the command line.

    :param args: arguments from _parse_args(). All of them are forwarded as
        keyword arguments to ModelMerger.merge_diffusion_models_and_save().
        If ``merged_model_name`` is empty, it is filled in with the input
        names joined by "+".
    :raises ValueError: if alpha lies outside [0, 1], fewer than two or more
        than three models are given, no --base_model is given, or the
        destination model already exists and --clobber was not given.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not 0 <= args.alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")
    # A merge needs at least two inputs. The previous check accepted a single
    # model despite this error message.
    if not (args.model_names and 2 <= len(args.model_names) <= 3):
        raise ValueError(
            "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
        )
    # --base_model has no default. Without it, model lookup and merging would
    # fail later with an opaque enum conversion error.
    if not args.base_model:
        raise ValueError(
            "Please provide the --base_model argument shared by the models to merge. Use --help for full usage."
        )

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.model_names)
        logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')

    model_manager = ModelManager(config.model_conf_path)
    if model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) and not args.clobber:
        raise ValueError(f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.')

    merger = ModelMerger(model_manager)
    merger.merge_diffusion_models_and_save(**vars(args))
    logger.info(f'Models merged into new model: "{args.merged_model_name}".')
2023-01-26 20:10:16 +00:00
def main():
    """Console entry point: parse arguments, then run the GUI or CLI merge.

    Errors are logged and the process exits with status -1.
    """
    args = _parse_args()
    # Point the shared configuration at the requested runtime directory.
    if args.root_dir:
        config.parse_args(["--root", str(args.root_dir)])

    runner = run_gui if args.front_end else run_cli
    try:
        runner(args)
    except widget.NotEnoughSpaceForWidget as exc:
        # Presumably a one-row model list means there was only one model to show.
        too_few_models = str(exc).startswith("Height of 1 allocated")
        if too_few_models:
            logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge")
        else:
            logger.error("Not enough room for the user interface. Try making this window larger.")
        sys.exit(-1)
    except Exception as exc:
        logger.error(exc)
        sys.exit(-1)
    except KeyboardInterrupt:
        sys.exit(-1)
2023-01-26 20:10:16 +00:00
# Run main() when this module is executed directly as a script.
if __name__ == "__main__":
    main()