mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests(mm): fix some objects in tests
This commit is contained in:
parent
bd4fd9693d
commit
4347d1c7f7
@ -20,8 +20,9 @@ from invokeai.backend.model_manager.config import (
|
|||||||
BaseModelType,
|
BaseModelType,
|
||||||
MainCheckpointConfig,
|
MainCheckpointConfig,
|
||||||
MainDiffusersConfig,
|
MainDiffusersConfig,
|
||||||
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
TextualInversionConfig,
|
TextualInversionFileConfig,
|
||||||
VaeDiffusersConfig,
|
VaeDiffusersConfig,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import BaseMetadata
|
from invokeai.backend.model_manager.metadata import BaseMetadata
|
||||||
@ -40,13 +41,13 @@ def store(
|
|||||||
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||||
|
|
||||||
|
|
||||||
def example_config() -> TextualInversionConfig:
|
def example_config() -> TextualInversionFileConfig:
|
||||||
return TextualInversionConfig(
|
return TextualInversionFileConfig(
|
||||||
path="/tmp/pokemon.bin",
|
path="/tmp/pokemon.bin",
|
||||||
name="old name",
|
name="old name",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType("embedding"),
|
type=ModelType.TextualInversion,
|
||||||
format="embedding_file",
|
format=ModelFormat.EmbeddingFile,
|
||||||
original_hash="ABC123",
|
original_hash="ABC123",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -55,14 +56,14 @@ def test_type(store: ModelRecordServiceBase):
|
|||||||
config = example_config()
|
config = example_config()
|
||||||
store.add_model("key1", config)
|
store.add_model("key1", config)
|
||||||
config1 = store.get_model("key1")
|
config1 = store.get_model("key1")
|
||||||
assert type(config1) == TextualInversionConfig
|
assert type(config1) == TextualInversionFileConfig
|
||||||
|
|
||||||
|
|
||||||
def test_add(store: ModelRecordServiceBase):
|
def test_add(store: ModelRecordServiceBase):
|
||||||
raw = {
|
raw = {
|
||||||
"path": "/tmp/foo.ckpt",
|
"path": "/tmp/foo.ckpt",
|
||||||
"name": "model1",
|
"name": "model1",
|
||||||
"base": BaseModelType("sd-1"),
|
"base": BaseModelType.StableDiffusion1,
|
||||||
"type": "main",
|
"type": "main",
|
||||||
"config_path": "/tmp/foo.yaml",
|
"config_path": "/tmp/foo.yaml",
|
||||||
"variant": "normal",
|
"variant": "normal",
|
||||||
@ -73,7 +74,7 @@ def test_add(store: ModelRecordServiceBase):
|
|||||||
config1 = store.get_model("key1")
|
config1 = store.get_model("key1")
|
||||||
assert config1 is not None
|
assert config1 is not None
|
||||||
assert type(config1) == MainCheckpointConfig
|
assert type(config1) == MainCheckpointConfig
|
||||||
assert config1.base == BaseModelType("sd-1")
|
assert config1.base == BaseModelType.StableDiffusion1
|
||||||
assert config1.name == "model1"
|
assert config1.name == "model1"
|
||||||
assert config1.original_hash == "111222333444"
|
assert config1.original_hash == "111222333444"
|
||||||
assert config1.current_hash is None
|
assert config1.current_hash is None
|
||||||
@ -138,31 +139,31 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
config1 = MainDiffusersConfig(
|
config1 = MainDiffusersConfig(
|
||||||
path="/tmp/config1",
|
path="/tmp/config1",
|
||||||
name="config1",
|
name="config1",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
name="config2",
|
name="config2",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG2HASH",
|
original_hash="CONFIG2HASH",
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VaeDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
name="config3",
|
name="config3",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType("vae"),
|
type=ModelType.Vae,
|
||||||
original_hash="CONFIG3HASH",
|
original_hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
for c in config1, config2, config3:
|
for c in config1, config2, config3:
|
||||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||||
matches = store.search_by_attr(model_type=ModelType("main"))
|
matches = store.search_by_attr(model_type=ModelType.Main)
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
assert matches[0].name in {"config1", "config2"}
|
assert matches[0].name in {"config1", "config2"}
|
||||||
|
|
||||||
matches = store.search_by_attr(model_type=ModelType("vae"))
|
matches = store.search_by_attr(model_type=ModelType.Vae)
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0].name == "config3"
|
assert matches[0].name == "config3"
|
||||||
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
|
||||||
@ -179,29 +180,29 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
def test_unique(store: ModelRecordServiceBase):
|
def test_unique(store: ModelRecordServiceBase):
|
||||||
config1 = MainDiffusersConfig(
|
config1 = MainDiffusersConfig(
|
||||||
path="/tmp/config1",
|
path="/tmp/config1",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VaeDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType("vae"),
|
type=ModelType.Vae,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config4 = MainDiffusersConfig(
|
config4 = MainDiffusersConfig(
|
||||||
path="/tmp/config4",
|
path="/tmp/config4",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
@ -219,56 +220,56 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
config1 = MainDiffusersConfig(
|
config1 = MainDiffusersConfig(
|
||||||
path="/tmp/config1",
|
path="/tmp/config1",
|
||||||
name="config1",
|
name="config1",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG1HASH",
|
original_hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
name="config2",
|
name="config2",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG2HASH",
|
original_hash="CONFIG2HASH",
|
||||||
)
|
)
|
||||||
config3 = MainDiffusersConfig(
|
config3 = MainDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG3HASH",
|
original_hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
config4 = MainDiffusersConfig(
|
config4 = MainDiffusersConfig(
|
||||||
path="/tmp/config4",
|
path="/tmp/config4",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType("sdxl"),
|
base=BaseModelType("sdxl"),
|
||||||
type=ModelType("main"),
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG3HASH",
|
original_hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
config5 = VaeDiffusersConfig(
|
config5 = VaeDiffusersConfig(
|
||||||
path="/tmp/config5",
|
path="/tmp/config5",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType("sd-1"),
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType("vae"),
|
type=ModelType.Vae,
|
||||||
original_hash="CONFIG3HASH",
|
original_hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
for c in config1, config2, config3, config4, config5:
|
for c in config1, config2, config3, config4, config5:
|
||||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||||
|
|
||||||
matches = store.search_by_attr(
|
matches = store.search_by_attr(
|
||||||
model_type=ModelType("main"),
|
model_type=ModelType.Main,
|
||||||
model_name="dup_name1",
|
model_name="dup_name1",
|
||||||
)
|
)
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
|
|
||||||
matches = store.search_by_attr(
|
matches = store.search_by_attr(
|
||||||
base_model=BaseModelType("sd-1"),
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
model_type=ModelType("main"),
|
model_type=ModelType.Main,
|
||||||
)
|
)
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
|
|
||||||
matches = store.search_by_attr(
|
matches = store.search_by_attr(
|
||||||
base_model=BaseModelType("sd-1"),
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
model_type=ModelType("vae"),
|
model_type=ModelType.Vae,
|
||||||
model_name="dup_name1",
|
model_name="dup_name1",
|
||||||
)
|
)
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
|
Loading…
Reference in New Issue
Block a user