2023-01-26 20:10:16 +00:00
"""
2023-03-03 06:02:00 +00:00
invokeai.frontend.merge provides main(), the entry point used to merge 2-3 models
together and create a new InvokeAI-registered diffusion model.
2023-01-26 20:10:16 +00:00
Copyright ( c ) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
2023-02-03 01:26:45 +00:00
import curses
2024-02-16 03:41:29 +00:00
import re
2023-01-26 20:10:16 +00:00
import sys
from argparse import Namespace
from pathlib import Path
2024-02-16 03:41:29 +00:00
from typing import List , Optional , Tuple
2023-01-26 20:10:16 +00:00
import npyscreen
2023-02-05 23:35:01 +00:00
from npyscreen import widget
2023-01-22 23:07:53 +00:00
2023-07-06 16:21:42 +00:00
from invokeai . app . services . config import InvokeAIAppConfig
2024-02-16 03:41:29 +00:00
from invokeai . app . services . download import DownloadQueueService
from invokeai . app . services . image_files . image_files_disk import DiskImageFileStorage
from invokeai . app . services . model_install import ModelInstallService
from invokeai . app . services . model_metadata import ModelMetadataStoreSQL
from invokeai . app . services . model_records import ModelRecordServiceBase , ModelRecordServiceSQL
from invokeai . app . services . shared . sqlite . sqlite_util import init_db
from invokeai . backend . model_manager import (
BaseModelType ,
ModelFormat ,
ModelType ,
ModelVariantType ,
)
from invokeai . backend . model_manager . merge import ModelMerger
from invokeai . backend . util . logging import InvokeAILogger
2023-08-18 14:57:18 +00:00
from invokeai . frontend . install . widgets import FloatTitleSlider , SingleSelectColumns , TextBox
2023-01-26 20:10:16 +00:00
2023-05-26 00:41:26 +00:00
# Shared application configuration object; main() later calls config.parse_args() on it.
config = InvokeAIAppConfig.get_config()
# Module-wide logger used for all user-facing status and error messages.
logger = InvokeAILogger.get_logger()
# Base model families offered for merging, each paired with the label shown in the text UI.
# Order matters: the GUI's base-model selector and the --base_model choices are derived
# from this list, and selector indices are mapped back into it by position.
BASE_TYPES = [
    (BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
    (BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
    (BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
]
2023-03-03 06:02:00 +00:00
2023-07-27 14:54:01 +00:00
2023-01-26 20:10:16 +00:00
def _parse_args(argv: Optional[List[str]] = None) -> Namespace:
    """Build the model-merging argument parser and parse the command line.

    :param argv: arguments to parse. When None (the default), argparse falls back to
        ``sys.argv[1:]``, so existing callers behave exactly as before.
    :returns: the parsed options as an ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="InvokeAI model merging")
    parser.add_argument(
        "--root_dir",
        type=Path,
        default=config.root,
        help="Path to the invokeai runtime directory",
    )
    parser.add_argument(
        "--front_end",
        "--gui",
        dest="front_end",
        action="store_true",
        default=False,
        help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
    )
    parser.add_argument(
        "--models",
        dest="model_names",
        type=str,
        nargs="+",
        help="Two to three model names to be merged",
    )
    parser.add_argument(
        "--base_model",
        type=str,
        # Only the base families listed in BASE_TYPES can be merged.
        choices=[x[0].value for x in BASE_TYPES],
        help="The base model shared by the models to be merged",
    )
    parser.add_argument(
        "--merged_model_name",
        "--destination",
        dest="merged_model_name",
        type=str,
        help="Name of the output model. If not specified, will be the concatenation of the input model names.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2nd and 3rd models",
    )
    parser.add_argument(
        "--interpolation",
        dest="interp",
        type=str,
        choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
        default="weighted_sum",
        help='Interpolation method to use. If three models are present, only "add_difference" will work.',
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Try to merge models even if they are incompatible with each other",
    )
    parser.add_argument(
        "--clobber",
        "--overwrite",
        dest="clobber",
        action="store_true",
        help="Overwrite the merged model if --merged_model_name already exists",
    )
    return parser.parse_args(argv)
2023-01-22 23:07:53 +00:00
2023-01-26 20:10:16 +00:00
# ------------------------- GUI HERE -------------------------
class mergeModelsForm(npyscreen.FormMultiPageAction):
    """npyscreen form that collects the parameters for a two- or three-way model merge.

    The user picks a base model family, two models (and optionally a third) of that
    family, a name for the merged model, an interpolation method and an alpha weight.
    When the form is confirmed, ``on_ok`` stores the resulting keyword arguments on
    ``parentApp.merge_arguments``.
    """

    # Interpolation methods offered for two-model merges. When a third model is
    # selected, only "add_difference" is offered and used.
    interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]

    def __init__(self, parentApp, name):
        self.parentApp = parentApp
        self.ALLOW_RESIZE = True
        self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
        super().__init__(parentApp, name)

    @property
    def record_store(self):
        """Model record store owned by the parent application."""
        return self.parentApp.record_store

    def afterEditing(self):
        # This is the only form; leave the application once it is done.
        self.parentApp.setNextForm(None)

    def create(self):
        """Lay out the form's widgets."""
        _, window_width = curses.initscr().getmaxyx()

        self.current_base = 0
        self.models = self.get_models(BASE_TYPES[self.current_base][0])
        self.model_names = [x[1] for x in self.models]
        max_width = max([len(x) for x in self.model_names])
        max_width += 6
        # Show the three model lists side by side when the terminal is wide enough.
        horizontal_layout = max_width * 3 < window_width

        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value="Select two models to merge and optionally a third.",
            editable=False,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            color="CONTROL",
            value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
            editable=False,
        )
        self.nextrely += 1
        self.base_select = self.add_widget_intelligent(
            SingleSelectColumns,
            values=[x[1] for x in BASE_TYPES],
            value=[self.current_base],
            columns=4,
            max_height=2,
            relx=8,
            scroll_exit=True,
        )
        self.base_select.on_changed = self._populate_models
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 1",
            color="GOOD",
            editable=False,
            rely=6 if horizontal_layout else None,
        )
        self.model1 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            values=self.model_names,
            value=0,
            max_height=len(self.model_names),
            max_width=max_width,
            scroll_exit=True,
            rely=7,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 2",
            color="GOOD",
            editable=False,
            relx=max_width + 3 if horizontal_layout else None,
            rely=6 if horizontal_layout else None,
        )
        self.model2 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            name="(2)",
            values=self.model_names,
            value=1,
            max_height=len(self.model_names),
            max_width=max_width,
            relx=max_width + 3 if horizontal_layout else None,
            rely=7 if horizontal_layout else None,
            scroll_exit=True,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 3",
            color="GOOD",
            editable=False,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=6 if horizontal_layout else None,
        )
        # The third model is optional, so its list starts with a "None" entry.
        models_plus_none = self.model_names.copy()
        models_plus_none.insert(0, "None")
        self.model3 = self.add_widget_intelligent(
            npyscreen.SelectOne,
            name="(3)",
            values=models_plus_none,
            value=0,
            max_height=len(self.model_names) + 1,
            max_width=max_width,
            scroll_exit=True,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=7 if horizontal_layout else None,
        )
        for m in [self.model1, self.model2, self.model3]:
            m.when_value_edited = self.models_changed
        self.merged_model_name = self.add_widget_intelligent(
            TextBox,
            name="Name for merged model:",
            labelColor="CONTROL",
            max_height=3,
            value="",
            scroll_exit=True,
        )
        self.force = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Force merge of models created by different diffusers library versions",
            labelColor="CONTROL",
            value=True,
            scroll_exit=True,
        )
        self.nextrely += 1
        self.merge_method = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Merge Method:",
            values=self.interpolations,
            value=0,
            labelColor="CONTROL",
            max_height=len(self.interpolations) + 1,
            scroll_exit=True,
        )
        self.alpha = self.add_widget_intelligent(
            FloatTitleSlider,
            name="Weight (alpha) to assign to second and third models:",
            out_of=1.0,
            step=0.01,
            lowest=0,
            value=0.5,
            labelColor="CONTROL",
            scroll_exit=True,
        )
        self.model1.editing = True

    def models_changed(self):
        """Sync the default merged-model name and the merge-method choices with the selected models."""
        models = self.model1.values
        selected_model1 = self.model1.value[0]
        selected_model2 = self.model2.value[0]
        selected_model3 = self.model3.value[0]
        merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
        self.merged_model_name.value = merged_model_name

        if selected_model3 > 0:
            # Three-model merges only support add_difference.
            self.merge_method.values = ["add_difference ( A+(B-C) )"]
            # model3's list has an extra leading "None" entry, so shift the index down by one.
            self.merged_model_name.value += f"+{models[selected_model3 - 1]}"
        else:
            self.merge_method.values = self.interpolations
        self.merge_method.value = 0

    def on_ok(self):
        """Validate the form; on success store the merge arguments on the parent app and exit."""
        if self.validate_field_values() and self.check_for_overwrite():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.merge_arguments = self.marshall_arguments()
            npyscreen.notify("Starting the merge...")
        else:
            # Keep the form open so the user can correct the problem.
            self.editing = True

    def on_cancel(self):
        sys.exit(0)

    def marshall_arguments(self) -> dict:
        """Translate the widget selections into keyword arguments for ``merge_diffusion_models_and_save``."""
        model_keys = [x[0] for x in self.models]
        models = [
            model_keys[self.model1.value[0]],
            model_keys[self.model2.value[0]],
        ]
        if self.model3.value[0] > 0:
            # Skip model3's leading "None" entry.
            models.append(model_keys[self.model3.value[0] - 1])
            interp = "add_difference"
        else:
            interp = self.interpolations[self.merge_method.value[0]]

        args = {
            "model_keys": models,
            # The base selector lists BASE_TYPES in order, so its index must be mapped
            # through BASE_TYPES, not through the full BaseModelType enum.
            "base_model": BASE_TYPES[self.base_select.value[0]][0],
            "alpha": self.alpha.value,
            "interp": interp,
            "force": self.force.value,
            "merged_model_name": self.merged_model_name.value,
        }
        return args

    def check_for_overwrite(self) -> bool:
        """Return True if the destination name is unused among the listed models or the user agrees to overwrite it."""
        model_out = self.merged_model_name.value
        if model_out not in self.model_names:
            return True
        else:
            return npyscreen.notify_yes_no(
                f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
            )

    def validate_field_values(self) -> bool:
        """Check the model selections; show a dialog listing any problems and return False if there are any."""
        bad_fields = []
        model_names = self.model_names
        if len(model_names) < 2:
            # Not enough models of this base type to merge; the selection indices are
            # meaningless (and possibly out of range) in this case.
            bad_fields.append("You need to have at least two diffusers models in order to merge")
        else:
            # Compare list positions rather than names, so that distinct models that
            # happen to share a name are still counted as different.
            selected = [self.model1.value[0], self.model2.value[0]]
            if self.model3.value[0] > 0:
                selected.append(self.model3.value[0] - 1)
            if len(set(selected)) < 2:
                selected_models = {model_names[i] for i in selected}
                bad_fields.append(
                    f"Please select two or three DIFFERENT models to compare. You selected {selected_models}"
                )
        if len(bad_fields) > 0:
            message = "The following problems were detected and must be corrected:"
            for problem in bad_fields:
                message += f"\n* {problem}"
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]:  # key to name
        """Return (key, name) pairs of normal-variant diffusers main models for ``base_model``, sorted by name."""
        models = [
            (x.key, x.name)
            for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
            if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
        ]
        return sorted(models, key=lambda x: x[1])

    def _populate_models(self, value: List[int]):
        """Reload the model pick-lists after the user selects a different base model."""
        self.current_base = value[0]
        base_model = BASE_TYPES[value[0]][0]
        self.models = self.get_models(base_model)
        self.model_names = [x[1] for x in self.models]

        models_plus_none = self.model_names.copy()
        models_plus_none.insert(0, "None")
        self.model1.values = self.model_names
        self.model2.values = self.model_names
        self.model3.values = models_plus_none

        # The previous selections index into the old lists and may now be out of range
        # or point at unrelated models, so reset them to the same defaults create() uses.
        self.model1.value = [0]
        self.model2.value = [1]
        self.model3.value = [0]
        if len(self.model_names) >= 2:
            # Refresh the default merged name and the merge-method choices.
            self.models_changed()

        self.display()
2023-07-27 14:54:01 +00:00
2023-01-26 20:10:16 +00:00
class Mergeapp(npyscreen.NPSAppManaged):
    """npyscreen application that hosts the merge-settings form and holds its result."""

    def __init__(self, record_store: ModelRecordServiceBase):
        super().__init__()
        # Record store the form queries for the models it lists.
        self.record_store = record_store
        # Set by the form's on_ok(). It stays None if the form never stores arguments,
        # so callers can detect that case instead of hitting an AttributeError.
        self.merge_arguments: Optional[dict] = None

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
        self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
2023-02-05 23:35:01 +00:00
2024-02-16 03:41:29 +00:00
def run_gui(args: Namespace) -> None:
    """Collect merge parameters through the text UI, then perform the merge.

    :param args: parsed command-line arguments. Unused here: the form gathers every merge
        parameter itself, and ``--root_dir`` is already applied by ``main``.
    """
    record_store: ModelRecordServiceBase = get_config_store()
    mergeapp = Mergeapp(record_store)
    mergeapp.run()
    # Use a separate name rather than shadowing the ``args`` parameter.
    merge_args = getattr(mergeapp, "merge_arguments", None)
    if merge_args is None:
        logger.info("No merge parameters were provided. Nothing to do.")
        return
    merger = get_model_merger(record_store)
    merger.merge_diffusion_models_and_save(**merge_args)
    merged_model_name = merge_args["merged_model_name"]
    logger.info(f'Models merged into new model: "{merged_model_name}".')
2023-01-26 20:10:16 +00:00
def run_cli(args: Namespace):
    """Merge models non-interactively using parsed command-line arguments.

    Each entry of ``args.model_names`` may be a model name, looked up among main models
    of ``args.base_model``, or a 32-character lowercase hexadecimal model key.

    :raises ValueError: if alpha is outside [0, 1], fewer than 2 or more than 3 models
        are given, a model name is given without --base_model, a name does not resolve
        to exactly one model, or the destination already exists and --clobber was not given.
    """
    # Explicit checks instead of assert, which is stripped when Python runs with -O.
    if not 0.0 <= args.alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")
    if not args.model_names or not 2 <= len(args.model_names) <= 3:
        raise ValueError(
            "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
        )
    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.model_names)
        logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')

    base_model = BaseModelType(args.base_model) if args.base_model else None

    record_store: ModelRecordServiceBase = get_config_store()
    existing = record_store.search_by_attr(
        model_name=args.merged_model_name,
        base_model=base_model,
        model_type=ModelType.Main,
    )
    if existing and not args.clobber:
        raise ValueError(f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.')

    model_keys: List[str] = []
    for name in args.model_names:
        # A model key is exactly 32 lowercase hex characters. The previous pattern
        # matched a single character only, so keys were never recognized.
        if re.fullmatch(r"[0-9a-f]{32}", name):
            model_keys.append(name)
            continue
        if base_model is None:
            raise ValueError(f"{name}: Please provide --base_model to look up models by name.")
        models = record_store.search_by_attr(model_name=name, model_type=ModelType.Main, base_model=base_model)
        if len(models) == 0:
            raise ValueError(f"{name}: Unknown model")
        if len(models) > 1:
            raise ValueError(f"{name}: More than one model by this name. Please specify the model key instead.")
        model_keys.append(models[0].key)

    # Build the merger (which starts the install service) only after every model resolved.
    merger = get_model_merger(record_store)
    merger.merge_diffusion_models_and_save(
        alpha=args.alpha,
        model_keys=model_keys,
        merged_model_name=args.merged_model_name,
        interp=args.interp,
        force=args.force,
    )
    logger.info(f'Models merged into new model: "{args.merged_model_name}".')
2023-01-26 20:10:16 +00:00
2024-02-16 03:41:29 +00:00
def get_config_store() -> ModelRecordServiceSQL:
    """Open the InvokeAI database and wrap it in an SQL-backed model record store.

    ``init_db`` receives a disk image store rooted at ``config.output_path / "images"``.
    The returned record service shares the same database handle with its metadata store.
    """
    outputs_root = config.output_path
    assert outputs_root is not None
    image_storage = DiskImageFileStorage(outputs_root / "images")
    database = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_storage)
    metadata_store = ModelMetadataStoreSQL(database)
    return ModelRecordServiceSQL(database, metadata_store)
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
    """Create a ModelMerger on top of a started ModelInstallService.

    :param record_store: record store handed to the install service.
    :returns: a merger that uses the started installer.
    """
    download_queue = DownloadQueueService()
    model_installer = ModelInstallService(
        app_config=config,
        record_store=record_store,
        download_queue=download_queue,
    )
    model_installer.start()
    return ModelMerger(model_installer)
2023-01-26 20:10:16 +00:00
def main():
    """Entry point: parse options, apply the root directory, then run the GUI or CLI merge.

    Errors are logged and turned into exit status -1.
    """
    args = _parse_args()
    root_override = ["--root", str(args.root_dir)] if args.root_dir else []
    config.parse_args(root_override)

    try:
        runner = run_gui if args.front_end else run_cli
        runner(args)
    except widget.NotEnoughSpaceForWidget as err:
        # A one-line-high model list means there were too few models to lay out.
        too_few_models = str(err).startswith("Height of 1 allocated")
        logger.error(
            "You need to have at least two diffusers models in order to merge"
            if too_few_models
            else "Not enough room for the user interface. Try making this window larger."
        )
        sys.exit(-1)
    except Exception as err:
        logger.error(err)
        sys.exit(-1)
    except KeyboardInterrupt:
        sys.exit(-1)
2023-01-26 20:10:16 +00:00
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()