diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 53720585ef..8be7089cf5 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -1,9 +1,13 @@ from pathlib import Path import pytest +from torch import tensor from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant +from invokeai.backend.model_manager.config import InvalidModelConfigException from invokeai.backend.model_manager.probe import ( + CkptType, + ModelProbe, VaeFolderProbe, get_default_settings_controlnet_t2i_adapter, get_default_settings_main, @@ -52,3 +56,25 @@ def test_default_settings_main(): assert get_default_settings_main(BaseModelType.StableDiffusionXL).height == 1024 assert get_default_settings_main(BaseModelType.StableDiffusionXLRefiner) is None assert get_default_settings_main(BaseModelType.Any) is None + + +def test_probe_handles_state_dict_with_integer_keys(): + # This structure isn't supported by invoke, but we still need to handle it gracefully. See #6044 + state_dict_with_integer_keys: CkptType = { + 320: ( + { + "linear1.weight": tensor([1.0]), + "linear1.bias": tensor([1.0]), + "linear2.weight": tensor([1.0]), + "linear2.bias": tensor([1.0]), + }, + { + "linear1.weight": tensor([1.0]), + "linear1.bias": tensor([1.0]), + "linear2.weight": tensor([1.0]), + "linear2.bias": tensor([1.0]), + }, + ), + } + with pytest.raises(InvalidModelConfigException): + ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)