mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
replace _class_map
in ModelConfigFactory
with a nested discriminated union
This commit is contained in:
parent
bd56e9bc81
commit
3a6ba236f5
@ -18,8 +18,6 @@ from ..dependencies import ApiDependencies
|
||||
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
|
||||
|
||||
ModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
|
@ -22,7 +22,8 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
@ -143,36 +144,42 @@ class _DiffusersConfig(ModelConfigBase):
|
||||
class LoRAConfig(ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
||||
|
||||
|
||||
class VaeCheckpointConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class VaeDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for standalone VAE models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||
"""Model config for ControlNet models (diffusers version)."""
|
||||
|
||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class TextualInversionConfig(ModelConfigBase):
|
||||
"""Model config for textual inversion embeddings."""
|
||||
|
||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
||||
|
||||
|
||||
@ -187,6 +194,7 @@ class _MainConfig(ModelConfigBase):
|
||||
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
# Note that we do not need prediction_type or upcast_attention here
|
||||
# because they are provided in the checkpoint's own config file.
|
||||
|
||||
@ -194,6 +202,7 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
@ -201,7 +210,9 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
class ONNXSD1Config(_MainConfig):
|
||||
"""Model config for ONNX format models based on sd-1."""
|
||||
|
||||
type: Literal[ModelType.ONNX] = ModelType.ONNX
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
@ -209,8 +220,10 @@ class ONNXSD1Config(_MainConfig):
|
||||
class ONNXSD2Config(_MainConfig):
|
||||
"""Model config for ONNX format models based on sd-2."""
|
||||
|
||||
type: Literal[ModelType.ONNX] = ModelType.ONNX
|
||||
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
|
||||
# No yaml config file for ONNX, so these are part of config
|
||||
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
|
||||
upcast_attention: bool = True
|
||||
|
||||
@ -218,79 +231,49 @@ class ONNXSD2Config(_MainConfig):
|
||||
class IPAdapterConfig(ModelConfigBase):
|
||||
"""Model config for IP Adaptor format models."""
|
||||
|
||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||
format: Literal[ModelFormat.InvokeAI]
|
||||
|
||||
|
||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||
"""Model config for ClipVision."""
|
||||
|
||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
class T2IConfig(ModelConfigBase):
|
||||
"""Model config for T2I."""
|
||||
|
||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||
format: Literal[ModelFormat.Diffusers]
|
||||
|
||||
|
||||
AnyModelConfig = Union[
|
||||
MainCheckpointConfig,
|
||||
MainDiffusersConfig,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
ONNXSD1Config,
|
||||
ONNXSD2Config,
|
||||
VaeCheckpointConfig,
|
||||
VaeDiffusersConfig,
|
||||
ControlNetDiffusersConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
IPAdapterConfig,
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator='base')]
|
||||
_ControlNetConfig = Annotated[Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator='format')]
|
||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator='format')]
|
||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator='format')]
|
||||
|
||||
AnyModelConfig = Annotated[
|
||||
Union[
|
||||
_MainModelConfig,
|
||||
_ONNXConfig,
|
||||
_VaeConfig,
|
||||
_ControlNetConfig,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
IPAdapterConfig,
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
],
|
||||
Field(discriminator='type')
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
_class_map: dict = {
|
||||
ModelFormat.Checkpoint: {
|
||||
ModelType.Main: MainCheckpointConfig,
|
||||
ModelType.Vae: VaeCheckpointConfig,
|
||||
},
|
||||
ModelFormat.Diffusers: {
|
||||
ModelType.Main: MainDiffusersConfig,
|
||||
ModelType.Lora: LoRAConfig,
|
||||
ModelType.Vae: VaeDiffusersConfig,
|
||||
ModelType.ControlNet: ControlNetDiffusersConfig,
|
||||
ModelType.CLIPVision: CLIPVisionDiffusersConfig,
|
||||
},
|
||||
ModelFormat.Lycoris: {
|
||||
ModelType.Lora: LoRAConfig,
|
||||
},
|
||||
ModelFormat.Onnx: {
|
||||
ModelType.ONNX: {
|
||||
BaseModelType.StableDiffusion1: ONNXSD1Config,
|
||||
BaseModelType.StableDiffusion2: ONNXSD2Config,
|
||||
},
|
||||
},
|
||||
ModelFormat.Olive: {
|
||||
ModelType.ONNX: {
|
||||
BaseModelType.StableDiffusion1: ONNXSD1Config,
|
||||
BaseModelType.StableDiffusion2: ONNXSD2Config,
|
||||
},
|
||||
},
|
||||
ModelFormat.EmbeddingFile: {
|
||||
ModelType.TextualInversion: TextualInversionConfig,
|
||||
},
|
||||
ModelFormat.EmbeddingFolder: {
|
||||
ModelType.TextualInversion: TextualInversionConfig,
|
||||
},
|
||||
ModelFormat.InvokeAI: {
|
||||
ModelType.IPAdapter: IPAdapterConfig,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls,
|
||||
@ -311,18 +294,8 @@ class ModelConfigFactory(object):
|
||||
if key:
|
||||
model_data.key = key
|
||||
return model_data
|
||||
try:
|
||||
format = model_data.get("format")
|
||||
type = model_data.get("type")
|
||||
model_base = model_data.get("base")
|
||||
class_to_return = dest_class or cls._class_map[format][type]
|
||||
if isinstance(class_to_return, dict): # additional level allowed
|
||||
class_to_return = class_to_return[model_base]
|
||||
model = class_to_return.model_validate(model_data)
|
||||
else:
|
||||
model = AnyModelConfigValidator.validate_python(model_data)
|
||||
if key:
|
||||
model.key = key # ensure consistency
|
||||
model.key = key
|
||||
return model
|
||||
except KeyError as exc:
|
||||
raise InvalidModelConfigException(f"Unknown combination of format '{format}' and type '{type}'") from exc
|
||||
except ValidationError as exc:
|
||||
raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc
|
||||
|
Loading…
Reference in New Issue
Block a user