diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 0bad714a17..eac2e1dbbf 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR. """ from __future__ import annotations -import os import hashlib +import os import textwrap -import yaml +import types from dataclasses import dataclass from pathlib import Path -from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types from shutil import rmtree, move +from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable import torch +import yaml from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig - from pydantic import BaseModel, Field import invokeai.backend.util.logging as logger @@ -259,6 +259,7 @@ from .models import ( ModelNotFoundException, InvalidModelException, DuplicateModelException, + ModelBase, ) # We are only starting to number the config file with release 3. @@ -361,7 +362,7 @@ class ModelManager(object): if model_key.startswith("_"): continue model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) # alias for config file model_config["model_format"] = model_config.pop("format") self.models[model_key] = model_class.create_config(**model_config) @@ -381,18 +382,24 @@ class ModelManager(object): # causing otherwise unreferenced models to be removed from memory self._read_models() - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> bool: + def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: """ - Given a model name, returns True if it is a valid - identifier. + Given a model name, returns True if it is a valid identifier. + + :param model_name: symbolic name of the model in models.yaml + :param model_type: ModelType enum indicating the type of model to return + :param base_model: BaseModelType enum indicating the base model used by this model + :param rescan: if True, scan_models_directory """ model_key = self.create_key(model_name, base_model, model_type) - return model_key in self.models + exists = model_key in self.models + + # if model not found try to find it (maybe file just pasted) + if rescan and not exists: + self.scan_models_directory(base_model=base_model, model_type=model_type) + exists = self.model_exists(model_name, base_model, model_type, rescan=False) + + return exists @classmethod def create_key( @@ -443,39 +450,32 @@ class ModelManager(object): :param model_name: symbolic name of the model in models.yaml :param model_type: ModelType enum indicating the type of model to return :param base_model: BaseModelType enum indicating the base model used by this model - :param submode_typel: an ModelType enum indicating the portion of + :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ - model_class = MODEL_CLASSES[base_model][model_type] model_key = self.create_key(model_name, base_model, model_type) - # if model not found try to find it (maybe file just pasted) - if model_key not in self.models: - self.scan_models_directory(base_model=base_model, model_type=model_type) - if model_key not in self.models: - raise ModelNotFoundException(f"Model not found - {model_key}") + if not self.model_exists(model_name, base_model, model_type, rescan=True): + raise ModelNotFoundException(f"Model not found - {model_key}") - model_config = self.models[model_key] - model_path = self.resolve_model_path(model_config.path) + model_config = self._get_model_config(base_model, model_name, model_type) + + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + + if is_submodel_override: + model_type = submodel_type + submodel_type = None + + model_class = self._get_implementation(base_model, model_type) if not model_path.exists(): if model_class.save_to_config: self.models[model_key].error = ModelError.NotFound - raise Exception(f'Files for model "{model_key}" not found') + raise Exception(f'Files for model "{model_key}" not found at {model_path}') else: self.models.pop(model_key, None) - raise ModelNotFoundException(f"Model not found - {model_key}") - - # vae/movq override - # TODO: - if submodel_type is not None and hasattr(model_config, submodel_type): - override_path = getattr(model_config, submodel_type) - if override_path: - model_path = self.resolve_path(override_path) - model_type = submodel_type - submodel_type = None - model_class = MODEL_CLASSES[base_model][model_type] + raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}') # TODO: path # TODO: is it accurate to use path as id @@ -513,6 +513,55 @@ class ModelManager(object): _cache=self.cache, ) + def _get_model_path( + self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None + ) -> (Path, bool): + """Extract a model's filesystem path from its config. + + :return: The fully qualified Path of the module (or submodule). + """ + model_path = model_config.path + is_submodel_override = False + + # Does the config explicitly override the submodel? + if submodel_type is not None and hasattr(model_config, submodel_type): + submodel_path = getattr(model_config, submodel_type) + if submodel_path is not None: + model_path = getattr(model_config, submodel_type) + is_submodel_override = True + + model_path = self.resolve_model_path(model_path) + return model_path, is_submodel_override + + def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase: + """Get a model's config object.""" + model_key = self.create_key(model_name, base_model, model_type) + try: + model_config = self.models[model_key] + except KeyError: + raise ModelNotFoundException(f"Model not found - {model_key}") + return model_config + + def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: + """Get the concrete implementation class for a specific model type.""" + model_class = MODEL_CLASSES[base_model][model_type] + return model_class + + def _instantiate( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel_type: Optional[SubModelType] = None, + ) -> ModelBase: + """Make a new instance of this model, without loading it.""" + model_config = self._get_model_config(base_model, model_name, model_type) + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + # FIXME: do non-overriden submodels get the right class? + constructor = self._get_implementation(base_model, model_type) + instance = constructor(model_path, base_model, model_type) + return instance + def model_info( self, model_name: str, @@ -660,7 +709,7 @@ class ModelManager(object): if path := model_attributes.get("path"): model_attributes["path"] = str(self.relative_model_path(Path(path))) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) model_config = model_class.create_config(**model_attributes) model_key = self.create_key(model_name, base_model, model_type) @@ -851,7 +900,7 @@ class ModelManager(object): for model_key, model_config in self.models.items(): model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) if model_class.save_to_config: # TODO: or exclude_unset better fits here? data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"}) @@ -909,7 +958,7 @@ class ModelManager(object): model_path = self.resolve_model_path(model_config.path).absolute() if not model_path.exists(): - model_class = MODEL_CLASSES[cur_base_model][cur_model_type] + model_class = self._get_implementation(cur_base_model, cur_model_type) if model_class.save_to_config: model_config.error = ModelError.NotFound self.models.pop(model_key, None) @@ -925,7 +974,7 @@ class ModelManager(object): for cur_model_type in ModelType: if model_type is not None and cur_model_type != model_type: continue - model_class = MODEL_CLASSES[cur_base_model][cur_model_type] + model_class = self._get_implementation(cur_base_model, cur_model_type) models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value)) if not models_dir.exists(): diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index b15844bcf8..957a102ffb 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -1,9 +1,14 @@ import os -import torch -import safetensors from enum import Enum from pathlib import Path -from typing import Optional, Union, Literal +from typing import Optional + +import safetensors +import torch +from diffusers.utils import is_safetensors_available +from omegaconf import OmegaConf + +from invokeai.app.services.config import InvokeAIAppConfig from .base import ( ModelBase, ModelConfigBase, @@ -18,9 +23,6 @@ from .base import ( InvalidModelException, ModelNotFoundException, ) -from invokeai.app.services.config import InvokeAIAppConfig -from diffusers.utils import is_safetensors_available -from omegaconf import OmegaConf class VaeModelFormat(str, Enum): @@ -80,7 +82,7 @@ class VaeModel(ModelBase): @classmethod def detect_format(cls, path: str): if not os.path.exists(path): - raise ModelNotFoundException() + raise ModelNotFoundException(f"Does not exist as local file: {path}") if os.path.isdir(path): if os.path.exists(os.path.join(path, "config.json")): diff --git a/pyproject.toml b/pyproject.toml index b3f12481a8..2ae297a6da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ dependencies = [ "dev" = [ "pudb", ] -"test" = ["pytest>6.0.0", "pytest-cov", "black"] +"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"] "xformers" = [ "xformers~=0.0.19; sys_platform!='darwin'", "triton; sys_platform=='linux'", diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py new file mode 100644 index 0000000000..4314bad595 --- /dev/null +++ b/tests/test_model_manager.py @@ -0,0 +1,38 @@ +from pathlib import Path + +import pytest + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType + +BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) +VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) + + +@pytest.fixture +def model_manager(datadir) -> ModelManager: + InvokeAIAppConfig.get_config(root=datadir) + return ModelManager(datadir / "configs" / "relative_sub.models.yaml") + + +def test_get_model_names(model_manager: ModelManager): + names = model_manager.model_names() + assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] + + +def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2]) + top_model_path, is_override = model_manager._get_model_path(model_config) + expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" + assert top_model_path == expected_model_path + assert not is_override + + +def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config( + VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2] + ) + vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) + expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" + assert vae_model_path == expected_vae_path + assert is_override diff --git a/tests/test_model_manager/configs/relative_sub.models.yaml b/tests/test_model_manager/configs/relative_sub.models.yaml new file mode 100644 index 0000000000..3ec7a3adff --- /dev/null +++ b/tests/test_model_manager/configs/relative_sub.models.yaml @@ -0,0 +1,15 @@ +__metadata__: + version: 3.0.0 + +sdxl/main/SDXL base: + path: sdxl/main/SDXL base 1_0 + description: SDXL base v1.0 + variant: normal + format: diffusers + +sdxl/main/SDXL with VAE: + path: sdxl/main/SDXL base 1_0 + description: SDXL with customized VAE + vae: sdxl/vae/sdxl-vae-fp16-fix/ + variant: normal + format: diffusers diff --git a/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json b/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json b/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json new file mode 100644 index 0000000000..e69de29bb2