mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): remove vae
field on _MainConfig
We will handle default VAE selection in the UI.
This commit is contained in:
parent
316573df2d
commit
dd31bc4586
@ -75,7 +75,6 @@ example_model_config = {
|
|||||||
"description": "string",
|
"description": "string",
|
||||||
"source": "string",
|
"source": "string",
|
||||||
"last_modified": 0,
|
"last_modified": 0,
|
||||||
"vae": "string",
|
|
||||||
"variant": "normal",
|
"variant": "normal",
|
||||||
"prediction_type": "epsilon",
|
"prediction_type": "epsilon",
|
||||||
"repo_variant": "fp16",
|
"repo_variant": "fp16",
|
||||||
|
@ -260,7 +260,6 @@ class TextualInversionFolderConfig(ModelConfigBase):
|
|||||||
class _MainConfig(ModelConfigBase):
|
class _MainConfig(ModelConfigBase):
|
||||||
"""Model config for main models."""
|
"""Model config for main models."""
|
||||||
|
|
||||||
vae: Optional[str] = Field(default=None)
|
|
||||||
variant: ModelVariantType = ModelVariantType.Normal
|
variant: ModelVariantType = ModelVariantType.Normal
|
||||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||||
upcast_attention: bool = False
|
upcast_attention: bool = False
|
||||||
|
@ -15,9 +15,7 @@ Use like this:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, Dict, Optional, Tuple, Type
|
from typing import Callable, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
from ..config import (
|
from ..config import (
|
||||||
@ -27,8 +25,6 @@ from ..config import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
VaeCheckpointConfig,
|
|
||||||
VaeDiffusersConfig,
|
|
||||||
)
|
)
|
||||||
from . import ModelLoaderBase
|
from . import ModelLoaderBase
|
||||||
|
|
||||||
@ -90,33 +86,15 @@ class ModelLoaderRegistry:
|
|||||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
|
||||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
|
||||||
|
|
||||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||||
if not implementation:
|
if not implementation:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||||
)
|
)
|
||||||
return implementation, conf2, submodel_type
|
return implementation, config, submodel_type
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _handle_subtype_overrides(
|
|
||||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
|
||||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
|
||||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
|
||||||
model_path = Path(config.vae)
|
|
||||||
config_class = (
|
|
||||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
|
||||||
)
|
|
||||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
|
||||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
|
||||||
submodel_type = None
|
|
||||||
else:
|
|
||||||
new_conf = config
|
|
||||||
return new_conf, submodel_type
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user