# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team """ Base class for the InvokeAI configuration system. It defines a type of pydantic BaseSettings object that is able to read and write from an omegaconf-based config file, with overriding of settings from environment variables and/or the command line. """ from __future__ import annotations import argparse import os import sys from argparse import ArgumentParser from pathlib import Path from pydantic import BaseModel from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints from omegaconf import DictConfig, ListConfig, OmegaConf, DictKeyType from pydantic_settings import BaseSettings, SettingsConfigDict from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str class InvokeAISettings(BaseSettings): """Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" initconf: ClassVar[Optional[DictConfig]] = None argparse_groups: ClassVar[Dict[str, Any]] = {} model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True) def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None: """Call to parse command-line arguments.""" parser = self.get_parser() opt, unknown_opts = parser.parse_known_args(argv) if len(unknown_opts) > 0: print("Unknown args:", unknown_opts) for name in self.model_fields: if name not in self._excluded(): value = getattr(opt, name) if isinstance(value, ListConfig): value = list(value) elif isinstance(value, DictConfig): value = dict(value) setattr(self, name, value) def to_yaml(self) -> str: """Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later.""" cls = self.__class__ type = get_args(get_type_hints(cls)["type"])[0] field_dict: Dict[str, Dict[str, Any]] = {type: {}} for name, field in self.model_fields.items(): if name in cls._excluded_from_yaml(): continue assert isinstance(field.json_schema_extra, dict) category = ( field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized" ) value = getattr(self, name) assert isinstance(category, str) if category not in field_dict[type]: field_dict[type][category] = {} if isinstance(value, BaseModel): dump = value.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True) field_dict[type][category][name] = dump continue if isinstance(value, list): if not value or len(value) == 0: continue primitive = isinstance(value[0], get_args(DictKeyType)) if not primitive: val_list: List[Dict[str, Any]] = [] for list_val in value: if isinstance(list_val, BaseModel): dump = list_val.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True) val_list.append(dump) field_dict[type][category][name] = val_list continue # keep paths as strings to make it easier to read field_dict[type][category][name] = str(value) if isinstance(value, Path) else value conf = OmegaConf.create(field_dict) return OmegaConf.to_yaml(conf) @classmethod def add_parser_arguments(cls, parser: ArgumentParser) -> None: """Dynamically create arguments for a settings parser.""" if "type" in get_type_hints(cls): settings_stanza = get_args(get_type_hints(cls)["type"])[0] else: settings_stanza = "Uncategorized" env_prefix = getattr(cls.model_config, "env_prefix", None) env_prefix = env_prefix if env_prefix is not None else settings_stanza.upper() initconf = ( cls.initconf.get(settings_stanza) if cls.initconf and settings_stanza in cls.initconf else OmegaConf.create() ) # create an upcase version of the environment in # order to achieve case-insensitive environment # variables (the way Windows does) upcase_environ = {} for key, value in os.environ.items(): upcase_environ[key.upper()] = value fields = cls.model_fields cls.argparse_groups = {} for name, field in fields.items(): if name not in cls._excluded(): current_default = field.default category = ( field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized" ) env_name = env_prefix + "_" + name if category in initconf and name in initconf.get(category): field.default = initconf.get(category).get(name) if env_name.upper() in upcase_environ: field.default = upcase_environ[env_name.upper()] cls.add_field_argument(parser, name, field) field.default = current_default @classmethod def cmd_name(cls, command_field: str = "type") -> str: """Return the category of a setting.""" hints = get_type_hints(cls) if command_field in hints: result: str = get_args(hints[command_field])[0] return result else: return "Uncategorized" @classmethod def get_parser(cls) -> ArgumentParser: """Get the command-line parser for a setting.""" parser = PagingArgumentParser( prog=cls.cmd_name(), description=cls.__doc__, ) cls.add_parser_arguments(parser) return parser @classmethod def _excluded(cls) -> List[str]: # internal fields that shouldn't be exposed as command line options return ["type", "initconf"] @classmethod def _excluded_from_yaml(cls) -> List[str]: # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options return [ "type", "initconf", "version", "from_file", "model", "root", "max_cache_size", "max_vram_cache_size", "always_use_cpu", "free_gpu_mem", "xformers_enabled", "tiled_decode", "lora_dir", "embedding_dir", "controlnet_dir", "conf_path", ] @classmethod def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None: """Add the argparse arguments for a setting parser.""" field_type = get_type_hints(cls).get(name) default = ( default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() ) if category := (field.json_schema_extra.get("category", None) if field.json_schema_extra else None): if category not in cls.argparse_groups: cls.argparse_groups[category] = command_parser.add_argument_group(category) argparse_group = cls.argparse_groups[category] else: argparse_group = command_parser if get_origin(field_type) == Literal: allowed_values = get_args(field.annotation) allowed_types = set() for val in allowed_values: allowed_types.add(type(val)) allowed_types_list = list(allowed_types) field_type = allowed_types_list[0] if len(allowed_types) == 1 else int_or_float_or_str argparse_group.add_argument( f"--{name}", dest=name, type=field_type, default=default, choices=allowed_values, help=field.description, ) elif get_origin(field_type) == Union: argparse_group.add_argument( f"--{name}", dest=name, type=int_or_float_or_str, default=default, help=field.description, ) elif get_origin(field_type) == list: argparse_group.add_argument( f"--{name}", dest=name, nargs="*", type=field.annotation, default=default, action=argparse.BooleanOptionalAction if field.annotation == bool else "store", help=field.description, ) else: argparse_group.add_argument( f"--{name}", dest=name, type=field.annotation, default=default, action=argparse.BooleanOptionalAction if field.annotation == bool else "store", help=field.description, )