tests: add test_probe_handles_state_dict_with_integer_keys

This commit is contained in:
psychedelicious 2024-03-26 08:56:12 +11:00
parent 243de683d7
commit 778922e603

View File

@ -1,9 +1,13 @@
from pathlib import Path from pathlib import Path
import pytest import pytest
from torch import tensor
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
from invokeai.backend.model_manager.config import InvalidModelConfigException
from invokeai.backend.model_manager.probe import ( from invokeai.backend.model_manager.probe import (
CkptType,
ModelProbe,
VaeFolderProbe, VaeFolderProbe,
get_default_settings_controlnet_t2i_adapter, get_default_settings_controlnet_t2i_adapter,
get_default_settings_main, get_default_settings_main,
@ -52,3 +56,25 @@ def test_default_settings_main():
assert get_default_settings_main(BaseModelType.StableDiffusionXL).height == 1024 assert get_default_settings_main(BaseModelType.StableDiffusionXL).height == 1024
assert get_default_settings_main(BaseModelType.StableDiffusionXLRefiner) is None assert get_default_settings_main(BaseModelType.StableDiffusionXLRefiner) is None
assert get_default_settings_main(BaseModelType.Any) is None assert get_default_settings_main(BaseModelType.Any) is None
def test_probe_handles_state_dict_with_integer_keys():
# This structure isn't supported by invoke, but we still need to handle it gracefully. See #6044
state_dict_with_integer_keys: CkptType = {
320: (
{
"linear1.weight": tensor([1.0]),
"linear1.bias": tensor([1.0]),
"linear2.weight": tensor([1.0]),
"linear2.bias": tensor([1.0]),
},
{
"linear1.weight": tensor([1.0]),
"linear1.bias": tensor([1.0]),
"linear2.weight": tensor([1.0]),
"linear2.bias": tensor([1.0]),
},
),
}
with pytest.raises(InvalidModelConfigException):
ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)