diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index f510279f18..0abcc19dcf 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management.models import get_all_model_configs -MODEL_CONFIGS = Union[tuple(get_all_model_configs())] +from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS +MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] models_router = APIRouter(prefix="/v1/models", tags=["models"]) diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 40995498bf..eff71798a5 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -29,10 +29,22 @@ MODEL_CLASSES = { #}, } -def get_all_model_configs(): +def _get_all_model_configs(): configs = set() for models in MODEL_CLASSES.values(): for _, model in models.items(): configs.update(model._get_configs().values()) configs.discard(None) - return list(configs) # TODO: set, list or tuple + return list(configs) + +MODEL_CONFIGS = _get_all_model_configs() +OPENAPI_MODEL_CONFIGS = list() + +for cfg in MODEL_CONFIGS: + model_name, cfg_name = cfg.__qualname__.split('.')[-2:] + openapi_cfg_name = model_name + cfg_name + name_wrapper = type(openapi_cfg_name, (cfg,), {}) + + #globals()[name] = value + vars()[openapi_cfg_name] = name_wrapper + OPENAPI_MODEL_CONFIGS.append(name_wrapper) diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index 687afbffbd..de9926c83e 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -18,7 +18,7 @@ class ControlNetModel(ModelBase): #model_class: Type #model_size: int - class ControlNetModelConfig(ModelConfigBase): + class Config(ModelConfigBase): format: Union[Literal["checkpoint"], Literal["diffusers"]] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 60865817b9..bcf3224ece 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw class LoRAModel(ModelBase): #model_size: int - class LoraModelConfig(ModelConfigBase): + class Config(ModelConfigBase): format: Union[Literal["lycoris"], Literal["diffusers"]] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index 0ac88c8a94..20aaae23a6 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -22,12 +22,12 @@ from omegaconf import OmegaConf class StableDiffusion1Model(DiffusersModel): - class StableDiffusion1DiffusersModelConfig(ModelConfigBase): + class DiffusersConfig(ModelConfigBase): format: Literal["diffusers"] vae: Optional[str] = Field(None) variant: ModelVariantType - class StableDiffusion1CheckpointModelConfig(ModelConfigBase): + class CheckpointConfig(ModelConfigBase): format: Literal["checkpoint"] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) @@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel): ) -> str: assert model_path == config.path - if isinstance(config, cls.StableDiffusion1CheckpointModelConfig): + if isinstance(config, cls.CheckpointConfig): return _convert_ckpt_and_cache( version=BaseModelType.StableDiffusion1, model_config=config, @@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel): class StableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly - class StableDiffusion2DiffusersModelConfig(ModelConfigBase): + class DiffusersConfig(ModelConfigBase): format: Literal["diffusers"] vae: Optional[str] = Field(None) variant: ModelVariantType prediction_type: SchedulerPredictionType upcast_attention: bool - class StableDiffusion2CheckpointModelConfig(ModelConfigBase): + class CheckpointConfig(ModelConfigBase): format: Literal["checkpoint"] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) @@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel): ) -> str: assert model_path == config.path - if isinstance(config, cls.StableDiffusion2CheckpointModelConfig): + if isinstance(config, cls.CheckpointConfig): return _convert_ckpt_and_cache( version=BaseModelType.StableDiffusion2, model_config=config, @@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): # TODO: rework def _convert_ckpt_and_cache( version: BaseModelType, - model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig], + model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig], output_path: str, ) -> str: """ diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index 0ed19e0b92..e8c96ff31e 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -14,7 +14,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw class TextualInversionModel(ModelBase): #model_size: int - class TextualInversionModelConfig(ModelConfigBase): + class Config(ModelConfigBase): format: None def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index f285648323..b78617869a 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -23,7 +23,7 @@ class VaeModel(ModelBase): #vae_class: Type #model_size: int - class VAEModelConfig(ModelConfigBase): + class Config(ModelConfigBase): format: Union[Literal["checkpoint"], Literal["diffusers"]] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):