From bbcbcd9b63b84cd4572af709b1e0f70045e48d08 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:36:18 +1100 Subject: [PATCH] fix(mm): only loras and main models get `trigger_phrases` --- invokeai/backend/model_manager/config.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index d274867707..9f8b163246 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -157,7 +157,6 @@ class ModelConfigBase(BaseModel): source_api_response: Optional[str] = Field( description="The original API response from the source, as stringified JSON.", default=None ) - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) default_settings: Optional[ModelDefaultSettings] = Field( description="Default settings for this model", default=None ) @@ -187,10 +186,14 @@ class DiffusersConfigBase(ModelConfigBase): repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default -class LoRALyCORISConfig(ModelConfigBase): +class LoRAConfigBase(ModelConfigBase): + type: Literal[ModelType.LoRA] = ModelType.LoRA + trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) + + +class LoRALyCORISConfig(LoRAConfigBase): """Model config for LoRA/Lycoris models.""" - type: Literal[ModelType.LoRA] = ModelType.LoRA format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS @staticmethod @@ -198,10 +201,9 @@ class LoRALyCORISConfig(ModelConfigBase): return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}") -class LoRADiffusersConfig(ModelConfigBase): +class LoRADiffusersConfig(LoRAConfigBase): """Model config for LoRA/Diffusers models.""" - type: Literal[ModelType.LoRA] = ModelType.LoRA format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers @staticmethod @@ -275,10 +277,14 @@ class TextualInversionFolderConfig(ModelConfigBase): return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}") -class MainCheckpointConfig(CheckpointConfigBase): +class MainConfigBase(ModelConfigBase): + type: Literal[ModelType.Main] = ModelType.Main + trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) + + +class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase): """Model config for main checkpoint models.""" - type: Literal[ModelType.Main] = ModelType.Main variant: ModelVariantType = ModelVariantType.Normal prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False @@ -288,11 +294,9 @@ class MainCheckpointConfig(CheckpointConfigBase): return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}") -class MainDiffusersConfig(DiffusersConfigBase): +class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase): """Model config for main diffusers models.""" - type: Literal[ModelType.Main] = ModelType.Main - @staticmethod def get_tag() -> Tag: return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")