tests(mm): fix some objects in tests

This commit is contained in:
psychedelicious 2024-03-01 15:48:15 +11:00
parent bd4fd9693d
commit 4347d1c7f7

View File

@ -20,8 +20,9 @@ from invokeai.backend.model_manager.config import (
BaseModelType,
MainCheckpointConfig,
MainDiffusersConfig,
ModelFormat,
ModelType,
TextualInversionConfig,
TextualInversionFileConfig,
VaeDiffusersConfig,
)
from invokeai.backend.model_manager.metadata import BaseMetadata
@ -40,13 +41,13 @@ def store(
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def example_config() -> TextualInversionConfig:
return TextualInversionConfig(
def example_config() -> TextualInversionFileConfig:
return TextualInversionFileConfig(
path="/tmp/pokemon.bin",
name="old name",
base=BaseModelType("sd-1"),
type=ModelType("embedding"),
format="embedding_file",
base=BaseModelType.StableDiffusion1,
type=ModelType.TextualInversion,
format=ModelFormat.EmbeddingFile,
original_hash="ABC123",
)
@ -55,14 +56,14 @@ def test_type(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config1 = store.get_model("key1")
assert type(config1) == TextualInversionConfig
assert type(config1) == TextualInversionFileConfig
def test_add(store: ModelRecordServiceBase):
raw = {
"path": "/tmp/foo.ckpt",
"name": "model1",
"base": BaseModelType("sd-1"),
"base": BaseModelType.StableDiffusion1,
"type": "main",
"config_path": "/tmp/foo.yaml",
"variant": "normal",
@ -73,7 +74,7 @@ def test_add(store: ModelRecordServiceBase):
config1 = store.get_model("key1")
assert config1 is not None
assert type(config1) == MainCheckpointConfig
assert config1.base == BaseModelType("sd-1")
assert config1.base == BaseModelType.StableDiffusion1
assert config1.name == "model1"
assert config1.original_hash == "111222333444"
assert config1.current_hash is None
@ -138,31 +139,31 @@ def test_filter(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
type=ModelType("main"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG2HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
name="config3",
base=BaseModelType("sd-2"),
type=ModelType("vae"),
type=ModelType.Vae,
original_hash="CONFIG3HASH",
)
for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
matches = store.search_by_attr(model_type=ModelType("main"))
matches = store.search_by_attr(model_type=ModelType.Main)
assert len(matches) == 2
assert matches[0].name in {"config1", "config2"}
matches = store.search_by_attr(model_type=ModelType("vae"))
matches = store.search_by_attr(model_type=ModelType.Vae)
assert len(matches) == 1
assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
@ -179,29 +180,29 @@ def test_filter(store: ModelRecordServiceBase):
def test_unique(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
path="/tmp/config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
base=BaseModelType("sd-2"),
type=ModelType("main"),
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
base=BaseModelType("sd-2"),
type=ModelType("vae"),
type=ModelType.Vae,
name="nonuniquename",
original_hash="CONFIG1HASH",
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
base=BaseModelType("sd-1"),
type=ModelType("main"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
)
@ -219,56 +220,56 @@ def test_filter_2(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
type=ModelType("main"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG2HASH",
)
config3 = MainDiffusersConfig(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType("main"),
type=ModelType.Main,
original_hash="CONFIG3HASH",
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sdxl"),
type=ModelType("main"),
type=ModelType.Main,
original_hash="CONFIG3HASH",
)
config5 = VaeDiffusersConfig(
path="/tmp/config5",
name="dup_name1",
base=BaseModelType("sd-1"),
type=ModelType("vae"),
base=BaseModelType.StableDiffusion1,
type=ModelType.Vae,
original_hash="CONFIG3HASH",
)
for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
matches = store.search_by_attr(
model_type=ModelType("main"),
model_type=ModelType.Main,
model_name="dup_name1",
)
assert len(matches) == 2
matches = store.search_by_attr(
base_model=BaseModelType("sd-1"),
model_type=ModelType("main"),
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Main,
)
assert len(matches) == 2
matches = store.search_by_attr(
base_model=BaseModelType("sd-1"),
model_type=ModelType("vae"),
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Vae,
model_name="dup_name1",
)
assert len(matches) == 1