Generate config names for openapi

This commit is contained in:
Sergey Borisov 2023-06-17 17:15:36 +03:00
parent 7a66856785
commit 16dc78f6c6
7 changed files with 27 additions and 15 deletions

View File

@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import get_all_model_configs
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])

View File

@ -29,10 +29,22 @@ MODEL_CLASSES = {
#},
}
def get_all_model_configs():
def _get_all_model_configs():
configs = set()
for models in MODEL_CLASSES.values():
for _, model in models.items():
configs.update(model._get_configs().values())
configs.discard(None)
return list(configs) # TODO: set, list or tuple
return list(configs)
MODEL_CONFIGS = _get_all_model_configs()
OPENAPI_MODEL_CONFIGS = list()
for cfg in MODEL_CONFIGS:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
name_wrapper = type(openapi_cfg_name, (cfg,), {})
#globals()[name] = value
vars()[openapi_cfg_name] = name_wrapper
OPENAPI_MODEL_CONFIGS.append(name_wrapper)

View File

@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class ControlNetModelConfig(ModelConfigBase):
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw
class LoRAModel(ModelBase):
#model_size: int
class LoraModelConfig(ModelConfigBase):
class Config(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -22,12 +22,12 @@ from omegaconf import OmegaConf
class StableDiffusion1Model(DiffusersModel):
class StableDiffusion1DiffusersModelConfig(ModelConfigBase):
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class StableDiffusion1CheckpointModelConfig(ModelConfigBase):
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel):
) -> str:
assert model_path == config.path
if isinstance(config, cls.StableDiffusion1CheckpointModelConfig):
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1,
model_config=config,
@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel):
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class StableDiffusion2DiffusersModelConfig(ModelConfigBase):
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class StableDiffusion2CheckpointModelConfig(ModelConfigBase):
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel):
) -> str:
assert model_path == config.path
if isinstance(config, cls.StableDiffusion2CheckpointModelConfig):
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2,
model_config=config,
@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
# TODO: rework
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig],
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
output_path: str,
) -> str:
"""

View File

@ -14,7 +14,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase):
#model_size: int
class TextualInversionModelConfig(ModelConfigBase):
class Config(ModelConfigBase):
format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -23,7 +23,7 @@ class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class VAEModelConfig(ModelConfigBase):
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):