model merging API ready for testing

This commit is contained in:
Lincoln Stein 2023-07-06 13:15:15 -04:00
parent ec7c2f07c6
commit 3e925fbf34
3 changed files with 98 additions and 58 deletions

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
from typing import Literal, Optional, Union from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
@ -11,9 +11,9 @@ from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import ( from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType SchedulerPredictionType,
) )
from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -21,6 +21,7 @@ models_router = APIRouter(prefix="/v1/models", tags=["models"])
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@ -170,7 +171,7 @@ async def delete_model(
404: { "description": "Model not found" }, 404: { "description": "Model not found" },
}, },
status_code = 200, status_code = 200,
response_model = Union[tuple(OPENAPI_MODEL_CONFIGS)], response_model = ConvertModelResponse,
) )
async def convert_model( async def convert_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
@ -195,55 +196,42 @@ async def convert_model(
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
# @socketio.on("mergeDiffusersModels") @models_router.put(
# def merge_diffusers_models(model_merge_info: dict): "/merge/{base_model}",
# try: operation_id="merge_models",
# models_to_merge = model_merge_info["models_to_merge"] responses={
# model_ids_or_paths = [ 200: { "description": "Model converted successfully" },
# self.generate.model_manager.model_name_or_path(x) 400: { "description": "Incompatible models" },
# for x in models_to_merge 404: { "description": "One or more models not found" },
# ] },
# merged_pipe = merge_diffusion_models( status_code = 200,
# model_ids_or_paths, response_model = MergeModelResponse,
# model_merge_info["alpha"], )
# model_merge_info["interp"], async def merge_models(
# model_merge_info["force"], base_model: BaseModelType = Path(description="Base model"),
# ) model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description = "Name of destination model"),
# dump_path = global_models_dir() / "merged_models" alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5),
# if model_merge_info["model_merge_save_path"] is not None: interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"),
# dump_path = Path(model_merge_info["model_merge_save_path"]) force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False),
) -> MergeModelResponse:
# os.makedirs(dump_path, exist_ok=True) """Convert a checkpoint model into a diffusers model"""
# dump_path = dump_path / model_merge_info["merged_model_name"] logger = ApiDependencies.invoker.services.logger
# merged_pipe.save_pretrained(dump_path, safe_serialization=1) try:
logger.info(f"Merging models: {model_names}")
# merged_model_config = dict( result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
# model_name=model_merge_info["merged_model_name"], base_model,
# description=f'Merge of models {", ".join(models_to_merge)}', merged_model_name or "+".join(model_names),
# commit_to_conf=opt.conf, alpha,
# ) interp,
force)
# if vae := self.generate.model_manager.config[models_to_merge[0]].get( model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
# "vae", None base_model = base_model,
# ): model_type = ModelType.Main,
# print(f">> Using configured VAE assigned to {models_to_merge[0]}") )
# merged_model_config.update(vae=vae) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
# self.generate.model_manager.import_diffuser_model( raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
# dump_path, **merged_model_config except ValueError as e:
# ) raise HTTPException(status_code=400, detail=str(e))
# new_model_list = self.generate.model_manager.list_models() return response
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:

View File

@ -4,10 +4,11 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from pydantic import Field
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType from types import ModuleType
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management import (
ModelManager, ModelManager,
BaseModelType, BaseModelType,
ModelType, ModelType,
@ -15,8 +16,11 @@ from invokeai.backend.model_management.model_manager import (
ModelInfo, ModelInfo,
AddModelResult, AddModelResult,
SchedulerPredictionType, SchedulerPredictionType,
ModelMerger,
MergeInterpolationMethod,
) )
import torch import torch
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from ...backend.util import choose_precision, choose_torch_device from ...backend.util import choose_precision, choose_torch_device
@ -207,6 +211,26 @@ class ModelManagerServiceBase(ABC):
''' '''
pass pass
@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
"""
pass
@abstractmethod @abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None: def commit(self, conf_file: Optional[Path] = None) -> None:
""" """
@ -501,3 +525,31 @@ class ModelManagerService(ModelManagerServiceBase):
that model. that model.
''' '''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper) return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models(
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
"""
merger = ModelMerger(self.mgr)
return merger.merge_diffusion_models_and_save(
model_names = model_names,
base_model = base_model,
merged_model_name = merged_model_name,
alpha = alpha,
interp = interp,
force = force,
)

View File

@ -1,7 +1,7 @@
""" """
Initialization file for invokeai.backend.model_management Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo, AddModelResult from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .model_merge import ModelMerger, MergeInterpolationMethod from .model_merge import ModelMerger, MergeInterpolationMethod