From 2faa7cee37f715df6de448bf72176309ead44070 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 14 Jul 2023 23:03:18 -0400 Subject: [PATCH] add rename_model route --- invokeai/app/api/routers/models.py | 89 ++++++++++++++++++- .../app/services/model_manager_service.py | 35 ++++++++ .../backend/install/model_install_backend.py | 2 - .../backend/model_management/model_manager.py | 49 ++++++++++ .../model_management/models/__init__.py | 4 +- .../backend/model_management/models/base.py | 1 - .../models/stable_diffusion.py | 3 +- 7 files changed, 175 insertions(+), 8 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index d0e0361ad9..c298114cbc 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -23,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] class ModelsList(BaseModel): models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @@ -79,7 +80,7 @@ async def update_model( return model_response @models_router.post( - "/", + "/import", operation_id="import_model", responses= { 201: {"description" : "The model imported successfully"}, @@ -95,7 +96,7 @@ async def import_model( prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), ) -> ImportModelResponse: - """ Add a model using its local path, repo_id, or remote URL """ + """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """ items_to_import = {location} prediction_types = { x.value: x for x in SchedulerPredictionType } @@ -127,7 +128,91 @@ async def import_model( logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) +@models_router.post( + "/add", + operation_id="add_model", + responses= { + 201: {"description" : "The model added successfully"}, + 404: {"description" : "The model could not be found"}, + 424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"}, + 409: {"description" : "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, + response_model=ImportModelResponse +) +async def add_model( + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), +) -> ImportModelResponse: + """ Add a model using the configuration information appropriate for its type. Only local models can be added by path""" + + logger = ApiDependencies.invoker.services.logger + try: + ApiDependencies.invoker.services.model_manager.add_model( + info.model_name, + info.base_model, + info.model_type, + model_attributes = info.dict() + ) + logger.info(f'Successfully added {info.model_name}') + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=info.model_name, + base_model=info.base_model, + model_type=info.model_type + ) + return parse_obj_as(ImportModelResponse, model_raw) + except KeyError as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + +@models_router.post( + "/rename/{base_model}/{model_type}/{model_name}", + operation_id="rename_model", + responses= { + 201: {"description" : "The model was renamed successfully"}, + 404: {"description" : "The model could not be found"}, + 409: {"description" : "There is already a model corresponding to the new name"}, + }, + status_code=201, + response_model=ImportModelResponse +) +async def rename_model( + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="current model name"), + new_name: Optional[str] = Query(description="new model name", default=None), + new_base: Optional[BaseModelType] = Query(description="new model base", default=None), +) -> ImportModelResponse: + """ Rename a model""" + + logger = ApiDependencies.invoker.services.logger + + try: + result = ApiDependencies.invoker.services.model_manager.rename_model( + base_model = base_model, + model_type = model_type, + model_name = model_name, + new_name = new_name, + new_base = new_base, + ) + logger.debug(result) + logger.info(f'Successfully renamed {model_name}=>{new_name}') + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=new_name or model_name, + base_model=new_base or base_model, + model_type=model_type + ) + return parse_obj_as(ImportModelResponse, model_raw) + except KeyError as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + @models_router.delete( "/{base_model}/{model_type}/{model_name}", operation_id="del_model", diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 3c5dad7b3e..67db5c9478 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -167,6 +167,18 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def rename_model(self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str, + ): + """ + Rename the indicated model. + """ + pass + @abstractmethod def list_checkpoint_configs( self @@ -615,3 +627,26 @@ class ModelManagerService(ModelManagerServiceBase): conf_path = config.legacy_conf_path root_path = config.root_path return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')] + + def rename_model(self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, + ): + """ + Rename the indicated model. Can provide a new name and/or a new base. + :param model_name: Current name of the model + :param base_model: Current base of the model + :param model_type: Model type (can't be changed) + :param new_name: New name for the model + :param new_base: New base for the model + """ + self.mgr.rename_model(base_model = base_model, + model_type = model_type, + model_name = model_name, + new_name = new_name, + new_base = new_base, + ) + diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index b6f6d62d97..2e537313ac 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -71,8 +71,6 @@ class ModelInstallList: class InstallSelections(): install_models: List[str]= field(default_factory=list) remove_models: List[str]=field(default_factory=list) -# scan_directory: Path = None -# autoscan_on_startup: bool=False @dataclass class ModelLoadInfo(): diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index f4485bf67a..55f6de9b5b 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -671,6 +671,55 @@ class ModelManager(object): config = model_config, ) + def rename_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, + ): + ''' + Rename or rebase a model. + ''' + if new_name is None and new_base is None: + self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.") + return + + model_key = self.create_key(model_name, base_model, model_type) + model_cfg = self.models.get(model_key, None) + if not model_cfg: + raise KeyError(f"Unknown model: {model_key}") + + old_path = self.app_config.root_path / model_cfg.path + new_name = new_name or model_name + new_base = new_base or base_model + new_key = self.create_key(new_name, new_base, model_type) + if new_key in self.models: + raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"') + + # if this is a model file/directory that we manage ourselves, we need to move it + if old_path.is_relative_to(self.app_config.models_path): + new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name + move(old_path, new_path) + model_cfg.path = str(new_path.relative_to(self.app_config.root_path)) + + # clean up caches + old_model_cache = self._get_model_cache_path(old_path) + if old_model_cache.exists(): + if old_model_cache.is_dir(): + rmtree(str(old_model_cache)) + else: + old_model_cache.unlink() + + cache_ids = self.cache_keys.pop(model_key, []) + for cache_id in cache_ids: + self.cache.uncache_model(cache_id) + + self.models.pop(model_key, None) # delete + self.models[new_key] = model_cfg + self.commit() + def convert_model ( self, model_name: str, diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 1c573b26b6..e404c56bdf 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items(): model_configs.discard(None) MODEL_CONFIGS.extend(model_configs) - for cfg in model_configs: + # LS: sort to get the checkpoint configs first, which makes + # for a better template in the Swagger docs + for cfg in sorted(model_configs, key=lambda x: str(x)): model_name, cfg_name = cfg.__qualname__.split('.')[-2:] openapi_cfg_name = model_name + cfg_name if openapi_cfg_name in vars(): diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index ddbc401e5b..c569872a81 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel): path: str # or Path description: Optional[str] = Field(None) model_format: Optional[str] = Field(None) - # do not save to config error: Optional[ModelError] = Field(None) class Config: diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 74751a40dd..3d2e50d8fb 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -37,8 +37,7 @@ class StableDiffusion1Model(DiffusersModel): vae: Optional[str] = Field(None) config: str variant: ModelVariantType - - + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 assert model_type == ModelType.Main