tidy(mm): use canonical capitalization for all model-related enums, classes

For example, "Lora" -> "LoRA", "Vae" -> "VAE".
This commit is contained in:
psychedelicious
2024-03-05 17:37:17 +11:00
parent 4f9bb00275
commit 7c9128b253
13 changed files with 77 additions and 77 deletions

View File

@@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
key=key, key=key,
submodel_type=SubModelType.Vae, submodel_type=SubModelType.VAE,
), ),
), ),
) )

View File

@@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
key=model_key, key=model_key,
submodel_type=SubModelType.Vae, submodel_type=SubModelType.VAE,
), ),
), ),
) )
@@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
key=model_key, key=model_key,
submodel_type=SubModelType.Vae, submodel_type=SubModelType.VAE,
), ),
), ),
) )

View File

@@ -59,8 +59,8 @@ class ModelType(str, Enum):
ONNX = "onnx" ONNX = "onnx"
Main = "main" Main = "main"
Vae = "vae" VAE = "vae"
Lora = "lora" LoRA = "lora"
ControlNet = "controlnet" # used by model_probe ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding" TextualInversion = "embedding"
IPAdapter = "ip_adapter" IPAdapter = "ip_adapter"
@@ -76,9 +76,9 @@ class SubModelType(str, Enum):
TextEncoder2 = "text_encoder_2" TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer" Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2" Tokenizer2 = "tokenizer_2"
Vae = "vae" VAE = "vae"
VaeDecoder = "vae_decoder" VAEDecoder = "vae_decoder"
VaeEncoder = "vae_encoder" VAEEncoder = "vae_encoder"
Scheduler = "scheduler" Scheduler = "scheduler"
SafetyChecker = "safety_checker" SafetyChecker = "safety_checker"
@@ -96,8 +96,8 @@ class ModelFormat(str, Enum):
Diffusers = "diffusers" Diffusers = "diffusers"
Checkpoint = "checkpoint" Checkpoint = "checkpoint"
Lycoris = "lycoris" LyCORIS = "lycoris"
Onnx = "onnx" ONNX = "onnx"
Olive = "olive" Olive = "olive"
EmbeddingFile = "embedding_file" EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder" EmbeddingFolder = "embedding_folder"
@@ -115,12 +115,12 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum): class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format.""" """Various hugging face variants on the diffusers format."""
DEFAULT = "" # model files without "fp16" or other qualifier - empty str Default = "" # model files without "fp16" or other qualifier - empty str
FP16 = "fp16" FP16 = "fp16"
FP32 = "fp32" FP32 = "fp32"
ONNX = "onnx" ONNX = "onnx"
OPENVINO = "openvino" OpenVINO = "openvino"
FLAX = "flax" Flax = "flax"
class ModelSourceType(str, Enum): class ModelSourceType(str, Enum):
@@ -183,51 +183,51 @@ class DiffusersConfigBase(ModelConfigBase):
"""Model config for diffusers-style models.""" """Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRALycorisConfig(ModelConfigBase): class LoRALyCORISConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models.""" """Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod @staticmethod
def get_tag() -> Tag: def get_tag() -> Tag:
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Lycoris.value}") return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class LoRADiffusersConfig(ModelConfigBase): class LoRADiffusersConfig(ModelConfigBase):
"""Model config for LoRA/Diffusers models.""" """Model config for LoRA/Diffusers models."""
type: Literal[ModelType.Lora] = ModelType.Lora type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod @staticmethod
def get_tag() -> Tag: def get_tag() -> Tag:
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Diffusers.value}") return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
class VaeCheckpointConfig(CheckpointConfigBase): class VAECheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models.""" """Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod @staticmethod
def get_tag() -> Tag: def get_tag() -> Tag:
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Checkpoint.value}") return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
class VaeDiffusersConfig(ModelConfigBase): class VAEDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version).""" """Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod @staticmethod
def get_tag() -> Tag: def get_tag() -> Tag:
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Diffusers.value}") return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlNetDiffusersConfig(DiffusersConfigBase): class ControlNetDiffusersConfig(DiffusersConfigBase):
@@ -356,11 +356,11 @@ AnyModelConfig = Annotated[
Union[ Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()], Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()], Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()], Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()], Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()], Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()], Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()], Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()], Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],

View File

@@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
from .. import ModelLoader, ModelLoaderRegistry from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
class LoraLoader(ModelLoader): class LoraLoader(ModelLoader):
"""Class to load LoRA models.""" """Class to load LoRA models."""

View File

@@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(GenericDiffusersLoader): class OnnyxDiffusersModel(GenericDiffusersLoader):
"""Class to load onnx models.""" """Class to load onnx models."""

View File

@@ -20,9 +20,9 @@ from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint) @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader): class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models.""" """Class to load VAE models."""

View File

@@ -97,8 +97,8 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main, "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main, "StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main, "LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae, "AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.Vae, "AutoencoderTiny": ModelType.VAE,
"ControlNetModel": ModelType.ControlNet, "ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision, "CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter, "T2IAdapter": ModelType.T2IAdapter,
@@ -143,7 +143,7 @@ class ModelProbe(object):
model_type = cls.get_model_type_from_folder(model_path) model_type = cls.get_model_type_from_folder(model_path)
else: else:
model_type = cls.get_model_type_from_checkpoint(model_path) model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type) probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class: if not probe_class:
@@ -172,7 +172,7 @@ class ModelProbe(object):
# additional fields needed for main and controlnet models # additional fields needed for main and controlnet models
if ( if (
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
and fields["format"] is ModelFormat.Checkpoint and fields["format"] is ModelFormat.Checkpoint
): ):
fields["config_path"] = cls._get_checkpoint_config_path( fields["config_path"] = cls._get_checkpoint_config_path(
@@ -185,7 +185,7 @@ class ModelProbe(object):
# additional fields needed for main non-checkpoint models # additional fields needed for main non-checkpoint models
elif fields["type"] == ModelType.Main and fields["format"] in [ elif fields["type"] == ModelType.Main and fields["format"] in [
ModelFormat.Onnx, ModelFormat.ONNX,
ModelFormat.Olive, ModelFormat.Olive,
ModelFormat.Diffusers, ModelFormat.Diffusers,
]: ]:
@@ -219,11 +219,11 @@ class ModelProbe(object):
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae return ModelType.VAE
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora return ModelType.LoRA
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora return ModelType.LoRA
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}: elif key in {"emb_params", "string_to_param"}:
@@ -245,7 +245,7 @@ class ModelProbe(object):
if (folder_path / f"learned_embeds.{suffix}").exists(): if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora return ModelType.LoRA
if (folder_path / "unet/model.onnx").exists(): if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists(): if (folder_path / "image_encoder.txt").exists():
@@ -301,7 +301,7 @@ class ModelProbe(object):
if base_type is BaseModelType.StableDiffusion1 if base_type is BaseModelType.StableDiffusion1
else "../controlnet/cldm_v21.yaml" else "../controlnet/cldm_v21.yaml"
) )
elif model_type is ModelType.Vae: elif model_type is ModelType.VAE:
config_file = ( config_file = (
"../stable-diffusion/v1-inference.yaml" "../stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1 if base_type is BaseModelType.StableDiffusion1
@@ -511,12 +511,12 @@ class FolderProbeBase(ProbeBase):
if ".fp16" in x.suffixes: if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16 return ModelRepoVariant.FP16
if "openvino_model" in x.name: if "openvino_model" in x.name:
return ModelRepoVariant.OPENVINO return ModelRepoVariant.OpenVINO
if "flax_model" in x.name: if "flax_model" in x.name:
return ModelRepoVariant.FLAX return ModelRepoVariant.Flax
if x.suffix == ".onnx": if x.suffix == ".onnx":
return ModelRepoVariant.ONNX return ModelRepoVariant.ONNX
return ModelRepoVariant.DEFAULT return ModelRepoVariant.Default
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
@@ -722,8 +722,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
############## register probe classes ###### ############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
@@ -731,8 +731,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

View File

@@ -35,7 +35,7 @@ def filter_files(
The file list can be obtained from the `files` field of HuggingFaceMetadata, The file list can be obtained from the `files` field of HuggingFaceMetadata,
as defined in `invokeai.backend.model_manager.metadata.metadata_base`. as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
""" """
variant = variant or ModelRepoVariant.DEFAULT variant = variant or ModelRepoVariant.Default
paths: List[Path] = [] paths: List[Path] = []
root = files[0].parts[0] root = files[0].parts[0]
@@ -90,11 +90,11 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
result.add(path) result.add(path)
elif "openvino_model" in path.name: elif "openvino_model" in path.name:
if variant == ModelRepoVariant.OPENVINO: if variant == ModelRepoVariant.OpenVINO:
result.add(path) result.add(path)
elif "flax_model" in path.name: elif "flax_model" in path.name:
if variant == ModelRepoVariant.FLAX: if variant == ModelRepoVariant.Flax:
result.add(path) result.add(path)
elif path.suffix in [".json", ".txt"]: elif path.suffix in [".json", ".txt"]:
@@ -103,7 +103,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
elif variant in [ elif variant in [
ModelRepoVariant.FP16, ModelRepoVariant.FP16,
ModelRepoVariant.FP32, ModelRepoVariant.FP32,
ModelRepoVariant.DEFAULT, ModelRepoVariant.Default,
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]: ] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple # For weights files, we want to select the best one for each subfolder. For example, we may have multiple
# text encoders: # text encoders:
@@ -127,7 +127,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant # Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT. # from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
if candidate_variant_label == f".{variant}" or ( if candidate_variant_label == f".{variant}" or (
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT] not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
): ):
score += 1 score += 1
@@ -148,7 +148,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# config and text files then we return an empty list # config and text files then we return an empty list
if ( if (
variant variant
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX] and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
and not any(variant.value in x.name for x in result) and not any(variant.value in x.name for x in result)
): ):
return set() return set()

View File

@@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.nextrely = top_of_table self.nextrely = top_of_table
self.lora_models = self.add_model_widgets( self.lora_models = self.add_model_widgets(
model_type=ModelType.Lora, model_type=ModelType.LoRA,
window_width=window_width, window_width=window_width,
) )
bottom_of_table = max(bottom_of_table, self.nextrely) bottom_of_table = max(bottom_of_table, self.nextrely)

View File

@@ -23,7 +23,7 @@ from invokeai.backend.model_manager.config import (
ModelSourceType, ModelSourceType,
ModelType, ModelType,
TextualInversionFileConfig, TextualInversionFileConfig,
VaeDiffusersConfig, VAEDiffusersConfig,
) )
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
@@ -141,12 +141,12 @@ def test_filter(store: ModelRecordServiceBase):
source="test/source", source="test/source",
source_type=ModelSourceType.Path, source_type=ModelSourceType.Path,
) )
config3 = VaeDiffusersConfig( config3 = VAEDiffusersConfig(
key="config3", key="config3",
path="/tmp/config3", path="/tmp/config3",
name="config3", name="config3",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType.Vae, type=ModelType.VAE,
hash="CONFIG3HASH", hash="CONFIG3HASH",
source="test/source", source="test/source",
source_type=ModelSourceType.Path, source_type=ModelSourceType.Path,
@@ -157,7 +157,7 @@ def test_filter(store: ModelRecordServiceBase):
assert len(matches) == 2 assert len(matches) == 2
assert matches[0].name in {"config1", "config2"} assert matches[0].name in {"config1", "config2"}
matches = store.search_by_attr(model_type=ModelType.Vae) matches = store.search_by_attr(model_type=ModelType.VAE)
assert len(matches) == 1 assert len(matches) == 1
assert matches[0].name == "config3" assert matches[0].name == "config3"
assert matches[0].key == "config3" assert matches[0].key == "config3"
@@ -190,10 +190,10 @@ def test_unique(store: ModelRecordServiceBase):
source="test/source/", source="test/source/",
source_type=ModelSourceType.Path, source_type=ModelSourceType.Path,
) )
config3 = VaeDiffusersConfig( config3 = VAEDiffusersConfig(
path="/tmp/config3", path="/tmp/config3",
base=BaseModelType("sd-2"), base=BaseModelType("sd-2"),
type=ModelType.Vae, type=ModelType.VAE,
name="nonuniquename", name="nonuniquename",
hash="CONFIG1HASH", hash="CONFIG1HASH",
source="test/source/", source="test/source/",
@@ -257,11 +257,11 @@ def test_filter_2(store: ModelRecordServiceBase):
source="test/source/", source="test/source/",
source_type=ModelSourceType.Path, source_type=ModelSourceType.Path,
) )
config5 = VaeDiffusersConfig( config5 = VAEDiffusersConfig(
path="/tmp/config5", path="/tmp/config5",
name="dup_name1", name="dup_name1",
base=BaseModelType.StableDiffusion1, base=BaseModelType.StableDiffusion1,
type=ModelType.Vae, type=ModelType.VAE,
hash="CONFIG3HASH", hash="CONFIG3HASH",
source="test/source/", source="test/source/",
source_type=ModelSourceType.Path, source_type=ModelSourceType.Path,
@@ -283,7 +283,7 @@ def test_filter_2(store: ModelRecordServiceBase):
matches = store.search_by_attr( matches = store.search_by_attr(
base_model=BaseModelType.StableDiffusion1, base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Vae, model_type=ModelType.VAE,
model_name="dup_name1", model_name="dup_name1",
) )
assert len(matches) == 1 assert len(matches) == 1

View File

@@ -28,7 +28,7 @@ from invokeai.backend.model_manager.config import (
ModelSourceType, ModelSourceType,
ModelType, ModelType,
ModelVariantType, ModelVariantType,
VaeDiffusersConfig, VAEDiffusersConfig,
) )
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@@ -162,13 +162,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db) store = ModelRecordServiceSQL(db)
# add five simple config records to the database # add five simple config records to the database
config1 = VaeDiffusersConfig( config1 = VAEDiffusersConfig(
key="test_config_1", key="test_config_1",
path="/tmp/foo1", path="/tmp/foo1",
format=ModelFormat.Diffusers, format=ModelFormat.Diffusers,
name="test2", name="test2",
base=BaseModelType.StableDiffusion2, base=BaseModelType.StableDiffusion2,
type=ModelType.Vae, type=ModelType.VAE,
hash="111222333444", hash="111222333444",
source="stabilityai/sdxl-vae", source="stabilityai/sdxl-vae",
source_type=ModelSourceType.HFRepoID, source_type=ModelSourceType.HFRepoID,
@@ -204,7 +204,7 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
format=ModelFormat.Diffusers, format=ModelFormat.Diffusers,
name="test4", name="test4",
base=BaseModelType.StableDiffusionXL, base=BaseModelType.StableDiffusionXL,
type=ModelType.Lora, type=ModelType.LoRA,
hash="111222333444", hash="111222333444",
source="author4/model4", source="author4/model4",
source_type=ModelSourceType.HFRepoID, source_type=ModelSourceType.HFRepoID,
@@ -215,7 +215,7 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
format=ModelFormat.Diffusers, format=ModelFormat.Diffusers,
name="test5", name="test5",
base=BaseModelType.StableDiffusion1, base=BaseModelType.StableDiffusion1,
type=ModelType.Lora, type=ModelType.LoRA,
hash="111222333444", hash="111222333444",
source="author4/model5", source="author4/model5",
source_type=ModelSourceType.HFRepoID, source_type=ModelSourceType.HFRepoID,

View File

@@ -104,7 +104,7 @@ def sdxl_base_files() -> List[Path]:
], ],
), ),
( (
ModelRepoVariant.DEFAULT, ModelRepoVariant.Default,
[ [
"model_index.json", "model_index.json",
"scheduler/scheduler_config.json", "scheduler/scheduler_config.json",
@@ -129,7 +129,7 @@ def sdxl_base_files() -> List[Path]:
], ],
), ),
( (
ModelRepoVariant.OPENVINO, ModelRepoVariant.OpenVINO,
[ [
"model_index.json", "model_index.json",
"scheduler/scheduler_config.json", "scheduler/scheduler_config.json",
@@ -211,7 +211,7 @@ def sdxl_base_files() -> List[Path]:
], ],
), ),
( (
ModelRepoVariant.FLAX, ModelRepoVariant.Flax,
[ [
"model_index.json", "model_index.json",
"scheduler/scheduler_config.json", "scheduler/scheduler_config.json",

View File

@@ -21,7 +21,7 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
base_type = probe.get_base_type() base_type = probe.get_base_type()
assert base_type == expected_type assert base_type == expected_type
repo_variant = probe.get_repo_variant() repo_variant = probe.get_repo_variant()
assert repo_variant == ModelRepoVariant.DEFAULT assert repo_variant == ModelRepoVariant.Default
def test_repo_variant(datadir: Path): def test_repo_variant(datadir: Path):