diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index d75c55010a..687afbffbd 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -18,7 +18,7 @@ class ControlNetModel(ModelBase): #model_class: Type #model_size: int - class Config(ModelConfigBase): + class ControlNetModelConfig(ModelConfigBase): format: Union[Literal["checkpoint"], Literal["diffusers"]] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): @@ -82,6 +82,6 @@ class ControlNetModel(ModelBase): base_model: BaseModelType, ) -> str: if cls.detect_format(model_path) != "diffusers": - raise NotImlemetedError("Checkpoint controlnet models currently unsupported") + raise NotImplementedError("Checkpoint controlnet models currently unsupported") else: return model_path diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index bcf3224ece..60865817b9 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw class LoRAModel(ModelBase): #model_size: int - class Config(ModelConfigBase): + class LoraModelConfig(ModelConfigBase): format: Union[Literal["lycoris"], Literal["diffusers"]] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index bd519c88c8..9856069ea5 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -22,12 +22,12 @@ from omegaconf import OmegaConf class StableDiffusion1Model(DiffusersModel): - class DiffusersConfig(ModelConfigBase): + class StableDiffusion1DiffusersModelConfig(ModelConfigBase): format: Literal["diffusers"] vae: Optional[str] = Field(None) variant: ModelVariantType - class CheckpointConfig(ModelConfigBase): + class StableDiffusion1CheckpointModelConfig(ModelConfigBase): format: Literal["checkpoint"] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) @@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel): ) -> str: assert model_path == config.path - if isinstance(config, cls.CheckpointConfig): + if isinstance(config, cls.CheckpointModelConfig): return _convert_ckpt_and_cache( version=BaseModelType.StableDiffusion1, model_config=config, @@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel): class StableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): + class StableDiffusion2DiffusersModelConfig(ModelConfigBase): format: Literal["diffusers"] vae: Optional[str] = Field(None) variant: ModelVariantType prediction_type: SchedulerPredictionType upcast_attention: bool - class CheckpointConfig(ModelConfigBase): + class StableDiffusion2CheckpointModelConfig(ModelConfigBase): format: Literal["checkpoint"] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) @@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel): ) -> str: assert model_path == config.path - if isinstance(config, cls.CheckpointConfig): + if isinstance(config, cls.CheckpointModelConfig): return _convert_ckpt_and_cache( version=BaseModelType.StableDiffusion2, model_config=config, @@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): # TODO: rework def _convert_ckpt_and_cache( version: BaseModelType, - model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig], + model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig], output_path: str, ) -> str: """ @@ -281,8 +281,8 @@ def _convert_ckpt_and_cache( prediction_type = SchedulerPredictionType.Epsilon elif version == BaseModelType.StableDiffusion2: - upcast_attention = config.upcast_attention - prediction_type = config.prediction_type + upcast_attention = model_config.upcast_attention + prediction_type = model_config.prediction_type else: raise Exception(f"Unknown model provided: {version}") diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index e8c96ff31e..0ed19e0b92 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -14,7 +14,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw class TextualInversionModel(ModelBase): #model_size: int - class Config(ModelConfigBase): + class TextualInversionModelConfig(ModelConfigBase): format: None def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index 1edb57ccc4..f285648323 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -1,5 +1,6 @@ import os import torch +import safetensors from pathlib import Path from typing import Optional, Union, Literal from .base import ( @@ -22,7 +23,7 @@ class VaeModel(ModelBase): #vae_class: Type #model_size: int - class Config(ModelConfigBase): + class VAEModelConfig(ModelConfigBase): format: Union[Literal["checkpoint"], Literal["diffusers"]] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):