prevent vae: '' from crashing model

This commit is contained in:
Lincoln Stein 2023-08-09 10:52:29 -04:00 committed by Kent Keirsey
parent 7bad9bcf53
commit a969707e45
4 changed files with 22 additions and 2 deletions

View File

@ -104,8 +104,12 @@ async def update_model(
): # model manager moved model path during rename - don't overwrite it ): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path") info.path = new_info.get("path")
# replace empty string values with None/null to avoid phenomenon of vae: ''
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
ApiDependencies.invoker.services.model_manager.update_model( ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict() model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info_dict
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(

View File

@ -526,7 +526,7 @@ class ModelManager(object):
# Does the config explicitly override the submodel? # Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type): if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type) submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None: if submodel_path is not None and len(submodel_path) > 0:
model_path = getattr(model_config, submodel_type) model_path = getattr(model_config, submodel_type)
is_submodel_override = True is_submodel_override = True

View File

@ -7,6 +7,7 @@ from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelTyp
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_NULL_OVERRIDE_MODEL_NAME = ("SDXL with empty VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
@pytest.fixture @pytest.fixture
@ -36,3 +37,11 @@ def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir:
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
assert vae_model_path == expected_vae_path assert vae_model_path == expected_vae_path
assert is_override assert is_override
def test_get_model_path_for_null_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_NULL_OVERRIDE_MODEL_NAME[1], VAE_NULL_OVERRIDE_MODEL_NAME[0], VAE_NULL_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
assert not is_override

View File

@ -13,3 +13,10 @@ sdxl/main/SDXL with VAE:
vae: sdxl/vae/sdxl-vae-fp16-fix/ vae: sdxl/vae/sdxl-vae-fp16-fix/
variant: normal variant: normal
format: diffusers format: diffusers
sdxl/main/SDXL with empty VAE:
path: sdxl/main/SDXL base 1_0
description: SDXL with customized VAE
vae: ''
variant: normal
format: diffusers