mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Generate config names for openapi
This commit is contained in:
parent
7a66856785
commit
16dc78f6c6
@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
|
|||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management.models import get_all_model_configs
|
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
||||||
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
|
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
|
@ -29,10 +29,22 @@ MODEL_CLASSES = {
|
|||||||
#},
|
#},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_all_model_configs():
|
def _get_all_model_configs():
|
||||||
configs = set()
|
configs = set()
|
||||||
for models in MODEL_CLASSES.values():
|
for models in MODEL_CLASSES.values():
|
||||||
for _, model in models.items():
|
for _, model in models.items():
|
||||||
configs.update(model._get_configs().values())
|
configs.update(model._get_configs().values())
|
||||||
configs.discard(None)
|
configs.discard(None)
|
||||||
return list(configs) # TODO: set, list or tuple
|
return list(configs)
|
||||||
|
|
||||||
|
MODEL_CONFIGS = _get_all_model_configs()
|
||||||
|
OPENAPI_MODEL_CONFIGS = list()
|
||||||
|
|
||||||
|
for cfg in MODEL_CONFIGS:
|
||||||
|
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||||
|
openapi_cfg_name = model_name + cfg_name
|
||||||
|
name_wrapper = type(openapi_cfg_name, (cfg,), {})
|
||||||
|
|
||||||
|
#globals()[name] = value
|
||||||
|
vars()[openapi_cfg_name] = name_wrapper
|
||||||
|
OPENAPI_MODEL_CONFIGS.append(name_wrapper)
|
||||||
|
@ -18,7 +18,7 @@ class ControlNetModel(ModelBase):
|
|||||||
#model_class: Type
|
#model_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class ControlNetModelConfig(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
@ -15,7 +15,7 @@ from ..lora import LoRAModel as LoRAModelRaw
|
|||||||
class LoRAModel(ModelBase):
|
class LoRAModel(ModelBase):
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class LoraModelConfig(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
@ -22,12 +22,12 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
class StableDiffusion1Model(DiffusersModel):
|
class StableDiffusion1Model(DiffusersModel):
|
||||||
|
|
||||||
class StableDiffusion1DiffusersModelConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
class StableDiffusion1CheckpointModelConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
format: Literal["checkpoint"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
@ -107,7 +107,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
) -> str:
|
) -> str:
|
||||||
assert model_path == config.path
|
assert model_path == config.path
|
||||||
|
|
||||||
if isinstance(config, cls.StableDiffusion1CheckpointModelConfig):
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion1,
|
version=BaseModelType.StableDiffusion1,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
@ -120,14 +120,14 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
class StableDiffusion2Model(DiffusersModel):
|
class StableDiffusion2Model(DiffusersModel):
|
||||||
|
|
||||||
# TODO: check that configs overwriten properly
|
# TODO: check that configs overwriten properly
|
||||||
class StableDiffusion2DiffusersModelConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
|
|
||||||
class StableDiffusion2CheckpointModelConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
format: Literal["checkpoint"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
@ -220,7 +220,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
) -> str:
|
) -> str:
|
||||||
assert model_path == config.path
|
assert model_path == config.path
|
||||||
|
|
||||||
if isinstance(config, cls.StableDiffusion2CheckpointModelConfig):
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=BaseModelType.StableDiffusion2,
|
version=BaseModelType.StableDiffusion2,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
@ -256,7 +256,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
|
|||||||
# TODO: rework
|
# TODO: rework
|
||||||
def _convert_ckpt_and_cache(
|
def _convert_ckpt_and_cache(
|
||||||
version: BaseModelType,
|
version: BaseModelType,
|
||||||
model_config: Union[StableDiffusion1Model.StableDiffusion1CheckpointModelConfig, StableDiffusion2Model.StableDiffusion2CheckpointModelConfig],
|
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
|
||||||
output_path: str,
|
output_path: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -14,7 +14,7 @@ from ..lora import TextualInversionModel as TextualInversionModelRaw
|
|||||||
class TextualInversionModel(ModelBase):
|
class TextualInversionModel(ModelBase):
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class TextualInversionModelConfig(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: None
|
format: None
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
@ -23,7 +23,7 @@ class VaeModel(ModelBase):
|
|||||||
#vae_class: Type
|
#vae_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class VAEModelConfig(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
Loading…
Reference in New Issue
Block a user