mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(mm): use .value
for model config discriminators
There is a breaking change in python 3.11 related to how enums with `str` as a mixin are formatted. This appears to have not caused any grief for us until now. Re-jigger the discriminator setup to use `.value` so everything works on both python 3.10 and 3.11.
This commit is contained in:
parent
44c40d7d1a
commit
0f60b1ced4
@ -177,7 +177,7 @@ class LoRALycorisConfig(ModelConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}")
|
||||
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Lycoris.value}")
|
||||
|
||||
|
||||
class LoRADiffusersConfig(ModelConfigBase):
|
||||
@ -188,7 +188,7 @@ class LoRADiffusersConfig(ModelConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
|
||||
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class VaeCheckpointConfig(CheckpointConfigBase):
|
||||
@ -199,7 +199,7 @@ class VaeCheckpointConfig(CheckpointConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}")
|
||||
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
@ -210,7 +210,7 @@ class VaeDiffusersConfig(ModelConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
|
||||
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||
@ -221,7 +221,7 @@ class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(CheckpointConfigBase):
|
||||
@ -232,7 +232,7 @@ class ControlNetCheckpointConfig(CheckpointConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}")
|
||||
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class TextualInversionFileConfig(ModelConfigBase):
|
||||
@ -243,7 +243,7 @@ class TextualInversionFileConfig(ModelConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}")
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
|
||||
|
||||
|
||||
class TextualInversionFolderConfig(ModelConfigBase):
|
||||
@ -254,7 +254,7 @@ class TextualInversionFolderConfig(ModelConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
|
||||
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfigBase):
|
||||
@ -267,7 +267,7 @@ class MainCheckpointConfig(CheckpointConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
class MainDiffusersConfig(DiffusersConfigBase):
|
||||
@ -277,7 +277,7 @@ class MainDiffusersConfig(DiffusersConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}")
|
||||
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
@ -289,7 +289,7 @@ class IPAdapterConfig(ModelConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}")
|
||||
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
@ -300,7 +300,7 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}")
|
||||
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class T2IAdapterConfig(ModelConfigBase):
|
||||
@ -311,7 +311,7 @@ class T2IAdapterConfig(ModelConfigBase):
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}")
|
||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
@ -319,9 +319,20 @@ def get_model_discriminator_value(v: Any) -> str:
|
||||
Computes the discriminator value for a model config.
|
||||
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||
"""
|
||||
format_ = None
|
||||
type_ = None
|
||||
if isinstance(v, dict):
|
||||
return f"{v.get('type')}.{v.get('format')}" # pyright: ignore [reportUnknownMemberType]
|
||||
return f"{v.type}.{v.format}"
|
||||
format_ = v.get("format")
|
||||
if isinstance(format_, Enum):
|
||||
format_ = format_.value
|
||||
type_ = v.get("type")
|
||||
if isinstance(type_, Enum):
|
||||
type_ = type_.value
|
||||
else:
|
||||
format_ = v.format.value
|
||||
type_ = v.type.value
|
||||
v = f"{type_}.{format_}"
|
||||
return v
|
||||
|
||||
|
||||
AnyModelConfig = Annotated[
|
||||
|
Loading…
Reference in New Issue
Block a user