mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
167 lines
4.4 KiB
Python
167 lines
4.4 KiB
Python
"""
Test the refactored model config classes.
"""
|
|
|
|
from invokeai.backend.model_manager.config import (
|
|
InvalidModelConfigException,
|
|
LoRAConfig,
|
|
MainCheckpointConfig,
|
|
MainDiffusersConfig,
|
|
ModelConfigFactory,
|
|
ONNXSD1Config,
|
|
ONNXSD2Config,
|
|
TextualInversionConfig,
|
|
ValidationError,
|
|
)
|
|
|
|
|
|
def test_checkpoints():
    """A main model in checkpoint format is built as a MainCheckpointConfig."""
    checkpoint_fields = {
        "path": "/tmp/foo.ckpt",
        "name": "foo",
        "base_model": "sd-1",
        "model_type": "main",
        "config": "/tmp/foo.yaml",
        "variant": "normal",
        "model_format": "checkpoint",
    }
    cfg = ModelConfigFactory.make_config(checkpoint_fields)

    assert isinstance(cfg, MainCheckpointConfig)
    assert cfg.model_format == "checkpoint"
    assert cfg.base_model == "sd-1"
    # No `vae` key was supplied above, so the resulting config must report None.
    assert cfg.vae is None
|
|
|
|
|
|
def test_diffusers():
    """A main model in diffusers format is built as a MainDiffusersConfig.

    Checks that base model, variant and an explicit ``vae`` path round-trip.
    """
    diffusers_fields = {
        "path": "/tmp/foo",
        "name": "foo",
        "base_model": "sd-2",
        "model_type": "main",
        "variant": "inpaint",
        "model_format": "diffusers",
        "vae": "/tmp/foobar/vae.pt",
    }
    cfg = ModelConfigFactory.make_config(diffusers_fields)

    assert isinstance(cfg, MainDiffusersConfig)
    assert cfg.model_format == "diffusers"
    assert cfg.base_model == "sd-2"
    assert cfg.variant == "inpaint"
    assert cfg.vae == "/tmp/foobar/vae.pt"
|
|
|
|
|
|
def test_invalid_diffusers():
    """A diffusers-format main model supplied with a ``config`` field must be rejected."""
    raw = dict(
        path="/tmp/foo",
        name="foo",
        base_model="sd-2",
        model_type="main",
        variant="inpaint",
        config="/tmp/foo.ckpt",
        model_format="diffusers",
    )
    # This is expected to fail with a validation error, because
    # diffusers format does not have a `config` field.
    try:
        ModelConfigFactory.make_config(raw)
    except InvalidModelConfigException:
        pass
    else:
        # Explicit raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would let this test pass silently.
        raise AssertionError("Validation should have failed")
|
|
|
|
|
|
def test_lora():
    """LoRA models are accepted in both lycoris and diffusers formats."""
    lora_base = {
        "path": "/tmp/foo",
        "name": "foo",
        "base_model": "sdxl",
        "model_type": "lora",
    }
    # Same fields each time; only the on-disk format varies.
    for fmt in ("lycoris", "diffusers"):
        cfg = ModelConfigFactory.make_config({**lora_base, "model_format": fmt})
        assert isinstance(cfg, LoRAConfig)
        assert cfg.model_format == fmt
|
|
|
|
|
|
def test_embedding():
    """An ``embedding_file`` model is built as a TextualInversionConfig."""
    cfg = ModelConfigFactory.make_config(
        {
            "path": "/tmp/foo",
            "name": "foo",
            "base_model": "sdxl-refiner",
            "model_type": "embedding",
            "model_format": "embedding_file",
        }
    )
    assert isinstance(cfg, TextualInversionConfig)
    assert cfg.model_format == "embedding_file"
|
|
|
|
|
|
def test_onnx():
    """ONNX models are built as ONNXSD1Config or ONNXSD2Config depending on base model.

    An sd-2 ONNX config is expected to be rejected until ``upcast_attention``
    (and, in this test, ``prediction_type``) is supplied.
    """
    raw = dict(
        path="/tmp/foo.ckpt",
        name="foo",
        base_model="sd-1",
        model_type="onnx",
        variant="normal",
        model_format="onnx",
    )
    config = ModelConfigFactory.make_config(raw)
    assert isinstance(config, ONNXSD1Config)
    assert config.model_format == "onnx"

    raw["base_model"] = "sd-2"
    # this should not validate without the upcast_attention field
    try:
        ModelConfigFactory.make_config(raw)
    except InvalidModelConfigException:
        pass
    else:
        # Explicit raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would let this check pass silently.
        raise AssertionError("Config should not have validated without upcast_attention")

    raw["upcast_attention"] = True
    raw["prediction_type"] = "epsilon"
    config = ModelConfigFactory.make_config(raw)
    assert isinstance(config, ONNXSD2Config)
    assert config.upcast_attention
|
|
|
|
|
|
def test_assignment():
    """Attribute assignment on a built config is validated.

    A valid value is accepted and stored; an invalid ``prediction_type`` must
    raise ValidationError.
    """
    raw = dict(
        path="/tmp/foo.ckpt",
        name="foo",
        base_model="sd-2",
        model_type="onnx",
        variant="normal",
        model_format="onnx",
        upcast_attention=True,
        prediction_type="epsilon",
    )
    config = ModelConfigFactory.make_config(raw)

    config.upcast_attention = False
    assert not config.upcast_attention

    try:
        config.prediction_type = "not valid"
    except ValidationError:
        pass
    else:
        # Explicit raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would let this check pass silently.
        raise AssertionError("Config should not have accepted invalid assignment")
|
|
|
|
|
|
def test_invalid_combination():
    """A ``main`` model type paired with ``onnx`` format must be rejected."""
    raw = dict(
        path="/tmp/foo.ckpt",
        name="foo",
        base_model="sd-2",
        model_type="main",
        variant="normal",
        model_format="onnx",
        upcast_attention=True,
        prediction_type="epsilon",
    )
    try:
        ModelConfigFactory.make_config(raw)
    except InvalidModelConfigException:
        pass
    else:
        # Explicit raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would let this test pass silently.
        raise AssertionError("This should have raised an InvalidModelConfigException")
|