InvokeAI/ldm/invoke/merge_diffusers.py
2023-01-26 15:10:16 -05:00

382 lines
13 KiB
Python

"""
ldm.invoke.merge_diffusers exports a single function called merge_diffusion_models()
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import argparse
import os
import sys
from argparse import Namespace
from pathlib import Path
from typing import List, Union
import npyscreen
from diffusers import DiffusionPipeline
from omegaconf import OmegaConf
from ldm.invoke.globals import (
Globals,
global_cache_dir,
global_config_file,
global_models_dir,
global_set_root,
)
from ldm.invoke.model_manager import ModelManager
# Name of the subdirectory, under the global models directory, into which merged models are saved.
DEST_MERGED_MODEL_DIR = "merged_models"
def merge_diffusion_models(
    model_ids_or_paths: List[Union[str, Path]],
    alpha: float = 0.5,
    interp: str = None,
    force: bool = False,
    **kwargs,
) -> DiffusionPipeline:
    """
    Merge several diffusers models and return the merged pipeline.

    The first entry of model_ids_or_paths is loaded with the
    "checkpoint_merger" custom pipeline, whose merge() method is then
    called with the complete list.

    model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
    alpha  - The interpolation parameter, from 0 to 1. It sets the ratio in which the checkpoints
             are merged: with an alpha of 0.8 the first model's checkpoints affect the final result
             far less than with an alpha of 0.2.
    interp - The interpolation method. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
             None selects the default, weighted sum interpolation. When merging three checkpoints,
             only "add_difference" is supported.
    force  - Whether to ignore a mismatch in model_config.json between the models. Defaults to False.
    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
             cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token,
             revision, torch_dtype, device_map
    """
    first_model = model_ids_or_paths[0]
    # The fallback is computed up front, exactly as a default passed to dict.get() would be.
    fallback_cache_dir = global_cache_dir()
    loader_options = {
        "cache_dir": kwargs.get("cache_dir", fallback_cache_dir),
        "custom_pipeline": "checkpoint_merger",
    }
    pipe = DiffusionPipeline.from_pretrained(first_model, **loader_options)

    # kwargs are forwarded untouched, so a caller-supplied cache_dir reaches merge() too.
    return pipe.merge(
        pretrained_model_name_or_path_list=model_ids_or_paths,
        alpha=alpha,
        interp=interp,
        force=force,
        **kwargs,
    )
def merge_diffusion_models_and_commit(
    models: List["str"],
    merged_model_name: str,
    alpha: float = 0.5,
    interp: str = None,
    force: bool = False,
    **kwargs,
):
    """
    Merge models known to the model manager, save the result under the
    merged-models directory and register it in the models configuration file.

    models - up to three models, designated by their InvokeAI models.yaml model name
    merged_model_name - name for the new model
    alpha  - The interpolation parameter, from 0 to 1. It sets the ratio in which the checkpoints
             are merged: with an alpha of 0.8 the first model's checkpoints affect the final result
             far less than with an alpha of 0.2.
    interp - The interpolation method. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
             None selects the default, weighted sum interpolation. When merging three checkpoints,
             only "add_difference" is supported.
    force  - Whether to ignore a mismatch in model_config.json between the models. Defaults to False.
    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
             cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token,
             revision, torch_dtype, device_map
    """
    config_file = global_config_file()
    model_manager = ModelManager(OmegaConf.load(config_file))

    # Every requested model must be registered and stored in diffusers format.
    for mod in models:
        assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
        assert (
            model_manager.model_info(mod).get("format", None) == "diffusers"
        ), f"** {mod} is not a diffusers model. It must be optimized before merging."

    model_ids_or_paths = list(map(model_manager.model_name_or_path, models))
    merged_pipe = merge_diffusion_models(
        model_ids_or_paths, alpha=alpha, interp=interp, force=force, **kwargs
    )

    # Write the merged pipeline to <models dir>/<DEST_MERGED_MODEL_DIR>/<merged_model_name>.
    merged_root = global_models_dir() / DEST_MERGED_MODEL_DIR
    os.makedirs(merged_root, exist_ok=True)
    dump_path = merged_root / merged_model_name
    merged_pipe.save_pretrained(dump_path, safe_serialization=1)

    source_names = ", ".join(models)
    import_args = {
        "model_name": merged_model_name,
        "description": f"Merge of models {source_names}",
    }

    # Carry over the VAE configured for the first source model, if it has one.
    vae = model_manager.config[models[0]].get("vae", None)
    if vae:
        print(f">> Using configured VAE assigned to {models[0]}")
        import_args["vae"] = vae

    model_manager.import_diffuser_model(dump_path, **import_args)
    model_manager.commit(config_file)
def _parse_args() -> Namespace:
    """
    Declare the command-line options of the merge script and return the
    parsed arguments.

    The options are listed in a table and registered in that order.
    """
    parser = argparse.ArgumentParser(description="InvokeAI model merging")

    # Each entry is (option strings, keyword arguments for add_argument).
    option_table = [
        (
            ("--root_dir",),
            dict(
                type=Path,
                default=Globals.root,
                help="Path to the invokeai runtime directory",
            ),
        ),
        (
            ("--front_end", "--gui"),
            dict(
                dest="front_end",
                action="store_true",
                default=False,
                help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
            ),
        ),
        (
            ("--models",),
            dict(
                type=str,
                nargs="+",
                help="Two to three model names to be merged",
            ),
        ),
        (
            ("--merged_model_name", "--destination"),
            dict(
                dest="merged_model_name",
                type=str,
                help="Name of the output model. If not specified, will be the concatenation of the input model names.",
            ),
        ),
        (
            ("--alpha",),
            dict(
                type=float,
                default=0.5,
                help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
            ),
        ),
        (
            ("--interpolation",),
            dict(
                dest="interp",
                type=str,
                choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
                default="weighted_sum",
                help='Interpolation method to use. If three models are present, only "add_difference" will work.',
            ),
        ),
        (
            ("--force",),
            dict(
                action="store_true",
                help="Try to merge models even if they are incompatible with each other",
            ),
        ),
        (
            ("--clobber", "--overwrite"),
            dict(
                dest="clobber",
                action="store_true",
                help="Overwrite the merged model if --merged_model_name already exists",
            ),
        ),
    ]

    for option_strings, settings in option_table:
        parser.add_argument(*option_strings, **settings)

    return parser.parse_args()
# ------------------------- GUI HERE -------------------------
class FloatSlider(npyscreen.Slider):
    """Slider whose textual label shows the value and maximum with two decimals."""

    # this is supposed to adjust display precision, but doesn't
    def translate_value(self):
        """Return the "value / out_of" label, right-justified to a fixed width."""
        label = "%3.2f / %3.2f" % (self.value, self.out_of)
        # Width is derived from the printed length of the maximum, plus padding.
        field_width = 2 * len(str(self.out_of)) + 4
        return label.rjust(field_width)
class FloatTitleSlider(npyscreen.TitleText):
    """Titled widget whose entry type is the FloatSlider defined in this module."""

    _entry_type = FloatSlider
class mergeModelsForm(npyscreen.FormMultiPageAction):
    """
    Text-mode form that collects the settings for a merge: the models to
    combine, the name of the merged model, the interpolation method, the
    alpha weight and the "force" flag.

    On OK the collected settings are stored on the parent application as
    `merge_arguments`; on Cancel the process exits.
    """

    # Interpolation methods offered when two models are selected.
    interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"]

    def __init__(self, parentApp, name):
        # parentApp is stored before the base constructor runs; presumably the
        # base constructor triggers create(), which needs model_manager.
        self.parentApp = parentApp
        super().__init__(parentApp, name)

    @property
    def model_manager(self):
        """The model manager owned by the parent application."""
        return self.parentApp.model_manager

    def afterEditing(self):
        """Tell the application there is no further form to show."""
        self.parentApp.setNextForm(None)

    def create(self):
        """Build the widgets of the form."""
        self.model_names = self.get_model_names()

        self.add_widget_intelligent(
            npyscreen.FixedText, name="Select up to three models to merge", value=""
        )
        self.models = self.add_widget_intelligent(
            npyscreen.TitleMultiSelect,
            name="Select two to three models to merge:",
            values=self.model_names,
            value=None,
            max_height=len(self.model_names) + 1,
            scroll_exit=True,
        )
        # Keep the method list and the suggested name in sync with the selection.
        self.models.when_value_edited = self.models_changed
        self.merged_model_name = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Name for merged model:",
            value="",
            scroll_exit=True,
        )
        self.force = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Force merge of incompatible models",
            value=False,
            scroll_exit=True,
        )
        self.merge_method = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Merge Method:",
            values=self.interpolations,
            value=0,
            max_height=len(self.interpolations) + 1,
            scroll_exit=True,
        )
        self.alpha = self.add_widget_intelligent(
            FloatTitleSlider,
            name="Weight (alpha) to assign to second and third models:",
            out_of=1,
            step=0.05,
            lowest=0,
            value=0.5,
            scroll_exit=True,
        )
        self.models.editing = True

    def models_changed(self):
        """
        React to a change of the model selection.

        With three models only "add_difference" is offered as merge method;
        otherwise the full list is restored. The merged model name is set to
        the selected names joined with "+". With more than three models a
        warning is shown and nothing else changes.
        """
        model_names = self.models.values
        # The selector is created with value=None; treat that as "nothing selected".
        selected_models = self.models.value or []
        if len(selected_models) > 3:
            npyscreen.notify_confirm(
                "Too many models selected for merging. Select two to three."
            )
            return
        elif len(selected_models) > 2:
            self.merge_method.values = ["add_difference"]
            self.merge_method.value = 0
        else:
            self.merge_method.values = self.interpolations
        self.merged_model_name.value = "+".join(
            [model_names[x] for x in selected_models]
        )

    def on_ok(self):
        """Validate the form; on success hand the arguments to the application."""
        if self.validate_field_values() and self.check_for_overwrite():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.merge_arguments = self.marshall_arguments()
            npyscreen.notify("Starting the merge...")
        else:
            # Stay on the form so the user can correct the input.
            self.editing = True

    def on_cancel(self):
        """Abort the program when the user cancels the form."""
        sys.exit(0)

    def _selected_interpolation(self) -> str:
        """Return the merge method currently selected in the merge_method widget."""
        selection = self.merge_method.value
        # The widget's value is assigned as a bare index in this class (value=0)
        # but read back as a sequence of indices; accept both shapes.
        index = selection if isinstance(selection, int) else selection[0]
        # Look the index up in the widget's CURRENT list: models_changed() narrows
        # it to ["add_difference"] for three models, so indexing the class-level
        # `interpolations` list would wrongly yield "weighted_sum".
        return self.merge_method.values[index]

    def marshall_arguments(self) -> dict:
        """
        Collect the form values into the keyword arguments expected by
        merge_diffusion_models_and_commit().
        """
        models = [self.models.values[x] for x in self.models.value]
        args = dict(
            models=models,
            alpha=self.alpha.value,
            interp=self._selected_interpolation(),
            force=self.force.value,
            merged_model_name=self.merged_model_name.value,
        )
        return args

    def check_for_overwrite(self) -> bool:
        """
        Return True when the merged model may be written: either the name is
        unused, or the user confirmed overwriting the existing model.
        """
        model_out = self.merged_model_name.value
        # Check every registered model name, not only the diffusers models shown
        # in the selector, mirroring the check made by the command-line path.
        if model_out not in self.model_manager.model_names():
            return True
        return npyscreen.notify_yes_no(
            f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
        )

    def validate_field_values(self) -> bool:
        """
        Check the form input; show the list of problems and return False if
        anything needs correcting, otherwise return True.
        """
        bad_fields = []
        selected_models = self.models.value or []
        if not 2 <= len(selected_models) <= 3:
            bad_fields.append("Please select two or three models to merge.")
        # The name is used as the output directory name and the registered model
        # name, so it must not be blank.
        if not (self.merged_model_name.value or "").strip():
            bad_fields.append("Please provide a name for the merged model.")

        if bad_fields:
            message = "The following problems were detected and must be corrected:"
            message += "".join(f"\n* {problem}" for problem in bad_fields)
            npyscreen.notify_confirm(message)
            return False
        return True

    def get_model_names(self) -> List[str]:
        """Return the sorted names of all registered models in diffusers format."""
        # No debugging print here: this runs while the text UI owns the terminal.
        model_names = [
            name
            for name in self.model_manager.model_names()
            if self.model_manager.model_info(name).get("format") == "diffusers"
        ]
        return sorted(model_names)
class Mergeapp(npyscreen.NPSAppManaged):
    """Application object that owns the model manager and shows the merge form."""

    def __init__(self):
        super().__init__()
        models_config = OmegaConf.load(global_config_file())
        # precision doesn't really matter here
        self.model_manager = ModelManager(models_config, "cpu", "float16")

    def onStart(self):
        """Select the default theme and register the merge form as the MAIN form."""
        theme = npyscreen.Themes.DefaultTheme
        npyscreen.setTheme(theme)
        form = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
        self.main = form
def run_gui(args: Namespace):
    """
    Run the text-mode front end, then perform the merge with the settings it collected.

    The incoming `args` is not read; the merge parameters come from the
    application's `merge_arguments` attribute.
    """
    app = Mergeapp()
    app.run()

    merge_settings = app.merge_arguments
    merge_diffusion_models_and_commit(**merge_settings)

    new_name = merge_settings["merged_model_name"]
    print(f'>> Models merged into new model: "{new_name}".')
def run_cli(args: Namespace):
    """
    Validate the command-line arguments and run the merge.

    args - parsed arguments; reads alpha, models, merged_model_name and clobber,
           and passes every attribute on to merge_diffusion_models_and_commit().
           If merged_model_name is empty it is set, on `args` itself, to the
           model names joined with "+".

    Raises ValueError when alpha is outside [0, 1], when fewer than two or more
    than three models are given (including no --models at all), or when the
    destination name already exists and clobber is not set. Explicit exceptions
    are used instead of assert so the checks still run under `python -O`.
    """
    if not 0 <= args.alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")

    # args.models is None when --models was not supplied.
    if not args.models or not 2 <= len(args.models) <= 3:
        raise ValueError("provide 2 or 3 models to merge")

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.models)
        print(
            f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
        )

    model_manager = ModelManager(OmegaConf.load(global_config_file()))
    if not args.clobber and args.merged_model_name in model_manager.model_names():
        raise ValueError(
            f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
        )

    merge_diffusion_models_and_commit(**vars(args))
def main():
    """
    Entry point of the merge script.

    Parses the command line, sets the InvokeAI root directory, points both
    HF_HOME and args.cache_dir at the diffusers cache directory, then runs
    either the text-mode front end or the command-line merge. Any exception
    is reported and the process exits with status -1.
    """
    args = _parse_args()
    global_set_root(args.root_dir)

    cache_dir = str(global_cache_dir("diffusers"))
    # because not clear the merge pipeline is honoring cache_dir
    os.environ["HF_HOME"] = cache_dir
    args.cache_dir = cache_dir

    try:
        if args.front_end:
            # run_gui() reports the merged model name itself; the name chosen in
            # the form is never written back to `args`, so printing
            # args.merged_model_name here would show a stale value (None by default).
            run_gui(args)
        else:
            run_cli(args)
            print(f">> Conversion successful. New model is named {args.merged_model_name}")
    except Exception as e:
        print(f"** An error occurred while merging the pipelines: {str(e)}")
        sys.exit(-1)
    except KeyboardInterrupt:
        sys.exit(-1)
# Run the merge tool only when this file is executed as a script.
if __name__ == "__main__":
    main()