diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py
index 8dbeaa3d05..c298114cbc 100644
--- a/invokeai/app/api/routers/models.py
+++ b/invokeai/app/api/routers/models.py
@@ -1,6 +1,7 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
 
+import pathlib
 from typing import Literal, List, Optional, Union
 
 from fastapi import Body, Path, Query, Response
@@ -22,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
+ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 
 class ModelsList(BaseModel):
     models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@@ -78,7 +80,7 @@ async def update_model(
     return model_response
 
 @models_router.post(
-    "/",
+    "/import",
     operation_id="import_model",
     responses= {
         201: {"description" : "The model imported successfully"},
@@ -94,7 +96,7 @@ async def import_model(
     prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
         Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
 ) -> ImportModelResponse:
-    """ Add a model using its local path, repo_id, or remote URL """
+    """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
     items_to_import = {location}
     prediction_types = { x.value: x for x in SchedulerPredictionType }
@@ -126,18 +128,100 @@ async def import_model(
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
 
+@models_router.post(
+    "/add",
+    operation_id="add_model",
+    responses= {
+        201: {"description" : "The model added successfully"},
+        404: {"description" : "The model could not be found"},
+        424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
+        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+    },
+    status_code=201,
+    response_model=ImportModelResponse
+)
+async def add_model(
+    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
+) -> ImportModelResponse:
+    """ Add a model using the configuration information appropriate for its type.
+    Only local models can be added by path."""
+
+    logger = ApiDependencies.invoker.services.logger
+    try:
+        ApiDependencies.invoker.services.model_manager.add_model(
+            info.model_name,
+            info.base_model,
+            info.model_type,
+            model_attributes = info.dict()
+        )
+        logger.info(f'Successfully added {info.model_name}')
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name=info.model_name,
+            base_model=info.base_model,
+            model_type=info.model_type
+        )
+        return parse_obj_as(ImportModelResponse, model_raw)
+    except KeyError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=404, detail=str(e))
+    except ValueError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))
+
+@models_router.post(
+    "/rename/{base_model}/{model_type}/{model_name}",
+    operation_id="rename_model",
+    responses= {
+        201: {"description" : "The model was renamed successfully"},
+        404: {"description" : "The model could not be found"},
+        409: {"description" : "There is already a model corresponding to the new name"},
+    },
+    status_code=201,
+    response_model=ImportModelResponse
+)
+async def rename_model(
+    base_model: BaseModelType = Path(description="Base model"),
+    model_type: ModelType = Path(description="The type of model"),
+    model_name: str = Path(description="current model name"),
+    new_name: Optional[str] = Query(description="new model name", default=None),
+    new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
+) -> ImportModelResponse:
+    """ Rename a model"""
+
+    logger = ApiDependencies.invoker.services.logger
+
+    try:
+        result = ApiDependencies.invoker.services.model_manager.rename_model(
+            base_model = base_model,
+            model_type = model_type,
+            model_name = model_name,
+            new_name = new_name,
+            new_base = new_base,
+        )
+        logger.debug(result)
+        logger.info(f'Successfully renamed {model_name}=>{new_name}')
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name=new_name or model_name,
+            base_model=new_base or base_model,
+            model_type=model_type
+        )
+        return parse_obj_as(ImportModelResponse, model_raw)
+    except KeyError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=404, detail=str(e))
+    except ValueError as e:
+        logger.error(str(e))
+        raise HTTPException(status_code=409, detail=str(e))
+
 @models_router.delete(
     "/{base_model}/{model_type}/{model_name}",
     operation_id="del_model",
     responses={
-        204: {
-            "description": "Model deleted successfully"
-        },
-        404: {
-            "description": "Model not found"
-        }
+        204: { "description": "Model deleted successfully" },
+        404: { "description": "Model not found" }
     },
+    status_code = 204,
+    response_model = None,
 )
 async def delete_model(
     base_model: BaseModelType = Path(description="Base model"),
@@ -173,14 +257,17 @@ async def convert_model(
     base_model: BaseModelType = Path(description="Base model"),
     model_type: ModelType = Path(description="The type of model"),
     model_name: str = Path(description="model name"),
+    convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
 ) -> ConvertModelResponse:
-    """Convert a checkpoint model into a diffusers model"""
+    """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
     logger = ApiDependencies.invoker.services.logger
     try:
         logger.info(f"Converting model: {model_name}")
+        dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
         ApiDependencies.invoker.services.model_manager.convert_model(model_name,
                                                                      base_model = base_model,
-                                                                     model_type = model_type
+                                                                     model_type = model_type,
+                                                                     convert_dest_directory = dest,
                                                                      )
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
                                                                               base_model = base_model,
@@ -191,6 +278,53 @@ async def convert_model(
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     return response
+
+@models_router.get(
+    "/search",
+    operation_id="search_for_models",
+    responses={
+        200: { "description": "Directory searched successfully" },
+        404: { "description": "Invalid directory path" },
+    },
+    status_code = 200,
+    response_model = List[pathlib.Path]
+)
+async def search_for_models(
+    search_path: pathlib.Path = Query(description="Directory path to search for models")
+)->List[pathlib.Path]:
+    if not search_path.is_dir():
+        raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not a directory")
+    return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
+
+@models_router.get(
+    "/ckpt_confs",
+    operation_id="list_ckpt_configs",
+    responses={
+        200: { "description" : "paths retrieved successfully" },
+    },
+    status_code = 200,
+    response_model = List[pathlib.Path]
+)
+async def list_ckpt_configs(
+)->List[pathlib.Path]:
+    """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
+    return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
+
+
+@models_router.get(
+    "/sync",
+    operation_id="sync_to_config",
+    responses={
+        201: { "description": "synchronization successful" },
+    },
+    status_code = 201,
+    response_model = None
+)
+async def sync_to_config(
+)->None:
+    """Call after making changes to models.yaml, the autoimport directories, or the models directory
+    to synchronize the in-memory data structures with those on disk."""
+    return ApiDependencies.invoker.services.model_manager.sync_to_config()
 
 @models_router.put(
     "/merge/{base_model}",
@@ -210,17 +344,21 @@ async def merge_models(
     alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
     interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
     force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
+    merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
 ) -> MergeModelResponse:
-    """Merge two or three diffusers models into a new combined model."""
+    """Merge two or three diffusers models into a new combined model, optionally saving to the indicated destination directory."""
     logger = ApiDependencies.invoker.services.logger
     try:
-        logger.info(f"Merging models: {model_names}")
+        logger.info(f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}")
+        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
         result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
                                                                              base_model,
-                                                                             merged_model_name or "+".join(model_names),
-                                                                             alpha,
-                                                                             interp,
-                                                                             force)
+                                                                             merged_model_name=merged_model_name or "+".join(model_names),
+                                                                             alpha=alpha,
+                                                                             interp=interp,
+                                                                             force=force,
+                                                                             merge_dest_directory = dest
+                                                                             )
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
                                                                               base_model = base_model,
                                                                               model_type = ModelType.Main,
diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py
index 1b1c43dc11..67db5c9478 100644
--- a/invokeai/app/services/model_manager_service.py
+++ b/invokeai/app/services/model_manager_service.py
@@ -19,7 +19,7 @@ from invokeai.backend.model_management import (
     ModelMerger,
     MergeInterpolationMethod,
 )
-
+from invokeai.backend.model_management.model_search import FindModels
 import torch
 
 from invokeai.app.models.exceptions import CanceledException
@@ -167,6 +167,27 @@ class ModelManagerServiceBase(ABC):
         """
         pass
 
+    @abstractmethod
+    def rename_model(self,
+                     model_name: str,
+                     base_model: BaseModelType,
+                     model_type: ModelType,
+                     new_name: Optional[str] = None,
+                     new_base: Optional[BaseModelType] = None,
+                     ):
+        """
+        Rename the indicated model.
+        """
+        pass
+
+    @abstractmethod
+    def list_checkpoint_configs(
+            self
+    )->List[Path]:
+        """
+        List the checkpoint config paths from ROOT/configs/stable-diffusion.
+        """
+        pass
+
     @abstractmethod
     def convert_model(
         self,
@@ -220,6 +241,7 @@ class ModelManagerServiceBase(ABC):
         alpha: Optional[float] = 0.5,
         interp: Optional[MergeInterpolationMethod] = None,
         force: Optional[bool] = False,
+        merge_dest_directory: Optional[Path] = None
     ) -> AddModelResult:
         """
         Merge two to three diffusers pipeline models and save as a new model.
@@ -228,9 +250,26 @@ class ModelManagerServiceBase(ABC):
         :param merged_model_name: Name of destination merged model
         :param alpha: Alpha strength to apply to 2d and 3d model
         :param interp: Interpolation method. None (default)
+        :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
         """
         pass
-
+
+    @abstractmethod
+    def search_for_models(self, directories: List[Path])->List[Path]:
+        """
+        Return list of all models found in the designated directories.
+        """
+        pass
+
+    @abstractmethod
+    def sync_to_config(self):
+        """
+        Re-read models.yaml, rescan the models directory, and reimport models
+        in the autoimport directories. Call after making changes outside the
+        model manager API.
+        """
+        pass
+
     @abstractmethod
     def commit(self, conf_file: Optional[Path] = None) -> None:
         """
@@ -431,16 +470,18 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         Delete the named model from configuration. If delete_files is true,
         then the underlying weight file or diffusers directory will be deleted
-        as well. Call commit() to write to disk.
+        as well.
         """
         self.logger.debug(f'delete model {model_name}')
         self.mgr.del_model(model_name, base_model, model_type)
+        self.mgr.commit()
 
     def convert_model(
         self,
         model_name: str,
         base_model: BaseModelType,
         model_type: Union[ModelType.Main,ModelType.Vae],
+        convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for the converted model"),
     ) -> AddModelResult:
         """
         Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -449,13 +490,14 @@ class ModelManagerService(ModelManagerServiceBase):
         :param model_name: Name of the model to convert
         :param base_model: Base model type
         :param model_type: Type of model ['vae' or 'main']
+        :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
 
         This will raise a ValueError if the model is not a checkpoint. It will
         also raise a ValueError in the event that there is a similarly-named diffusers
         directory already in place.
""" self.logger.debug(f'convert model {model_name}') - return self.mgr.convert_model(model_name, base_model, model_type) + return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) def commit(self, conf_file: Optional[Path]=None): """ @@ -536,6 +578,7 @@ class ModelManagerService(ModelManagerServiceBase): alpha: Optional[float] = 0.5, interp: Optional[MergeInterpolationMethod] = None, force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), ) -> AddModelResult: """ Merge two to three diffusrs pipeline models and save as a new model. @@ -544,6 +587,7 @@ class ModelManagerService(ModelManagerServiceBase): :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model :param interp: Interpolation method. None (default) + :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) """ merger = ModelMerger(self.mgr) try: @@ -554,7 +598,55 @@ class ModelManagerService(ModelManagerServiceBase): alpha = alpha, interp = interp, force = force, + merge_dest_directory=merge_dest_directory, ) except AssertionError as e: raise ValueError(e) return result + + def search_for_models(self, directory: Path)->List[Path]: + """ + Return list of all models found in the designated directory. + """ + search = FindModels(directory,self.logger) + return search.list_models() + + def sync_to_config(self): + """ + Re-read models.yaml, rescan the models directory, and reimport models + in the autoimport directories. Call after making changes outside the + model manager API. + """ + return self.mgr.sync_to_config() + + def list_checkpoint_configs(self)->List[Path]: + """ + List the checkpoint config paths from ROOT/configs/stable-diffusion. + """ + config = self.mgr.app_config + conf_path = config.legacy_conf_path + root_path = config.root_path + return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')] + + def rename_model(self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, + ): + """ + Rename the indicated model. Can provide a new name and/or a new base. 
+        :param model_name: Current name of the model
+        :param base_model: Current base of the model
+        :param model_type: Model type (can't be changed)
+        :param new_name: New name for the model
+        :param new_base: New base for the model
+        """
+        self.mgr.rename_model(base_model = base_model,
+                              model_type = model_type,
+                              model_name = model_name,
+                              new_name = new_name,
+                              new_base = new_base,
+                              )
+
diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py
index b6f6d62d97..2e537313ac 100644
--- a/invokeai/backend/install/model_install_backend.py
+++ b/invokeai/backend/install/model_install_backend.py
@@ -71,8 +71,6 @@ class ModelInstallList:
 class InstallSelections():
     install_models: List[str]= field(default_factory=list)
     remove_models: List[str]=field(default_factory=list)
-#    scan_directory: Path = None
-#    autoscan_on_startup: bool=False
 
 @dataclass
 class ModelLoadInfo():
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index 6b9d085885..55f6de9b5b 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -247,6 +247,7 @@ import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util import CUDA_DEVICE, Chdir
 from .model_cache import ModelCache, ModelLocker
+from .model_search import ModelSearch
 from .models import (
     BaseModelType, ModelType, SubModelType,
     ModelError, SchedulerPredictionType, MODEL_CLASSES,
@@ -322,16 +323,7 @@ class ModelManager(object):
             self.config_meta = ConfigMeta(**config.pop("__metadata__"))
             # TODO: metadata not found
             # TODO: version check
-
-        self.models = dict()
-        for model_key, model_config in config.items():
-            model_name, base_model, model_type = self.parse_key(model_key)
-            model_class = MODEL_CLASSES[base_model][model_type]
-            # alias for config file
-            model_config["model_format"] = model_config.pop("format")
-            self.models[model_key] = model_class.create_config(**model_config)
-
-        # check config version number and update on disk/RAM if necessary
+
         self.app_config = InvokeAIAppConfig.get_config()
         self.logger = logger
         self.cache = ModelCache(
@@ -342,11 +334,41 @@ class ModelManager(object):
             sequential_offload = sequential_offload,
             logger = logger,
         )
+
+        self._read_models(config)
+
+    def _read_models(self, config: Optional[DictConfig] = None):
+        if not config:
+            if self.config_path:
+                config = OmegaConf.load(self.config_path)
+            else:
+                return
+
+        self.models = dict()
+        for model_key, model_config in config.items():
+            if model_key.startswith('_'):
+                continue
+            model_name, base_model, model_type = self.parse_key(model_key)
+            model_class = MODEL_CLASSES[base_model][model_type]
+            # alias for config file
+            model_config["model_format"] = model_config.pop("format")
+            self.models[model_key] = model_class.create_config(**model_config)
+
+        # check config version number and update on disk/RAM if necessary
         self.cache_keys = dict()
 
         # add controlnet, lora and textual_inversion models from disk
         self.scan_models_directory()
 
+    def sync_to_config(self):
+        """
+        Call this when `models.yaml` has been changed externally.
+        This will reinitialize internal data structures.
+        """
+        # Reread the models directory; note that this will reinitialize the cache,
+        # causing otherwise unreferenced models to be removed from memory
+        self._read_models()
+
     def model_exists(
         self,
         model_name: str,
@@ -527,7 +549,10 @@ class ModelManager(object):
         model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
         models = []
         for model_key in model_keys:
-            model_config = self.models[model_key]
+            model_config = self.models.get(model_key)
+            if not model_config:
+                self.logger.error(f'Unknown model {model_name}')
+                raise KeyError(f'Unknown model {model_name}')
 
             cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
             if base_model is not None and cur_base_model != base_model:
@@ -646,11 +671,61 @@ class ModelManager(object):
             config = model_config,
         )
 
+    def rename_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        new_name: Optional[str] = None,
+        new_base: Optional[BaseModelType] = None,
+    ):
+        '''
+        Rename or rebase a model.
+        '''
+        if new_name is None and new_base is None:
+            self.logger.error(f"rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
+            return
+
+        model_key = self.create_key(model_name, base_model, model_type)
+        model_cfg = self.models.get(model_key, None)
+        if not model_cfg:
+            raise KeyError(f"Unknown model: {model_key}")
+
+        old_path = self.app_config.root_path / model_cfg.path
+        new_name = new_name or model_name
+        new_base = new_base or base_model
+        new_key = self.create_key(new_name, new_base, model_type)
+        if new_key in self.models:
+            raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')
+
+        # if this is a model file/directory that we manage ourselves, we need to move it
+        if old_path.is_relative_to(self.app_config.models_path):
+            new_path = self.app_config.root_path / 'models' / new_base.value / model_type.value / new_name
+            move(old_path, new_path)
+            model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
+
+        # clean up caches
+        old_model_cache = self._get_model_cache_path(old_path)
+        if old_model_cache.exists():
+            if old_model_cache.is_dir():
+                rmtree(str(old_model_cache))
+            else:
+                old_model_cache.unlink()
+
+        cache_ids = self.cache_keys.pop(model_key, [])
+        for cache_id in cache_ids:
+            self.cache.uncache_model(cache_id)
+
+        self.models.pop(model_key, None)  # delete
+        self.models[new_key] = model_cfg
+        self.commit()
+
     def convert_model (
         self,
         model_name: str,
         base_model: BaseModelType,
         model_type: Union[ModelType.Main,ModelType.Vae],
+        dest_directory: Optional[Path]=None,
     ) -> AddModelResult:
         '''
         Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -677,14 +752,14 @@ class ModelManager(object):
         )
         checkpoint_path = self.app_config.root_path / info["path"]
         old_diffusers_path = self.app_config.models_path / model.location
-        new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
+        new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
         if new_diffusers_path.exists():
             raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
 
         try:
             move(old_diffusers_path,new_diffusers_path)
             info["model_format"] = "diffusers"
-            info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path))
+            info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
             info.pop('config')
 
             result = self.add_model(model_name, base_model, model_type,
@@ -824,6 +899,7 @@ class ModelManager(object):
 
         if (new_models_found or imported_models) and self.config_path:
             self.commit()
+
     def autoimport(self)->Dict[str, AddModelResult]:
         '''
         Scan the autoimport directory (if defined) and import new models, delete defunct models.
@@ -831,63 +907,42 @@ class ModelManager(object):
         '''
         # avoid circular import
         from invokeai.backend.install.model_install_backend import ModelInstall
        from invokeai.frontend.install.model_install import ask_user_for_prediction_type
-
+
+        class ScanAndImport(ModelSearch):
+            def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
+                super().__init__(directories, logger)
+                self.installer = installer
+                self.ignore = ignore
+
+            def on_search_started(self):
+                self.new_models_found = dict()
+
+            def on_model_found(self, model: Path):
+                if model not in self.ignore:
+                    self.new_models_found.update(self.installer.heuristic_import(model))
+
+            def on_search_completed(self):
+                self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
+
+            def models_found(self):
+                return self.new_models_found
+
         installer = ModelInstall(config = self.app_config,
                                  model_manager = self,
                                  prediction_type_helper = ask_user_for_prediction_type,
                                  )
-
-        scanned_dirs = set()
-        config = self.app_config
-        known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
-
-        for autodir in [config.autoimport_dir,
-                        config.lora_dir,
-                        config.embedding_dir,
-                        config.controlnet_dir]:
-            if autodir is None:
-                continue
-
-            self.logger.info(f'Scanning {autodir} for models to import')
-            installed = dict()
-
-            autodir = self.app_config.root_path / autodir
-            if not autodir.exists():
-                continue
-
-            items_scanned = 0
-            new_models_found = dict()
-
-            for root, dirs, files in os.walk(autodir):
-                items_scanned += len(dirs) + len(files)
-                for d in dirs:
-                    path = Path(root) / d
-                    if path in known_paths or path.parent in scanned_dirs:
-                        scanned_dirs.add(path)
-                        continue
-                    if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]):
-                        try:
-                            new_models_found.update(installer.heuristic_import(path))
-                            scanned_dirs.add(path)
-                        except ValueError as e:
-                            self.logger.warning(str(e))
-
-                for f in files:
-                    path = Path(root) / f
-                    if path in known_paths or path.parent in scanned_dirs:
-                        continue
-                    if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
-                        try:
-                            import_result = installer.heuristic_import(path)
-                            new_models_found.update(import_result)
-                        except ValueError as e:
-                            self.logger.warning(str(e))
-
-            self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
-            installed.update(new_models_found)
-
-        return installed
+        config = self.app_config
+        known_paths = {config.root_path / x['path'] for x in self.list_models()}
+        directories = {config.root_path / x for x in [config.autoimport_dir,
+                                                      config.lora_dir,
+                                                      config.embedding_dir,
+                                                      config.controlnet_dir] if x
+                       }
+        scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
+        scanner.search()
+        return scanner.models_found()
 
     def heuristic_import(self,
                          items_to_import: Set[str],
@@ -925,3 +980,4 @@ class ModelManager(object):
                 successfully_installed.update(installed)
         self.commit()
         return successfully_installed
+
diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py
index 39f951d2b4..6427b9e430 100644
--- a/invokeai/backend/model_management/model_merge.py
+++ b/invokeai/backend/model_management/model_merge.py
@@ -11,7 +11,7 @@ from enum import Enum
 from pathlib import Path
 from diffusers import DiffusionPipeline
 from diffusers import logging as dlogging
-from typing import List, Union
+from typing import List, Union, Optional
 
 import invokeai.backend.util.logging as logger
@@ -74,6 +74,7 @@ class ModelMerger(object):
         alpha: float = 0.5,
         interp: MergeInterpolationMethod = None,
         force: bool = False,
+        merge_dest_directory: Optional[Path] = None,
         **kwargs,
     ) -> AddModelResult:
         """
@@ -85,7 +86,7 @@ class ModelMerger(object):
         :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None.
                        Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C).
         :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
-
+        :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
         **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
              cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
         """
@@ -111,7 +112,7 @@ class ModelMerger(object):
             merged_pipe = self.merge_diffusion_models(
                 model_paths, alpha, merge_method, force, **kwargs
             )
-            dump_path = config.models_path / base_model.value / ModelType.Main.value
+            dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value
             dump_path.mkdir(parents=True, exist_ok=True)
             dump_path = dump_path / merged_model_name
 
diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py
new file mode 100644
index 0000000000..1e282b4bb8
--- /dev/null
+++ b/invokeai/backend/model_management/model_search.py
@@ -0,0 +1,103 @@
+# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
+"""
+Abstract base class for recursive directory search for models.
+"""
+
+import os
+import types
+from abc import ABC, abstractmethod
+from typing import List, Set
+from pathlib import Path
+
+import invokeai.backend.util.logging as logger
+
+class ModelSearch(ABC):
+    def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
+        """
+        Initialize a recursive model directory search.
+        :param directories: List of directory Paths to recurse through
+        :param logger: Logger to use
+        """
+        self.directories = directories
+        self.logger = logger
+        self._items_scanned = 0
+        self._models_found = 0
+        self._scanned_dirs = set()
+        self._scanned_paths = set()
+        self._pruned_paths = set()
+
+    @abstractmethod
+    def on_search_started(self):
+        """
+        Called before the scan starts.
+        """
+        pass
+
+    @abstractmethod
+    def on_model_found(self, model: Path):
+        """
+        Process a found model. Raise an exception if something goes wrong.
+        :param model: Model to process - could be a directory or checkpoint.
+        """
+        pass
+
+    @abstractmethod
+    def on_search_completed(self):
+        """
+        Perform some activity when the scan is completed.
May use instance + variables, items_scanned and models_found + """ + pass + + def search(self): + self.on_search_started() + for dir in self.directories: + self.walk_directory(dir) + self.on_search_completed() + + def walk_directory(self, path: Path): + for root, dirs, files in os.walk(path): + if str(Path(root).name).startswith('.'): + self._pruned_paths.add(root) + if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): + continue + + self._items_scanned += len(dirs) + len(files) + for d in dirs: + path = Path(root) / d + if path in self._scanned_paths or path.parent in self._scanned_dirs: + self._scanned_dirs.add(path) + continue + if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): + try: + self.on_model_found(path) + self._models_found += 1 + self._scanned_dirs.add(path) + except Exception as e: + self.logger.warning(str(e)) + + for f in files: + path = Path(root) / f + if path.parent in self._scanned_dirs: + continue + if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: + try: + self.on_model_found(path) + self._models_found += 1 + except Exception as e: + self.logger.warning(str(e)) + +class FindModels(ModelSearch): + def on_search_started(self): + self.models_found: Set[Path] = set() + + def on_model_found(self,model: Path): + self.models_found.add(model) + + def on_search_completed(self): + pass + + def list_models(self) -> List[Path]: + self.search() + return self.models_found + + diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 1c573b26b6..e404c56bdf 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -48,7 +48,9 @@ for base_model, models in MODEL_CLASSES.items(): model_configs.discard(None) MODEL_CONFIGS.extend(model_configs) - for cfg in model_configs: + # LS: sort to get the checkpoint configs first, which makes + # for a better template in the Swagger docs + for cfg in sorted(model_configs, key=lambda x: str(x)): model_name, cfg_name = cfg.__qualname__.split('.')[-2:] openapi_cfg_name = model_name + cfg_name if openapi_cfg_name in vars(): diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index ddbc401e5b..c569872a81 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -59,7 +59,6 @@ class ModelConfigBase(BaseModel): path: str # or Path description: Optional[str] = Field(None) model_format: Optional[str] = Field(None) - # do not save to config error: Optional[ModelError] = Field(None) class Config: diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 74751a40dd..3d2e50d8fb 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -37,8 +37,7 @@ class StableDiffusion1Model(DiffusersModel): vae: Optional[str] = Field(None) config: str variant: ModelVariantType - - + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 assert model_type == ModelType.Main