mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests: add test_probe_handles_state_dict_with_integer_keys
This commit is contained in:
parent
243de683d7
commit
778922e603
@ -1,9 +1,13 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from torch import tensor
|
||||||
|
|
||||||
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
|
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
|
||||||
|
from invokeai.backend.model_manager.config import InvalidModelConfigException
|
||||||
from invokeai.backend.model_manager.probe import (
|
from invokeai.backend.model_manager.probe import (
|
||||||
|
CkptType,
|
||||||
|
ModelProbe,
|
||||||
VaeFolderProbe,
|
VaeFolderProbe,
|
||||||
get_default_settings_controlnet_t2i_adapter,
|
get_default_settings_controlnet_t2i_adapter,
|
||||||
get_default_settings_main,
|
get_default_settings_main,
|
||||||
@ -52,3 +56,25 @@ def test_default_settings_main():
|
|||||||
assert get_default_settings_main(BaseModelType.StableDiffusionXL).height == 1024
|
assert get_default_settings_main(BaseModelType.StableDiffusionXL).height == 1024
|
||||||
assert get_default_settings_main(BaseModelType.StableDiffusionXLRefiner) is None
|
assert get_default_settings_main(BaseModelType.StableDiffusionXLRefiner) is None
|
||||||
assert get_default_settings_main(BaseModelType.Any) is None
|
assert get_default_settings_main(BaseModelType.Any) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_probe_handles_state_dict_with_integer_keys():
|
||||||
|
# This structure isn't supported by invoke, but we still need to handle it gracefully. See #6044
|
||||||
|
state_dict_with_integer_keys: CkptType = {
|
||||||
|
320: (
|
||||||
|
{
|
||||||
|
"linear1.weight": tensor([1.0]),
|
||||||
|
"linear1.bias": tensor([1.0]),
|
||||||
|
"linear2.weight": tensor([1.0]),
|
||||||
|
"linear2.bias": tensor([1.0]),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"linear1.weight": tensor([1.0]),
|
||||||
|
"linear1.bias": tensor([1.0]),
|
||||||
|
"linear2.weight": tensor([1.0]),
|
||||||
|
"linear2.bias": tensor([1.0]),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
with pytest.raises(InvalidModelConfigException):
|
||||||
|
ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)
|
||||||
|
Loading…
Reference in New Issue
Block a user