Merge branch 'lstein/model-manager-router-api'

This commit is contained in:
Lincoln Stein 2023-07-06 15:13:41 -04:00
commit f78f10bef6
3 changed files with 24 additions and 18 deletions

View File

@ -204,12 +204,12 @@ async def convert_model(
response_model = MergeModelResponse,
)
async def merge_models(
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description = "Name of destination model"),
alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"),
force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False),
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model"),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger

View File

@ -546,11 +546,15 @@ class ModelManagerService(ModelManagerServiceBase):
:param interp: Interpolation method. None (default)
"""
merger = ModelMerger(self.mgr)
return merger.merge_diffusion_models_and_save(
model_names = model_names,
base_model = base_model,
merged_model_name = merged_model_name,
alpha = alpha,
interp = interp,
force = force,
)
try:
result = merger.merge_diffusion_models_and_save(
model_names = model_names,
base_model = base_model,
merged_model_name = merged_model_name,
alpha = alpha,
interp = interp,
force = force,
)
except AssertionError as e:
raise ValueError(e)
return result

View File

@ -15,14 +15,13 @@ from typing import List, Union
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
WeightedSum = "weighted_sum"
class ModelMerger(object):
def __init__(self, manager: ModelManager):
@ -97,15 +96,18 @@ class ModelMerger(object):
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged")
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert len(model_names) <= 2 or \
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
logger.debug(f'interp = {interp}, merge_method={merge_method}')
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)