diff --git a/invokeai/app/services/config/config_base.py b/invokeai/app/services/config/config_base.py index 37d84f57bf..3301f191eb 100644 --- a/invokeai/app/services/config/config_base.py +++ b/invokeai/app/services/config/config_base.py @@ -11,11 +11,12 @@ the command line. from __future__ import annotations import argparse +import json import os import sys from argparse import ArgumentParser from pathlib import Path -from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints +from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, Type, get_args, get_origin, get_type_hints from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf from pydantic import BaseModel @@ -23,6 +24,23 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str +class ParseModelListAction(argparse.Action): + """An argparse action that parses a JSON string into a list of Pydantic models.""" + + model_type: Type[BaseModel] + + def __init__(self, model_type: Type[BaseModel], *args, **kwargs): # type: ignore + super(ParseModelListAction, self).__init__(*args, **kwargs) # type: ignore + self.model_type = model_type + + def __call__(self, parser, namespace, values, option_string=None): # type: ignore + try: + items_data = json.loads(values) # type: ignore + items = [self.model_type(**item_data) for item_data in items_data] + setattr(namespace, self.dest, items) + except Exception as e: + parser.error(f"Could not parse models: {e}") + class InvokeAISettings(BaseSettings): """Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" @@ -152,7 +170,7 @@ class InvokeAISettings(BaseSettings): @classmethod def _excluded(cls) -> List[str]: # internal fields that shouldn't be exposed as command line options - return ["type", "initconf", "remote_api_tokens"] + return ["type", "initconf"] @classmethod def _excluded_from_yaml(cls) -> List[str]: @@ -194,7 +212,28 @@ class InvokeAISettings(BaseSettings): else: argparse_group = command_parser - if get_origin(field_type) == Literal: + def matches_optional_list_of_basemodel_subclasses(field_type): + args = get_args(field_type) + for arg in args: + list_origin = get_origin(arg) + if list_origin is list: + list_args = get_args(arg) + if len(list_args) == 1 and issubclass(list_args[0], BaseModel): + return list_args[0] + return None + if name == "remote_api_tokens": + pass + if bm_type:=matches_optional_list_of_basemodel_subclasses(field_type): + argparse_group.add_argument( + f"--{name}", + dest=name, + action=ParseModelListAction, + model_type=bm_type, + type=str, + default=default, + help=field.description, + ) + elif get_origin(field_type) == Literal: allowed_values = get_args(field.annotation) allowed_types = set() for val in allowed_values: