diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 3e8888be24..ebe7ffbbd0 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -533,7 +533,7 @@ class ModelManager(object): model_path = self.resolve_model_path(model_path) return model_path, is_submodel_override - def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase: + def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase: """Get a model's config object.""" model_key = self.create_key(model_name, base_model, model_type) try: diff --git a/pyproject.toml b/pyproject.toml index b3f12481a8..2ae297a6da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ dependencies = [ "dev" = [ "pudb", ] -"test" = ["pytest>6.0.0", "pytest-cov", "black"] +"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"] "xformers" = [ "xformers~=0.0.19; sys_platform!='darwin'", "triton; sys_platform=='linux'", diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py new file mode 100644 index 0000000000..af0394eac2 --- /dev/null +++ b/tests/test_model_manager.py @@ -0,0 +1,36 @@ +from pathlib import Path + +import pytest + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType + + +@pytest.fixture +def model_manager(datadir) -> ModelManager: + InvokeAIAppConfig.get_config(root=datadir) + return ModelManager(datadir / "configs" / "relative_sub.models.yaml") + + +def test_get_model_names(model_manager: ModelManager): + names = model_manager.model_names() + assert names[:2] == [ + ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main), + ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main), + ] + + +def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL base", ModelType.Main) + top_model_path, is_override = model_manager._get_model_path(model_config) + expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" + assert top_model_path == expected_model_path + assert not is_override + + +def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL with VAE", ModelType.Main) + vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) + expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" + assert vae_model_path == expected_vae_path + assert is_override diff --git a/tests/test_model_manager/configs/relative_sub.models.yaml b/tests/test_model_manager/configs/relative_sub.models.yaml new file mode 100644 index 0000000000..757c50e3b5 --- /dev/null +++ b/tests/test_model_manager/configs/relative_sub.models.yaml @@ -0,0 +1,15 @@ +__metadata__: + version: 3.0.0 + +sdxl/main/SDXL base: + path: sdxl/main/SDXL base 1_0 + description: SDXL base v1.0 + variant: normal + format: diffusers + +sdxl/main/SDXL with VAE: + path: sdxl/main/SDXL base 1_0 + description: SDXL base v1.0 + vae: sdxl/vae/sdxl-vae-fp16-fix/ + variant: normal + format: diffusers diff --git a/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json b/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json b/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json new file mode 100644 index 0000000000..e69de29bb2