diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 964cc19f19..b4685caf10 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -150,7 +150,7 @@ class _DiffusersConfig(ModelConfigBase): """Model config for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - + repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT class LoRAConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" @@ -179,7 +179,6 @@ class ControlNetDiffusersConfig(_DiffusersConfig): type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" @@ -215,7 +214,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False - class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index cd048d2fe7..ba3ac3dd0c 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -20,6 +20,7 @@ from .config import ( ModelFormat, ModelType, ModelVariantType, + ModelRepoVariant, SchedulerPredictionType, ) from .hash import FastModelHash @@ -155,6 +156,9 @@ class ModelProbe(object): fields["original_hash"] = fields.get("original_hash") or hash fields["current_hash"] = fields.get("current_hash") or hash + if format_type == ModelFormat.Diffusers: + fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() + # additional fields needed for main and controlnet models if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint: fields["config"] = cls._get_checkpoint_config_path( @@ -477,6 +481,20 @@ class FolderProbeBase(ProbeBase): def get_format(self) -> ModelFormat: return ModelFormat("diffusers") + def get_repo_variant(self) -> ModelRepoVariant: + # get all files ending in .bin or .safetensors + weight_files = list(self.model_path.glob('**/*.safetensors')) + weight_files.extend(list(self.model_path.glob('**/*.bin'))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return ModelRepoVariant.OPENVINO + if "flax_model" in x.name: + return ModelRepoVariant.FLAX + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.DEFAULT class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: @@ -522,6 +540,7 @@ class PipelineFolderProbe(FolderProbeBase): except Exception: pass return ModelVariantType.Normal + class VaeFolderProbe(FolderProbeBase): diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 248b7d602f..415559a64c 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -3,7 +3,7 @@ from pathlib import Path import pytest from invokeai.backend import BaseModelType -from invokeai.backend.model_management.model_probe import VaeFolderProbe +from invokeai.backend.model_manager.probe import VaeFolderProbe @pytest.mark.parametrize( @@ -20,3 +20,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat probe = VaeFolderProbe(sd1_vae_path) base_type = probe.get_base_type() assert base_type == expected_type + repo_variant = probe.get_repo_variant() + assert repo_variant == 'default' + +def test_repo_variant(datadir: Path): + probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") + repo_variant = probe.get_repo_variant() + assert repo_variant == 'fp16' diff --git a/tests/test_model_probe/vae/taesdxl-fp16/config.json b/tests/test_model_probe/vae/taesdxl-fp16/config.json new file mode 100644 index 0000000000..62f01c3eb4 --- /dev/null +++ b/tests/test_model_probe/vae/taesdxl-fp16/config.json @@ -0,0 +1,37 @@ +{ + "_class_name": "AutoencoderTiny", + "_diffusers_version": "0.20.0.dev0", + "act_fn": "relu", + "decoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "encoder_block_out_channels": [ + 64, + 64, + 64, + 64 + ], + "force_upcast": false, + "in_channels": 3, + "latent_channels": 4, + "latent_magnitude": 3, + "latent_shift": 0.5, + "num_decoder_blocks": [ + 3, + 3, + 3, + 1 + ], + "num_encoder_blocks": [ + 1, + 3, + 3, + 3 + ], + "out_channels": 3, + "scaling_factor": 1.0, + "upsampling_scaling_factor": 2 +} diff --git a/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors b/tests/test_model_probe/vae/taesdxl-fp16/diffusion_pytorch_model.fp16.safetensors new file mode 100644 index 0000000000..e69de29bb2