mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
clean up merge_models
This commit is contained in:
@ -1,91 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.invoke.globals import (Globals, global_cache_dir, global_config_file,
|
||||
global_set_root)
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
|
||||
parser = argparse.ArgumentParser(description="InvokeAI textual inversion training")
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
required=True,
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
dest="merged_model_name",
|
||||
type=str,
|
||||
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation",
|
||||
dest="interp",
|
||||
type=str,
|
||||
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
||||
default="weighted_sum",
|
||||
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Try to merge models even if they are incompatible with each other",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clobber",
|
||||
"--overwrite",
|
||||
dest='clobber',
|
||||
action="store_true",
|
||||
help="Overwrite the merged model if --merged_model_name already exists",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
global_set_root(args.root_dir)
|
||||
|
||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||
assert len(args.models) >= 1 and len(args.models) <= 3, "provide 2 or 3 models to merge"
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.models)
|
||||
print(
|
||||
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||
)
|
||||
|
||||
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
||||
assert (args.clobber or args.merged_model_name not in model_manager.model_names()), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
# It seems that the merge pipeline is not honoring cache_dir, so we set the
|
||||
# HF_HOME environment variable here *before* we load diffusers.
|
||||
cache_dir = str(global_cache_dir("diffusers"))
|
||||
os.environ["HF_HOME"] = cache_dir
|
||||
from ldm.invoke.merge_diffusers import merge_diffusion_models
|
||||
|
||||
try:
|
||||
merge_diffusion_models(**vars(args))
|
||||
print(f'>> Models merged into new model: "{args.merged_model_name}".')
|
||||
except Exception as e:
|
||||
print(f"** An error occurred while merging the pipelines: {str(e)}")
|
||||
print("** DETAILS:")
|
||||
print(traceback.format_exc())
|
||||
sys.exit(-1)
|
@ -1,217 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import npyscreen
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import argparse
|
||||
from ldm.invoke.globals import Globals, global_set_root, global_cache_dir, global_config_file
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
class FloatSlider(npyscreen.Slider):
|
||||
# this is supposed to adjust display precision, but doesn't
|
||||
def translate_value(self):
|
||||
stri = "%3.2f / %3.2f" %(self.value, self.out_of)
|
||||
l = (len(str(self.out_of)))*2+4
|
||||
stri = stri.rjust(l)
|
||||
return stri
|
||||
|
||||
class FloatTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = FloatSlider
|
||||
|
||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
|
||||
interpolations = ['weighted_sum',
|
||||
'sigmoid',
|
||||
'inv_sigmoid',
|
||||
'add_difference']
|
||||
|
||||
def __init__(self, parentApp, name):
|
||||
self.parentApp = parentApp
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
@property
|
||||
def model_manager(self):
|
||||
return self.parentApp.model_manager
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
self.model_names = self.get_model_names()
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
name="Select up to three models to merge",
|
||||
value=''
|
||||
)
|
||||
self.model1 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='First Model:',
|
||||
values=self.model_names,
|
||||
value=0,
|
||||
max_height=len(self.model_names)+1
|
||||
)
|
||||
self.model2 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Second Model:',
|
||||
values=self.model_names,
|
||||
value=1,
|
||||
max_height=len(self.model_names)+1
|
||||
)
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0,'None')
|
||||
self.model3 = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Third Model:',
|
||||
values=models_plus_none,
|
||||
value=0,
|
||||
max_height=len(self.model_names)+1,
|
||||
)
|
||||
|
||||
for m in [self.model1,self.model2,self.model3]:
|
||||
m.when_value_edited = self.models_changed
|
||||
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name='Merge Method:',
|
||||
values=self.interpolations,
|
||||
value=0,
|
||||
max_height=len(self.interpolations),
|
||||
)
|
||||
self.alpha = self.add_widget_intelligent(
|
||||
FloatTitleSlider,
|
||||
name='Weight (alpha) to assign to second and third models:',
|
||||
out_of=1,
|
||||
step=0.05,
|
||||
lowest=0,
|
||||
value=0.5,
|
||||
)
|
||||
self.force = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name='Force merge of incompatible models',
|
||||
value=False,
|
||||
)
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name='Name for merged model',
|
||||
value='',
|
||||
)
|
||||
|
||||
def models_changed(self):
|
||||
models = self.model1.values
|
||||
selected_model1 = self.model1.value[0]
|
||||
selected_model2 = self.model2.value[0]
|
||||
selected_model3 = self.model3.value[0]
|
||||
merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
|
||||
self.merged_model_name.value = merged_model_name
|
||||
|
||||
if selected_model3 > 0:
|
||||
self.merge_method.values=['add_difference'],
|
||||
self.merged_model_name.value += f'+{models[selected_model3]}'
|
||||
else:
|
||||
self.merge_method.values=self.interpolations
|
||||
self.merge_method.value=0
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values() and self.check_for_overwrite():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||
npyscreen.notify('Starting the merge...')
|
||||
import ldm.invoke.merge_diffusers # this keeps the message up while diffusers loads
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def on_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self)->dict:
|
||||
model_names = self.model_names
|
||||
models = [
|
||||
model_names[self.model1.value[0]],
|
||||
model_names[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_names[self.model3.value[0]-1])
|
||||
|
||||
args = dict(
|
||||
models=models,
|
||||
alpha = self.alpha.value,
|
||||
interp = self.interpolations[self.merge_method.value[0]],
|
||||
force = self.force.value,
|
||||
merged_model_name = self.merged_model_name.value,
|
||||
)
|
||||
return args
|
||||
|
||||
def check_for_overwrite(self)->bool:
|
||||
model_out = self.merged_model_name.value
|
||||
if model_out not in self.model_names:
|
||||
return True
|
||||
else:
|
||||
return npyscreen.notify_yes_no(f'The chosen merged model destination, {model_out}, is already in use. Overwrite?')
|
||||
|
||||
def validate_field_values(self)->bool:
|
||||
bad_fields = []
|
||||
model_names = self.model_names
|
||||
selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
|
||||
if self.model3.value[0] > 0:
|
||||
selected_models.add(model_names[self.model3.value[0]-1])
|
||||
if len(selected_models) < 2:
|
||||
bad_fields.append(f'Please select two or three DIFFERENT models to compare. You selected {selected_models}')
|
||||
if len(bad_fields) > 0:
|
||||
message = 'The following problems were detected and must be corrected:'
|
||||
for problem in bad_fields:
|
||||
message += f'\n* {problem}'
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self)->List[str]:
|
||||
model_names = [name for name in self.model_manager.model_names() if self.model_manager.model_info(name).get('format') == 'diffusers']
|
||||
print(model_names)
|
||||
return sorted(model_names)
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
conf = OmegaConf.load(global_config_file())
|
||||
self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main = self.addForm('MAIN', mergeModelsForm, name='Merge Models Settings')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='InvokeAI textual inversion training')
|
||||
parser.add_argument(
|
||||
'--root_dir','--root-dir',
|
||||
type=Path,
|
||||
default=Globals.root,
|
||||
help='Path to the invokeai runtime directory',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
global_set_root(args.root_dir)
|
||||
|
||||
cache_dir = str(global_cache_dir('diffusers')) # because not clear the merge pipeline is honoring cache_dir
|
||||
os.environ['HF_HOME'] = cache_dir
|
||||
|
||||
mergeapp = Mergeapp()
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
args.update(cache_dir = cache_dir)
|
||||
from ldm.invoke.merge_diffusers import merge_diffusion_models
|
||||
|
||||
try:
|
||||
merge_diffusion_models(**args)
|
||||
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
|
||||
except Exception as e:
|
||||
print(f'** An error occurred while merging the pipelines: {str(e)}')
|
||||
print('** DETAILS:')
|
||||
print(traceback.format_exc())
|
||||
sys.exit(-1)
|
Reference in New Issue
Block a user