chore: Update model config type names

This commit is contained in:
blessedcoolant 2023-06-17 21:14:37 +12:00 committed by psychedelicious
parent 2d889e133d
commit 9838dda1b7
5 changed files with 15 additions and 14 deletions

View File

@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
class ControlNetModelConfig(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
@ -82,6 +82,6 @@ class ControlNetModel(ModelBase):
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path

View File

@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
class LoraModelConfig(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -22,12 +22,12 @@ from omegaconf import OmegaConf
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
class StableDiffusion1DiffusersModelConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
class StableDiffusion1CheckpointModelConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel):
) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig):
if isinstance(config, cls.CheckpointModelConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1,
model_config=config,
@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel):
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
class StableDiffusion2DiffusersModelConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase):
class StableDiffusion2CheckpointModelConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel):
) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig):
if isinstance(config, cls.CheckpointModelConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
model_config=config,
@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
# TODO: rework
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig],
output_path: str,
) -> str:
"""
@ -281,8 +281,8 @@ def _convert_ckpt_and_cache(
prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2:
upcast_attention = config.upcast_attention
prediction_type = config.prediction_type
upcast_attention = model_config.upcast_attention
prediction_type = model_config.prediction_type
else:
raise Exception(f"Unknown model provided: {version}")

View File

@ -15,7 +15,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
class TextualInversionModelConfig(ModelConfigBase):
format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -1,5 +1,6 @@
import os
import torch
import safetensors
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@ -22,7 +23,7 @@ class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class Config(ModelConfigBase):
class VAEModelConfig(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):