fix(mm): use .value for model config discriminators

There is a breaking change in python 3.11 related to how enums with `str` as a mixin are formatted. This appears to have not caused any grief for us until now.

Re-jigger the discriminator setup to use `.value` so everything works on both python 3.10 and 3.11.
This commit is contained in:
psychedelicious 2024-03-04 22:36:52 +11:00
parent 44c40d7d1a
commit 0f60b1ced4

View File

@ -177,7 +177,7 @@ class LoRALycorisConfig(ModelConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}")
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Lycoris.value}")
class LoRADiffusersConfig(ModelConfigBase):
@ -188,7 +188,7 @@ class LoRADiffusersConfig(ModelConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Diffusers.value}")
class VaeCheckpointConfig(CheckpointConfigBase):
@ -199,7 +199,7 @@ class VaeCheckpointConfig(CheckpointConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}")
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Checkpoint.value}")
class VaeDiffusersConfig(ModelConfigBase):
@ -210,7 +210,7 @@ class VaeDiffusersConfig(ModelConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Diffusers.value}")
class ControlNetDiffusersConfig(DiffusersConfigBase):
@ -221,7 +221,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
class ControlNetCheckpointConfig(CheckpointConfigBase):
@ -232,7 +232,7 @@ class ControlNetCheckpointConfig(CheckpointConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}")
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
class TextualInversionFileConfig(ModelConfigBase):
@ -243,7 +243,7 @@ class TextualInversionFileConfig(ModelConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}")
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
class TextualInversionFolderConfig(ModelConfigBase):
@ -254,7 +254,7 @@ class TextualInversionFolderConfig(ModelConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainCheckpointConfig(CheckpointConfigBase):
@ -267,7 +267,7 @@ class MainCheckpointConfig(CheckpointConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainDiffusersConfig(DiffusersConfigBase):
@ -277,7 +277,7 @@ class MainDiffusersConfig(DiffusersConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}")
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
class IPAdapterConfig(ModelConfigBase):
@ -289,7 +289,7 @@ class IPAdapterConfig(ModelConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}")
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class CLIPVisionDiffusersConfig(ModelConfigBase):
@ -300,7 +300,7 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}")
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IAdapterConfig(ModelConfigBase):
@ -311,7 +311,7 @@ class T2IAdapterConfig(ModelConfigBase):
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}")
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
def get_model_discriminator_value(v: Any) -> str:
@ -319,9 +319,20 @@ def get_model_discriminator_value(v: Any) -> str:
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
format_ = None
type_ = None
if isinstance(v, dict):
return f"{v.get('type')}.{v.get('format')}" # pyright: ignore [reportUnknownMemberType]
return f"{v.type}.{v.format}"
format_ = v.get("format")
if isinstance(format_, Enum):
format_ = format_.value
type_ = v.get("type")
if isinstance(type_, Enum):
type_ = type_.value
else:
format_ = v.format.value
type_ = v.type.value
v = f"{type_}.{format_}"
return v
AnyModelConfig = Annotated[