From 4347d1c7f7e76561e5aba68b12bc15ced26204d0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 1 Mar 2024 15:48:15 +1100 Subject: [PATCH] tests(mm): fix some objects in tests --- .../model_records/test_model_records_sql.py | 71 ++++++++++--------- 1 file changed, 36 insertions(+), 35 deletions(-) diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index d418faee36..b8a1f3502e 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -20,8 +20,9 @@ from invokeai.backend.model_manager.config import ( BaseModelType, MainCheckpointConfig, MainDiffusersConfig, + ModelFormat, ModelType, - TextualInversionConfig, + TextualInversionFileConfig, VaeDiffusersConfig, ) from invokeai.backend.model_manager.metadata import BaseMetadata @@ -40,13 +41,13 @@ def store( return ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db)) -def example_config() -> TextualInversionConfig: - return TextualInversionConfig( +def example_config() -> TextualInversionFileConfig: + return TextualInversionFileConfig( path="/tmp/pokemon.bin", name="old name", - base=BaseModelType("sd-1"), - type=ModelType("embedding"), - format="embedding_file", + base=BaseModelType.StableDiffusion1, + type=ModelType.TextualInversion, + format=ModelFormat.EmbeddingFile, original_hash="ABC123", ) @@ -55,14 +56,14 @@ def test_type(store: ModelRecordServiceBase): config = example_config() store.add_model("key1", config) config1 = store.get_model("key1") - assert type(config1) == TextualInversionConfig + assert type(config1) == TextualInversionFileConfig def test_add(store: ModelRecordServiceBase): raw = { "path": "/tmp/foo.ckpt", "name": "model1", - "base": BaseModelType("sd-1"), + "base": BaseModelType.StableDiffusion1, "type": "main", "config_path": "/tmp/foo.yaml", "variant": "normal", @@ -73,7 +74,7 @@ def test_add(store: ModelRecordServiceBase): config1 = store.get_model("key1") assert config1 is not None assert type(config1) == MainCheckpointConfig - assert config1.base == BaseModelType("sd-1") + assert config1.base == BaseModelType.StableDiffusion1 assert config1.name == "model1" assert config1.original_hash == "111222333444" assert config1.current_hash is None @@ -138,31 +139,31 @@ def test_filter(store: ModelRecordServiceBase): config1 = MainDiffusersConfig( path="/tmp/config1", name="config1", - base=BaseModelType("sd-1"), - type=ModelType("main"), + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, original_hash="CONFIG1HASH", ) config2 = MainDiffusersConfig( path="/tmp/config2", name="config2", - base=BaseModelType("sd-1"), - type=ModelType("main"), + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, original_hash="CONFIG2HASH", ) config3 = VaeDiffusersConfig( path="/tmp/config3", name="config3", base=BaseModelType("sd-2"), - type=ModelType("vae"), + type=ModelType.Vae, original_hash="CONFIG3HASH", ) for c in config1, config2, config3: store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c) - matches = store.search_by_attr(model_type=ModelType("main")) + matches = store.search_by_attr(model_type=ModelType.Main) assert len(matches) == 2 assert matches[0].name in {"config1", "config2"} - matches = store.search_by_attr(model_type=ModelType("vae")) + matches = store.search_by_attr(model_type=ModelType.Vae) assert len(matches) == 1 assert matches[0].name == "config3" assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest() @@ -179,29 +180,29 @@ def test_filter(store: ModelRecordServiceBase): def test_unique(store: ModelRecordServiceBase): config1 = MainDiffusersConfig( path="/tmp/config1", - base=BaseModelType("sd-1"), - type=ModelType("main"), + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, name="nonuniquename", original_hash="CONFIG1HASH", ) config2 = MainDiffusersConfig( path="/tmp/config2", base=BaseModelType("sd-2"), - type=ModelType("main"), + type=ModelType.Main, name="nonuniquename", original_hash="CONFIG1HASH", ) config3 = VaeDiffusersConfig( path="/tmp/config3", base=BaseModelType("sd-2"), - type=ModelType("vae"), + type=ModelType.Vae, name="nonuniquename", original_hash="CONFIG1HASH", ) config4 = MainDiffusersConfig( path="/tmp/config4", - base=BaseModelType("sd-1"), - type=ModelType("main"), + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, name="nonuniquename", original_hash="CONFIG1HASH", ) @@ -219,56 +220,56 @@ def test_filter_2(store: ModelRecordServiceBase): config1 = MainDiffusersConfig( path="/tmp/config1", name="config1", - base=BaseModelType("sd-1"), - type=ModelType("main"), + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, original_hash="CONFIG1HASH", ) config2 = MainDiffusersConfig( path="/tmp/config2", name="config2", - base=BaseModelType("sd-1"), - type=ModelType("main"), + base=BaseModelType.StableDiffusion1, + type=ModelType.Main, original_hash="CONFIG2HASH", ) config3 = MainDiffusersConfig( path="/tmp/config3", name="dup_name1", base=BaseModelType("sd-2"), - type=ModelType("main"), + type=ModelType.Main, original_hash="CONFIG3HASH", ) config4 = MainDiffusersConfig( path="/tmp/config4", name="dup_name1", base=BaseModelType("sdxl"), - type=ModelType("main"), + type=ModelType.Main, original_hash="CONFIG3HASH", ) config5 = VaeDiffusersConfig( path="/tmp/config5", name="dup_name1", - base=BaseModelType("sd-1"), - type=ModelType("vae"), + base=BaseModelType.StableDiffusion1, + type=ModelType.Vae, original_hash="CONFIG3HASH", ) for c in config1, config2, config3, config4, config5: store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c) matches = store.search_by_attr( - model_type=ModelType("main"), + model_type=ModelType.Main, model_name="dup_name1", ) assert len(matches) == 2 matches = store.search_by_attr( - base_model=BaseModelType("sd-1"), - model_type=ModelType("main"), + base_model=BaseModelType.StableDiffusion1, + model_type=ModelType.Main, ) assert len(matches) == 2 matches = store.search_by_attr( - base_model=BaseModelType("sd-1"), - model_type=ModelType("vae"), + base_model=BaseModelType.StableDiffusion1, + model_type=ModelType.Vae, model_name="dup_name1", ) assert len(matches) == 1