mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
partial implementation of merge
This commit is contained in:
parent
d4550b3059
commit
cfa3b2419c
@ -222,7 +222,10 @@ class ModelInstall(object):
|
|||||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||||
model_result = None
|
model_result = None
|
||||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
if not info:
|
||||||
|
logger.warning(f'Unable to parse format of {path}')
|
||||||
|
return None
|
||||||
|
model_name = path.stem if info.format == "checkpoint" else path.name
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||||
attributes = self._make_attributes(path,info)
|
attributes = self._make_attributes(path,info)
|
||||||
|
@ -4,4 +4,5 @@ Initialization file for invokeai.backend.model_management
|
|||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||||
|
from .model_merge import merge_diffusion_models_and_save, MergeInterpolationMethod
|
||||||
|
|
||||||
|
119
invokeai/backend/model_management/model_merge.py
Normal file
119
invokeai/backend/model_management/model_merge.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
"""
|
||||||
|
invokeai.backend.model_management.model_merge exports:
|
||||||
|
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||||
|
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||||
|
|
||||||
|
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||||
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
from diffusers import logging as dlogging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
||||||
|
|
||||||
|
class MergeInterpolationMethod(str, Enum):
|
||||||
|
Sigmoid = "sigmoid"
|
||||||
|
InvSigmoid = "inv_sigmoid"
|
||||||
|
AddDifference = "add_difference"
|
||||||
|
|
||||||
|
def merge_diffusion_models(
|
||||||
|
model_paths: List[Path],
|
||||||
|
alpha: float = 0.5,
|
||||||
|
interp: InterpolationMethod = None,
|
||||||
|
force: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> DiffusionPipeline:
|
||||||
|
"""
|
||||||
|
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||||
|
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||||
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
"""
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
verbosity = dlogging.get_verbosity()
|
||||||
|
dlogging.set_verbosity_error()
|
||||||
|
|
||||||
|
pipe = DiffusionPipeline.from_pretrained(
|
||||||
|
model_paths[0],
|
||||||
|
custom_pipeline="checkpoint_merger",
|
||||||
|
)
|
||||||
|
merged_pipe = pipe.merge(
|
||||||
|
pretrained_model_name_or_path_list=model_paths,
|
||||||
|
alpha=alpha,
|
||||||
|
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
|
||||||
|
force=force,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
dlogging.set_verbosity(verbosity)
|
||||||
|
return merged_pipe
|
||||||
|
|
||||||
|
|
||||||
|
def merge_diffusion_models_and_save (
|
||||||
|
models: List["str"],
|
||||||
|
base_model: BaseModelType,
|
||||||
|
merged_model_name: str,
|
||||||
|
config: InvokeAIAppConfig,
|
||||||
|
alpha: float = 0.5,
|
||||||
|
interp: InterpolationMethod = None,
|
||||||
|
force: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||||
|
:param base_model: base model (must be the same for all merged models!)
|
||||||
|
:param merged_model_name: name for new model
|
||||||
|
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||||
|
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||||
|
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||||
|
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||||
|
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||||
|
|
||||||
|
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||||
|
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||||
|
"""
|
||||||
|
model_manager = ModelManager(config.model_conf_path)
|
||||||
|
model_paths = list()
|
||||||
|
vae = None
|
||||||
|
for mod in models:
|
||||||
|
info = model_manager.model_info(mod, base_model=base_model, model_type=ModelType.main)
|
||||||
|
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||||
|
assert info["format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||||
|
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
|
||||||
|
if mod == models[0]:
|
||||||
|
vae = info["vae"]
|
||||||
|
model_paths.extend([info["path"]])
|
||||||
|
|
||||||
|
merged_pipe = merge_diffusion_models(
|
||||||
|
model_paths, alpha, interp, force, **kwargs
|
||||||
|
)
|
||||||
|
dump_path = config.models_path / base_model.value / ModelType.main.value
|
||||||
|
dump_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
dump_path = dump_path / merged_model_name
|
||||||
|
|
||||||
|
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||||
|
attributes = dict(
|
||||||
|
path = dump_path,
|
||||||
|
description = f"Merge of models {', '.join(models)}",
|
||||||
|
model_format = "diffusers",
|
||||||
|
variant = ModelVariantType.Normal.value,
|
||||||
|
vae = vae,
|
||||||
|
)
|
||||||
|
model_manager.add_model(merged_model_name,
|
||||||
|
base_model = base_model,
|
||||||
|
model_type = ModelType.Main,
|
||||||
|
model_attributes = attributes,
|
||||||
|
clobber = True
|
||||||
|
)
|
@ -6,6 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
|||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import curses
|
import curses
|
||||||
|
import enum
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
@ -21,98 +22,14 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.services.config import InvokeAIAppConfig
|
from invokeai.services.config import InvokeAIAppConfig
|
||||||
from ...backend.model_management import ModelManager
|
from ...backend.model_management import (
|
||||||
|
merge_diffusion_models_and_save,
|
||||||
|
ModelManager, MergeInterpolationMethod, BaseModelType
|
||||||
|
)
|
||||||
from ...frontend.install.widgets import FloatTitleSlider
|
from ...frontend.install.widgets import FloatTitleSlider
|
||||||
|
|
||||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def merge_diffusion_models(
|
|
||||||
model_ids_or_paths: List[Union[str, Path]],
|
|
||||||
alpha: float = 0.5,
|
|
||||||
interp: str = None,
|
|
||||||
force: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> DiffusionPipeline:
|
|
||||||
"""
|
|
||||||
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
|
|
||||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
|
||||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
|
||||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
|
||||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
|
||||||
|
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
|
||||||
"""
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
verbosity = dlogging.get_verbosity()
|
|
||||||
dlogging.set_verbosity_error()
|
|
||||||
|
|
||||||
pipe = DiffusionPipeline.from_pretrained(
|
|
||||||
model_ids_or_paths[0],
|
|
||||||
cache_dir=kwargs.get("cache_dir", config.cache_dir),
|
|
||||||
custom_pipeline="checkpoint_merger",
|
|
||||||
)
|
|
||||||
merged_pipe = pipe.merge(
|
|
||||||
pretrained_model_name_or_path_list=model_ids_or_paths,
|
|
||||||
alpha=alpha,
|
|
||||||
interp=interp,
|
|
||||||
force=force,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
dlogging.set_verbosity(verbosity)
|
|
||||||
return merged_pipe
|
|
||||||
|
|
||||||
|
|
||||||
def merge_diffusion_models_and_commit(
|
|
||||||
models: List["str"],
|
|
||||||
merged_model_name: str,
|
|
||||||
alpha: float = 0.5,
|
|
||||||
interp: str = None,
|
|
||||||
force: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
models - up to three models, designated by their InvokeAI models.yaml model name
|
|
||||||
merged_model_name = name for new model
|
|
||||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
|
||||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
|
||||||
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
|
||||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
|
||||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
|
||||||
|
|
||||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
|
||||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
|
||||||
"""
|
|
||||||
config_file = config.model_conf_path
|
|
||||||
model_manager = ModelManager(OmegaConf.load(config_file))
|
|
||||||
for mod in models:
|
|
||||||
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
|
|
||||||
assert (
|
|
||||||
model_manager.model_info(mod).get("format", None) == "diffusers"
|
|
||||||
), f"** {mod} is not a diffusers model. It must be optimized before merging."
|
|
||||||
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
|
||||||
|
|
||||||
merged_pipe = merge_diffusion_models(
|
|
||||||
model_ids_or_paths, alpha, interp, force, **kwargs
|
|
||||||
)
|
|
||||||
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
|
|
||||||
|
|
||||||
os.makedirs(dump_path, exist_ok=True)
|
|
||||||
dump_path = dump_path / merged_model_name
|
|
||||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
|
||||||
import_args = dict(
|
|
||||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
|
||||||
)
|
|
||||||
if vae := model_manager.config[models[0]].get("vae", None):
|
|
||||||
logger.info(f"Using configured VAE assigned to {models[0]}")
|
|
||||||
import_args.update(vae=vae)
|
|
||||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
|
||||||
model_manager.commit(config_file)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_args() -> Namespace:
|
def _parse_args() -> Namespace:
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -135,6 +52,12 @@ def _parse_args() -> Namespace:
|
|||||||
nargs="+",
|
nargs="+",
|
||||||
help="Two to three model names to be merged",
|
help="Two to three model names to be merged",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_type",
|
||||||
|
type=str,
|
||||||
|
choices=[x.value for x in BaseModelType],
|
||||||
|
help="The base model shared by the models to be merged",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--merged_model_name",
|
"--merged_model_name",
|
||||||
"--destination",
|
"--destination",
|
||||||
@ -405,7 +328,7 @@ def run_gui(args: Namespace):
|
|||||||
mergeapp.run()
|
mergeapp.run()
|
||||||
|
|
||||||
args = mergeapp.merge_arguments
|
args = mergeapp.merge_arguments
|
||||||
merge_diffusion_models_and_commit(**args)
|
merge_diffusion_models_and_save(**args)
|
||||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||||
|
|
||||||
|
|
||||||
@ -432,13 +355,7 @@ def run_cli(args: Namespace):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = _parse_args()
|
args = _parse_args()
|
||||||
config.root = args.root_dir
|
config.parse_args(['--root',args.root_dir])
|
||||||
|
|
||||||
cache_dir = config.cache_dir
|
|
||||||
os.environ[
|
|
||||||
"HF_HOME"
|
|
||||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
|
||||||
args.cache_dir = cache_dir
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if args.front_end:
|
if args.front_end:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user