Generate config names for openapi

This commit is contained in:
Sergey Borisov 2023-06-17 17:15:36 +03:00
parent 7a66856785
commit 16dc78f6c6
7 changed files with 27 additions and 15 deletions

View File

@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import get_all_model_configs from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(get_all_model_configs())] MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])

View File

@ -29,10 +29,22 @@ MODEL_CLASSES = {
#}, #},
} }
def get_all_model_configs(): def _get_all_model_configs():
configs = set() configs = set()
for models in MODEL_CLASSES.values(): for models in MODEL_CLASSES.values():
for _, model in models.items(): for _, model in models.items():
configs.update(model._get_configs().values()) configs.update(model._get_configs().values())
configs.discard(None) configs.discard(None)
return list(configs) # TODO: set, list or tuple return list(configs)
MODEL_CONFIGS = _get_all_model_configs()
OPENAPI_MODEL_CONFIGS = list()
for cfg in MODEL_CONFIGS:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
name_wrapper = type(openapi_cfg_name, (cfg,), {})
#globals()[name] = value
vars()[openapi_cfg_name] = name_wrapper
OPENAPI_MODEL_CONFIGS.append(name_wrapper)

View File

@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
#model_class: Type #model_class: Type
#model_size: int #model_size: int
class ControlNetModelConfig(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]] format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw
class LoRAModel(ModelBase): class LoRAModel(ModelBase):
#model_size: int #model_size: int
class LoraModelConfig(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]] format: Union[Literal["lycoris"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -22,12 +22,12 @@ from omegaconf import OmegaConf
class StableDiffusion1Model(DiffusersModel): class StableDiffusion1Model(DiffusersModel):
class StableDiffusion1DiffusersModelConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] format: Literal["diffusers"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
class StableDiffusion1CheckpointModelConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"] format: Literal["checkpoint"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel):
) -> str: ) -> str:
assert model_path == config.path assert model_path == config.path
if isinstance(config, cls.StableDiffusion1CheckpointModelConfig): if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1, version=BaseModelType.StableDiffusion1,
model_config=config, model_config=config,
@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel):
class StableDiffusion2Model(DiffusersModel): class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly # TODO: check that configs overwriten properly
class StableDiffusion2DiffusersModelConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] format: Literal["diffusers"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool upcast_attention: bool
class StableDiffusion2CheckpointModelConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"] format: Literal["checkpoint"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel):
) -> str: ) -> str:
assert model_path == config.path assert model_path == config.path
if isinstance(config, cls.StableDiffusion2CheckpointModelConfig): if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2, version=BaseModelType.StableDiffusion2,
model_config=config, model_config=config,
@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
# TODO: rework # TODO: rework
def _convert_ckpt_and_cache( def _convert_ckpt_and_cache(
version: BaseModelType, version: BaseModelType,
model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig], model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
output_path: str, output_path: str,
) -> str: ) -> str:
""" """

View File

@ -14,7 +14,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw
class TextualInversionModel(ModelBase): class TextualInversionModel(ModelBase):
#model_size: int #model_size: int
class TextualInversionModelConfig(ModelConfigBase): class Config(ModelConfigBase):
format: None format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):

View File

@ -23,7 +23,7 @@ class VaeModel(ModelBase):
#vae_class: Type #vae_class: Type
#model_size: int #model_size: int
class VAEModelConfig(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]] format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):