Merge branch 'lstein/model-manager-router-api'

This commit is contained in:
Lincoln Stein 2023-07-06 15:13:41 -04:00
commit f78f10bef6
3 changed files with 24 additions and 18 deletions

View File

@ -204,12 +204,12 @@ async def convert_model(
response_model = MergeModelResponse, response_model = MergeModelResponse,
) )
async def merge_models( async def merge_models(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3), model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description = "Name of destination model"), merged_model_name: Optional[str] = Body(description="Name of destination model"),
alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5), alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False), force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
) -> MergeModelResponse: ) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model""" """Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger

View File

@ -546,11 +546,15 @@ class ModelManagerService(ModelManagerServiceBase):
:param interp: Interpolation method. None (default) :param interp: Interpolation method. None (default)
""" """
merger = ModelMerger(self.mgr) merger = ModelMerger(self.mgr)
return merger.merge_diffusion_models_and_save( try:
model_names = model_names, result = merger.merge_diffusion_models_and_save(
base_model = base_model, model_names = model_names,
merged_model_name = merged_model_name, base_model = base_model,
alpha = alpha, merged_model_name = merged_model_name,
interp = interp, alpha = alpha,
force = force, interp = interp,
) force = force,
)
except AssertionError as e:
raise ValueError(e)
return result

View File

@ -15,14 +15,13 @@ from typing import List, Union
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
class MergeInterpolationMethod(str, Enum): class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid" Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid" InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference" AddDifference = "add_difference"
WeightedSum = "weighted_sum"
class ModelMerger(object): class ModelMerger(object):
def __init__(self, manager: ModelManager): def __init__(self, manager: ModelManager):
@ -97,15 +96,18 @@ class ModelMerger(object):
for mod in model_names: for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown" assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging" assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged") assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert len(model_names) <= 2 or \
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae # pick up the first model's vae
if mod == model_names[0]: if mod == model_names[0]:
vae = info.get("vae") vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]]) model_paths.extend([config.root_path / info["path"]])
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp) merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
logger.debug(f'interp = {interp}, merge_method={merge_method}')
merged_pipe = self.merge_diffusion_models( merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs model_paths, alpha, merge_method, force, **kwargs
) )