mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add name, base_mode, type fields to model info
This commit is contained in:
parent
f8d7477c7a
commit
ef83a2fffe
@ -530,6 +530,8 @@ class ModelManager(object):
|
|||||||
|
|
||||||
models[cur_base_model][cur_model_type][cur_model_name] = dict(
|
models[cur_base_model][cur_model_type][cur_model_name] = dict(
|
||||||
**model_config.dict(exclude_defaults=True),
|
**model_config.dict(exclude_defaults=True),
|
||||||
|
|
||||||
|
# OpenAPIModelInfoBase
|
||||||
name=cur_model_name,
|
name=cur_model_name,
|
||||||
base_model=cur_base_model,
|
base_model=cur_base_model,
|
||||||
type=cur_model_type,
|
type=cur_model_type,
|
||||||
@ -646,7 +648,7 @@ class ModelManager(object):
|
|||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
# TODO: or exclude_unset better fits here?
|
# TODO: or exclude_unset better fits here?
|
||||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
|
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||||
|
|
||||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
@ -40,10 +41,15 @@ def _get_all_model_configs():
|
|||||||
MODEL_CONFIGS = _get_all_model_configs()
|
MODEL_CONFIGS = _get_all_model_configs()
|
||||||
OPENAPI_MODEL_CONFIGS = list()
|
OPENAPI_MODEL_CONFIGS = list()
|
||||||
|
|
||||||
|
class OpenAPIModelInfoBase(BaseModel):
|
||||||
|
name: str
|
||||||
|
base_model: BaseModelType
|
||||||
|
type: ModelType
|
||||||
|
|
||||||
for cfg in MODEL_CONFIGS:
|
for cfg in MODEL_CONFIGS:
|
||||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||||
openapi_cfg_name = model_name + cfg_name
|
openapi_cfg_name = model_name + cfg_name
|
||||||
name_wrapper = type(openapi_cfg_name, (cfg,), {})
|
name_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), {})
|
||||||
|
|
||||||
#globals()[name] = value
|
#globals()[name] = value
|
||||||
vars()[openapi_cfg_name] = name_wrapper
|
vars()[openapi_cfg_name] = name_wrapper
|
||||||
|
@ -53,7 +53,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
format: Optional[str] = Field(None)
|
format: Optional[str] = Field(None)
|
||||||
default: Optional[bool] = Field(False)
|
default: Optional[bool] = Field(False)
|
||||||
# do not save to config
|
# do not save to config
|
||||||
error: Optional[ModelError] = Field(None, exclude=True)
|
error: Optional[ModelError] = Field(None)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
use_enum_values = True
|
use_enum_values = True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user