From 7c9128b2539d3359c3867963b3e9ffc2dc221248 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 5 Mar 2024 17:37:17 +1100 Subject: [PATCH] tidy(mm): use canonical capitalization for all model-related enums, classes For example, "Lora" -> "LoRA", "Vae" -> "VAE". --- invokeai/app/invocations/model.py | 2 +- invokeai/app/invocations/sdxl.py | 4 +- invokeai/backend/model_manager/config.py | 52 +++++++++---------- .../model_manager/load/model_loaders/lora.py | 4 +- .../model_manager/load/model_loaders/onnx.py | 2 +- .../model_manager/load/model_loaders/vae.py | 6 +-- invokeai/backend/model_manager/probe.py | 34 ++++++------ .../model_manager/util/select_hf_files.py | 12 ++--- invokeai/frontend/install/model_install.py | 2 +- .../model_records/test_model_records_sql.py | 18 +++---- .../model_manager/model_manager_fixtures.py | 10 ++-- .../util/test_hf_model_select.py | 6 +-- tests/test_model_probe.py | 2 +- 13 files changed, 77 insertions(+), 77 deletions(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 6087bc82db..cb69558be5 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=key, - submodel_type=SubModelType.Vae, + submodel_type=SubModelType.VAE, ), ), ) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 0df27c0011..4e783defec 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=model_key, - submodel_type=SubModelType.Vae, + submodel_type=SubModelType.VAE, ), ), ) @@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): vae=VaeField( vae=ModelInfo( key=model_key, - submodel_type=SubModelType.Vae, + submodel_type=SubModelType.VAE, ), ), ) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 973f749c52..8f0f437eb8 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -59,8 +59,8 @@ class ModelType(str, Enum): ONNX = "onnx" Main = "main" - Vae = "vae" - Lora = "lora" + VAE = "vae" + LoRA = "lora" ControlNet = "controlnet" # used by model_probe TextualInversion = "embedding" IPAdapter = "ip_adapter" @@ -76,9 +76,9 @@ class SubModelType(str, Enum): TextEncoder2 = "text_encoder_2" Tokenizer = "tokenizer" Tokenizer2 = "tokenizer_2" - Vae = "vae" - VaeDecoder = "vae_decoder" - VaeEncoder = "vae_encoder" + VAE = "vae" + VAEDecoder = "vae_decoder" + VAEEncoder = "vae_encoder" Scheduler = "scheduler" SafetyChecker = "safety_checker" @@ -96,8 +96,8 @@ class ModelFormat(str, Enum): Diffusers = "diffusers" Checkpoint = "checkpoint" - Lycoris = "lycoris" - Onnx = "onnx" + LyCORIS = "lycoris" + ONNX = "onnx" Olive = "olive" EmbeddingFile = "embedding_file" EmbeddingFolder = "embedding_folder" @@ -115,12 +115,12 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - DEFAULT = "" # model files without "fp16" or other qualifier - empty str + Default = "" # model files without "fp16" or other qualifier - empty str FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" - OPENVINO = "openvino" - FLAX = "flax" + OpenVINO = "openvino" + Flax = "flax" class ModelSourceType(str, Enum): @@ -183,51 +183,51 @@ class DiffusersConfigBase(ModelConfigBase): """Model config for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT + repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default -class LoRALycorisConfig(ModelConfigBase): +class LoRALyCORISConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" - type: Literal[ModelType.Lora] = ModelType.Lora - format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris + type: Literal[ModelType.LoRA] = ModelType.LoRA + format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Lora.value}.{ModelFormat.Lycoris.value}") + return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}") class LoRADiffusersConfig(ModelConfigBase): """Model config for LoRA/Diffusers models.""" - type: Literal[ModelType.Lora] = ModelType.Lora + type: Literal[ModelType.LoRA] = ModelType.LoRA format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Lora.value}.{ModelFormat.Diffusers.value}") + return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}") -class VaeCheckpointConfig(CheckpointConfigBase): +class VAECheckpointConfig(CheckpointConfigBase): """Model config for standalone VAE models.""" - type: Literal[ModelType.Vae] = ModelType.Vae + type: Literal[ModelType.VAE] = ModelType.VAE format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Vae.value}.{ModelFormat.Checkpoint.value}") + return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}") -class VaeDiffusersConfig(ModelConfigBase): +class VAEDiffusersConfig(ModelConfigBase): """Model config for standalone VAE models (diffusers version).""" - type: Literal[ModelType.Vae] = ModelType.Vae + type: Literal[ModelType.VAE] = ModelType.VAE format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Vae.value}.{ModelFormat.Diffusers.value}") + return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}") class ControlNetDiffusersConfig(DiffusersConfigBase): @@ -356,11 +356,11 @@ AnyModelConfig = Annotated[ Union[ Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()], Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()], - Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()], - Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()], + Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()], + Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()], Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()], - Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()], + Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()], Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()], Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()], diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index e308531a4f..436442a622 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod from .. import ModelLoader, ModelLoaderRegistry -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS) class LoraLoader(ModelLoader): """Class to load LoRA models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py index 38f0274acc..e771cac8eb 100644 --- a/invokeai/backend/model_manager/load/model_loaders/onnx.py +++ b/invokeai/backend/model_manager/load/model_loaders/onnx.py @@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry from .generic_diffusers import GenericDiffusersLoader -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) class OnnyxDiffusersModel(GenericDiffusersLoader): """Class to load onnx models.""" diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index e4fc811346..e18351138f 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -20,9 +20,9 @@ from .. import ModelLoaderRegistry from .generic_diffusers import GenericDiffusersLoader -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint) class VaeLoader(GenericDiffusersLoader): """Class to load VAE models.""" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 774959f7ef..759cb2fc46 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -97,8 +97,8 @@ class ModelProbe(object): "StableDiffusionXLImg2ImgPipeline": ModelType.Main, "StableDiffusionXLInpaintPipeline": ModelType.Main, "LatentConsistencyModelPipeline": ModelType.Main, - "AutoencoderKL": ModelType.Vae, - "AutoencoderTiny": ModelType.Vae, + "AutoencoderKL": ModelType.VAE, + "AutoencoderTiny": ModelType.VAE, "ControlNetModel": ModelType.ControlNet, "CLIPVisionModelWithProjection": ModelType.CLIPVision, "T2IAdapter": ModelType.T2IAdapter, @@ -143,7 +143,7 @@ class ModelProbe(object): model_type = cls.get_model_type_from_folder(model_path) else: model_type = cls.get_model_type_from_checkpoint(model_path) - format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type + format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type probe_class = cls.PROBES[format_type].get(model_type) if not probe_class: @@ -172,7 +172,7 @@ class ModelProbe(object): # additional fields needed for main and controlnet models if ( - fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] + fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] is ModelFormat.Checkpoint ): fields["config_path"] = cls._get_checkpoint_config_path( @@ -185,7 +185,7 @@ class ModelProbe(object): # additional fields needed for main non-checkpoint models elif fields["type"] == ModelType.Main and fields["format"] in [ - ModelFormat.Onnx, + ModelFormat.ONNX, ModelFormat.Olive, ModelFormat.Diffusers, ]: @@ -219,11 +219,11 @@ class ModelProbe(object): if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): return ModelType.Main elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): - return ModelType.Vae + return ModelType.VAE elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): - return ModelType.Lora + return ModelType.LoRA elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): - return ModelType.Lora + return ModelType.LoRA elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): return ModelType.ControlNet elif key in {"emb_params", "string_to_param"}: @@ -245,7 +245,7 @@ class ModelProbe(object): if (folder_path / f"learned_embeds.{suffix}").exists(): return ModelType.TextualInversion if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): - return ModelType.Lora + return ModelType.LoRA if (folder_path / "unet/model.onnx").exists(): return ModelType.ONNX if (folder_path / "image_encoder.txt").exists(): @@ -301,7 +301,7 @@ class ModelProbe(object): if base_type is BaseModelType.StableDiffusion1 else "../controlnet/cldm_v21.yaml" ) - elif model_type is ModelType.Vae: + elif model_type is ModelType.VAE: config_file = ( "../stable-diffusion/v1-inference.yaml" if base_type is BaseModelType.StableDiffusion1 @@ -511,12 +511,12 @@ class FolderProbeBase(ProbeBase): if ".fp16" in x.suffixes: return ModelRepoVariant.FP16 if "openvino_model" in x.name: - return ModelRepoVariant.OPENVINO + return ModelRepoVariant.OpenVINO if "flax_model" in x.name: - return ModelRepoVariant.FLAX + return ModelRepoVariant.Flax if x.suffix == ".onnx": return ModelRepoVariant.ONNX - return ModelRepoVariant.DEFAULT + return ModelRepoVariant.Default class PipelineFolderProbe(FolderProbeBase): @@ -722,8 +722,8 @@ class T2IAdapterFolderProbe(FolderProbeBase): ############## register probe classes ###### ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) @@ -731,8 +731,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) +ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) +ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index 595d03f599..4a63ab27b7 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -35,7 +35,7 @@ def filter_files( The file list can be obtained from the `files` field of HuggingFaceMetadata, as defined in `invokeai.backend.model_manager.metadata.metadata_base`. """ - variant = variant or ModelRepoVariant.DEFAULT + variant = variant or ModelRepoVariant.Default paths: List[Path] = [] root = files[0].parts[0] @@ -90,11 +90,11 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path result.add(path) elif "openvino_model" in path.name: - if variant == ModelRepoVariant.OPENVINO: + if variant == ModelRepoVariant.OpenVINO: result.add(path) elif "flax_model" in path.name: - if variant == ModelRepoVariant.FLAX: + if variant == ModelRepoVariant.Flax: result.add(path) elif path.suffix in [".json", ".txt"]: @@ -103,7 +103,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path elif variant in [ ModelRepoVariant.FP16, ModelRepoVariant.FP32, - ModelRepoVariant.DEFAULT, + ModelRepoVariant.Default, ] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]: # For weights files, we want to select the best one for each subfolder. For example, we may have multiple # text encoders: @@ -127,7 +127,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path # Some special handling is needed here if there is not an exact match and if we cannot infer the variant # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT. if candidate_variant_label == f".{variant}" or ( - not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT] + not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default] ): score += 1 @@ -148,7 +148,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path # config and text files then we return an empty list if ( variant - and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX] + and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax] and not any(variant.value in x.name for x in result) ): return set() diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 2f7fd0a1d0..4ef038277c 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): self.nextrely = top_of_table self.lora_models = self.add_model_widgets( - model_type=ModelType.Lora, + model_type=ModelType.LoRA, window_width=window_width, ) bottom_of_table = max(bottom_of_table, self.nextrely) diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index aeb7362e38..94b61bf1bf 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -23,7 +23,7 @@ from invokeai.backend.model_manager.config import ( ModelSourceType, ModelType, TextualInversionFileConfig, - VaeDiffusersConfig, + VAEDiffusersConfig, ) from invokeai.backend.util.logging import InvokeAILogger from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 @@ -141,12 +141,12 @@ def test_filter(store: ModelRecordServiceBase): source="test/source", source_type=ModelSourceType.Path, ) - config3 = VaeDiffusersConfig( + config3 = VAEDiffusersConfig( key="config3", path="/tmp/config3", name="config3", base=BaseModelType("sd-2"), - type=ModelType.Vae, + type=ModelType.VAE, hash="CONFIG3HASH", source="test/source", source_type=ModelSourceType.Path, @@ -157,7 +157,7 @@ def test_filter(store: ModelRecordServiceBase): assert len(matches) == 2 assert matches[0].name in {"config1", "config2"} - matches = store.search_by_attr(model_type=ModelType.Vae) + matches = store.search_by_attr(model_type=ModelType.VAE) assert len(matches) == 1 assert matches[0].name == "config3" assert matches[0].key == "config3" @@ -190,10 +190,10 @@ def test_unique(store: ModelRecordServiceBase): source="test/source/", source_type=ModelSourceType.Path, ) - config3 = VaeDiffusersConfig( + config3 = VAEDiffusersConfig( path="/tmp/config3", base=BaseModelType("sd-2"), - type=ModelType.Vae, + type=ModelType.VAE, name="nonuniquename", hash="CONFIG1HASH", source="test/source/", @@ -257,11 +257,11 @@ def test_filter_2(store: ModelRecordServiceBase): source="test/source/", source_type=ModelSourceType.Path, ) - config5 = VaeDiffusersConfig( + config5 = VAEDiffusersConfig( path="/tmp/config5", name="dup_name1", base=BaseModelType.StableDiffusion1, - type=ModelType.Vae, + type=ModelType.VAE, hash="CONFIG3HASH", source="test/source/", source_type=ModelSourceType.Path, @@ -283,7 +283,7 @@ def test_filter_2(store: ModelRecordServiceBase): matches = store.search_by_attr( base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.Vae, + model_type=ModelType.VAE, model_name="dup_name1", ) assert len(matches) == 1 diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 112b3765ff..e22ce4ac81 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -28,7 +28,7 @@ from invokeai.backend.model_manager.config import ( ModelSourceType, ModelType, ModelVariantType, - VaeDiffusersConfig, + VAEDiffusersConfig, ) from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.util.logging import InvokeAILogger @@ -162,13 +162,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas db = create_mock_sqlite_database(mm2_app_config, logger) store = ModelRecordServiceSQL(db) # add five simple config records to the database - config1 = VaeDiffusersConfig( + config1 = VAEDiffusersConfig( key="test_config_1", path="/tmp/foo1", format=ModelFormat.Diffusers, name="test2", base=BaseModelType.StableDiffusion2, - type=ModelType.Vae, + type=ModelType.VAE, hash="111222333444", source="stabilityai/sdxl-vae", source_type=ModelSourceType.HFRepoID, @@ -204,7 +204,7 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas format=ModelFormat.Diffusers, name="test4", base=BaseModelType.StableDiffusionXL, - type=ModelType.Lora, + type=ModelType.LoRA, hash="111222333444", source="author4/model4", source_type=ModelSourceType.HFRepoID, @@ -215,7 +215,7 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas format=ModelFormat.Diffusers, name="test5", base=BaseModelType.StableDiffusion1, - type=ModelType.Lora, + type=ModelType.LoRA, hash="111222333444", source="author4/model5", source_type=ModelSourceType.HFRepoID, diff --git a/tests/backend/model_manager/util/test_hf_model_select.py b/tests/backend/model_manager/util/test_hf_model_select.py index 9a410bc0e3..a29827e8c4 100644 --- a/tests/backend/model_manager/util/test_hf_model_select.py +++ b/tests/backend/model_manager/util/test_hf_model_select.py @@ -104,7 +104,7 @@ def sdxl_base_files() -> List[Path]: ], ), ( - ModelRepoVariant.DEFAULT, + ModelRepoVariant.Default, [ "model_index.json", "scheduler/scheduler_config.json", @@ -129,7 +129,7 @@ def sdxl_base_files() -> List[Path]: ], ), ( - ModelRepoVariant.OPENVINO, + ModelRepoVariant.OpenVINO, [ "model_index.json", "scheduler/scheduler_config.json", @@ -211,7 +211,7 @@ def sdxl_base_files() -> List[Path]: ], ), ( - ModelRepoVariant.FLAX, + ModelRepoVariant.Flax, [ "model_index.json", "scheduler/scheduler_config.json", diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index be823e2be9..78a9ec50b4 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -21,7 +21,7 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat base_type = probe.get_base_type() assert base_type == expected_type repo_variant = probe.get_repo_variant() - assert repo_variant == ModelRepoVariant.DEFAULT + assert repo_variant == ModelRepoVariant.Default def test_repo_variant(datadir: Path):