Merge branch 'main' into feat/clip_skip

blessedcoolant committed 2023-07-07 16:21:53 +12:00 in commit 7aa918677e
10 changed files with 683 additions and 457 deletions

View File

@@ -1,75 +1,30 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein

-from typing import Literal, Optional, Union
+from typing import Literal, List, Optional, Union

-from fastapi import Query, Body
-from fastapi.routing import APIRouter, HTTPException
-from pydantic import BaseModel, Field, parse_obj_as
-from ..dependencies import ApiDependencies
+from fastapi import Body, Path, Query, Response
+from fastapi.routing import APIRouter
+from pydantic import BaseModel, parse_obj_as
+from starlette.exceptions import HTTPException

 from invokeai.backend import BaseModelType, ModelType
-from invokeai.backend.model_management import AddModelResult
-from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
-
-MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+from invokeai.backend.model_management.models import (
+    OPENAPI_MODEL_CONFIGS,
+    SchedulerPredictionType,
+)
+from invokeai.backend.model_management import MergeInterpolationMethod
+
+from ..dependencies import ApiDependencies

 models_router = APIRouter(prefix="/v1/models", tags=["models"])

-class VaeRepo(BaseModel):
-    repo_id: str = Field(description="The repo ID to use for this VAE")
-    path: Optional[str] = Field(description="The path to the VAE")
-    subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
-
-class ModelInfo(BaseModel):
-    description: Optional[str] = Field(description="A description of the model")
-    model_name: str = Field(description="The name of the model")
-    model_type: str = Field(description="The type of the model")
-
-class DiffusersModelInfo(ModelInfo):
-    format: Literal['folder'] = 'folder'
-
-    vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
-    repo_id: Optional[str] = Field(description="The repo ID to use for this model")
-    path: Optional[str] = Field(description="The path to the model")
-
-class CkptModelInfo(ModelInfo):
-    format: Literal['ckpt'] = 'ckpt'
-
-    config: str = Field(description="The path to the model config")
-    weights: str = Field(description="The path to the model weights")
-    vae: str = Field(description="The path to the model VAE")
-    width: Optional[int] = Field(description="The width of the model")
-    height: Optional[int] = Field(description="The height of the model")
-
-class SafetensorsModelInfo(CkptModelInfo):
-    format: Literal['safetensors'] = 'safetensors'
-
-class CreateModelRequest(BaseModel):
-    name: str = Field(description="The name of the model")
-    info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
-
-class CreateModelResponse(BaseModel):
-    name: str = Field(description="The name of the new model")
-    info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
-    status: str = Field(description="The status of the API response")
-
-class ImportModelResponse(BaseModel):
-    name: str = Field(description="The name of the imported model")
-    # base_model: str = Field(description="The base model")
-    # model_type: str = Field(description="The model type")
-    info: AddModelResult = Field(description="The model info")
-    status: str = Field(description="The status of the API response")
-
-class ConversionRequest(BaseModel):
-    name: str = Field(description="The name of the new model")
-    info: CkptModelInfo = Field(description="The converted model info")
-    save_location: str = Field(description="The path to save the converted model weights")
-
-class ConvertedModelResponse(BaseModel):
-    name: str = Field(description="The name of the new model")
-    info: DiffusersModelInfo = Field(description="The converted model info")
+UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]

 class ModelsList(BaseModel):
-    models: list[MODEL_CONFIGS]
+    models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]

 @models_router.get(
     "/",
@@ -77,75 +32,103 @@ class ModelsList(BaseModel):
     responses={200: {"model": ModelsList }},
 )
 async def list_models(
-    base_model: Optional[BaseModelType] = Query(
-        default=None, description="Base model"
-    ),
-    model_type: Optional[ModelType] = Query(
-        default=None, description="The type of model to get"
-    ),
+    base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
+    model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
 ) -> ModelsList:
     """Gets a list of models"""
     models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
     models = parse_obj_as(ModelsList, { "models": models_raw })
     return models

-@models_router.post(
-    "/",
+@models_router.patch(
+    "/{base_model}/{model_type}/{model_name}",
     operation_id="update_model",
-    responses={200: {"status": "success"}},
+    responses={200: {"description" : "The model was updated successfully"},
+               404: {"description" : "The model could not be found"},
+               400: {"description" : "Bad request"}
+    },
+    status_code = 200,
+    response_model = UpdateModelResponse,
 )
 async def update_model(
-    model_request: CreateModelRequest
-) -> CreateModelResponse:
+    base_model: BaseModelType = Path(description="Base model"),
+    model_type: ModelType = Path(description="The type of model"),
+    model_name: str = Path(description="model name"),
+    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
+) -> UpdateModelResponse:
     """ Add Model """
-    model_request_info = model_request.info
-    info_dict = model_request_info.dict()
-    model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
-
-    ApiDependencies.invoker.services.model_manager.add_model(
-        model_name=model_request.name,
-        model_attributes=info_dict,
-        clobber=True,
-    )
+    try:
+        ApiDependencies.invoker.services.model_manager.update_model(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            model_attributes=info.dict()
+        )
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+        )
+        model_response = parse_obj_as(UpdateModelResponse, model_raw)
+    except KeyError as e:
+        raise HTTPException(status_code=404, detail=str(e))
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
     return model_response

 @models_router.post(
-    "/import",
+    "/",
     operation_id="import_model",
     responses= {
         201: {"description" : "The model imported successfully"},
         404: {"description" : "The model could not be found"},
+        424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
+        409: {"description" : "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
     response_model=ImportModelResponse
 )
 async def import_model(
-    name: str = Query(description="A model path, repo_id or URL to import"),
-    prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
+    location: str = Body(description="A model path, repo_id or URL to import"),
+    prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
+        Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
 ) -> ImportModelResponse:
     """ Add a model using its local path, repo_id, or remote URL """
-    items_to_import = {name}
+    items_to_import = {location}
     prediction_types = { x.value: x for x in SchedulerPredictionType }
     logger = ApiDependencies.invoker.services.logger

-    installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
-        items_to_import = items_to_import,
-        prediction_type_helper = lambda x: prediction_types.get(prediction_type)
-    )
-    if info := installed_models.get(name):
-        logger.info(f'Successfully imported {name}, got {info}')
-        return ImportModelResponse(
-            name = name,
-            info = info,
-            status = "success",
-        )
-    else:
-        logger.error(f'Model {name} not imported')
-        raise HTTPException(status_code=404, detail=f'Model {name} not found')
+    try:
+        installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
+            items_to_import = items_to_import,
+            prediction_type_helper = lambda x: prediction_types.get(prediction_type)
+        )
+        info = installed_models.get(location)
+
+        if not info:
+            logger.error("Import failed")
+            raise HTTPException(status_code=424)
+
+        logger.info(f'Successfully imported {location}, got {info}')
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name=info.name,
+            base_model=info.base_model,
+            model_type=info.model_type
+        )
+        return parse_obj_as(ImportModelResponse, model_raw)
+    except KeyError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=404, detail=str(e))
+    except ValueError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))

 @models_router.delete(
-    "/{model_name}",
+    "/{base_model}/{model_type}/{model_name}",
     operation_id="del_model",
     responses={
         204: {
@@ -156,144 +139,95 @@ async def import_model(
         }
     },
 )
-async def delete_model(model_name: str) -> None:
+async def delete_model(
+    base_model: BaseModelType = Path(description="Base model"),
+    model_type: ModelType = Path(description="The type of model"),
+    model_name: str = Path(description="model name"),
+) -> Response:
     """Delete Model"""
-    model_names = ApiDependencies.invoker.services.model_manager.model_names()
     logger = ApiDependencies.invoker.services.logger
-    model_exists = model_name in model_names

-    # check if model exists
-    logger.info(f"Checking for model {model_name}...")
-
-    if model_exists:
-        logger.info(f"Deleting Model: {model_name}")
-        ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
-        logger.info(f"Model Deleted: {model_name}")
-        raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
-
-    else:
-        logger.error("Model not found")
+    try:
+        ApiDependencies.invoker.services.model_manager.del_model(model_name,
+                                                                 base_model = base_model,
+                                                                 model_type = model_type
+                                                                 )
+        logger.info(f"Deleted model: {model_name}")
+        return Response(status_code=204)
+    except KeyError:
+        logger.error(f"Model not found: {model_name}")
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

-# @socketio.on("convertToDiffusers")
-# def convert_to_diffusers(model_to_convert: dict):
-#     try:
-#         if model_info := self.generate.model_manager.model_info(
-#             model_name=model_to_convert["model_name"]
-#         ):
-#             if "weights" in model_info:
-#                 ckpt_path = Path(model_info["weights"])
-#                 original_config_file = Path(model_info["config"])
-#                 model_name = model_to_convert["model_name"]
-#                 model_description = model_info["description"]
-#             else:
-#                 self.socketio.emit(
-#                     "error", {"message": "Model is not a valid checkpoint file"}
-#                 )
-#         else:
-#             self.socketio.emit(
-#                 "error", {"message": "Could not retrieve model info."}
-#             )
-
-#         if not ckpt_path.is_absolute():
-#             ckpt_path = Path(Globals.root, ckpt_path)
-
-#         if original_config_file and not original_config_file.is_absolute():
-#             original_config_file = Path(Globals.root, original_config_file)
-
-#         diffusers_path = Path(
-#             ckpt_path.parent.absolute(), f"{model_name}_diffusers"
-#         )
-
-#         if model_to_convert["save_location"] == "root":
-#             diffusers_path = Path(
-#                 global_converted_ckpts_dir(), f"{model_name}_diffusers"
-#             )
-
-#         if (
-#             model_to_convert["save_location"] == "custom"
-#             and model_to_convert["custom_location"] is not None
-#         ):
-#             diffusers_path = Path(
-#                 model_to_convert["custom_location"], f"{model_name}_diffusers"
-#             )
-
-#         if diffusers_path.exists():
-#             shutil.rmtree(diffusers_path)
-
-#         self.generate.model_manager.convert_and_import(
-#             ckpt_path,
-#             diffusers_path,
-#             model_name=model_name,
-#             model_description=model_description,
-#             vae=None,
-#             original_config_file=original_config_file,
-#             commit_to_conf=opt.conf,
-#         )
-
-#         new_model_list = self.generate.model_manager.list_models()
-#         socketio.emit(
-#             "modelConverted",
-#             {
-#                 "new_model_name": model_name,
-#                 "model_list": new_model_list,
-#                 "update": True,
-#             },
-#         )
-#         print(f">> Model Converted: {model_name}")
-#     except Exception as e:
-#         self.handle_exceptions(e)
-
-# @socketio.on("mergeDiffusersModels")
-# def merge_diffusers_models(model_merge_info: dict):
-#     try:
-#         models_to_merge = model_merge_info["models_to_merge"]
-#         model_ids_or_paths = [
-#             self.generate.model_manager.model_name_or_path(x)
-#             for x in models_to_merge
-#         ]
-#         merged_pipe = merge_diffusion_models(
-#             model_ids_or_paths,
-#             model_merge_info["alpha"],
-#             model_merge_info["interp"],
-#             model_merge_info["force"],
-#         )
-#         dump_path = global_models_dir() / "merged_models"
-#         if model_merge_info["model_merge_save_path"] is not None:
-#             dump_path = Path(model_merge_info["model_merge_save_path"])
-#         os.makedirs(dump_path, exist_ok=True)
-#         dump_path = dump_path / model_merge_info["merged_model_name"]
-#         merged_pipe.save_pretrained(dump_path, safe_serialization=1)
-#         merged_model_config = dict(
-#             model_name=model_merge_info["merged_model_name"],
-#             description=f'Merge of models {", ".join(models_to_merge)}',
-#             commit_to_conf=opt.conf,
-#         )
-#         if vae := self.generate.model_manager.config[models_to_merge[0]].get(
-#             "vae", None
-#         ):
-#             print(f">> Using configured VAE assigned to {models_to_merge[0]}")
-#             merged_model_config.update(vae=vae)
-#         self.generate.model_manager.import_diffuser_model(
-#             dump_path, **merged_model_config
-#         )
-#         new_model_list = self.generate.model_manager.list_models()
-#         socketio.emit(
-#             "modelsMerged",
-#             {
-#                 "merged_models": models_to_merge,
-#                 "merged_model_name": model_merge_info["merged_model_name"],
-#                 "model_list": new_model_list,
-#                 "update": True,
-#             },
-#         )
-#         print(f">> Models Merged: {models_to_merge}")
-#         print(f">> New Model Added: {model_merge_info['merged_model_name']}")
-#     except Exception as e:
+@models_router.put(
+    "/convert/{base_model}/{model_type}/{model_name}",
+    operation_id="convert_model",
+    responses={
+        200: { "description": "Model converted successfully" },
+        400: { "description": "Bad request" },
+        404: { "description": "Model not found" },
+    },
+    status_code = 200,
+    response_model = ConvertModelResponse,
+)
+async def convert_model(
+    base_model: BaseModelType = Path(description="Base model"),
+    model_type: ModelType = Path(description="The type of model"),
+    model_name: str = Path(description="model name"),
+) -> ConvertModelResponse:
+    """Convert a checkpoint model into a diffusers model"""
+    logger = ApiDependencies.invoker.services.logger
+    try:
+        logger.info(f"Converting model: {model_name}")
+        ApiDependencies.invoker.services.model_manager.convert_model(model_name,
+                                                                     base_model = base_model,
+                                                                     model_type = model_type
+                                                                     )
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
+                                                                              base_model = base_model,
+                                                                              model_type = model_type)
+        response = parse_obj_as(ConvertModelResponse, model_raw)
+    except KeyError:
+        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    return response
+
+@models_router.put(
+    "/merge/{base_model}",
+    operation_id="merge_models",
+    responses={
+        200: { "description": "Models merged successfully" },
+        400: { "description": "Incompatible models" },
+        404: { "description": "One or more models not found" },
+    },
+    status_code = 200,
+    response_model = MergeModelResponse,
+)
+async def merge_models(
+    base_model: BaseModelType = Path(description="Base model"),
+    model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
+    merged_model_name: Optional[str] = Body(description="Name of destination model"),
+    alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
+    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
+    force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
+) -> MergeModelResponse:
+    """Merge two to three diffusers pipeline models into a new one"""
+    logger = ApiDependencies.invoker.services.logger
+    try:
+        logger.info(f"Merging models: {model_names}")
+        result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
                                                                             base_model,
+                                                                             merged_model_name or "+".join(model_names),
+                                                                             alpha,
+                                                                             interp,
+                                                                             force)
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
+                                                                              base_model = base_model,
+                                                                              model_type = ModelType.Main,
+                                                                              )
+        response = parse_obj_as(ConvertModelResponse, model_raw)
+    except KeyError:
+        raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    return response
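The net effect of this file's changes: model identity moves into the URL as /{base_model}/{model_type}/{model_name}, import becomes a POST body rather than query parameters, and every endpoint returns a full model-config object. A client-side sketch of the new surface follows — host, port, the "/api" mount point, and all model names are illustrative assumptions, not part of the commit:

# Hypothetical client walkthrough of the reworked endpoints. The server address,
# the "/api" mount point, and the model names are assumptions for illustration.
import requests

BASE = "http://localhost:9090/api/v1/models"

# Import: now a POST to "/" with a JSON body rather than query parameters.
resp = requests.post(f"{BASE}/", json={"location": "runwayml/stable-diffusion-v1-5"})
resp.raise_for_status()
config = resp.json()  # a full model-config object, per ImportModelResponse

# Update: PATCH the composite path /{base_model}/{model_type}/{model_name};
# "sd-1" and "main" are the BaseModelType/ModelType enum values.
config["description"] = "imported via the new API"
requests.patch(f"{BASE}/sd-1/main/stable-diffusion-v1-5", json=config)

# Convert a checkpoint model in place, then merge two models under one base.
requests.put(f"{BASE}/convert/sd-1/main/my-checkpoint-model")
requests.put(
    f"{BASE}/merge/sd-1",
    json={
        "model_names": ["model-a", "model-b"],
        "merged_model_name": "model-ab",
        "alpha": 0.5,
        "interp": "weighted_sum",
    },
)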

View File

@@ -2,22 +2,29 @@
 from __future__ import annotations

-import torch
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
-from dataclasses import dataclass
+from pydantic import Field
+from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
+from types import ModuleType

-from invokeai.backend.model_management.model_manager import (
+from invokeai.backend.model_management import (
     ModelManager,
     BaseModelType,
     ModelType,
     SubModelType,
     ModelInfo,
+    AddModelResult,
+    SchedulerPredictionType,
+    ModelMerger,
+    MergeInterpolationMethod,
 )
+import torch

 from invokeai.app.models.exceptions import CanceledException
-from .config import InvokeAIAppConfig
 from ...backend.util import choose_precision, choose_torch_device
+from .config import InvokeAIAppConfig

 if TYPE_CHECKING:
     from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC):
     def __init__(
         self,
         config: InvokeAIAppConfig,
-        logger: types.ModuleType,
+        logger: ModuleType,
     ):
         """
         Initialize with the path to the models.yaml config file.
@@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC):
     def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
         """
         Given a model name returns a dict-like (OmegaConf) object describing it.
-        """
-        pass
-
-    @abstractmethod
-    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
-        """
-        Returns a list of all the model names known.
+        Uses the exact format as the omegaconf stanza.
         """
         pass
@@ -101,7 +102,20 @@ class ModelManagerServiceBase(ABC):
         }
         """
         pass

+    @abstractmethod
+    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+        """
+        Return information about the model using the same format as list_models()
+        """
+        pass
+
+    @abstractmethod
+    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
+        """
+        Returns a list of all the model names known.
+        """
+        pass
+
     @abstractmethod
     def add_model(
@@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC):
         model_type: ModelType,
         model_attributes: dict,
         clobber: bool = False
-    ) -> None:
+    ) -> AddModelResult:
         """
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC):
         """
         pass

+    @abstractmethod
+    def update_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        model_attributes: dict,
+    ) -> AddModelResult:
+        """
+        Update the named model with a dictionary of attributes. Will fail with a
+        KeyError exception if the name does not already exist.
+        On a successful update, the config will be changed in memory. Will fail
+        with an assertion error if provided attributes are incorrect or
+        the model name is missing. Call commit() to write changes to disk.
+        """
+        pass
+
     @abstractmethod
     def del_model(
         self,
@@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC):
         """
         pass

+    @abstractmethod
+    def convert_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: Union[ModelType.Main,ModelType.Vae],
+    ) -> AddModelResult:
+        """
+        Convert a checkpoint file into a diffusers folder, deleting the cached
+        version and deleting the original checkpoint file if it is in the models
+        directory.
+        :param model_name: Name of the model to convert
+        :param base_model: Base model type
+        :param model_type: Type of model ['vae' or 'main']
+
+        This will raise a ValueError unless the model is a checkpoint. It will
+        also raise a ValueError in the event that there is a similarly-named diffusers
+        directory already in place.
+        """
+        pass
+
     @abstractmethod
     def heuristic_import(self,
-                         items_to_import: Set[str],
-                         prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
-                         )->Dict[str, AddModelResult]:
+                         items_to_import: set[str],
+                         prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
+                         )->dict[str, AddModelResult]:
         '''Import a list of paths, repo_ids or URLs. Returns the set of
         successfully imported items.
         :param items_to_import: Set of strings corresponding to models to be imported.
@@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC):
         pass

     @abstractmethod
-    def commit(self, conf_file: Path = None) -> None:
+    def merge_models(
+        self,
+        model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
+        base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
+        merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
+        alpha: Optional[float] = 0.5,
+        interp: Optional[MergeInterpolationMethod] = None,
+        force: Optional[bool] = False,
+    ) -> AddModelResult:
+        """
+        Merge two to three diffusers pipeline models and save as a new model.
+        :param model_names: List of 2-3 models to merge
+        :param base_model: Base model to use for all models
+        :param merged_model_name: Name of destination merged model
+        :param alpha: Alpha strength to apply to 2d and 3d model
+        :param interp: Interpolation method. None (default)
+        """
+        pass
+
+    @abstractmethod
+    def commit(self, conf_file: Optional[Path] = None) -> None:
         """
         Write current configuration out to the indicated file.
         If no conf_file is provided, then replaces the
@@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase):
     def __init__(
         self,
         config: InvokeAIAppConfig,
-        logger: types.ModuleType,
+        logger: ModuleType,
     ):
         """
         Initialize with the path to the models.yaml config file.
@@ -299,12 +372,19 @@ class ModelManagerService(ModelManagerServiceBase):
         base_model: Optional[BaseModelType] = None,
         model_type: Optional[ModelType] = None
     ) -> list[dict]:
-    # ) -> dict:
         """
         Return a list of models.
         """
         return self.mgr.list_models(base_model, model_type)

+    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+        """
+        Return information about the model using the same format as list_models()
+        """
+        return self.mgr.list_model(model_name=model_name,
+                                   base_model=base_model,
+                                   model_type=model_type)
+
     def add_model(
         self,
         model_name: str,
@@ -320,9 +400,28 @@ class ModelManagerService(ModelManagerServiceBase):
         with an assertion error if provided attributes are incorrect or
         the model name is missing. Call commit() to write changes to disk.
         """
+        self.logger.debug(f'add/update model {model_name}')
         return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)

+    def update_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        model_attributes: dict,
+    ) -> AddModelResult:
+        """
+        Update the named model with a dictionary of attributes. Will fail with a
+        KeyError exception if the name does not already exist.
+        On a successful update, the config will be changed in memory. Will fail
+        with an assertion error if provided attributes are incorrect or
+        the model name is missing. Call commit() to write changes to disk.
+        """
+        self.logger.debug(f'update model {model_name}')
+        if not self.model_exists(model_name, base_model, model_type):
+            raise KeyError(f"Unknown model {model_name}")
+        return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
+
     def del_model(
         self,
         model_name: str,
@@ -334,8 +433,29 @@ class ModelManagerService(ModelManagerServiceBase):
         then the underlying weight file or diffusers directory will be deleted
         as well. Call commit() to write to disk.
         """
+        self.logger.debug(f'delete model {model_name}')
         self.mgr.del_model(model_name, base_model, model_type)

+    def convert_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: Union[ModelType.Main,ModelType.Vae],
+    ) -> AddModelResult:
+        """
+        Convert a checkpoint file into a diffusers folder, deleting the cached
+        version and deleting the original checkpoint file if it is in the models
+        directory.
+        :param model_name: Name of the model to convert
+        :param base_model: Base model type
+        :param model_type: Type of model ['vae' or 'main']
+
+        This will raise a ValueError unless the model is a checkpoint. It will
+        also raise a ValueError in the event that there is a similarly-named diffusers
+        directory already in place.
+        """
+        self.logger.debug(f'convert model {model_name}')
+        return self.mgr.convert_model(model_name, base_model, model_type)
+
     def commit(self, conf_file: Optional[Path]=None):
         """
@@ -387,9 +507,9 @@ class ModelManagerService(ModelManagerServiceBase):
         return self.mgr.logger

     def heuristic_import(self,
-                         items_to_import: Set[str],
-                         prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
-                         )->Dict[str, AddModelResult]:
+                         items_to_import: set[str],
+                         prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
+                         )->dict[str, AddModelResult]:
         '''Import a list of paths, repo_ids or URLs. Returns the set of
         successfully imported items.
         :param items_to_import: Set of strings corresponding to models to be imported.
@@ -406,4 +526,35 @@ class ModelManagerService(ModelManagerServiceBase):
         of the set is a dict corresponding to the newly-created OmegaConf stanza for
         that model.
         '''
         return self.mgr.heuristic_import(items_to_import, prediction_type_helper)

+    def merge_models(
+        self,
+        model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
+        base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
+        merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
+        alpha: Optional[float] = 0.5,
+        interp: Optional[MergeInterpolationMethod] = None,
+        force: Optional[bool] = False,
+    ) -> AddModelResult:
+        """
+        Merge two to three diffusers pipeline models and save as a new model.
+        :param model_names: List of 2-3 models to merge
+        :param base_model: Base model to use for all models
+        :param merged_model_name: Name of destination merged model
+        :param alpha: Alpha strength to apply to 2d and 3d model
+        :param interp: Interpolation method. None (default)
+        """
+        merger = ModelMerger(self.mgr)
+        try:
+            result = merger.merge_diffusion_models_and_save(
+                model_names = model_names,
+                base_model = base_model,
+                merged_model_name = merged_model_name,
+                alpha = alpha,
+                interp = interp,
+                force = force,
+            )
+        except AssertionError as e:
+            raise ValueError(e)
+        return result
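The service layer now carries a uniform exception contract — KeyError for an unknown model, ValueError for an invalid request — which the router translates into 404 and 400. A minimal sketch of calling the new merge_models() directly, assuming `service` is an initialized ModelManagerService and the model names are placeholders:

# Minimal sketch: `service` is assumed to be an initialized ModelManagerService;
# the model names are placeholders for installed SD-1 diffusers models.
from invokeai.backend.model_management import BaseModelType, MergeInterpolationMethod

try:
    result = service.merge_models(
        model_names=["model-a", "model-b"],
        base_model=BaseModelType.StableDiffusion1,
        merged_model_name="model-ab",
        alpha=0.5,
        interp=MergeInterpolationMethod.Sigmoid,
        force=False,
    )
    print(f"merged into {result.name} ({result.base_model}/{result.model_type})")
except KeyError as e:    # unknown model -> the router's 404
    print("not found:", e)
except ValueError as e:  # incompatible models -> the router's 400
    print("bad request:", e)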

View File

@@ -166,14 +166,18 @@ class ModelInstall(object):
         # add requested models
         for path in selections.install_models:
             logger.info(f'Installing {path} [{job}/{jobs}]')
-            self.heuristic_import(path)
+            try:
+                self.heuristic_import(path)
+            except (ValueError, KeyError) as e:
+                logger.error(str(e))
             job += 1

         self.mgr.commit()

     def heuristic_import(self,
                          model_path_id_or_url: Union[str,Path],
-                         models_installed: Set[Path]=None)->Dict[str, AddModelResult]:
+                         models_installed: Set[Path]=None,
+                         )->Dict[str, AddModelResult]:
         '''
         :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
         :param models_installed: Set of installed models, used for recursive invocation
@@ -187,61 +191,53 @@ class ModelInstall(object):
         self.current_id = model_path_id_or_url
         path = Path(model_path_id_or_url)

-        try:
-            # checkpoint file, or similar
-            if path.is_file():
-                models_installed.update(self._install_path(path))
-
-            # folders style or similar
-            elif path.is_dir() and any([(path/x).exists() for x in \
-                                        {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
-                                        ]
-                                       ):
-                models_installed.update(self._install_path(path))
-
-            # recursive scan
-            elif path.is_dir():
-                for child in path.iterdir():
-                    self.heuristic_import(child, models_installed=models_installed)
-
-            # huggingface repo
-            elif len(str(path).split('/')) == 2:
-                models_installed.update(self._install_repo(str(path)))
-
-            # a URL
-            elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")):
-                models_installed.update(self._install_url(model_path_id_or_url))
-
-            else:
-                logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
-
-        except ValueError as e:
-            logger.error(str(e))
+        # checkpoint file, or similar
+        if path.is_file():
+            models_installed.update({str(path):self._install_path(path)})
+
+        # folders style or similar
+        elif path.is_dir() and any([(path/x).exists() for x in \
+                                    {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
+                                    ]
+                                   ):
+            models_installed.update(self._install_path(path))
+
+        # recursive scan
+        elif path.is_dir():
+            for child in path.iterdir():
+                self.heuristic_import(child, models_installed=models_installed)
+
+        # huggingface repo
+        elif len(str(model_path_id_or_url).split('/')) == 2:
+            models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
+
+        # a URL
+        elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
+            models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
+
+        else:
+            raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')

         return models_installed

     # install a model from a local path. The optional info parameter is there to prevent
     # the model from being probed twice in the event that it has already been probed.
-    def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]:
-        try:
-            model_result = None
-            info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
-            model_name = path.stem if path.is_file() else path.name
-            if self.mgr.model_exists(model_name, info.base_type, info.model_type):
-                raise ValueError(f'A model named "{model_name}" is already installed.')
-            attributes = self._make_attributes(path,info)
-            model_result = self.mgr.add_model(model_name = model_name,
-                                              base_model = info.base_type,
-                                              model_type = info.model_type,
-                                              model_attributes = attributes,
-                                              )
-        except Exception as e:
-            logger.warning(f'{str(e)} Skipping registration.')
-            return {}
-        return {str(path): model_result}
+    def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
+        info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
+        if not info:
+            logger.warning(f'Unable to parse format of {path}')
+            return None
+        model_name = path.stem if path.is_file() else path.name
+        if self.mgr.model_exists(model_name, info.base_type, info.model_type):
+            raise ValueError(f'A model named "{model_name}" is already installed.')
+        attributes = self._make_attributes(path,info)
+        return self.mgr.add_model(model_name = model_name,
+                                  base_model = info.base_type,
+                                  model_type = info.model_type,
+                                  model_attributes = attributes,
+                                  )

-    def _install_url(self, url: str)->dict:
-        # copy to a staging area, probe, import and delete
+    def _install_url(self, url: str)->AddModelResult:
         with TemporaryDirectory(dir=self.config.models_path) as staging:
             location = download_with_resume(url,Path(staging))
             if not location:
@@ -253,7 +249,7 @@ class ModelInstall(object):
             # staged version will be garbage-collected at this time
             return self._install_path(Path(models_path), info)

-    def _install_repo(self, repo_id: str)->dict:
+    def _install_repo(self, repo_id: str)->AddModelResult:
         hinfo = HfApi().model_info(repo_id)

         # we try to figure out how to download this most economically
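heuristic_import() now returns a mapping from each import string to its AddModelResult and raises instead of logging-and-continuing, leaving error policy to the caller (as the try/except added to the install loop above shows). A sketch, assuming `installer` is a configured ModelInstall and the repo id is a placeholder:

# Sketch: `installer` is assumed to be a configured ModelInstall instance;
# the repo id below is a placeholder.
try:
    results = installer.heuristic_import("runwayml/stable-diffusion-v1-5")
    for source, added in results.items():
        print(f"{source} -> {added.name} [{added.base_model}/{added.model_type}]")
except KeyError as e:
    print("not a local path, repo_id or URL:", e)
except ValueError as e:
    print("a model with this name is already installed:", e)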

View File

@@ -1,7 +1,8 @@
 """
 Initialization file for invokeai.backend.model_management
 """
-from .model_manager import ModelManager, ModelInfo, AddModelResult
+from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
 from .model_cache import ModelCache
 from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
+from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@@ -2,8 +2,8 @@ from __future__ import annotations

 import copy
 from contextlib import contextmanager
-from typing import Optional, Dict, Tuple, Any, Union, List
 from pathlib import Path
+from typing import Any, Dict, Optional, Tuple, Union, List

 import torch
 from compel.embeddings_provider import BaseTextualInversionManager

View File

@@ -234,7 +234,7 @@ import textwrap
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
-from shutil import rmtree
+from shutil import rmtree, move

 import torch
 from omegaconf import OmegaConf
@@ -279,7 +279,7 @@ class InvalidModelError(Exception):
     pass

 class AddModelResult(BaseModel):
-    name: str = Field(description="The name of the model after import")
+    name: str = Field(description="The name of the model after installation")
     model_type: ModelType = Field(description="The type of model")
     base_model: BaseModelType = Field(description="The base model")
     config: ModelConfigBase = Field(description="The configuration of the model")
@@ -491,17 +491,32 @@ class ModelManager(object):
         """
         return [(self.parse_key(x)) for x in self.models.keys()]

+    def list_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+    ) -> dict:
+        """
+        Returns a dict describing one installed model, using
+        the combined format of the list_models() method.
+        """
+        models = self.list_models(base_model,model_type,model_name)
+        return models[0] if models else None
+
     def list_models(
         self,
         base_model: Optional[BaseModelType] = None,
         model_type: Optional[ModelType] = None,
+        model_name: Optional[str] = None,
     ) -> list[dict]:
         """
         Return a list of models.
         """
+        model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
         models = []
-        for model_key in sorted(self.models, key=str.casefold):
+        for model_key in model_keys:
             model_config = self.models[model_key]

             cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@@ -546,10 +561,7 @@ class ModelManager(object):
         model_cfg = self.models.pop(model_key, None)

         if model_cfg is None:
-            self.logger.error(
-                f"Unknown model {model_key}"
-            )
-            return
+            raise KeyError(f"Unknown model {model_key}")

         # note: this is not guaranteed to release memory (the model may have other references)
         cache_ids = self.cache_keys.pop(model_key, [])
@@ -615,6 +627,7 @@ class ModelManager(object):
             self.cache.uncache_model(cache_id)

         self.models[model_key] = model_config
+        self.commit()

         return AddModelResult(
             name = model_name,
             model_type = model_type,
@@ -622,6 +635,60 @@ class ModelManager(object):
             config = model_config,
         )

+    def convert_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: Union[ModelType.Main,ModelType.Vae],
+    ) -> AddModelResult:
+        '''
+        Convert a checkpoint file into a diffusers folder, deleting the cached
+        version and deleting the original checkpoint file if it is in the models
+        directory.
+        :param model_name: Name of the model to convert
+        :param base_model: Base model type
+        :param model_type: Type of model ['vae' or 'main']
+
+        This will raise a ValueError unless the model is a checkpoint.
+        '''
+        info = self.model_info(model_name, base_model, model_type)
+        if info["model_format"] != "checkpoint":
+            raise ValueError(f"not a checkpoint format model: {model_name}")
+
+        # We are taking advantage of a side effect of get_model() that converts checkpoints
+        # into cached diffusers directories stored at `location`. It doesn't matter
+        # what submodeltype we request here, so we get the smallest.
+        submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
+        model = self.get_model(model_name,
+                               base_model,
+                               model_type,
+                               **submodel,
+                               )
+        checkpoint_path = self.app_config.root_path / info["path"]
+        old_diffusers_path = self.app_config.models_path / model.location
+        new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
+        if new_diffusers_path.exists():
+            raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
+
+        try:
+            move(old_diffusers_path,new_diffusers_path)
+            info["model_format"] = "diffusers"
+            info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
+            info.pop('config')
+
+            result = self.add_model(model_name, base_model, model_type,
+                                    model_attributes = info,
+                                    clobber=True)
+        except:
+            # something went wrong, so don't leave a dangling diffusers model in the
+            # directory or it will cause a duplicate model error!
+            rmtree(new_diffusers_path)
+            raise
+
+        if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
+            checkpoint_path.unlink()
+
+        return result
+
     def search_models(self, search_folder):
         self.logger.info(f"Finding Models In: {search_folder}")
         models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
@@ -821,6 +888,10 @@ class ModelManager(object):
         The result is a set of successfully installed models. Each element
         of the set is a dict corresponding to the newly-created OmegaConf stanza for
         that model.
+
+        May raise the following exceptions:
+        - KeyError   - one or more of the items to import is not a valid path, repo_id or URL
+        - ValueError - a corresponding model already exists
         '''
         # avoid circular import here
         from invokeai.backend.install.model_install_backend import ModelInstall
@@ -830,11 +901,7 @@ class ModelManager(object):
                                  prediction_type_helper = prediction_type_helper,
                                  model_manager = self)
         for thing in items_to_import:
-            try:
-                installed = installer.heuristic_import(thing)
-                successfully_installed.update(installed)
-            except Exception as e:
-                self.logger.warning(f'{thing} could not be imported: {str(e)}')
+            installed = installer.heuristic_import(thing)
+            successfully_installed.update(installed)

         self.commit()
         return successfully_installed
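Two behavioral points worth noting above: list_model() is just list_models() narrowed to a single key, and convert_model() leans on get_model()'s side effect of caching a diffusers conversion of a checkpoint, then moves that cache into place. A sketch under the assumption that `mgr` is a loaded ModelManager and the model names are placeholders:

# Sketch: `mgr` is assumed to be a loaded ModelManager; model names are placeholders.
from invokeai.backend.model_management import BaseModelType, ModelType

one = mgr.list_model("model-a",
                     base_model=BaseModelType.StableDiffusion1,
                     model_type=ModelType.Main)
same = mgr.list_models(BaseModelType.StableDiffusion1, ModelType.Main,
                       model_name="model-a")
assert one == same[0]  # list_model() is list_models() narrowed to one key

# convert_model() moves the cached diffusers copy under models/<base>/<type>/<name>,
# re-registers it, and raises ValueError for a non-checkpoint model.
result = mgr.convert_model("my-checkpoint",
                           base_model=BaseModelType.StableDiffusion1,
                           model_type=ModelType.Main)
print(result.name, result.config)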

View File

@ -0,0 +1,131 @@
"""
invokeai.backend.model_management.model_merge exports:
merge_diffusion_models() -- combine multiple models by location and return a pipeline object
merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""
import warnings
from enum import Enum
from pathlib import Path
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
from typing import List, Union
import invokeai.backend.util.logging as logger
from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
class MergeInterpolationMethod(str, Enum):
WeightedSum = "weighted_sum"
Sigmoid = "sigmoid"
InvSigmoid = "inv_sigmoid"
AddDifference = "add_difference"
class ModelMerger(object):
def __init__(self, manager: ModelManager):
self.manager = manager
def merge_diffusion_models(
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
:param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_paths[0],
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_paths,
alpha=alpha,
interp=interp.value if interp else None, #diffusers API treats None as "weighted sum"
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_save (
self,
model_names: List[str],
base_model: Union[BaseModelType,str],
merged_model_name: str,
alpha: float = 0.5,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> AddModelResult:
"""
:param models: up to three models, designated by their InvokeAI models.yaml model name
:param base_model: base model (must be the same for all merged models!)
:param merged_model_name: name for new model
:param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
:param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
:param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
config = self.manager.app_config
base_model = BaseModelType(base_model)
vae = None
for mod in model_names:
info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main)
assert info, f"model {mod}, base_model {base_model}, is unknown"
assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging"
assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged"
assert len(model_names) <= 2 or \
interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported"
# pick up the first model's vae
if mod == model_names[0]:
vae = info.get("vae")
model_paths.extend([config.root_path / info["path"]])
merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp)
logger.debug(f'interp = {interp}, merge_method={merge_method}')
merged_pipe = self.merge_diffusion_models(
model_paths, alpha, merge_method, force, **kwargs
)
dump_path = config.models_path / base_model.value / ModelType.Main.value
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict(
path = str(dump_path),
description = f"Merge of models {', '.join(model_names)}",
model_format = "diffusers",
variant = ModelVariantType.Normal.value,
vae = vae,
)
return self.manager.add_model(merged_model_name,
base_model = base_model,
model_type = ModelType.Main,
model_attributes = attributes,
clobber = True
)
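ModelMerger wraps the diffusers checkpoint_merger community pipeline behind the model manager, so the merged output lands in models/&lt;base&gt;/main/&lt;name&gt; and is registered via add_model(). A usage sketch with placeholder model names and config path; per the asserts above, a three-way merge must use add_difference:

# Sketch only: the models.yaml path and model names are placeholders.
from omegaconf import OmegaConf
from invokeai.backend.model_management import (
    ModelManager, ModelMerger, MergeInterpolationMethod, BaseModelType,
)

mgr = ModelManager(OmegaConf.load("models.yaml"))  # config path is illustrative
merger = ModelMerger(mgr)
result = merger.merge_diffusion_models_and_save(
    model_names=["model-a", "model-b", "model-c"],
    base_model=BaseModelType.StableDiffusion1,
    merged_model_name="abc-merged",
    interp=MergeInterpolationMethod.AddDifference,  # required for three models
    alpha=0.5,
)
print(result.name)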

View File

@@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel):
                 version=BaseModelType.StableDiffusion1,
                 model_config=config,
                 output_path=output_path,
-            )
+            )
         else:
             return model_path

View File

@ -1,4 +1,5 @@
""" """
Initialization file for invokeai.frontend.merge Initialization file for invokeai.frontend.merge
""" """
from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models from .merge_diffusers import main as invokeai_merge_diffusers

View File

@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
""" """
import argparse import argparse
import curses import curses
import os
import sys import sys
import warnings
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import List, Union from typing import List, Union
@ -20,99 +18,15 @@ from npyscreen import widget
from omegaconf import OmegaConf from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager from invokeai.backend.model_management import (
from ...frontend.install.widgets import FloatTitleSlider ModelMerger, MergeInterpolationMethod,
ModelManager, ModelType, BaseModelType,
)
from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns
DEST_MERGED_MODEL_DIR = "merged_models"
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],
alpha: float = 0.5,
interp: str = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
"""
model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
pipe = DiffusionPipeline.from_pretrained(
model_ids_or_paths[0],
cache_dir=kwargs.get("cache_dir", config.cache_dir),
custom_pipeline="checkpoint_merger",
)
merged_pipe = pipe.merge(
pretrained_model_name_or_path_list=model_ids_or_paths,
alpha=alpha,
interp=interp,
force=force,
**kwargs,
)
dlogging.set_verbosity(verbosity)
return merged_pipe
def merge_diffusion_models_and_commit(
models: List["str"],
merged_model_name: str,
alpha: float = 0.5,
interp: str = None,
force: bool = False,
**kwargs,
):
"""
models - up to three models, designated by their InvokeAI models.yaml model name
merged_model_name = name for new model
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
config_file = config.model_conf_path
model_manager = ModelManager(OmegaConf.load(config_file))
for mod in models:
assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
assert (
model_manager.model_info(mod).get("format", None) == "diffusers"
), f"** {mod} is not a diffusers model. It must be optimized before merging."
model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]
merged_pipe = merge_diffusion_models(
model_ids_or_paths, alpha, interp, force, **kwargs
)
dump_path = config.models_dir / DEST_MERGED_MODEL_DIR
os.makedirs(dump_path, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
import_args = dict(
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
)
if vae := model_manager.config[models[0]].get("vae", None):
logger.info(f"Using configured VAE assigned to {models[0]}")
import_args.update(vae=vae)
model_manager.import_diffuser_model(dump_path, **import_args)
model_manager.commit(config_file)
def _parse_args() -> Namespace: def _parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="InvokeAI model merging") parser = argparse.ArgumentParser(description="InvokeAI model merging")
parser.add_argument( parser.add_argument(
@ -131,10 +45,17 @@ def _parse_args() -> Namespace:
) )
parser.add_argument( parser.add_argument(
"--models", "--models",
dest="model_names",
type=str, type=str,
nargs="+", nargs="+",
help="Two to three model names to be merged", help="Two to three model names to be merged",
) )
parser.add_argument(
"--base_model",
type=str,
choices=[x.value for x in BaseModelType],
help="The base model shared by the models to be merged",
)
parser.add_argument( parser.add_argument(
"--merged_model_name", "--merged_model_name",
"--destination", "--destination",
@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
window_height, window_width = curses.initscr().getmaxyx() window_height, window_width = curses.initscr().getmaxyx()
self.model_names = self.get_model_names() self.model_names = self.get_model_names()
self.current_base = 0
max_width = max([len(x) for x in self.model_names]) max_width = max([len(x) for x in self.model_names])
max_width += 6 max_width += 6
horizontal_layout = max_width * 3 < window_width horizontal_layout = max_width * 3 < window_width
@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.", value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
editable=False, editable=False,
) )
        self.nextrely += 1
        self.base_select = self.add_widget_intelligent(
            SingleSelectColumns,
            values=[
                'Models Built on SD-1.x',
                'Models Built on SD-2.x',
            ],
            value=[self.current_base],
            columns=4,
            max_height=2,
            relx=8,
            scroll_exit=True,
        )
        self.base_select.on_changed = self._populate_models
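        # Changing the base-model selection re-populates the three model lists
        # below via _populate_models(), so only models that share the selected
        # base are offered for merging.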
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="MODEL 1",
            color="GOOD",
            editable=False,
            rely=6 if horizontal_layout else None,
        )
        self.model1 = self.add_widget_intelligent(
            npyscreen.SelectOne,
@@ -222,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
            max_height=len(self.model_names),
            max_width=max_width,
            scroll_exit=True,
            rely=7,
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
@@ -230,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD", color="GOOD",
editable=False, editable=False,
relx=max_width + 3 if horizontal_layout else None, relx=max_width + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None, rely=6 if horizontal_layout else None,
) )
self.model2 = self.add_widget_intelligent( self.model2 = self.add_widget_intelligent(
npyscreen.SelectOne, npyscreen.SelectOne,
@ -240,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
            max_height=len(self.model_names),
            max_width=max_width,
            relx=max_width + 3 if horizontal_layout else None,
            rely=7 if horizontal_layout else None,
            scroll_exit=True,
        )
        self.add_widget_intelligent(
@@ -249,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
color="GOOD", color="GOOD",
editable=False, editable=False,
relx=max_width * 2 + 3 if horizontal_layout else None, relx=max_width * 2 + 3 if horizontal_layout else None,
rely=4 if horizontal_layout else None, rely=6 if horizontal_layout else None,
) )
models_plus_none = self.model_names.copy() models_plus_none = self.model_names.copy()
models_plus_none.insert(0, "None") models_plus_none.insert(0, "None")
@ -262,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
            max_width=max_width,
            scroll_exit=True,
            relx=max_width * 2 + 3 if horizontal_layout else None,
            rely=7 if horizontal_layout else None,
        )
        for m in [self.model1, self.model2, self.model3]:
            m.when_value_edited = self.models_changed
        self.merged_model_name = self.add_widget_intelligent(
            TextBox,
            name="Name for merged model:",
            labelColor="CONTROL",
            max_height=3,
            value="",
            scroll_exit=True,
        )
        self.force = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Force merge of models created by different diffusers library versions",
            labelColor="CONTROL",
            value=True,
            scroll_exit=True,
        )
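        # The force checkbox now defaults to True, so merges proceed across
        # differing diffusers library versions unless the user unchecks it.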
        self.nextrely += 1
        self.merge_method = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Merge Method:",
@@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
        interp = self.interpolations[self.merge_method.value[0]]
        args = dict(
            model_names=models,
            base_model=tuple(BaseModelType)[self.base_select.value[0]],
            alpha=self.alpha.value,
            interp=interp,
            force=self.force.value,
@@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
        else:
            return True

    def get_model_names(self, base_model: BaseModelType = None) -> List[str]:
        model_names = [
            info["name"]
            for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model)
            if info["model_format"] == "diffusers"
        ]
        return sorted(model_names)
    def _populate_models(self, value=None):
        base_model = tuple(BaseModelType)[value[0]]
        self.model_names = self.get_model_names(base_model)
        models_plus_none = self.model_names.copy()
        models_plus_none.insert(0, "None")
        self.model1.values = self.model_names
        self.model2.values = self.model_names
        self.model3.values = models_plus_none
        self.display()
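
# tuple(BaseModelType)[index] above maps the selector's row index to an enum
# member in declaration order. A self-contained sketch of that mapping, using
# a simplified stand-in for the real enum (member names/values are assumptions):
def _example_base_mapping() -> None:
    from enum import Enum

    class _Base(str, Enum):
        sd1 = "sd-1"
        sd2 = "sd-2"

    # tuple() preserves declaration order: row 0 -> "sd-1", row 1 -> "sd-2"
    assert tuple(_Base)[0].value == "sd-1"
    assert tuple(_Base)[1].value == "sd-2"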
class Mergeapp(npyscreen.NPSAppManaged):
    def __init__(self, model_manager: ModelManager):
        super().__init__()
        self.model_manager = model_manager

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
@@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged):
def run_gui(args: Namespace):
    model_manager = ModelManager(config.model_conf_path)
    mergeapp = Mergeapp(model_manager)
    mergeapp.run()

    args = mergeapp.merge_arguments
    merger = ModelMerger(model_manager)
    merger.merge_diffusion_models_and_save(**args)
    logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
def run_cli(args: Namespace):
    assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
    assert (
        args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
    ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."

    if not args.merged_model_name:
        args.merged_model_name = "+".join(args.model_names)
        logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')

    model_manager = ModelManager(config.model_conf_path)
    assert (
        not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber
    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

    merger = ModelMerger(model_manager)
    merger.merge_diffusion_models_and_save(**vars(args))
    logger.info(f'Models merged into new model: "{args.merged_model_name}".')
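
# A programmatic sketch of the ModelMerger path used above. Argument names are
# taken from this diff; exact signatures are assumptions, and the model names
# are placeholders.
def _example_programmatic_merge() -> None:
    model_manager = ModelManager(config.model_conf_path)
    merger = ModelMerger(model_manager)
    merger.merge_diffusion_models_and_save(
        model_names=["modelA", "modelB"],
        base_model=tuple(BaseModelType)[0],  # first enum member: the SD-1 base
        merged_model_name="modelA+modelB",
        alpha=0.5,
        interp=None,  # default weighted-sum merge
        force=True,
    )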
def main():
    args = _parse_args()
    config.parse_args(['--root', str(args.root_dir)])
    cache_dir = config.cache_dir
    # because it is not clear that the merge pipeline honors cache_dir
    os.environ["HF_HOME"] = cache_dir
    args.cache_dir = cache_dir
    try:
        if args.front_end: