2023-09-20 19:48:59 +00:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from invokeai.backend import BaseModelType
|
|
|
|
from invokeai.backend.model_management.model_probe import VaeFolderProbe
|
|
|
|
|
|
|
|
|
2023-09-20 19:53:25 +00:00
|
|
|
@pytest.mark.parametrize(
    "vae_path,expected_type",
    [
        ("sd-vae-ft-mse", BaseModelType.StableDiffusion1),
        ("sdxl-vae", BaseModelType.StableDiffusionXL),
        ("taesd", BaseModelType.StableDiffusion1),
        ("taesdxl", BaseModelType.StableDiffusionXL),
    ],
)
def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Path):
    """Check that ``VaeFolderProbe.get_base_type`` identifies each VAE fixture folder.

    Each parametrized case names a folder under ``<datadir>/vae`` together with the
    ``BaseModelType`` the probe is expected to report for it.

    NOTE(review): ``datadir`` is not defined here; it is presumably the
    pytest-datadir fixture pointing at this test module's data directory — confirm.
    """
    # The cases cover both SD1 and SDXL VAEs, so the folder variable is named
    # neutrally rather than implying an SD1-only path.
    vae_folder = datadir / "vae" / vae_path
    probe = VaeFolderProbe(vae_folder)
    base_type = probe.get_base_type()
    assert base_type == expected_type
|