model merging API ready for testing

This commit is contained in:
Lincoln Stein 2023-07-06 13:15:15 -04:00
parent ec7c2f07c6
commit 3e925fbf34
3 changed files with 98 additions and 58 deletions

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
from typing import Literal, Optional, Union
from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -11,9 +11,9 @@ from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType
SchedulerPredictionType,
)
from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -21,6 +21,7 @@ models_router = APIRouter(prefix="/v1/models", tags=["models"])
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@ -170,7 +171,7 @@ async def delete_model(
404: { "description": "Model not found" },
},
status_code = 200,
response_model = Union[tuple(OPENAPI_MODEL_CONFIGS)],
response_model = ConvertModelResponse,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
@ -195,55 +196,42 @@ async def convert_model(
raise HTTPException(status_code=400, detail=str(e))
return response
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
responses={
200: { "description": "Model converted successfully" },
400: { "description": "Incompatible models" },
404: { "description": "One or more models not found" },
},
status_code = 200,
response_model = MergeModelResponse,
)
async def merge_models(
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description = "Name of destination model"),
alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"),
force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False),
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names}")
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model,
merged_model_name or "+".join(model_names),
alpha,
interp,
force)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model,
model_type = ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response

View File

@ -4,10 +4,11 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from pydantic import Field
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType
from invokeai.backend.model_management.model_manager import (
from invokeai.backend.model_management import (
ModelManager,
BaseModelType,
ModelType,
@ -15,8 +16,11 @@ from invokeai.backend.model_management.model_manager import (
ModelInfo,
AddModelResult,
SchedulerPredictionType,
ModelMerger,
MergeInterpolationMethod,
)
import torch
from invokeai.app.models.exceptions import CanceledException
from ...backend.util import choose_precision, choose_torch_device
@ -207,6 +211,26 @@ class ModelManagerServiceBase(ABC):
'''
pass
@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@ -501,3 +525,31 @@ class ModelManagerService(ModelManagerServiceBase):
that model.
'''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
"""
merger = ModelMerger(self.mgr)
return merger.merge_diffusion_models_and_save(
model_names = model_names,
base_model = base_model,
merged_model_name = merged_model_name,
alpha = alpha,
interp = interp,
force = force,
)

View File

@ -1,7 +1,7 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .model_merge import ModelMerger, MergeInterpolationMethod