model merge backend, CLI and TUI working

Lincoln Stein 2023-07-06 12:21:42 -04:00
parent f7daa6e71d
commit ec7c2f07c6
5 changed files with 174 additions and 135 deletions

View File

@ -4,5 +4,5 @@ Initialization file for invokeai.backend.model_management
from .model_manager import ModelManager, ModelInfo, AddModelResult
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .model_merge import merge_diffusion_models_and_save, MergeInterpolationMethod
from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@ -279,7 +279,7 @@ class InvalidModelError(Exception):
pass
class AddModelResult(BaseModel):
name: str = Field(description="The name of the model after import")
name: str = Field(description="The name of the model after installation")
model_type: ModelType = Field(description="The type of model")
base_model: BaseModelType = Field(description="The base model")
config: ModelConfigBase = Field(description="The configuration of the model")
@ -496,7 +496,7 @@ class ModelManager(object):
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
)->dict:
) -> dict:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.

View File

@ -11,109 +11,119 @@ from enum import Enum
from pathlib import Path
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from typing import List
from typing import List, Union
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
class MergeInterpolationMethod(str, Enum):
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
WeightedSum = "weighted_sum"
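A quick standalone sketch of how this enum behaves: MergeInterpolationMethod subclasses str, so members can be constructed from, and compare equal to, their raw string values, which is what the merge code below relies on when it calls MergeInterpolationMethod(interp) and compares interp against 'weighted_sum'. The snippet is illustrative only and not part of the diff.

from invokeai.backend.model_management import MergeInterpolationMethod

# Sketch only: MergeInterpolationMethod is a str-backed Enum.
method = MergeInterpolationMethod("add_difference")   # construct a member from its raw value
assert method is MergeInterpolationMethod.AddDifference
assert method == "add_difference"        # str subclass: compares equal to the plain string
assert method.value == "add_difference"  # .value is the string handed to the diffusers merge API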
def merge_diffusion_models(
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
force=force,
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter, ranging from 0 to 1. It controls the ratio in which the checkpoints are merged:
an alpha of 0.8 means the first model contributes far less to the result than it would with an alpha of 0.2.
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation, which is weighted-sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
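For intuition about what the interpolation methods compute, here is a toy sketch in plain PyTorch. The tensors and alpha value are made up, and the real numerics live inside diffusers' "checkpoint_merger" custom pipeline, so treat this only as an approximation of the two basic rules: weighted sum, and add_difference as described in the docstring below.

import torch

# Toy sketch only: hypothetical flattened weights for three models.
theta_a = torch.tensor([1.0, 2.0])
theta_b = torch.tensor([3.0, 4.0])
theta_c = torch.tensor([0.5, 1.0])
alpha = 0.5

# Weighted sum: a higher alpha shifts the result away from the first model.
weighted_sum = (1 - alpha) * theta_a + alpha * theta_b

# add_difference, "A + (B - C)"; diffusers' implementation scales the difference by alpha.
add_difference = theta_a + alpha * (theta_b - theta_c)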
def merge_diffusion_models_and_save(
self,
model_names: List[str],
base_model: Union[BaseModelType,str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> AddModelResult:
"""
:param model_names: up to three models, designated by their InvokeAI models.yaml model names
:param base_model: base model (must be the same for all merged models!)
:param merged_model_name: name for new model
:param alpha: The interpolation parameter, ranging from 0 to 1. It controls the ratio in which the checkpoints are merged:
an alpha of 0.8 means the first model contributes far less to the result than it would with an alpha of 0.2.
:param interp: The interpolation method to use for the merging. Supports "weighted_sum", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation, which is weighted-sum interpolation. For merging three checkpoints, only "add_difference" is supported; add_difference computes A + (B - C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
dlogging.set_verbosity(verbosity)
return merged_pipe
dump_path = config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
def merge_diffusion_models_and_save (
models: List["str"],
base_model: BaseModelType,
merged_model_name: str,
config: InvokeAIAppConfig,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
):
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
:param merged_model_name: name for new model
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_manager = ModelManager(config.model_conf_path)
model_paths = list()
vae = None
for mod in models:
info = model_manager.model_info(mod, base_model=base_model, model_type=ModelType.main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
if mod == models[0]:
vae = info["vae"]
model_paths.extend([info["path"]])
merged_pipe = merge_diffusion_models(
model_paths, alpha, interp, force, **kwargs
)
dump_path = config.models_path / base_model.value / ModelType.main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict(
path = dump_path,
description = f"Merge of models {', '.join(models)}",
model_format = "diffusers",
variant = ModelVariantType.Normal.value,
vae = vae,
)
model_manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,
clobber = True
)
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path = str(dump_path),
description = f"Merge of models {', '.join(model_names)}",
model_format = "diffusers",
variant = ModelVariantType.Normal.value,
vae = vae,
)
return self.manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,
clobber = True
)
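Taken together, a hedged usage sketch of the new backend API that the CLI and TUI below drive. The model names, merged name, and base-model value are hypothetical, and it assumes the InvokeAI config has already been initialized for the intended root.

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
    ModelManager,
    ModelMerger,
    MergeInterpolationMethod,
)

config = InvokeAIAppConfig.get_config()          # assumes the root directory is already set up
manager = ModelManager(config.model_conf_path)
merger = ModelMerger(manager)

result = merger.merge_diffusion_models_and_save(
    model_names=["model-a", "model-b"],          # hypothetical models.yaml names
    base_model="sd-1",                           # assumed BaseModelType value, shared by both models
    merged_model_name="model-a+model-b",
    alpha=0.5,
    interp=MergeInterpolationMethod.WeightedSum,
    force=False,
)
print(result.name, result.base_model, result.model_type)   # AddModelResult fields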

View File

@ -1,4 +1,5 @@
"""
Initialization file for invokeai.frontend.merge
"""
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models
from .merge_diffusers import main as invokeai_merge_diffusers

View File

@ -6,10 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import curses
import enum
import os
import sys
import warnings
from argparse import Namespace
from pathlib import Path
from typing import List, Union
@ -21,12 +18,12 @@ from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.services.config import InvokeAIAppConfig
from ...backend.model_management import (
merge_diffusion_models_and_save,
ModelManager, MergeInterpolationMethod, BaseModelType
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import (
ModelMerger, MergeInterpolationMethod,
ModelManager, ModelType, BaseModelType,
)
from ...frontend.install.widgets import FloatTitleSlider
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
config = InvokeAIAppConfig.get_config()
@ -48,12 +45,13 @@ def _parse_args() -> Namespace:
)
parser.add_argument(
"--models",
dest="model_names",
type=str,
nargs="+",
help="Two to three model names to be merged",
)
parser.add_argument(
"--base_type",
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
help="The base model shared by the models to be merged",
@ -115,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names()
self.current_base = 0
max_width = max([len(x) for x in self.model_names])
max_width += 6
horizontal_layout = max_width * 3 < window_width
@ -131,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False,
)
self.nextrely += 1
self.base_select = self.add_widget_intelligent(
SingleSelectColumns,
values=[
'Models Built on SD-1.x',
'Models Built on SD-2.x',
],
value=[self.current_base],
columns = 4,
max_height = 2,
relx=8,
scroll_exit = True,
)
self.base_select.on_changed = self._populate_models
self.add_widget_intelligent(
npyscreen.FixedText,
value="MODEL 1",
color="GOOD",
editable=False,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model1 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -145,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
scroll_exit=True,
rely=5,
rely=7,
)
self.add_widget_intelligent(
npyscreen.FixedText,
@ -153,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne,
@ -163,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_height=len(self.model_names),
max_width=max_width,
relx=max_width + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
scroll_exit=True,
)
self.add_widget_intelligent(
@ -172,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD",
editable=False,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None,
rely=6 if horizontal_layout else None,
)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
@ -185,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
max_width=max_width,
scroll_exit=True,
relx=max_width * 2 + 3 if horizontal_layout else None,
rely=5 if horizontal_layout else None,
rely=7 if horizontal_layout else None,
)
for m in [self.model1, self.model2, self.model3]:
m.when_value_edited = self.models_changed
self.merged_model_name = self.add_widget_intelligent(
npyscreen.TitleText,
TextBox,
name="Name for merged model:",
labelColor="CONTROL",
max_height=3,
value="",
scroll_exit=True,
)
self.force = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Force merge of incompatible models",
name="Force merge of models created by different diffusers library versions",
labelColor="CONTROL",
value=False,
value=True,
scroll_exit=True,
)
self.nextrely += 1
self.merge_method = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="Merge Method:",
@ -264,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
interp = self.interpolations[self.merge_method.value[0]]
args = dict(
models=models,
model_names=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
alpha=self.alpha.value,
interp=interp,
force=self.force.value,
@ -302,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
return True
def get_model_names(self) -> List[str]:
def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
model_names = [
name
for name in self.model_manager.model_names()
if self.model_manager.model_info(name).get("format") == "diffusers"
info["name"]
for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
if info["model_format"] == "diffusers"
]
return sorted(model_names)
def _populate_models(self, value=None):
base_model = tuple(BaseModelType)[value[0]]
self.model_names = self.get_model_names(base_model)
models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None")
self.model1.values = self.model_names
self.model2.values = self.model_names
self.model3.values = models_plus_none
self.display()
class Mergeapp(npyscreen.NPSAppManaged):
def __init__(self):
def __init__(self, model_manager: ModelManager):
super().__init__()
conf = OmegaConf.load(config.model_conf_path)
self.model_manager = ModelManager(
conf, "cpu", "float16"
) # precision doesn't really matter here
self.model_manager = model_manager
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
@ -324,38 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
def run_gui(args: Namespace):
mergeapp = Mergeapp()
model_manager = ModelManager(config.model_conf_path)
mergeapp = Mergeapp(model_manager)
mergeapp.run()
args = mergeapp.merge_arguments
merge_diffusion_models_and_save(**args)
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**args)
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
def run_cli(args: Namespace):
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert (
args.models and len(args.models) >= 1 and len(args.models) <= 3
args.model_names and len(args.model_names) >= 2 and len(args.model_names) <= 3
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
if not args.merged_model_name:
args.merged_model_name = "+".join(args.models)
args.merged_model_name = "+".join(args.model_names)
logger.info(
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
)
model_manager = ModelManager(OmegaConf.load(config.model_conf_path))
assert (
args.clobber or args.merged_model_name not in model_manager.model_names()
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
model_manager = ModelManager(config.model_conf_path)
assert (
not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
merge_diffusion_models_and_commit(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
merger = ModelMerger(model_manager)
merger.merge_diffusion_models_and_save(**vars(args))
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
def main():
args = _parse_args()
config.parse_args(['--root',args.root_dir])
config.parse_args(['--root', str(args.root_dir)])
try:
if args.front_end: