diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index dcbdbec82d..4e23a69d90 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -1,75 +1,30 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein -from typing import Literal, Optional, Union -from fastapi import Query, Body -from fastapi.routing import APIRouter, HTTPException -from pydantic import BaseModel, Field, parse_obj_as -from ..dependencies import ApiDependencies +from typing import Literal, List, Optional, Union + +from fastapi import Body, Path, Query, Response +from fastapi.routing import APIRouter +from pydantic import BaseModel, parse_obj_as +from starlette.exceptions import HTTPException + from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management import AddModelResult -from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType -MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] +from invokeai.backend.model_management.models import ( + OPENAPI_MODEL_CONFIGS, + SchedulerPredictionType, +) +from invokeai.backend.model_management import MergeInterpolationMethod +from ..dependencies import ApiDependencies models_router = APIRouter(prefix="/v1/models", tags=["models"]) -class VaeRepo(BaseModel): - repo_id: str = Field(description="The repo ID to use for this VAE") - path: Optional[str] = Field(description="The path to the VAE") - subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") - -class ModelInfo(BaseModel): - description: Optional[str] = Field(description="A description of the model") - model_name: str = Field(description="The name of the model") - model_type: str = Field(description="The type of the model") - -class DiffusersModelInfo(ModelInfo): - format: Literal['folder'] = 'folder' - - vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model") - repo_id: Optional[str] = Field(description="The repo ID to use for this model") - path: Optional[str] = Field(description="The path to the model") - -class CkptModelInfo(ModelInfo): - format: Literal['ckpt'] = 'ckpt' - - config: str = Field(description="The path to the model config") - weights: str = Field(description="The path to the model weights") - vae: str = Field(description="The path to the model VAE") - width: Optional[int] = Field(description="The width of the model") - height: Optional[int] = Field(description="The height of the model") - -class SafetensorsModelInfo(CkptModelInfo): - format: Literal['safetensors'] = 'safetensors' - -class CreateModelRequest(BaseModel): - name: str = Field(description="The name of the model") - info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") - -class CreateModelResponse(BaseModel): - name: str = Field(description="The name of the new model") - info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") - status: str = Field(description="The status of the API response") - -class ImportModelResponse(BaseModel): - name: str = Field(description="The name of the imported model") -# base_model: str = Field(description="The base model") -# model_type: str = Field(description="The model type") - info: AddModelResult = 
Field(description="The model info") - status: str = Field(description="The status of the API response") - -class ConversionRequest(BaseModel): - name: str = Field(description="The name of the new model") - info: CkptModelInfo = Field(description="The converted model info") - save_location: str = Field(description="The path to save the converted model weights") - -class ConvertedModelResponse(BaseModel): - name: str = Field(description="The name of the new model") - info: DiffusersModelInfo = Field(description="The converted model info") +UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] class ModelsList(BaseModel): - models: list[MODEL_CONFIGS] - + models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @models_router.get( "/", @@ -77,75 +32,103 @@ class ModelsList(BaseModel): responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: Optional[BaseModelType] = Query( - default=None, description="Base model" - ), - model_type: Optional[ModelType] = Query( - default=None, description="The type of model to get" - ), + base_model: Optional[BaseModelType] = Query(default=None, description="Base model"), + model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Gets a list of models""" models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) models = parse_obj_as(ModelsList, { "models": models_raw }) return models -@models_router.post( - "/", +@models_router.patch( + "/{base_model}/{model_type}/{model_name}", operation_id="update_model", - responses={200: {"status": "success"}}, + responses={200: {"description" : "The model was updated successfully"}, + 404: {"description" : "The model could not be found"}, + 400: {"description" : "Bad request"} + }, + status_code = 200, + response_model = UpdateModelResponse, ) async def update_model( - model_request: CreateModelRequest -) -> CreateModelResponse: + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), +) -> UpdateModelResponse: """ Update an existing model """ - model_request_info = model_request.info - info_dict = model_request_info.dict() - model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success") - - ApiDependencies.invoker.services.model_manager.add_model( - model_name=model_request.name, - model_attributes=info_dict, - clobber=True, - ) + try: + ApiDependencies.invoker.services.model_manager.update_model( + model_name=model_name, + base_model=base_model, + model_type=model_type, + model_attributes=info.dict() + ) + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=model_name, + base_model=base_model, + model_type=model_type, + ) + model_response = parse_obj_as(UpdateModelResponse, model_raw) + except KeyError as e: + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) return model_response
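As a quick sanity check of the new PATCH route, a hypothetical client call might look like the sketch below (host, port, the /api mount point and the payload fields are all assumptions for illustration; the accepted body is whichever OPENAPI_MODEL_CONFIGS schema matches the target model):

    import requests  # illustrative client-side sketch, not part of this PR

    # PATCH /api/v1/models/{base_model}/{model_type}/{model_name}
    resp = requests.patch(
        "http://localhost:9090/api/v1/models/sd-1/main/my-model",  # assumed host/port/mount
        json={"model_format": "diffusers", "path": "sd-1/main/my-model", "description": "updated"},  # hypothetical fields
    )
    resp.raise_for_status()     # 404 if the model is unknown, 400 on bad attributes
    print(resp.json()["name"])  # the updated model config is echoed back

@models_router.post( - "/import", + "/", operation_id="import_model", responses= { 201: {"description" : "The model imported successfully"}, 404: {"description" : "The model could not be found"}, + 424: 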
{"description" : "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description" : "There is already a model corresponding to this path or repo_id"}, }, status_code=201, response_model=ImportModelResponse ) async def import_model( - name: str = Query(description="A model path, repo_id or URL to import"), - prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), + location: str = Body(description="A model path, repo_id or URL to import"), + prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ + Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), ) -> ImportModelResponse: """ Add a model using its local path, repo_id, or remote URL """ - items_to_import = {name} + + items_to_import = {location} prediction_types = { x.value: x for x in SchedulerPredictionType } logger = ApiDependencies.invoker.services.logger - - installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( - items_to_import = items_to_import, - prediction_type_helper = lambda x: prediction_types.get(prediction_type) - ) - if info := installed_models.get(name): - logger.info(f'Successfully imported {name}, got {info}') - return ImportModelResponse( - name = name, - info = info, - status = "success", + + try: + installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( + items_to_import = items_to_import, + prediction_type_helper = lambda x: prediction_types.get(prediction_type) ) - else: - logger.error(f'Model {name} not imported') - raise HTTPException(status_code=404, detail=f'Model {name} not found') + info = installed_models.get(location) + + if not info: + logger.error("Import failed") + raise HTTPException(status_code=424) + + logger.info(f'Successfully imported {location}, got {info}') + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=info.name, + base_model=info.base_model, + model_type=info.model_type + ) + return parse_obj_as(ImportModelResponse, model_raw) + + except KeyError as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + @models_router.delete( - "/{model_name}", + "/{base_model}/{model_type}/{model_name}", operation_id="del_model", responses={ 204: { @@ -156,144 +139,95 @@ async def import_model( } }, ) -async def delete_model(model_name: str) -> None: +async def delete_model( + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), +) -> Response: """Delete Model""" - model_names = ApiDependencies.invoker.services.model_manager.model_names() logger = ApiDependencies.invoker.services.logger - model_exists = model_name in model_names - - # check if model exists - logger.info(f"Checking for model {model_name}...") - - if model_exists: - logger.info(f"Deleting Model: {model_name}") - ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True) - logger.info(f"Model Deleted: {model_name}") - raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") - else: - logger.error("Model not found") + try: + ApiDependencies.invoker.services.model_manager.del_model(model_name, + base_model = base_model, + model_type = 
model_type + ) + logger.info(f"Deleted model: {model_name}") + return Response(status_code=204) + except KeyError: + logger.error(f"Model not found: {model_name}") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") - - # @socketio.on("convertToDiffusers") - # def convert_to_diffusers(model_to_convert: dict): - # try: - # if model_info := self.generate.model_manager.model_info( - # model_name=model_to_convert["model_name"] - # ): - # if "weights" in model_info: - # ckpt_path = Path(model_info["weights"]) - # original_config_file = Path(model_info["config"]) - # model_name = model_to_convert["model_name"] - # model_description = model_info["description"] - # else: - # self.socketio.emit( - # "error", {"message": "Model is not a valid checkpoint file"} - # ) - # else: - # self.socketio.emit( - # "error", {"message": "Could not retrieve model info."} - # ) - - # if not ckpt_path.is_absolute(): - # ckpt_path = Path(Globals.root, ckpt_path) - - # if original_config_file and not original_config_file.is_absolute(): - # original_config_file = Path(Globals.root, original_config_file) - - # diffusers_path = Path( - # ckpt_path.parent.absolute(), f"{model_name}_diffusers" - # ) - - # if model_to_convert["save_location"] == "root": - # diffusers_path = Path( - # global_converted_ckpts_dir(), f"{model_name}_diffusers" - # ) - - # if ( - # model_to_convert["save_location"] == "custom" - # and model_to_convert["custom_location"] is not None - # ): - # diffusers_path = Path( - # model_to_convert["custom_location"], f"{model_name}_diffusers" - # ) - - # if diffusers_path.exists(): - # shutil.rmtree(diffusers_path) - - # self.generate.model_manager.convert_and_import( - # ckpt_path, - # diffusers_path, - # model_name=model_name, - # model_description=model_description, - # vae=None, - # original_config_file=original_config_file, - # commit_to_conf=opt.conf, - # ) - - # new_model_list = self.generate.model_manager.list_models() - # socketio.emit( - # "modelConverted", - # { - # "new_model_name": model_name, - # "model_list": new_model_list, - # "update": True, - # }, - # ) - # print(f">> Model Converted: {model_name}") - # except Exception as e: - # self.handle_exceptions(e) - - # @socketio.on("mergeDiffusersModels") - # def merge_diffusers_models(model_merge_info: dict): - # try: - # models_to_merge = model_merge_info["models_to_merge"] - # model_ids_or_paths = [ - # self.generate.model_manager.model_name_or_path(x) - # for x in models_to_merge - # ] - # merged_pipe = merge_diffusion_models( - # model_ids_or_paths, - # model_merge_info["alpha"], - # model_merge_info["interp"], - # model_merge_info["force"], - # ) - - # dump_path = global_models_dir() / "merged_models" - # if model_merge_info["model_merge_save_path"] is not None: - # dump_path = Path(model_merge_info["model_merge_save_path"]) - - # os.makedirs(dump_path, exist_ok=True) - # dump_path = dump_path / model_merge_info["merged_model_name"] - # merged_pipe.save_pretrained(dump_path, safe_serialization=1) - - # merged_model_config = dict( - # model_name=model_merge_info["merged_model_name"], - # description=f'Merge of models {", ".join(models_to_merge)}', - # commit_to_conf=opt.conf, - # ) - - # if vae := self.generate.model_manager.config[models_to_merge[0]].get( - # "vae", None - # ): - # print(f">> Using configured VAE assigned to {models_to_merge[0]}") - # merged_model_config.update(vae=vae) - - # self.generate.model_manager.import_diffuser_model( - # dump_path, **merged_model_config - # ) - # new_model_list = 
self.generate.model_manager.list_models() - - # socketio.emit( - # "modelsMerged", - # { - # "merged_models": models_to_merge, - # "merged_model_name": model_merge_info["merged_model_name"], - # "model_list": new_model_list, - # "update": True, - # }, - # ) - # print(f">> Models Merged: {models_to_merge}") - # print(f">> New Model Added: {model_merge_info['merged_model_name']}") - # except Exception as e: +@models_router.put( + "/convert/{base_model}/{model_type}/{model_name}", + operation_id="convert_model", + responses={ + 200: { "description": "Model converted successfully" }, + 400: {"description" : "Bad request" }, + 404: { "description": "Model not found" }, + }, + status_code = 200, + response_model = ConvertModelResponse, +) +async def convert_model( + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), +) -> ConvertModelResponse: + """Convert a checkpoint model into a diffusers model""" + logger = ApiDependencies.invoker.services.logger + try: + logger.info(f"Converting model: {model_name}") + ApiDependencies.invoker.services.model_manager.convert_model(model_name, + base_model = base_model, + model_type = model_type + ) + model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, + base_model = base_model, + model_type = model_type) + response = parse_obj_as(ConvertModelResponse, model_raw) + except KeyError: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return response + +@models_router.put( + "/merge/{base_model}", + operation_id="merge_models", + responses={ + 200: { "description": "Models merged successfully" }, + 400: { "description": "Incompatible models" }, + 404: { "description": "One or more models not found" }, + }, + status_code = 200, + response_model = MergeModelResponse, +) +async def merge_models( + base_model: BaseModelType = Path(description="Base model"), + model_names: List[str] = Body(description="model names", min_items=2, max_items=3), + merged_model_name: Optional[str] = Body(description = "Name of destination model"), + alpha: Optional[float] = Body(description = "Alpha weighting strength to apply to 2d and 3d models", default=0.5), + interp: Union[MergeInterpolationMethod, None] = Body(description = "Interpolation method"), + force: Optional[bool] = Body(description = "Force merging of models created with different versions of diffusers", default=False), +) -> MergeModelResponse: + """Merge two to three diffusers models into a new combined model""" + logger = ApiDependencies.invoker.services.logger + try: + logger.info(f"Merging models: {model_names}") + result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, + base_model, + merged_model_name or "+".join(model_names), + alpha, + interp, + force) + model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, + base_model = base_model, + model_type = ModelType.Main, + ) + response = parse_obj_as(MergeModelResponse, model_raw) + except KeyError: + raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return response
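A similar hypothetical smoke test of the merge route (host, port and model names are assumptions; because the handler declares several Body parameters, FastAPI expects them embedded in a single JSON object, as shown):

    import requests  # illustrative only; assumes the router is mounted under /api

    resp = requests.put(
        "http://localhost:9090/api/v1/models/merge/sd-1",  # assumed host/port
        json={
            "model_names": ["modelA", "modelB"],           # hypothetical installed models
            "merged_model_name": "modelA+modelB",
            "alpha": 0.5,
            "interp": "add_difference",
            "force": False,
        },
    )
    print(resp.status_code)  # 200 on success, 404 for unknown models, 400 for incompatible ones

diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 455d9d021f..eb2c014b1a 100644 --- 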
a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -2,22 +2,29 @@ from __future__ import annotations -import torch from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING -from dataclasses import dataclass +from pydantic import Field +from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING +from types import ModuleType -from invokeai.backend.model_management.model_manager import ( +from invokeai.backend.model_management import ( ModelManager, BaseModelType, ModelType, SubModelType, ModelInfo, + AddModelResult, + SchedulerPredictionType, + ModelMerger, + MergeInterpolationMethod, ) + + +import torch from invokeai.app.models.exceptions import CanceledException -from .config import InvokeAIAppConfig from ...backend.util import choose_precision, choose_torch_device +from .config import InvokeAIAppConfig if TYPE_CHECKING: from ..invocations.baseinvocation import BaseInvocation, InvocationContext @@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC): def __init__( self, config: InvokeAIAppConfig, - logger: types.ModuleType, + logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC): def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: """ Given a model name returns a dict-like (OmegaConf) object describing it. - """ - pass - - @abstractmethod - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. + Uses the same format as the omegaconf stanza. """ pass @@ -101,7 +102,20 @@ class ModelManagerServiceBase(ABC): } """ pass + + @abstractmethod + def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Return information about the model using the same format as list_models() + """ + pass + @abstractmethod + def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: + """ + Returns a list of all the model names known. + """ + pass @abstractmethod def add_model( @@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC): model_type: ModelType, model_attributes: dict, clobber: bool = False - ) -> None: + ) -> AddModelResult: """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. @@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def update_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + model_attributes: dict, + ) -> AddModelResult: + """ + Update the named model with a dictionary of attributes. Will fail with a + KeyError exception if the name does not already exist. + + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or + the model name is missing. Call commit() to write changes to disk. + """ + pass + @abstractmethod def del_model( self, @@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def convert_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: Union[ModelType.Main,ModelType.Vae], + ) -> AddModelResult: + """ + Convert a checkpoint file into a diffusers folder, deleting the cached + version and deleting the original checkpoint file if it is in the models + directory. 
+ :param model_name: Name of the model to convert + :param base_model: Base model type + :param model_type: Type of model ['vae' or 'main'] + + This will raise a ValueError unless the model is a checkpoint. It will + also raise a ValueError in the event that there is a similarly-named diffusers + directory already in place. + """ + pass + @abstractmethod def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, + )->dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def commit(self, conf_file: Path = None) -> None: + def merge_models( + self, + model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"), + base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"), + merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + ) -> AddModelResult: + """ + Merge two to three diffusers pipeline models and save as a new model. + :param model_names: List of 2-3 models to merge + :param base_model: Base model to use for all models + :param merged_model_name: Name of destination merged model + :param alpha: Alpha strength to apply to 2d and 3d models + :param interp: Interpolation method. None (default) means weighted sum. + """ + pass + + @abstractmethod + def commit(self, conf_file: Optional[Path] = None) -> None: """ Write current configuration out to the indicated file. If no conf_file is provided, then replaces the @@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase): def __init__( self, config: InvokeAIAppConfig, - logger: types.ModuleType, + logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -299,12 +372,19 @@ class ModelManagerService(ModelManagerServiceBase): base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None ) -> list[dict]: - # ) -> dict: """ Return a list of models. """ return self.mgr.list_models(base_model, model_type) + def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Return information about the model using the same format as list_models() + """ + return self.mgr.list_model(model_name=model_name, + base_model=base_model, + model_type=model_type) + def add_model( self, model_name: str, @@ -320,9 +400,28 @@ class ModelManagerService(ModelManagerServiceBase): with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ + self.logger.debug(f'add/update model {model_name}') return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) - + def update_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + model_attributes: dict, + ) -> AddModelResult: + """ + Update the named model with a dictionary of attributes. Will fail with a + KeyError exception if the name does not already exist. 
+ On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or + the model name is missing. Call commit() to write changes to disk. + """ + self.logger.debug(f'update model {model_name}') + if not self.model_exists(model_name, base_model, model_type): + raise KeyError(f"Unknown model {model_name}") + return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) + def del_model( self, model_name: str, @@ -334,8 +433,29 @@ class ModelManagerService(ModelManagerServiceBase): then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ + self.logger.debug(f'delete model {model_name}') self.mgr.del_model(model_name, base_model, model_type) + def convert_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: Union[ModelType.Main,ModelType.Vae], + ) -> AddModelResult: + """ + Convert a checkpoint file into a diffusers folder, deleting the cached + version and deleting the original checkpoint file if it is in the models + directory. + :param model_name: Name of the model to convert + :param base_model: Base model type + :param model_type: Type of model ['vae' or 'main'] + + This will raise a ValueError unless the model is a checkpoint. It will + also raise a ValueError in the event that there is a similarly-named diffusers + directory already in place. + """ + self.logger.debug(f'convert model {model_name}') + return self.mgr.convert_model(model_name, base_model, model_type) def commit(self, conf_file: Optional[Path]=None): """ Write current configuration out to the indicated file. If no conf_file is provided, then replaces the @@ -387,9 +507,9 @@ class ModelManagerService(ModelManagerServiceBase): return self.mgr.logger def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, + )->dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -406,4 +526,31 @@ class ModelManagerService(ModelManagerServiceBase): of the set is a dict corresponding to the newly-created OmegaConf stanza for that model. ''' - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) + return self.mgr.heuristic_import(items_to_import, prediction_type_helper) + + def merge_models( + self, + model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"), + base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"), + merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + ) -> AddModelResult: + """ + Merge two to three diffusers pipeline models and save as a new model. + :param model_names: List of 2-3 models to merge + :param base_model: Base model to use for all models + :param merged_model_name: Name of destination merged model + :param alpha: Alpha strength to apply to 2d and 3d models + :param interp: Interpolation method. 
None (default) means weighted sum. + """ + merger = ModelMerger(self.mgr) + return merger.merge_diffusion_models_and_save( + model_names = model_names, + base_model = base_model, + merged_model_name = merged_model_name, + alpha = alpha, + interp = interp, + force = force, + ) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 86a922c05a..b9225d1654 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -166,14 +166,18 @@ class ModelInstall(object): # add requested models for path in selections.install_models: logger.info(f'Installing {path} [{job}/{jobs}]') - self.heuristic_import(path) + try: + self.heuristic_import(path) + except (ValueError, KeyError) as e: + logger.error(str(e)) job += 1 self.mgr.commit() def heuristic_import(self, - model_path_id_or_url: Union[str,Path], - models_installed: Set[Path]=None)->Dict[str, AddModelResult]: + model_path_id_or_url: Union[str,Path], + models_installed: Set[Path]=None, + )->Dict[str, AddModelResult]: ''' :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL :param models_installed: Set of installed models, used for recursive invocation @@ -187,61 +191,53 @@ class ModelInstall(object): self.current_id = model_path_id_or_url path = Path(model_path_id_or_url) - try: - # checkpoint file, or similar - if path.is_file(): - models_installed.update(self._install_path(path)) + # checkpoint file, or similar + if path.is_file(): + models_installed.update({str(path):self._install_path(path)}) - # folders style or similar - elif path.is_dir() and any([(path/x).exists() for x in \ - {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} - ] - ): - models_installed.update(self._install_path(path)) + # folders style or similar + elif path.is_dir() and any([(path/x).exists() for x in \ + {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} + ] + ): + models_installed.update({str(path): self._install_path(path)}) - # recursive scan - elif path.is_dir(): - for child in path.iterdir(): - self.heuristic_import(child, models_installed=models_installed) + # recursive scan + elif path.is_dir(): + for child in path.iterdir(): + self.heuristic_import(child, models_installed=models_installed) - # huggingface repo - elif len(str(path).split('/')) == 2: - models_installed.update(self._install_repo(str(path))) + # huggingface repo + elif len(str(model_path_id_or_url).split('/')) == 2: + models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) - # a URL - elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")): - models_installed.update(self._install_url(model_path_id_or_url)) + # a URL + elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")): + models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) - else: - logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') - - except ValueError as e: - logger.error(str(e)) + else: + raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL') return models_installed # install a model from a local path. The optional info parameter is there to prevent # the model from being probed twice in the event that it has already been probed. 
- def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]: - try: - model_result = None - info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) - model_name = path.stem if path.is_file() else path.name - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - raise ValueError(f'A model named "{model_name}" is already installed.') - attributes = self._make_attributes(path,info) - model_result = self.mgr.add_model(model_name = model_name, - base_model = info.base_type, - model_type = info.model_type, - model_attributes = attributes, - ) - except Exception as e: - logger.warning(f'{str(e)} Skipping registration.') - return {} - return {str(path): model_result} + def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult: + info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) + if not info: + logger.warning(f'Unable to parse format of {path}') + return None + model_name = path.stem if path.is_file() else path.name + if self.mgr.model_exists(model_name, info.base_type, info.model_type): + raise ValueError(f'A model named "{model_name}" is already installed.') + attributes = self._make_attributes(path,info) + return self.mgr.add_model(model_name = model_name, + base_model = info.base_type, + model_type = info.model_type, + model_attributes = attributes, + ) - def _install_url(self, url: str)->dict: - # copy to a staging area, probe, import and delete + def _install_url(self, url: str)->AddModelResult: with TemporaryDirectory(dir=self.config.models_path) as staging: location = download_with_resume(url,Path(staging)) if not location: @@ -253,7 +249,7 @@ class ModelInstall(object): # staged version will be garbage-collected at this time return self._install_path(Path(models_path), info) - def _install_repo(self, repo_id: str)->dict: + def _install_repo(self, repo_id: str)->AddModelResult: hinfo = HfApi().model_info(repo_id) # we try to figure out how to download this most economically diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 34e0b15728..e31085acef 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -1,7 +1,8 @@ """ Initialization file for invokeai.backend.model_management """ -from .model_manager import ModelManager, ModelInfo, AddModelResult +from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType from .model_cache import ModelCache from .models import BaseModelType, ModelType, SubModelType, ModelVariantType +from .model_merge import ModelMerger, MergeInterpolationMethod diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index d8ecdf81c2..e98d71e85c 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -2,8 +2,8 @@ from __future__ import annotations import copy from contextlib import contextmanager +from typing import Optional, Dict, Tuple, Any, Union, List from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union, List import torch from compel.embeddings_provider import BaseTextualInversionManager diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index db8a691d29..b4827bfd32 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -234,7 +234,7 @@ import textwrap from 
dataclasses import dataclass from pathlib import Path from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types -from shutil import rmtree +from shutil import rmtree, move import torch from omegaconf import OmegaConf @@ -279,7 +279,7 @@ class InvalidModelError(Exception): pass class AddModelResult(BaseModel): - name: str = Field(description="The name of the model after import") + name: str = Field(description="The name of the model after installation") model_type: ModelType = Field(description="The type of model") base_model: BaseModelType = Field(description="The base model") config: ModelConfigBase = Field(description="The configuration of the model") @@ -491,17 +491,32 @@ class ModelManager(object): """ return [(self.parse_key(x)) for x in self.models.keys()] + def list_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + ) -> dict: + """ + Returns a dict describing one installed model, using + the combined format of the list_models() method. + """ + models = self.list_models(base_model,model_type,model_name) + return models[0] if models else None + def list_models( self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, + model_name: Optional[str] = None, ) -> list[dict]: """ Return a list of models. """ + model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold) models = [] - for model_key in sorted(self.models, key=str.casefold): + for model_key in model_keys: model_config = self.models[model_key] cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) @@ -546,10 +561,7 @@ class ModelManager(object): model_cfg = self.models.pop(model_key, None) if model_cfg is None: - self.logger.error( - f"Unknown model {model_key}" - ) - return + raise KeyError(f"Unknown model {model_key}") # note: this is not guaranteed to release memory (the model may have other references) cache_ids = self.cache_keys.pop(model_key, []) @@ -615,6 +627,7 @@ class ModelManager(object): self.cache.uncache_model(cache_id) self.models[model_key] = model_config + self.commit() return AddModelResult( name = model_name, model_type = model_type, @@ -622,6 +635,60 @@ class ModelManager(object): config = model_config, ) + def convert_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: Union[ModelType.Main,ModelType.Vae], + ) -> AddModelResult: + ''' + Convert a checkpoint file into a diffusers folder, deleting the cached + version and deleting the original checkpoint file if it is in the models + directory. + :param model_name: Name of the model to convert + :param base_model: Base model type + :param model_type: Type of model ['vae' or 'main'] + + This will raise a ValueError unless the model is a checkpoint. + ''' + info = self.model_info(model_name, base_model, model_type) + if info["model_format"] != "checkpoint": + raise ValueError(f"not a checkpoint format model: {model_name}") + + # We are taking advantage of a side effect of get_model() that converts checkpoints + # into cached diffusers directories stored at `location`. It doesn't matter + # what submodel type we request here, so we get the smallest. 
+ submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {} + model = self.get_model(model_name, + base_model, + model_type, + **submodel, + ) + checkpoint_path = self.app_config.root_path / info["path"] + old_diffusers_path = self.app_config.models_path / model.location + new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name + if new_diffusers_path.exists(): + raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") + + try: + move(old_diffusers_path,new_diffusers_path) + info["model_format"] = "diffusers" + info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path)) + info.pop('config') + + result = self.add_model(model_name, base_model, model_type, + model_attributes = info, + clobber=True) + except: + # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error! + rmtree(new_diffusers_path) + raise + + if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path): + checkpoint_path.unlink() + + return result + def search_models(self, search_folder): self.logger.info(f"Finding Models In: {search_folder}") models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") @@ -821,6 +888,10 @@ class ModelManager(object): The result is a set of successfully installed models. Each element of the set is a dict corresponding to the newly-created OmegaConf stanza for that model. + + May raise the following exceptions: + - KeyError - one or more of the items to import is not a valid path, repo_id or URL + - ValueError - a corresponding model already exists ''' # avoid circular import here from invokeai.backend.install.model_install_backend import ModelInstall @@ -830,11 +901,7 @@ class ModelManager(object): prediction_type_helper = prediction_type_helper, model_manager = self) for thing in items_to_import: - try: - installed = installer.heuristic_import(thing) - successfully_installed.update(installed) - except Exception as e: - self.logger.warning(f'{thing} could not be imported: {str(e)}') - + installed = installer.heuristic_import(thing) + successfully_installed.update(installed) self.commit() return successfully_installed
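Since heuristic_import() now propagates KeyError and ValueError to its callers instead of swallowing them, a minimal sketch of driving the importer directly looks like this (the `mgr` instance and the import items are hypothetical; the prediction-type helper mirrors the lambda used in the import_model route above):

    # illustrative sketch -- assumes `mgr` is an already-configured ModelManager
    from invokeai.backend.model_management.models import SchedulerPredictionType

    results = mgr.heuristic_import(
        {"/tmp/analog-diffusion-1.0.ckpt", "stabilityai/stable-diffusion-2-1"},  # hypothetical items
        prediction_type_helper=lambda path: SchedulerPredictionType("v_prediction"),
    )
    for location, result in results.items():  # each value is an AddModelResult
        print(location, "->", result.name, result.base_model.value, result.model_type.value)

diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py new file mode 100644 index 0000000000..1a110a47b8 --- /dev/null +++ b/invokeai/backend/model_management/model_merge.py @@ -0,0 +1,129 @@ +""" +invokeai.backend.model_management.model_merge exports: +merge_diffusion_models() -- combine multiple models by location and return a pipeline object +merge_diffusion_models_and_save() -- combine multiple models by ModelManager ID and write to models.yaml + +Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team +""" + +import warnings +from enum import Enum +from pathlib import Path +from diffusers import DiffusionPipeline +from diffusers import logging as dlogging +from typing import List, Union + +import invokeai.backend.util.logging as logger + +from invokeai.app.services.config import InvokeAIAppConfig +from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult + +class MergeInterpolationMethod(str, Enum): + Sigmoid = "sigmoid" + InvSigmoid = "inv_sigmoid" + AddDifference = "add_difference" + WeightedSum = "weighted_sum" + +class ModelMerger(object): + def __init__(self, manager: ModelManager): + self.manager = manager + + def merge_diffusion_models( 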
self, + model_paths: List[Path], + alpha: float = 0.5, + interp: MergeInterpolationMethod = None, + force: bool = False, + **kwargs, + ) -> DiffusionPipeline: + """ + :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids + :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha + would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 + :param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. + Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. + :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. + + **kwargs - the default DiffusionPipeline.get_config_dict kwargs: + cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + verbosity = dlogging.get_verbosity() + dlogging.set_verbosity_error() + + pipe = DiffusionPipeline.from_pretrained( + model_paths[0], + custom_pipeline="checkpoint_merger", + ) + merged_pipe = pipe.merge( + pretrained_model_name_or_path_list=model_paths, + alpha=alpha, + interp=interp.value if interp else None, #diffusers API treats None as "weighted sum" + force=force, + **kwargs, + ) + dlogging.set_verbosity(verbosity) + return merged_pipe + + + def merge_diffusion_models_and_save ( + self, + model_names: List[str], + base_model: Union[BaseModelType,str], + merged_model_name: str, + alpha: float = 0.5, + interp: MergeInterpolationMethod = None, + force: bool = False, + **kwargs, + ) -> AddModelResult: + """ + :param models: up to three models, designated by their InvokeAI models.yaml model name + :param base_model: base model (must be the same for all merged models!) + :param merged_model_name: name for new model + :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha + would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 + :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. + Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). + :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. + + **kwargs - the default DiffusionPipeline.get_config_dict kwargs: + cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map + """ + model_paths = list() + config = self.manager.app_config + base_model = BaseModelType(base_model) + vae = None + + for mod in model_names: + info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) + assert info, f"model {mod}, base_model {base_model}, is unknown" + assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. 
It must be optimized before merging" + assert info["variant"] == "normal", (f"{mod} is a {info['variant']} model, which cannot currently be merged") + # pick up the first model's vae + if mod == model_names[0]: + vae = info.get("vae") + model_paths.extend([config.root_path / info["path"]]) + + merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp) + merged_pipe = self.merge_diffusion_models( + model_paths, alpha, merge_method, force, **kwargs + ) + dump_path = config.models_path / base_model.value / ModelType.Main.value + dump_path.mkdir(parents=True, exist_ok=True) + dump_path = dump_path / merged_model_name + + merged_pipe.save_pretrained(dump_path, safe_serialization=1) + attributes = dict( + path = str(dump_path), + description = f"Merge of models {', '.join(model_names)}", + model_format = "diffusers", + variant = ModelVariantType.Normal.value, + vae = vae, + ) + return self.manager.add_model(merged_model_name, + base_model = base_model, + model_type = ModelType.Main, + model_attributes = attributes, + clobber = True + ) diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index a5d43c98a2..c98d5a0ae8 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel): version=BaseModelType.StableDiffusion1, model_config=config, output_path=output_path, - ) + ) else: return model_path diff --git a/invokeai/frontend/merge/__init__.py b/invokeai/frontend/merge/__init__.py index fb892fd7db..f1fc66c39e 100644 --- a/invokeai/frontend/merge/__init__.py +++ b/invokeai/frontend/merge/__init__.py @@ -1,4 +1,5 @@ """ Initialization file for invokeai.frontend.merge """ -from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models +from .merge_diffusers import main as invokeai_merge_diffusers + diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 9da04b97f8..c20d913883 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team """ import argparse import curses -import os import sys -import warnings from argparse import Namespace from pathlib import Path from typing import List, Union @@ -20,99 +18,15 @@ from npyscreen import widget from omegaconf import OmegaConf import invokeai.backend.util.logging as logger -from invokeai.services.config import InvokeAIAppConfig -from ...backend.model_management import ModelManager -from ...frontend.install.widgets import FloatTitleSlider +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_management import ( + ModelMerger, MergeInterpolationMethod, + ModelManager, ModelType, BaseModelType, +) +from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns -DEST_MERGED_MODEL_DIR = "merged_models" config = InvokeAIAppConfig.get_config() -def merge_diffusion_models( - model_ids_or_paths: List[Union[str, Path]], - alpha: float = 0.5, - interp: str = None, - force: bool = False, - **kwargs, -) -> DiffusionPipeline: - """ - model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids - alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. 
A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. - force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - - pipe = DiffusionPipeline.from_pretrained( - model_ids_or_paths[0], - cache_dir=kwargs.get("cache_dir", config.cache_dir), - custom_pipeline="checkpoint_merger", - ) - merged_pipe = pipe.merge( - pretrained_model_name_or_path_list=model_ids_or_paths, - alpha=alpha, - interp=interp, - force=force, - **kwargs, - ) - dlogging.set_verbosity(verbosity) - return merged_pipe - - -def merge_diffusion_models_and_commit( - models: List["str"], - merged_model_name: str, - alpha: float = 0.5, - interp: str = None, - force: bool = False, - **kwargs, -): - """ - models - up to three models, designated by their InvokeAI models.yaml model name - merged_model_name = name for new model - alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). - force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - config_file = config.model_conf_path - model_manager = ModelManager(OmegaConf.load(config_file)) - for mod in models: - assert mod in model_manager.model_names(), f'** Unknown model "{mod}"' - assert ( - model_manager.model_info(mod).get("format", None) == "diffusers" - ), f"** {mod} is not a diffusers model. It must be optimized before merging." 
- model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models] - - merged_pipe = merge_diffusion_models( - model_ids_or_paths, alpha, interp, force, **kwargs - ) - dump_path = config.models_dir / DEST_MERGED_MODEL_DIR - - os.makedirs(dump_path, exist_ok=True) - dump_path = dump_path / merged_model_name - merged_pipe.save_pretrained(dump_path, safe_serialization=1) - import_args = dict( - model_name=merged_model_name, description=f'Merge of models {", ".join(models)}' - ) - if vae := model_manager.config[models[0]].get("vae", None): - logger.info(f"Using configured VAE assigned to {models[0]}") - import_args.update(vae=vae) - model_manager.import_diffuser_model(dump_path, **import_args) - model_manager.commit(config_file) - - def _parse_args() -> Namespace: parser = argparse.ArgumentParser(description="InvokeAI model merging") parser.add_argument( @@ -131,10 +45,17 @@ def _parse_args() -> Namespace: ) parser.add_argument( "--models", + dest="model_names", type=str, nargs="+", help="Two to three model names to be merged", ) + parser.add_argument( + "--base_model", + type=str, + choices=[x.value for x in BaseModelType], + help="The base model shared by the models to be merged", + ) parser.add_argument( "--merged_model_name", "--destination", @@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): window_height, window_width = curses.initscr().getmaxyx() self.model_names = self.get_model_names() + self.current_base = 0 max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width @@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): value="Use up and down arrows to move, <space> to select an item, and <tab> to move from one field to the next.", editable=False, ) + self.nextrely += 1 + self.base_select = self.add_widget_intelligent( + SingleSelectColumns, + values=[ + 'Models Built on SD-1.x', + 'Models Built on SD-2.x', + ], + value=[self.current_base], + columns = 4, + max_height = 2, + relx=8, + scroll_exit = True, + ) + self.base_select.on_changed = self._populate_models self.add_widget_intelligent( npyscreen.FixedText, value="MODEL 1", color="GOOD", editable=False, - rely=4 if horizontal_layout else None, + rely=6 if horizontal_layout else None, ) self.model1 = self.add_widget_intelligent( npyscreen.SelectOne, values=self.model_names, value=0, max_height=len(self.model_names), max_width=max_width, scroll_exit=True, - rely=5, + rely=7, ) self.add_widget_intelligent( npyscreen.FixedText, value="MODEL 2", color="GOOD", editable=False, relx=max_width + 3 if horizontal_layout else None, - rely=4 if horizontal_layout else None, + rely=6 if horizontal_layout else None, ) self.model2 = self.add_widget_intelligent( npyscreen.SelectOne, name="(2)", values=self.model_names, value=1, max_height=len(self.model_names), max_width=max_width, relx=max_width + 3 if horizontal_layout else None, - rely=5 if horizontal_layout else None, + rely=7 if horizontal_layout else None, scroll_exit=True, ) self.add_widget_intelligent( npyscreen.FixedText, value="MODEL 3", color="GOOD", editable=False, relx=max_width * 2 + 3 if horizontal_layout else None, - rely=4 if horizontal_layout else None, + rely=6 if horizontal_layout else None, ) models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") @@ -262,24 +198,26 @@ class 
mergeModelsForm(npyscreen.FormMultiPageAction): max_width=max_width, scroll_exit=True, relx=max_width * 2 + 3 if horizontal_layout else None, - rely=5 if horizontal_layout else None, + rely=7 if horizontal_layout else None, ) for m in [self.model1, self.model2, self.model3]: m.when_value_edited = self.models_changed self.merged_model_name = self.add_widget_intelligent( - npyscreen.TitleText, + TextBox, name="Name for merged model:", labelColor="CONTROL", + max_height=3, value="", scroll_exit=True, ) self.force = self.add_widget_intelligent( npyscreen.Checkbox, - name="Force merge of incompatible models", + name="Force merge of models created by different diffusers library versions", labelColor="CONTROL", - value=False, + value=True, scroll_exit=True, ) + self.nextrely += 1 self.merge_method = self.add_widget_intelligent( npyscreen.TitleSelectOne, name="Merge Method:", @@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): interp = self.interpolations[self.merge_method.value[0]] args = dict( - models=models, + model_names=models, + base_model=tuple(BaseModelType)[self.base_select.value[0]], alpha=self.alpha.value, interp=interp, force=self.force.value, @@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): else: return True - def get_model_names(self) -> List[str]: + def get_model_names(self, base_model: BaseModelType=None) -> List[str]: model_names = [ - name - for name in self.model_manager.model_names() - if self.model_manager.model_info(name).get("format") == "diffusers" + info["name"] + for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) + if info["model_format"] == "diffusers" ] return sorted(model_names) + def _populate_models(self,value=None): + base_model = tuple(BaseModelType)[value[0]] + self.model_names = self.get_model_names(base_model) + + models_plus_none = self.model_names.copy() + models_plus_none.insert(0, "None") + self.model1.values = self.model_names + self.model2.values = self.model_names + self.model3.values = models_plus_none + + self.display() + class Mergeapp(npyscreen.NPSAppManaged): - def __init__(self): + def __init__(self, model_manager:ModelManager): super().__init__() - conf = OmegaConf.load(config.model_conf_path) - self.model_manager = ModelManager( - conf, "cpu", "float16" - ) # precision doesn't really matter here + self.model_manager = model_manager def onStart(self): npyscreen.setTheme(npyscreen.Themes.ElegantTheme) @@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged): def run_gui(args: Namespace): - mergeapp = Mergeapp() + model_manager = ModelManager(config.model_conf_path) + mergeapp = Mergeapp(model_manager) mergeapp.run() args = mergeapp.merge_arguments - merge_diffusion_models_and_commit(**args) + merger = ModelMerger(model_manager) + merger.merge_diffusion_models_and_save(**args) logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') def run_cli(args: Namespace): assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" assert ( - args.models and len(args.models) >= 1 and len(args.models) <= 3 + args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3 ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage." if not args.merged_model_name: - args.merged_model_name = "+".join(args.models) + args.merged_model_name = "+".join(args.model_names) logger.info( f'No --merged_model_name provided. 
Defaulting to "{args.merged_model_name}"' ) - model_manager = ModelManager(OmegaConf.load(config.model_conf_path)) - assert ( - args.clobber or args.merged_model_name not in model_manager.model_names() - ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' + model_manager = ModelManager(config.model_conf_path) + assert ( + not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber + ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - merge_diffusion_models_and_commit(**vars(args)) - logger.info(f'Models merged into new model: "{args.merged_model_name}".') + merger = ModelMerger(model_manager) + merger.merge_diffusion_models_and_save(**vars(args)) + logger.info(f'Models merged into new model: "{args.merged_model_name}".') def main(): args = _parse_args() - config.root = args.root_dir - - cache_dir = config.cache_dir - os.environ[ - "HF_HOME" - ] = cache_dir # because not clear the merge pipeline is honoring cache_dir - args.cache_dir = cache_dir + config.parse_args(['--root',str(args.root_dir)]) try: if args.front_end:
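For reference, the same merge can also be driven directly from Python through the new backend entry point used by both the CLI and the REST route. A minimal sketch (the model names are placeholders; the manager is constructed the same way run_gui() does above):

    # illustrative only -- assumes two SD-1 diffusers models are already installed
    from invokeai.app.services.config import InvokeAIAppConfig
    from invokeai.backend.model_management import (
        BaseModelType, MergeInterpolationMethod, ModelManager, ModelMerger,
    )

    config = InvokeAIAppConfig.get_config()
    mgr = ModelManager(config.model_conf_path)
    merger = ModelMerger(mgr)
    result = merger.merge_diffusion_models_and_save(
        model_names=["modelA", "modelB"],            # hypothetical installed models
        base_model=BaseModelType.StableDiffusion1,
        merged_model_name="modelA+modelB",
        alpha=0.5,                                   # higher alpha weights the later models more heavily
        interp=MergeInterpolationMethod.WeightedSum,
        force=False,
    )
    print(result.name, result.model_type.value)      # AddModelResult describing the new models.yaml stanza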