mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model merging API ready for testing
This commit is contained in:
parent
ec7c2f07c6
commit
3e925fbf34
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
|
||||||
|
|
||||||
|
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, List, Optional, Union
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
@ -11,9 +11,9 @@ from starlette.exceptions import HTTPException
|
|||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management.models import (
|
from invokeai.backend.model_management.models import (
|
||||||
OPENAPI_MODEL_CONFIGS,
|
OPENAPI_MODEL_CONFIGS,
|
||||||
SchedulerPredictionType
|
SchedulerPredictionType,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
@ -21,6 +21,7 @@ models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
|||||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
@ -170,7 +171,7 @@ async def delete_model(
|
|||||||
404: { "description": "Model not found" },
|
404: { "description": "Model not found" },
|
||||||
},
|
},
|
||||||
status_code = 200,
|
status_code = 200,
|
||||||
response_model = Union[tuple(OPENAPI_MODEL_CONFIGS)],
|
response_model = ConvertModelResponse,
|
||||||
)
|
)
|
||||||
async def convert_model(
|
async def convert_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
@ -195,55 +196,42 @@ async def convert_model(
|
|||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# @socketio.on("mergeDiffusersModels")
|
@models_router.put(
|
||||||
# def merge_diffusers_models(model_merge_info: dict):
|
"/merge/{base_model}",
|
||||||
# try:
|
operation_id="merge_models",
|
||||||
# models_to_merge = model_merge_info["models_to_merge"]
|
responses={
|
||||||
# model_ids_or_paths = [
|
200: { "description": "Model converted successfully" },
|
||||||
# self.generate.model_manager.model_name_or_path(x)
|
400: { "description": "Incompatible models" },
|
||||||
# for x in models_to_merge
|
404: { "description": "One or more models not found" },
|
||||||
# ]
|
},
|
||||||
# merged_pipe = merge_diffusion_models(
|
status_code = 200,
|
||||||
# model_ids_or_paths,
|
response_model = MergeModelResponse,
|
||||||
# model_merge_info["alpha"],
|
)
|
||||||
# model_merge_info["interp"],
|
async def merge_models(
|
||||||
# model_merge_info["force"],
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
# )
|
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||||
|
merged_model_name: Optional[str] = Body(description = "Name of destination model"),
|
||||||
# dump_path = global_models_dir() / "merged_models"
|
alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
# if model_merge_info["model_merge_save_path"] is not None:
|
interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"),
|
||||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False),
|
||||||
|
) -> MergeModelResponse:
|
||||||
# os.makedirs(dump_path, exist_ok=True)
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
logger = ApiDependencies.invoker.services.logger
|
||||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
try:
|
||||||
|
logger.info(f"Merging models: {model_names}")
|
||||||
# merged_model_config = dict(
|
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||||
# model_name=model_merge_info["merged_model_name"],
|
base_model,
|
||||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
merged_model_name or "+".join(model_names),
|
||||||
# commit_to_conf=opt.conf,
|
alpha,
|
||||||
# )
|
interp,
|
||||||
|
force)
|
||||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||||
# "vae", None
|
base_model = base_model,
|
||||||
# ):
|
model_type = ModelType.Main,
|
||||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
)
|
||||||
# merged_model_config.update(vae=vae)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
|
except KeyError:
|
||||||
# self.generate.model_manager.import_diffuser_model(
|
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||||
# dump_path, **merged_model_config
|
except ValueError as e:
|
||||||
# )
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
# new_model_list = self.generate.model_manager.list_models()
|
return response
|
||||||
|
|
||||||
# socketio.emit(
|
|
||||||
# "modelsMerged",
|
|
||||||
# {
|
|
||||||
# "merged_models": models_to_merge,
|
|
||||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
|
||||||
# "model_list": new_model_list,
|
|
||||||
# "update": True,
|
|
||||||
# },
|
|
||||||
# )
|
|
||||||
# print(f">> Models Merged: {models_to_merge}")
|
|
||||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
|
||||||
# except Exception as e:
|
|
||||||
|
@ -4,10 +4,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from pydantic import Field
|
||||||
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_manager import (
|
from invokeai.backend.model_management import (
|
||||||
ModelManager,
|
ModelManager,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
@ -15,8 +16,11 @@ from invokeai.backend.model_management.model_manager import (
|
|||||||
ModelInfo,
|
ModelInfo,
|
||||||
AddModelResult,
|
AddModelResult,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
|
ModelMerger,
|
||||||
|
MergeInterpolationMethod,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
from ...backend.util import choose_precision, choose_torch_device
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
@ -207,6 +211,26 @@ class ModelManagerServiceBase(ABC):
|
|||||||
'''
|
'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def merge_models(
|
||||||
|
self,
|
||||||
|
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||||
|
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||||
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
|
alpha: Optional[float] = 0.5,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
force: Optional[bool] = False,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
|
:param model_names: List of 2-3 models to merge
|
||||||
|
:param base_model: Base model to use for all models
|
||||||
|
:param merged_model_name: Name of destination merged model
|
||||||
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
|
:param interp: Interpolation method. None (default)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
"""
|
"""
|
||||||
@ -501,3 +525,31 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
that model.
|
that model.
|
||||||
'''
|
'''
|
||||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||||
|
|
||||||
|
def merge_models(
|
||||||
|
self,
|
||||||
|
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||||
|
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||||
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
|
alpha: Optional[float] = 0.5,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
force: Optional[bool] = False,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
|
:param model_names: List of 2-3 models to merge
|
||||||
|
:param base_model: Base model to use for all models
|
||||||
|
:param merged_model_name: Name of destination merged model
|
||||||
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
|
:param interp: Interpolation method. None (default)
|
||||||
|
"""
|
||||||
|
merger = ModelMerger(self.mgr)
|
||||||
|
return merger.merge_diffusion_models_and_save(
|
||||||
|
model_names = model_names,
|
||||||
|
base_model = base_model,
|
||||||
|
merged_model_name = merged_model_name,
|
||||||
|
alpha = alpha,
|
||||||
|
interp = interp,
|
||||||
|
force = force,
|
||||||
|
)
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend.model_management
|
Initialization file for invokeai.backend.model_management
|
||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult
|
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
Loading…
Reference in New Issue
Block a user