mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests(mm): fix some objects in tests
This commit is contained in:
parent
bd4fd9693d
commit
4347d1c7f7
@ -20,8 +20,9 @@ from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
MainDiffusersConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
TextualInversionConfig,
|
||||
TextualInversionFileConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import BaseMetadata
|
||||
@ -40,13 +41,13 @@ def store(
|
||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
|
||||
def example_config() -> TextualInversionConfig:
|
||||
return TextualInversionConfig(
|
||||
def example_config() -> TextualInversionFileConfig:
|
||||
return TextualInversionFileConfig(
|
||||
path="/tmp/pokemon.bin",
|
||||
name="old name",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("embedding"),
|
||||
format="embedding_file",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.TextualInversion,
|
||||
format=ModelFormat.EmbeddingFile,
|
||||
original_hash="ABC123",
|
||||
)
|
||||
|
||||
@ -55,14 +56,14 @@ def test_type(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
config1 = store.get_model("key1")
|
||||
assert type(config1) == TextualInversionConfig
|
||||
assert type(config1) == TextualInversionFileConfig
|
||||
|
||||
|
||||
def test_add(store: ModelRecordServiceBase):
|
||||
raw = {
|
||||
"path": "/tmp/foo.ckpt",
|
||||
"name": "model1",
|
||||
"base": BaseModelType("sd-1"),
|
||||
"base": BaseModelType.StableDiffusion1,
|
||||
"type": "main",
|
||||
"config_path": "/tmp/foo.yaml",
|
||||
"variant": "normal",
|
||||
@ -73,7 +74,7 @@ def test_add(store: ModelRecordServiceBase):
|
||||
config1 = store.get_model("key1")
|
||||
assert config1 is not None
|
||||
assert type(config1) == MainCheckpointConfig
|
||||
assert config1.base == BaseModelType("sd-1")
|
||||
assert config1.base == BaseModelType.StableDiffusion1
|
||||
assert config1.name == "model1"
|
||||
assert config1.original_hash == "111222333444"
|
||||
assert config1.current_hash is None
|
||||
@ -138,31 +139,31 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG2HASH",
|
||||
)
|
||||
config3 = VaeDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
name="config3",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("vae"),
|
||||
type=ModelType.Vae,
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||
matches = store.search_by_attr(model_type=ModelType("main"))
|
||||
matches = store.search_by_attr(model_type=ModelType.Main)
|
||||
assert len(matches) == 2
|
||||
assert matches[0].name in {"config1", "config2"}
|
||||
|
||||
matches = store.search_by_attr(model_type=ModelType("vae"))
|
||||
matches = store.search_by_attr(model_type=ModelType.Vae)
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "config3"
|
||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||
@ -179,29 +180,29 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
def test_unique(store: ModelRecordServiceBase):
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config3 = VaeDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("vae"),
|
||||
type=ModelType.Vae,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
@ -219,56 +220,56 @@ def test_filter_2(store: ModelRecordServiceBase):
|
||||
config1 = MainDiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG2HASH",
|
||||
)
|
||||
config3 = MainDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sdxl"),
|
||||
type=ModelType("main"),
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
config5 = VaeDiffusersConfig(
|
||||
path="/tmp/config5",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("vae"),
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Vae,
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
for c in config1, config2, config3, config4, config5:
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||
|
||||
matches = store.search_by_attr(
|
||||
model_type=ModelType("main"),
|
||||
model_type=ModelType.Main,
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_attr(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("main"),
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_attr(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("vae"),
|
||||
base_model=BaseModelType.StableDiffusion1,
|
||||
model_type=ModelType.Vae,
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 1
|
||||
|
Loading…
Reference in New Issue
Block a user