fix(mm): only loras and main models get trigger_phrases

This commit is contained in:
psychedelicious 2024-03-07 15:36:18 +11:00
parent fdecb886b2
commit bbcbcd9b63

View File

@ -157,7 +157,6 @@ class ModelConfigBase(BaseModel):
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
@ -187,10 +186,14 @@ class DiffusersConfigBase(ModelConfigBase):
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRALyCORISConfig(ModelConfigBase):
class LoRAConfigBase(ModelConfigBase):
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class LoRALyCORISConfig(LoRAConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
@ -198,10 +201,9 @@ class LoRALyCORISConfig(ModelConfigBase):
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class LoRADiffusersConfig(ModelConfigBase):
class LoRADiffusersConfig(LoRAConfigBase):
"""Model config for LoRA/Diffusers models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
@ -275,10 +277,14 @@ class TextualInversionFolderConfig(ModelConfigBase):
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainCheckpointConfig(CheckpointConfigBase):
class MainConfigBase(ModelConfigBase):
type: Literal[ModelType.Main] = ModelType.Main
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@ -288,11 +294,9 @@ class MainCheckpointConfig(CheckpointConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainDiffusersConfig(DiffusersConfigBase):
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")