mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(mm): use canonical capitalization for all model-related enums, classes
For example, "Lora" -> "LoRA", "Vae" -> "VAE".
This commit is contained in:
@ -59,8 +59,8 @@ class ModelType(str, Enum):
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
Vae = "vae"
|
||||
Lora = "lora"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
@ -76,9 +76,9 @@ class SubModelType(str, Enum):
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Vae = "vae"
|
||||
VaeDecoder = "vae_decoder"
|
||||
VaeEncoder = "vae_encoder"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
@ -96,8 +96,8 @@ class ModelFormat(str, Enum):
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
Lycoris = "lycoris"
|
||||
Onnx = "onnx"
|
||||
LyCORIS = "lycoris"
|
||||
ONNX = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
@ -115,12 +115,12 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OPENVINO = "openvino"
|
||||
FLAX = "flax"
|
||||
OpenVINO = "openvino"
|
||||
Flax = "flax"
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
@ -183,51 +183,51 @@ class DiffusersConfigBase(ModelConfigBase):
|
||||
"""Model config for diffusers-style models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
||||
|
||||
|
||||
class LoRALycorisConfig(ModelConfigBase):
|
||||
class LoRALyCORISConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Lycoris.value}")
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
|
||||
class LoRADiffusersConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Diffusers.value}")
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class VaeCheckpointConfig(CheckpointConfigBase):
|
||||
class VAECheckpointConfig(CheckpointConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Checkpoint.value}")
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
class VAEDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Diffusers.value}")
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||
@ -356,11 +356,11 @@ AnyModelConfig = Annotated[
|
||||
Union[
|
||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||
Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
|
||||
Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
|
||||
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||
|
Reference in New Issue
Block a user