mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Set model type to const value in openapi schema, add model format enums to model schema(as they not not referenced in case of Literal definition)
This commit is contained in:
parent
da566b59e8
commit
21245a0fb2
@ -120,6 +120,22 @@ def custom_openapi():
|
||||
|
||||
invoker_schema["output"] = outputs_ref
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
|
||||
if name in openapi_schema["components"]["schemas"]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||
openapi_schema["components"]["schemas"][name] = dict(
|
||||
title=name,
|
||||
description="An enumeration.",
|
||||
type="string",
|
||||
enum=list(v.value for v in model_config_format_enum),
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
@ -1,4 +1,7 @@
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
@ -30,15 +33,7 @@ MODEL_CLASSES = {
|
||||
#},
|
||||
}
|
||||
|
||||
def _get_all_model_configs():
|
||||
configs = set()
|
||||
for models in MODEL_CLASSES.values():
|
||||
for _, model in models.items():
|
||||
configs.update(model._get_configs().values())
|
||||
configs.discard(None)
|
||||
return list(configs)
|
||||
|
||||
MODEL_CONFIGS = _get_all_model_configs()
|
||||
MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
@ -46,11 +41,55 @@ class OpenAPIModelInfoBase(BaseModel):
|
||||
base_model: BaseModelType
|
||||
type: ModelType
|
||||
|
||||
for cfg in MODEL_CONFIGS:
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
name_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), {})
|
||||
|
||||
#globals()[name] = value
|
||||
vars()[openapi_cfg_name] = name_wrapper
|
||||
OPENAPI_MODEL_CONFIGS.append(name_wrapper)
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
for model_type, model_class in models.items():
|
||||
model_configs = set(model_class._get_configs().values())
|
||||
model_configs.discard(None)
|
||||
MODEL_CONFIGS.extend(model_configs)
|
||||
|
||||
for cfg in model_configs:
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
if openapi_cfg_name in vars():
|
||||
continue
|
||||
|
||||
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||
__annotations__ = dict(
|
||||
type=Literal[model_type.value],
|
||||
),
|
||||
))
|
||||
|
||||
#globals()[openapi_cfg_name] = api_wrapper
|
||||
vars()[openapi_cfg_name] = api_wrapper
|
||||
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||
|
||||
def get_model_config_enums():
|
||||
enums = list()
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
fields = inspect.get_annotations(model_config)
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except:
|
||||
raise Exception("format field not found")
|
||||
|
||||
# model_format: None
|
||||
# model_format: SomeModelFormat
|
||||
# model_format: Literal[SomeModelFormat.Diffusers]
|
||||
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
|
||||
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
enums.append(field)
|
||||
|
||||
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||
enums.append(type(field.__args__[0]))
|
||||
|
||||
elif field is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||
|
||||
return enums
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user