mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
partial implementation of merge
This commit is contained in:
parent
d4550b3059
commit
cfa3b2419c
@ -222,7 +222,10 @@ class ModelInstall(object):
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||
model_result = None
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
model_name = path.stem if info.format=='checkpoint' else path.name
|
||||
if not info:
|
||||
logger.warning(f'Unable to parse format of {path}')
|
||||
return None
|
||||
model_name = path.stem if info.format == "checkpoint" else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path,info)
|
||||
|
@ -4,4 +4,5 @@ Initialization file for invokeai.backend.model_management
|
||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
from .model_merge import merge_diffusion_models_and_save, MergeInterpolationMethod
|
||||
|
||||
|
119
invokeai/backend/model_management/model_merge.py
Normal file
119
invokeai/backend/model_management/model_merge.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""
|
||||
invokeai.backend.model_management.model_merge exports:
|
||||
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
|
||||
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from typing import List
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
Sigmoid = "sigmoid"
|
||||
InvSigmoid = "inv_sigmoid"
|
||||
AddDifference = "add_difference"
|
||||
|
||||
def merge_diffusion_models(
|
||||
model_paths: List[Path],
|
||||
alpha: float = 0.5,
|
||||
interp: InterpolationMethod = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_paths[0],
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_paths,
|
||||
alpha=alpha,
|
||||
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_save (
|
||||
models: List["str"],
|
||||
base_model: BaseModelType,
|
||||
merged_model_name: str,
|
||||
config: InvokeAIAppConfig,
|
||||
alpha: float = 0.5,
|
||||
interp: InterpolationMethod = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
:param models: up to three models, designated by their InvokeAI models.yaml model name
|
||||
:param base_model: base model (must be the same for all merged models!)
|
||||
:param merged_model_name: name for new model
|
||||
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
model_manager = ModelManager(config.model_conf_path)
|
||||
model_paths = list()
|
||||
vae = None
|
||||
for mod in models:
|
||||
info = model_manager.model_info(mod, base_model=base_model, model_type=ModelType.main)
|
||||
assert info, f"model {mod}, base_model {base_model}, is unknown"
|
||||
assert info["format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
|
||||
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
|
||||
if mod == models[0]:
|
||||
vae = info["vae"]
|
||||
model_paths.extend([info["path"]])
|
||||
|
||||
merged_pipe = merge_diffusion_models(
|
||||
model_paths, alpha, interp, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_path / base_model.value / ModelType.main.value
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
attributes = dict(
|
||||
path = dump_path,
|
||||
description = f"Merge of models {', '.join(models)}",
|
||||
model_format = "diffusers",
|
||||
variant = ModelVariantType.Normal.value,
|
||||
vae = vae,
|
||||
)
|
||||
model_manager.add_model(merged_model_name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
model_attributes = attributes,
|
||||
clobber = True
|
||||
)
|
@ -6,6 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
import argparse
|
||||
import curses
|
||||
import enum
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
@ -21,98 +22,14 @@ from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.services.config import InvokeAIAppConfig
|
||||
from ...backend.model_management import ModelManager
|
||||
from ...backend.model_management import (
|
||||
merge_diffusion_models_and_save,
|
||||
ModelManager, MergeInterpolationMethod, BaseModelType
|
||||
)
|
||||
from ...frontend.install.widgets import FloatTitleSlider
|
||||
|
||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
def merge_diffusion_models(
|
||||
model_ids_or_paths: List[Union[str, Path]],
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_ids_or_paths[0],
|
||||
cache_dir=kwargs.get("cache_dir", config.cache_dir),
|
||||
custom_pipeline="checkpoint_merger",
|
||||
)
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_ids_or_paths,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
**kwargs,
|
||||
)
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return merged_pipe
|
||||
|
||||
|
||||
def merge_diffusion_models_and_commit(
|
||||
models: List["str"],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: str = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
models - up to three models, designated by their InvokeAI models.yaml model name
|
||||
merged_model_name = name for new model
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
|
||||
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
|
||||
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
|
||||
|
||||
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
|
||||
"""
|
||||
config_file = config.model_conf_path
|
||||
model_manager = ModelManager(OmegaConf.load(config_file))
|
||||
for mod in models:
|
||||
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
|
||||
assert (
|
||||
model_manager.model_info(mod).get("format", None) == "diffusers"
|
||||
), f"** {mod} is not a diffusers model. It must be optimized before merging."
|
||||
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
|
||||
|
||||
merged_pipe = merge_diffusion_models(
|
||||
model_ids_or_paths, alpha, interp, force, **kwargs
|
||||
)
|
||||
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
|
||||
|
||||
os.makedirs(dump_path, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
import_args = dict(
|
||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
||||
)
|
||||
if vae := model_manager.config[models[0]].get("vae", None):
|
||||
logger.info(f"Using configured VAE assigned to {models[0]}")
|
||||
import_args.update(vae=vae)
|
||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
||||
model_manager.commit(config_file)
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
parser.add_argument(
|
||||
@ -135,6 +52,12 @@ def _parse_args() -> Namespace:
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_type",
|
||||
type=str,
|
||||
choices=[x.value for x in BaseModelType],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
@ -405,7 +328,7 @@ def run_gui(args: Namespace):
|
||||
mergeapp.run()
|
||||
|
||||
args = mergeapp.merge_arguments
|
||||
merge_diffusion_models_and_commit(**args)
|
||||
merge_diffusion_models_and_save(**args)
|
||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
||||
|
||||
|
||||
@ -432,13 +355,7 @@ def run_cli(args: Namespace):
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
config.root = args.root_dir
|
||||
|
||||
cache_dir = config.cache_dir
|
||||
os.environ[
|
||||
"HF_HOME"
|
||||
] = cache_dir # because not clear the merge pipeline is honoring cache_dir
|
||||
args.cache_dir = cache_dir
|
||||
config.parse_args(['--root',args.root_dir])
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
|
Loading…
Reference in New Issue
Block a user