diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index d274867707..9f8b163246 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -157,7 +157,6 @@ class ModelConfigBase(BaseModel): source_api_response: Optional[str] = Field( description="The original API response from the source, as stringified JSON.", default=None ) - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) default_settings: Optional[ModelDefaultSettings] = Field( description="Default settings for this model", default=None ) @@ -187,10 +186,14 @@ class DiffusersConfigBase(ModelConfigBase): repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default -class LoRALyCORISConfig(ModelConfigBase): +class LoRAConfigBase(ModelConfigBase): + type: Literal[ModelType.LoRA] = ModelType.LoRA + trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) + + +class LoRALyCORISConfig(LoRAConfigBase): """Model config for LoRA/Lycoris models.""" - type: Literal[ModelType.LoRA] = ModelType.LoRA format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS @staticmethod @@ -198,10 +201,9 @@ class LoRALyCORISConfig(ModelConfigBase): return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}") -class LoRADiffusersConfig(ModelConfigBase): +class LoRADiffusersConfig(LoRAConfigBase): """Model config for LoRA/Diffusers models.""" - type: Literal[ModelType.LoRA] = ModelType.LoRA format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers @staticmethod @@ -275,10 +277,14 @@ class TextualInversionFolderConfig(ModelConfigBase): return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}") -class MainCheckpointConfig(CheckpointConfigBase): +class MainConfigBase(ModelConfigBase): + type: Literal[ModelType.Main] = ModelType.Main + trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) + + +class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase): """Model config for main checkpoint models.""" - type: Literal[ModelType.Main] = ModelType.Main variant: ModelVariantType = ModelVariantType.Normal prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False @@ -288,11 +294,9 @@ class MainCheckpointConfig(CheckpointConfigBase): return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}") -class MainDiffusersConfig(DiffusersConfigBase): +class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase): """Model config for main diffusers models.""" - type: Literal[ModelType.Main] = ModelType.Main - @staticmethod def get_tag() -> Tag: return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")