2023-09-20 19:48:59 +00:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import pytest
|
2024-03-25 21:56:12 +00:00
|
|
|
from torch import tensor
|
2023-09-20 19:48:59 +00:00
|
|
|
|
2024-02-06 03:56:32 +00:00
|
|
|
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
|
2024-03-25 21:56:12 +00:00
|
|
|
from invokeai.backend.model_manager.config import InvalidModelConfigException
|
2024-03-12 23:02:55 +00:00
|
|
|
from invokeai.backend.model_manager.probe import (
|
2024-03-25 21:56:12 +00:00
|
|
|
CkptType,
|
|
|
|
ModelProbe,
|
2024-03-12 23:02:55 +00:00
|
|
|
VaeFolderProbe,
|
|
|
|
get_default_settings_controlnet_t2i_adapter,
|
|
|
|
get_default_settings_main,
|
|
|
|
)
|
2023-09-20 19:48:59 +00:00
|
|
|
|
|
|
|
|
2023-09-20 19:53:25 +00:00
|
|
|
@pytest.mark.parametrize(
    "vae_path,expected_type",
    [
        ("sd-vae-ft-mse", BaseModelType.StableDiffusion1),
        ("sdxl-vae", BaseModelType.StableDiffusionXL),
        ("taesd", BaseModelType.StableDiffusion1),
        ("taesdxl", BaseModelType.StableDiffusionXL),
    ],
)
def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Path):
    """Probe a VAE fixture folder and check its base type and repo variant.

    Args:
        vae_path: Name of a VAE fixture directory under ``datadir / "vae"``.
        expected_type: Base model type ``VaeFolderProbe`` should report.
        datadir: Per-module test data directory supplied by a fixture defined
            outside this file (presumably pytest-datadir — TODO confirm).
    """
    # The parametrized fixtures include SDXL VAEs too, so use a neutral name
    # rather than an "sd1_" prefix.
    vae_dir = datadir / "vae" / vae_path
    probe = VaeFolderProbe(vae_dir)

    assert probe.get_base_type() == expected_type
    # Every parametrized fixture is expected to be the default repo variant;
    # the fp16 variant is covered separately by test_repo_variant.
    assert probe.get_repo_variant() == ModelRepoVariant.Default
|
2024-02-01 04:37:59 +00:00
|
|
|
|
2024-01-22 19:37:23 +00:00
|
|
|
|
|
|
|
def test_repo_variant(datadir: Path):
    """The ``taesdxl-fp16`` VAE fixture folder reports the FP16 repo variant."""
    fp16_vae_folder = datadir / "vae" / "taesdxl-fp16"
    assert VaeFolderProbe(fp16_vae_folder).get_repo_variant() == ModelRepoVariant.FP16
|
2024-03-08 09:39:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_controlnet_t2i_default_settings():
    """Check the default preprocessor chosen for ControlNet/T2I-Adapter models.

    The model names in ``expected_preprocessors`` must each yield settings with
    the given preprocessor. A name that matches nothing must yield ``None``.
    """
    expected_preprocessors = {
        "some_canny_model": "canny_image_processor",
        "some_depth_model": "depth_anything_image_processor",
        "some_pose_model": "dw_openpose_image_processor",
    }
    for model_name, expected in expected_preprocessors.items():
        settings = get_default_settings_controlnet_t2i_adapter(model_name)
        # The helper may return None (see the final assertion). Check that
        # first so a regression fails with a readable message rather than an
        # AttributeError on `.preprocessor`.
        assert settings is not None, f"no default settings for {model_name!r}"
        assert settings.preprocessor == expected, model_name

    assert get_default_settings_controlnet_t2i_adapter("i like turtles") is None
|
|
|
|
|
|
|
|
|
|
|
|
def test_default_settings_main():
    """Check default image dimensions for each main-model base type.

    SD1/SD2 default to 512x512 and SDXL to 1024x1024. The SDXL refiner and
    ``Any`` have no default settings, so the helper returns ``None``.
    """
    expected_sizes = {
        BaseModelType.StableDiffusion1: 512,
        BaseModelType.StableDiffusion2: 512,
        BaseModelType.StableDiffusionXL: 1024,
    }
    for base_type, size in expected_sizes.items():
        # Query once per base type and guard against None, so a missing entry
        # fails with a clear message instead of an AttributeError.
        settings = get_default_settings_main(base_type)
        assert settings is not None, f"no default settings for {base_type}"
        assert settings.width == size, base_type
        assert settings.height == size, base_type

    for base_type in (BaseModelType.StableDiffusionXLRefiner, BaseModelType.Any):
        assert get_default_settings_main(base_type) is None, base_type
|
2024-03-25 21:56:12 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_probe_handles_state_dict_with_integer_keys():
    """Probing a state dict keyed by an integer raises InvalidModelConfigException.

    This structure isn't supported by invoke, but we still need to handle it
    gracefully. See #6044.
    """

    def make_linear_params():
        # Return a fresh dict per call so the two tuple members are distinct
        # objects, exactly as two separate dict literals would be.
        return {
            "linear1.weight": tensor([1.0]),
            "linear1.bias": tensor([1.0]),
            "linear2.weight": tensor([1.0]),
            "linear2.bias": tensor([1.0]),
        }

    state_dict_with_integer_keys: CkptType = {
        320: (make_linear_params(), make_linear_params()),
    }

    with pytest.raises(InvalidModelConfigException):
        ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)
|