model merge backend, CLI and TUI working

This commit is contained in:
Lincoln Stein
2023-07-06 12:21:42 -04:00
parent f7daa6e71d
commit ec7c2f07c6
5 changed files with 174 additions and 135 deletions

View File

@ -1,4 +1,5 @@
"""
Initialization file for invokeai.frontend.merge
"""
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
from .merge_diffusers import main as invokeai_merge_diffusers

View File

@ -6,10 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import enum
import os
import sys
import warnings
from argparse import Namespace
from pathlib import Path
from typing import List, Union
@ -21,12 +18,12 @@ from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.services.config import InvokeAIAppConfig
from ...backend.model_management import (
merge_diffusion_models_and_save,
ModelManager, MergeInterpolationMethod, BaseModelType
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
ModelMerger, MergeInterpolationMethod,
ModelManager, ModelType, BaseModelType,
)
from ...frontend.install.widgets import FloatTitleSlider
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
config = InvokeAIAppConfig.get_config()
@ -48,12 +45,13 @@ def _parse_args() -> Namespace:
)
parser.add_argument(
"--models",
dest="model_names",
type=str,
nargs="+",
help="Two to three model names to be merged",
)
parser.add_argument(
"--base_type",
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
help="The base model shared by the models to be merged",
@ -115,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@ -131,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
'Models Built on SD-1.x',
'Models Built on SD-2.x',
],
value=[self.current_base],
columns = 4,
max_height = 2,
relx=8,
scroll_exit = True,
)
self.base_select.on_changed = self._populate_models
self.add_widget_intelligent(
npyscreen.FixedText,
value="MODEL 1",
color="GOOD",
editable=False,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -145,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
rely=7,
)
self.add_widget_intelligent(
npyscreen.FixedText,
@ -153,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -163,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
relx=max_width + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
scroll_exit=True,
)
self.add_widget_intelligent(
@ -172,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@ -185,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_width=max_width,
scroll_exit=True,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
)
for m in [self.model1, self.model2, self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
TextBox,
name="Name for merged model:",
labelColor="CONTROL",
max_height=3,
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
name="Force merge of models created by different diffusers library versions",
labelColor="CONTROL",
value=False,
value=True,
scroll_exit=True,
)
self.nextrely += 1
self.merge_method = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Merge Method:",
@ -264,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
interp = self.interpolations[self.merge_method.value[0]]
args = dict(
models=models,
model_names=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
alpha=self.alpha.value,
interp=interp,
force=self.force.value,
@ -302,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self) -> List[str]:
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
model_names = [
name
for name in self.model_manager.model_names()
if self.model_manager.model_info(name).get("format") == "diffusers"
info["name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
]
return sorted(model_names)
def _populate_models(self,value=None):
base_model = tuple(BaseModelType)[value[0]]
self.model_names = self.get_model_names(base_model)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
self.model1.values = self.model_names
self.model2.values = self.model_names
self.model3.values = models_plus_none
self.display()
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
def __init__(self, model_manager:ModelManager):
super().__init__()
conf = OmegaConf.load(config.model_conf_path)
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
self.model_manager = model_manager
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
@ -324,38 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
def run_gui(args: Namespace):
mergeapp = Mergeapp()
model_manager = ModelManager(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
mergeapp.run()
args = mergeapp.merge_arguments
merge_diffusion_models_and_save(**args)
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert (
args.models and len(args.models) >= 1 and len(args.models) <= 3
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
args.merged_model_name = "+".join(args.model_names)
logger.info(
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
assert (
args.clobber or args.merged_model_name not in model_manager.model_names()
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
model_manager = ModelManager(config.model_conf_path)
assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def main():
args = _parse_args()
config.parse_args(['--root',args.root_dir])
config.parse_args(['--root',str(args.root_dir)])
try:
if args.front_end: