tests(mm): fix some objects in tests

This commit is contained in:
psychedelicious 2024-03-01 15:48:15 +11:00
parent bd4fd9693d
commit 4347d1c7f7

View File

@ -20,8 +20,9 @@ from invokeai.backend.model_manager.config import (
BaseModelType, BaseModelType,
MainCheckpointConfig, MainCheckpointConfig,
MainDiffusersConfig, MainDiffusersConfig,
ModelFormat,
ModelType, ModelType,
TextualInversionConfig, TextualInversionFileConfig,
VaeDiffusersConfig, VaeDiffusersConfig,
) )
from invokeai.backend.model_manager.metadata import BaseMetadata from invokeai.backend.model_manager.metadata import BaseMetadata
@ -40,13 +41,13 @@ def store(
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
def example_config() -> TextualInversionConfig: def example_config() -> TextualInversionFileConfig:
return TextualInversionConfig( return TextualInversionFileConfig(
path="/tmp/pokemon.bin", path="/tmp/pokemon.bin",
name="old name", name="old name",
base=BaseModelType("sd-1"), base=BaseModelType.StableDiffusion1,
type=ModelType("embedding"), type=ModelType.TextualInversion,
format="embedding_file", format=ModelFormat.EmbeddingFile,
original_hash="ABC123", original_hash="ABC123",
) )
@ -55,14 +56,14 @@ def test_type(store: ModelRecordServiceBase):
config = example_config() config = example_config()
store.add_model("key1", config) store.add_model("key1", config)
config1 = store.get_model("key1") config1 = store.get_model("key1")
assert type(config1) == TextualInversionConfig assert type(config1) == TextualInversionFileConfig
def test_add(store: ModelRecordServiceBase): def test_add(store: ModelRecordServiceBase):
raw = { raw = {
"path": "/tmp/foo.ckpt", "path": "/tmp/foo.ckpt",
"name": "model1", "name": "model1",
"base": BaseModelType("sd-1"), "base": BaseModelType.StableDiffusion1,
"type": "main", "type": "main",
"config_path": "/tmp/foo.yaml", "config_path": "/tmp/foo.yaml",
"variant": "normal", "variant": "normal",
@ -73,7 +74,7 @@ def test_add(store: ModelRecordServiceBase):
config1 = store.get_model("key1") config1 = store.get_model("key1")
assert config1 is not None assert config1 is not None
assert type(config1) == MainCheckpointConfig assert type(config1) == MainCheckpointConfig
assert config1.base == BaseModelType("sd-1") assert config1.base == BaseModelType.StableDiffusion1
assert config1.name == "model1" assert config1.name == "model1"
assert config1.original_hash == "111222333444" assert config1.original_hash == "111222333444"
assert config1.current_hash is None assert config1.current_hash is None
@ -138,31 +139,31 @@ def test_filter(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig( config1 = MainDiffusersConfig(
path="/tmp/config1", path="/tmp/config1",
name="config1", name="config1",
base=BaseModelType("sd-1"), base=BaseModelType.StableDiffusion1,
type=ModelType("main"), type=ModelType.Main,
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
config2 = MainDiffusersConfig( config2 = MainDiffusersConfig(
path="/tmp/config2", path="/tmp/config2",
name="config2", name="config2",
base=BaseModelType("sd-1"), base=BaseModelType.StableDiffusion1,
type=ModelType("main"), type=ModelType.Main,
original_hash="CONFIG2HASH", original_hash="CONFIG2HASH",
) )
config3 = VaeDiffusersConfig( config3 = VaeDiffusersConfig(
path="/tmp/config3", path="/tmp/config3",
name="config3", name="config3",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType("vae"), type=ModelType.Vae,
original_hash="CONFIG3HASH", original_hash="CONFIG3HASH",
) )
for c in config1, config2, config3: for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c) store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
matches = store.search_by_attr(model_type=ModelType("main")) matches = store.search_by_attr(model_type=ModelType.Main)
assert len(matches) == 2 assert len(matches) == 2
assert matches[0].name in {"config1", "config2"} assert matches[0].name in {"config1", "config2"}
matches = store.search_by_attr(model_type=ModelType("vae")) matches = store.search_by_attr(model_type=ModelType.Vae)
assert len(matches) == 1 assert len(matches) == 1
assert matches[0].name == "config3" assert matches[0].name == "config3"
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest() assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
@ -179,29 +180,29 @@ def test_filter(store: ModelRecordServiceBase):
def test_unique(store: ModelRecordServiceBase): def test_unique(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig( config1 = MainDiffusersConfig(
path="/tmp/config1", path="/tmp/config1",
base=BaseModelType("sd-1"), base=BaseModelType.StableDiffusion1,
type=ModelType("main"), type=ModelType.Main,
name="nonuniquename", name="nonuniquename",
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
config2 = MainDiffusersConfig( config2 = MainDiffusersConfig(
path="/tmp/config2", path="/tmp/config2",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType("main"), type=ModelType.Main,
name="nonuniquename", name="nonuniquename",
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
config3 = VaeDiffusersConfig( config3 = VaeDiffusersConfig(
path="/tmp/config3", path="/tmp/config3",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType("vae"), type=ModelType.Vae,
name="nonuniquename", name="nonuniquename",
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
config4 = MainDiffusersConfig( config4 = MainDiffusersConfig(
path="/tmp/config4", path="/tmp/config4",
base=BaseModelType("sd-1"), base=BaseModelType.StableDiffusion1,
type=ModelType("main"), type=ModelType.Main,
name="nonuniquename", name="nonuniquename",
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
@ -219,56 +220,56 @@ def test_filter_2(store: ModelRecordServiceBase):
config1 = MainDiffusersConfig( config1 = MainDiffusersConfig(
path="/tmp/config1", path="/tmp/config1",
name="config1", name="config1",
base=BaseModelType("sd-1"), base=BaseModelType.StableDiffusion1,
type=ModelType("main"), type=ModelType.Main,
original_hash="CONFIG1HASH", original_hash="CONFIG1HASH",
) )
config2 = MainDiffusersConfig( config2 = MainDiffusersConfig(
path="/tmp/config2", path="/tmp/config2",
name="config2", name="config2",
base=BaseModelType("sd-1"), base=BaseModelType.StableDiffusion1,
type=ModelType("main"), type=ModelType.Main,
original_hash="CONFIG2HASH", original_hash="CONFIG2HASH",
) )
config3 = MainDiffusersConfig( config3 = MainDiffusersConfig(
path="/tmp/config3", path="/tmp/config3",
name="dup_name1", name="dup_name1",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType("main"), type=ModelType.Main,
original_hash="CONFIG3HASH", original_hash="CONFIG3HASH",
) )
config4 = MainDiffusersConfig( config4 = MainDiffusersConfig(
path="/tmp/config4", path="/tmp/config4",
name="dup_name1", name="dup_name1",
base=BaseModelType("sdxl"), base=BaseModelType("sdxl"),
type=ModelType("main"), type=ModelType.Main,
original_hash="CONFIG3HASH", original_hash="CONFIG3HASH",
) )
config5 = VaeDiffusersConfig( config5 = VaeDiffusersConfig(
path="/tmp/config5", path="/tmp/config5",
name="dup_name1", name="dup_name1",
base=BaseModelType("sd-1"), base=BaseModelType.StableDiffusion1,
type=ModelType("vae"), type=ModelType.Vae,
original_hash="CONFIG3HASH", original_hash="CONFIG3HASH",
) )
for c in config1, config2, config3, config4, config5: for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
matches = store.search_by_attr( matches = store.search_by_attr(
model_type=ModelType("main"), model_type=ModelType.Main,
model_name="dup_name1", model_name="dup_name1",
) )
assert len(matches) == 2 assert len(matches) == 2
matches = store.search_by_attr( matches = store.search_by_attr(
base_model=BaseModelType("sd-1"), base_model=BaseModelType.StableDiffusion1,
model_type=ModelType("main"), model_type=ModelType.Main,
) )
assert len(matches) == 2 assert len(matches) == 2
matches = store.search_by_attr( matches = store.search_by_attr(
base_model=BaseModelType("sd-1"), base_model=BaseModelType.StableDiffusion1,
model_type=ModelType("vae"), model_type=ModelType.Vae,
model_name="dup_name1", model_name="dup_name1",
) )
assert len(matches) == 1 assert len(matches) == 1