From a969707e4569201cf6fab297e8734fb312337448 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 9 Aug 2023 10:52:29 -0400 Subject: [PATCH] prevent vae: '' from crashing model --- invokeai/app/api/routers/models.py | 6 +++++- invokeai/backend/model_management/model_manager.py | 2 +- tests/test_model_manager.py | 9 +++++++++ .../test_model_manager/configs/relative_sub.models.yaml | 7 +++++++ 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 2f59f1dd0f..b6c1edbbe1 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -104,8 +104,12 @@ async def update_model( ): # model manager moved model path during rename - don't overwrite it info.path = new_info.get("path") + # replace empty string values with None/null to avoid phenomenon of vae: '' + info_dict = info.dict() + info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()} + ApiDependencies.invoker.services.model_manager.update_model( - model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict() + model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict ) model_raw = ApiDependencies.invoker.services.model_manager.list_model( diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 6664d32540..adc3aaa661 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -526,7 +526,7 @@ class ModelManager(object): # Does the config explicitly override the submodel? if submodel_type is not None and hasattr(model_config, submodel_type): submodel_path = getattr(model_config, submodel_type) - if submodel_path is not None: + if submodel_path is not None and len(submodel_path) > 0: model_path = getattr(model_config, submodel_type) is_submodel_override = True diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py index 4314bad595..4aa2c4d3b2 100644 --- a/tests/test_model_manager.py +++ b/tests/test_model_manager.py @@ -7,6 +7,7 @@ from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelTyp BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) +VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main) @pytest.fixture @@ -36,3 +37,11 @@ def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" assert vae_model_path == expected_vae_path assert is_override + + +def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config( + VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2] + ) + vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) + assert not is_override diff --git a/tests/test_model_manager/configs/relative_sub.models.yaml b/tests/test_model_manager/configs/relative_sub.models.yaml index 3ec7a3adff..2e26710d13 100644 --- a/tests/test_model_manager/configs/relative_sub.models.yaml +++ b/tests/test_model_manager/configs/relative_sub.models.yaml @@ -13,3 +13,10 @@ sdxl/main/SDXL with VAE: vae: sdxl/vae/sdxl-vae-fp16-fix/ variant: normal format: diffusers + +sdxl/main/SDXL with empty VAE: + path: sdxl/main/SDXL base 1_0 + description: SDXL with customized VAE + vae: '' + variant: normal + format: diffusers