Add name, base_mode, type fields to model info

This commit is contained in:
Sergey Borisov 2023-06-17 22:48:44 +03:00 committed by psychedelicious
parent f8d7477c7a
commit ef83a2fffe
3 changed files with 11 additions and 3 deletions

View File

@ -530,6 +530,8 @@ class ModelManager(object):
models[cur_base_model][cur_model_type][cur_model_name] = dict(
**model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase
name=cur_model_name,
base_model=cur_base_model,
type=cur_model_type,
@ -646,7 +648,7 @@ class ModelManager(object):
model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path

View File

@ -1,3 +1,4 @@
from pydantic import BaseModel
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
@ -40,10 +41,15 @@ def _get_all_model_configs():
MODEL_CONFIGS = _get_all_model_configs()
OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
name: str
base_model: BaseModelType
type: ModelType
for cfg in MODEL_CONFIGS:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
name_wrapper = type(openapi_cfg_name, (cfg,), {})
name_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), {})
#globals()[name] = value
vars()[openapi_cfg_name] = name_wrapper

View File

@ -53,7 +53,7 @@ class ModelConfigBase(BaseModel):
format: Optional[str] = Field(None)
default: Optional[bool] = Field(False)
# do not save to config
error: Optional[ModelError] = Field(None, exclude=True)
error: Optional[ModelError] = Field(None)
class Config:
use_enum_values = True