mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Generate config names for openapi
This commit is contained in:
parent
bf0d5f4cfc
commit
01d17601b8
@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models import get_all_model_configs
|
||||
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
|
@ -29,10 +29,22 @@ MODEL_CLASSES = {
|
||||
#},
|
||||
}
|
||||
|
||||
def get_all_model_configs():
|
||||
def _get_all_model_configs():
|
||||
configs = set()
|
||||
for models in MODEL_CLASSES.values():
|
||||
for _, model in models.items():
|
||||
configs.update(model._get_configs().values())
|
||||
configs.discard(None)
|
||||
return list(configs) # TODO: set, list or tuple
|
||||
return list(configs)
|
||||
|
||||
MODEL_CONFIGS = _get_all_model_configs()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
|
||||
for cfg in MODEL_CONFIGS:
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
name_wrapper = type(openapi_cfg_name, (cfg,), {})
|
||||
|
||||
#globals()[name] = value
|
||||
vars()[openapi_cfg_name] = name_wrapper
|
||||
OPENAPI_MODEL_CONFIGS.append(name_wrapper)
|
||||
|
@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
|
||||
#model_class: Type
|
||||
#model_size: int
|
||||
|
||||
class ControlNetModelConfig(ModelConfigBase):
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
|
@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw
|
||||
class LoRAModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class LoraModelConfig(ModelConfigBase):
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
|
@ -22,12 +22,12 @@ from omegaconf import OmegaConf
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
class StableDiffusion1DiffusersModelConfig(ModelConfigBase):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class StableDiffusion1CheckpointModelConfig(ModelConfigBase):
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
) -> str:
|
||||
assert model_path == config.path
|
||||
|
||||
if isinstance(config, cls.StableDiffusion1CheckpointModelConfig):
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion1,
|
||||
model_config=config,
|
||||
@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
# TODO: check that configs overwriten properly
|
||||
class StableDiffusion2DiffusersModelConfig(ModelConfigBase):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
class StableDiffusion2CheckpointModelConfig(ModelConfigBase):
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
) -> str:
|
||||
assert model_path == config.path
|
||||
|
||||
if isinstance(config, cls.StableDiffusion2CheckpointModelConfig):
|
||||
if isinstance(config, cls.CheckpointConfig):
|
||||
return _convert_ckpt_and_cache(
|
||||
version=BaseModelType.StableDiffusion2,
|
||||
model_config=config,
|
||||
@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
||||
# TODO: rework
|
||||
def _convert_ckpt_and_cache(
|
||||
version: BaseModelType,
|
||||
model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig],
|
||||
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
|
||||
output_path: str,
|
||||
) -> str:
|
||||
"""
|
||||
|
@ -15,7 +15,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw
|
||||
class TextualInversionModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class TextualInversionModelConfig(ModelConfigBase):
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
|
@ -23,7 +23,7 @@ class VaeModel(ModelBase):
|
||||
#vae_class: Type
|
||||
#model_size: int
|
||||
|
||||
class VAEModelConfig(ModelConfigBase):
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
|
Loading…
Reference in New Issue
Block a user