model merge backend, CLI and TUI working

This commit is contained in:
Lincoln Stein 2023-07-06 12:21:42 -04:00
parent f7daa6e71d
commit ec7c2f07c6
5 changed files with 174 additions and 135 deletions

View File

@ -4,5 +4,5 @@ Initialization file for invokeai.backend.model_management
from .model_manager import ModelManager, ModelInfo, AddModelResult
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .model_merge import merge_diffusion_models_and_save, MergeInterpolationMethod
from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@ -279,7 +279,7 @@ class InvalidModelError(Exception):
pass
class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after import")
name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")

View File

@ -11,19 +11,25 @@ from enum import Enum
from pathlib import Path
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from typing import List
from typing import List, Union
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
class MergeInterpolationMethod(str, Enum):
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
WeightedSum = "weighted_sum"
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
@ -62,15 +68,15 @@ def merge_diffusion_models(
def merge_diffusion_models_and_save (
models: List["str"],
base_model: BaseModelType,
self,
model_names: List[str],
base_model: Union[BaseModelType,str],
merged_model_name: str,
config: InvokeAIAppConfig,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
):
) -> AddModelResult:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
@ -84,34 +90,38 @@ def merge_diffusion_models_and_save (
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_manager = ModelManager(config.model_conf_path)
model_paths = list()
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
for mod in models:
info = model_manager.model_info(mod, base_model=base_model, model_type=ModelType.main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
if mod == models[0]:
vae = info["vae"]
model_paths.extend([info["path"]])
merged_pipe = merge_diffusion_models(
model_paths, alpha, interp, force, **kwargs
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
dump_path = config.models_path / base_model.value / ModelType.main.value
dump_path = config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict(
path = dump_path,
description = f"Merge of models {', '.join(models)}",
path = str(dump_path),
description = f"Merge of models {', '.join(model_names)}",
model_format = "diffusers",
variant = ModelVariantType.Normal.value,
vae = vae,
)
model_manager.add_model(merged_model_name,
return self.manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,

View File

@ -1,4 +1,5 @@
"""
Initialization file for invokeai.frontend.merge
"""
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
from .merge_diffusers import main as invokeai_merge_diffusers

View File

@ -6,10 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import enum
import os
import sys
import warnings
from argparse import Namespace
from pathlib import Path
from typing import List, Union
@ -21,12 +18,12 @@ from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.services.config import InvokeAIAppConfig
from ...backend.model_management import (
merge_diffusion_models_and_save,
ModelManager, MergeInterpolationMethod, BaseModelType
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
ModelMerger, MergeInterpolationMethod,
ModelManager, ModelType, BaseModelType,
)
from ...frontend.install.widgets import FloatTitleSlider
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
config = InvokeAIAppConfig.get_config()
@ -48,12 +45,13 @@ def _parse_args() -> Namespace:
)
parser.add_argument(
"--models",
dest="model_names",
type=str,
nargs="+",
help="Two to three model names to be merged",
)
parser.add_argument(
"--base_type",
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
help="The base model shared by the models to be merged",
@ -115,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@ -131,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
'Models Built on SD-1.x',
'Models Built on SD-2.x',
],
value=[self.current_base],
columns = 4,
max_height = 2,
relx=8,
scroll_exit = True,
)
self.base_select.on_changed = self._populate_models
self.add_widget_intelligent(
npyscreen.FixedText,
value="MODEL 1",
color="GOOD",
editable=False,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -145,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
rely=7,
)
self.add_widget_intelligent(
npyscreen.FixedText,
@ -153,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -163,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
relx=max_width + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
scroll_exit=True,
)
self.add_widget_intelligent(
@ -172,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@ -185,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_width=max_width,
scroll_exit=True,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
)
for m in [self.model1, self.model2, self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
TextBox,
name="Name for merged model:",
labelColor="CONTROL",
max_height=3,
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
name="Force merge of models created by different diffusers library versions",
labelColor="CONTROL",
value=False,
value=True,
scroll_exit=True,
)
self.nextrely += 1
self.merge_method = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Merge Method:",
@ -264,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
interp = self.interpolations[self.merge_method.value[0]]
args = dict(
models=models,
model_names=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
alpha=self.alpha.value,
interp=interp,
force=self.force.value,
@ -302,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self) -> List[str]:
def get_model_names(self, base_model: BaseModelType=None) -> List[str]:
model_names = [
name
for name in self.model_manager.model_names()
if self.model_manager.model_info(name).get("format") == "diffusers"
info["name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
]
return sorted(model_names)
def _populate_models(self,value=None):
base_model = tuple(BaseModelType)[value[0]]
self.model_names = self.get_model_names(base_model)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
self.model1.values = self.model_names
self.model2.values = self.model_names
self.model3.values = models_plus_none
self.display()
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
def __init__(self, model_manager:ModelManager):
super().__init__()
conf = OmegaConf.load(config.model_conf_path)
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
self.model_manager = model_manager
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
@ -324,38 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
def run_gui(args: Namespace):
mergeapp = Mergeapp()
model_manager = ModelManager(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
mergeapp.run()
args = mergeapp.merge_arguments
merge_diffusion_models_and_save(**args)
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert (
args.models and len(args.models) >= 1 and len(args.models) <= 3
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
args.merged_model_name = "+".join(args.model_names)
logger.info(
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
model_manager = ModelManager(config.model_conf_path)
assert (
args.clobber or args.merged_model_name not in model_manager.model_names()
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args))
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def main():
args = _parse_args()
config.parse_args(['--root',args.root_dir])
config.parse_args(['--root',str(args.root_dir)])
try:
if args.front_end: