From e9352227f3b587d523ce67d3113ec466d30e58d3 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 6 Jul 2023 15:12:34 -0400 Subject: [PATCH] add merge api --- invokeai/app/api/routers/models.py | 12 +++++------ .../app/services/model_manager_service.py | 20 +++++++++++-------- .../backend/model_management/model_merge.py | 10 ++++++---- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 4e23a69d90..8dbeaa3d05 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -204,12 +204,12 @@ async def convert_model( response_model = MergeModelResponse, ) async def merge_models( - base_model: BaseModelType = Path(description="Base model"), - model_names: List[str] = Body(description="model name", min_items=2, max_items=3), - merged_model_name: Optional[str] = Body(description = "Name of destination model"), - alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5), - interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"), - force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False), + base_model: BaseModelType = Path(description="Base model"), + model_names: List[str] = Body(description="model name", min_items=2, max_items=3), + merged_model_name: Optional[str] = Body(description="Name of destination model"), + alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), + interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), + force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), ) -> MergeModelResponse: """Convert a checkpoint model into a diffusers model""" logger = ApiDependencies.invoker.services.logger diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index eb2c014b1a..6359247cde 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -546,11 +546,15 @@ class ModelManagerService(ModelManagerServiceBase): :param interp: Interpolation method. None (default) """ merger = ModelMerger(self.mgr) - return merger.merge_diffusion_models_and_save( - model_names = model_names, - base_model = base_model, - merged_model_name = merged_model_name, - alpha = alpha, - interp = interp, - force = force, - ) + try: + result = merger.merge_diffusion_models_and_save( + model_names = model_names, + base_model = base_model, + merged_model_name = merged_model_name, + alpha = alpha, + interp = interp, + force = force, + ) + except AssertionError as e: + raise ValueError(e) + return result diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py index 1a110a47b8..39f951d2b4 100644 --- a/invokeai/backend/model_management/model_merge.py +++ b/invokeai/backend/model_management/model_merge.py @@ -15,14 +15,13 @@ from typing import List, Union import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult class MergeInterpolationMethod(str, Enum): + WeightedSum = "weighted_sum" Sigmoid = "sigmoid" InvSigmoid = "inv_sigmoid" AddDifference = "add_difference" - WeightedSum = "weighted_sum" class ModelMerger(object): def __init__(self, manager: ModelManager): @@ -97,15 +96,18 @@ class ModelMerger(object): for mod in model_names: info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) - assert info, f"model {mod}, base_model {base_model}, is unknown" + assert info, f"model {mod}, base_model {base_model}, is unknown" assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging" - assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged") + assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" + assert len(model_names) <= 2 or \ + interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported" # pick up the first model's vae if mod == model_names[0]: vae = info.get("vae") model_paths.extend([config.root_path / info["path"]]) merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp) + logger.debug(f'interp = {interp}, merge_method={merge_method}') merged_pipe = self.merge_diffusion_models( model_paths, alpha, merge_method, force, **kwargs )