refactor(mm): remove vae field on _MainConfig

We will handle default VAE selection in the UI.
This commit is contained in:
psychedelicious 2024-03-01 13:00:55 +11:00
parent 316573df2d
commit dd31bc4586
3 changed files with 3 additions and 27 deletions

View File

@ -75,7 +75,6 @@ example_model_config = {
"description": "string", "description": "string",
"source": "string", "source": "string",
"last_modified": 0, "last_modified": 0,
"vae": "string",
"variant": "normal", "variant": "normal",
"prediction_type": "epsilon", "prediction_type": "epsilon",
"repo_variant": "fp16", "repo_variant": "fp16",

View File

@ -260,7 +260,6 @@ class TextualInversionFolderConfig(ModelConfigBase):
class _MainConfig(ModelConfigBase): class _MainConfig(ModelConfigBase):
"""Model config for main models.""" """Model config for main models."""
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False upcast_attention: bool = False

View File

@ -15,9 +15,7 @@ Use like this:
""" """
import hashlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type from typing import Callable, Dict, Optional, Tuple, Type
from ..config import ( from ..config import (
@ -27,8 +25,6 @@ from ..config import (
ModelFormat, ModelFormat,
ModelType, ModelType,
SubModelType, SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
) )
from . import ModelLoaderBase from . import ModelLoaderBase
@ -90,33 +86,15 @@ class ModelLoaderRegistry:
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type.""" """Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2) implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation: if not implementation:
raise NotImplementedError( raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
) )
return implementation, conf2, submodel_type return implementation, config, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
@staticmethod @staticmethod
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: