refactor(mm): add CheckpointConfigBase for all ckpt models

This commit is contained in:
psychedelicious
2024-03-01 15:21:35 +11:00
parent 0a614943f6
commit 76cbc745e1
4 changed files with 17 additions and 18 deletions

View File

@ -135,7 +135,6 @@ class ModelConfigBase(BaseModel):
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(description="human readable description of the model", default=None)
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@ -160,6 +159,9 @@ class CheckpointConfigBase(ModelConfigBase):
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
last_modified: Optional[float] = Field(
description="When this model was last converted to diffusers", default_factory=time.time
)
class DiffusersConfigBase(ModelConfigBase):
@ -191,7 +193,7 @@ class LoRADiffusersConfig(ModelConfigBase):
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
class VaeCheckpointConfig(ModelConfigBase):
class VaeCheckpointConfig(CheckpointConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
@ -257,25 +259,20 @@ class TextualInversionFolderConfig(ModelConfigBase):
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class MainCheckpointConfig(CheckpointConfigBase, _MainConfig):
class MainCheckpointConfig(CheckpointConfigBase):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
class MainDiffusersConfig(DiffusersConfigBase, _MainConfig):
class MainDiffusersConfig(DiffusersConfigBase):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
@ -382,6 +379,6 @@ class ModelConfigFactory(object):
assert model is not None
if key:
model.key = key
if timestamp:
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
model.last_modified = timestamp
return model # type: ignore