From 0f60b1ced472981a8b276a140bb31e9878053f7b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 4 Mar 2024 22:36:52 +1100 Subject: [PATCH] fix(mm): use `.value` for model config discriminators There is a breaking change in python 3.11 related to how enums with `str` as a mixin are formatted. This appears to have not caused any grief for us until now. Re-jigger the discriminator setup to use `.value` so everything works on both python 3.10 and 3.11. --- invokeai/backend/model_manager/config.py | 41 +++++++++++++++--------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 0386eab8ca..f1733ed79a 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -177,7 +177,7 @@ class LoRALycorisConfig(ModelConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}") + return Tag(f"{ModelType.Lora.value}.{ModelFormat.Lycoris.value}") class LoRADiffusersConfig(ModelConfigBase): @@ -188,7 +188,7 @@ class LoRADiffusersConfig(ModelConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}") + return Tag(f"{ModelType.Lora.value}.{ModelFormat.Diffusers.value}") class VaeCheckpointConfig(CheckpointConfigBase): @@ -199,7 +199,7 @@ class VaeCheckpointConfig(CheckpointConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}") + return Tag(f"{ModelType.Vae.value}.{ModelFormat.Checkpoint.value}") class VaeDiffusersConfig(ModelConfigBase): @@ -210,7 +210,7 @@ class VaeDiffusersConfig(ModelConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}") + return Tag(f"{ModelType.Vae.value}.{ModelFormat.Diffusers.value}") class ControlNetDiffusersConfig(DiffusersConfigBase): @@ -221,7 +221,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}") + return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}") class ControlNetCheckpointConfig(CheckpointConfigBase): @@ -232,7 +232,7 @@ class ControlNetCheckpointConfig(CheckpointConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}") + return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}") class TextualInversionFileConfig(ModelConfigBase): @@ -243,7 +243,7 @@ class TextualInversionFileConfig(ModelConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}") + return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}") class TextualInversionFolderConfig(ModelConfigBase): @@ -254,7 +254,7 @@ class TextualInversionFolderConfig(ModelConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}") + return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}") class MainCheckpointConfig(CheckpointConfigBase): @@ -267,7 +267,7 @@ class MainCheckpointConfig(CheckpointConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}") + return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}") class MainDiffusersConfig(DiffusersConfigBase): @@ -277,7 +277,7 @@ class MainDiffusersConfig(DiffusersConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}") + return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}") class IPAdapterConfig(ModelConfigBase): @@ -289,7 +289,7 @@ class IPAdapterConfig(ModelConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}") + return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}") class CLIPVisionDiffusersConfig(ModelConfigBase): @@ -300,7 +300,7 @@ class CLIPVisionDiffusersConfig(ModelConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}") + return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}") class T2IAdapterConfig(ModelConfigBase): @@ -311,7 +311,7 @@ class T2IAdapterConfig(ModelConfigBase): @staticmethod def get_tag() -> Tag: - return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}") + return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}") def get_model_discriminator_value(v: Any) -> str: @@ -319,9 +319,20 @@ def get_model_discriminator_value(v: Any) -> str: Computes the discriminator value for a model config. https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator """ + format_ = None + type_ = None if isinstance(v, dict): - return f"{v.get('type')}.{v.get('format')}" # pyright: ignore [reportUnknownMemberType] - return f"{v.type}.{v.format}" + format_ = v.get("format") + if isinstance(format_, Enum): + format_ = format_.value + type_ = v.get("type") + if isinstance(type_, Enum): + type_ = type_.value + else: + format_ = v.format.value + type_ = v.type.value + v = f"{type_}.{format_}" + return v AnyModelConfig = Annotated[