chore: Update model config type names

This commit is contained in:
blessedcoolant 2023-06-17 21:14:37 +12:00
parent 4cbc802e36
commit 67d05d2066
5 changed files with 15 additions and 14 deletions

View File

@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
#model_class: Type #model_class: Type
#model_size: int #model_size: int
class Config(ModelConfigBase): class ControlNetModelConfig(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]] format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
@ -82,6 +82,6 @@ class ControlNetModel(ModelBase):
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) != "diffusers": if cls.detect_format(model_path) != "diffusers":
raise NotImlemetedError("Checkpoint controlnet models currently unsupported") raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else: else:
return model_path return model_path

View File

@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw
class LoRAModel(ModelBase): class LoRAModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class LoraModelConfig(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]] format: Union[Literal["lycoris"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -22,12 +22,12 @@ from omegaconf import OmegaConf
class StableDiffusion1Model(DiffusersModel): class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase): class StableDiffusion1DiffusersModelConfig(ModelConfigBase):
format: Literal["diffusers"] format: Literal["diffusers"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
class CheckpointConfig(ModelConfigBase): class StableDiffusion1CheckpointModelConfig(ModelConfigBase):
format: Literal["checkpoint"] format: Literal["checkpoint"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel):
) -> str: ) -> str:
assert model_path == config.path assert model_path == config.path
if isinstance(config, cls.CheckpointConfig): if isinstance(config, cls.CheckpointModelConfig):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1, version=BaseModelType.StableDiffusion1,
model_config=config, model_config=config,
@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel):
class StableDiffusion2Model(DiffusersModel): class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly # TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase): class StableDiffusion2DiffusersModelConfig(ModelConfigBase):
format: Literal["diffusers"] format: Literal["diffusers"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool upcast_attention: bool
class CheckpointConfig(ModelConfigBase): class StableDiffusion2CheckpointModelConfig(ModelConfigBase):
format: Literal["checkpoint"] format: Literal["checkpoint"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel):
) -> str: ) -> str:
assert model_path == config.path assert model_path == config.path
if isinstance(config, cls.CheckpointConfig): if isinstance(config, cls.CheckpointModelConfig):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2, version=BaseModelType.StableDiffusion2,
model_config=config, model_config=config,
@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
# TODO: rework # TODO: rework
def _convert_ckpt_and_cache( def _convert_ckpt_and_cache(
version: BaseModelType, version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig], model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig],
output_path: str, output_path: str,
) -> str: ) -> str:
""" """
@ -281,8 +281,8 @@ def _convert_ckpt_and_cache(
prediction_type = SchedulerPredictionType.Epsilon prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2: elif version == BaseModelType.StableDiffusion2:
upcast_attention = config.upcast_attention upcast_attention = model_config.upcast_attention
prediction_type = config.prediction_type prediction_type = model_config.prediction_type
else: else:
raise Exception(f"Unknown model provided: {version}") raise Exception(f"Unknown model provided: {version}")

View File

@ -14,7 +14,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase): class TextualInversionModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class TextualInversionModelConfig(ModelConfigBase):
format: None format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -1,5 +1,6 @@
import os import os
import torch import torch
import safetensors
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
@ -22,7 +23,7 @@ class VaeModel(ModelBase):
#vae_class: Type #vae_class: Type
#model_size: int #model_size: int
class Config(ModelConfigBase): class VAEModelConfig(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]] format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):