Allow optional base model lists to be passed in argparse

This commit is contained in:
Brandon Rising 2024-03-11 19:00:46 -04:00
parent 628639c565
commit a828ea5de9

View File

@ -11,11 +11,12 @@ the command line.
from __future__ import annotations
import argparse
import json
import os
import sys
from argparse import ArgumentParser
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, Type, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
from pydantic import BaseModel
@ -23,6 +24,23 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
class ParseModelListAction(argparse.Action):
"""An argparse action that parses a JSON string into a list of Pydantic models."""
model_type: Type[BaseModel]
def __init__(self, model_type: Type[BaseModel], *args, **kwargs): # type: ignore
super(ParseModelListAction, self).__init__(*args, **kwargs) # type: ignore
self.model_type = model_type
def __call__(self, parser, namespace, values, option_string=None): # type: ignore
try:
items_data = json.loads(values) # type: ignore
items = [self.model_type(**item_data) for item_data in items_data]
setattr(namespace, self.dest, items)
except Exception as e:
parser.error(f"Could not parse models: {e}")
class InvokeAISettings(BaseSettings):
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
@ -152,7 +170,7 @@ class InvokeAISettings(BaseSettings):
@classmethod
def _excluded(cls) -> List[str]:
# internal fields that shouldn't be exposed as command line options
return ["type", "initconf", "remote_api_tokens"]
return ["type", "initconf"]
@classmethod
def _excluded_from_yaml(cls) -> List[str]:
@ -194,7 +212,28 @@ class InvokeAISettings(BaseSettings):
else:
argparse_group = command_parser
if get_origin(field_type) == Literal:
def matches_optional_list_of_basemodel_subclasses(field_type):
args = get_args(field_type)
for arg in args:
list_origin = get_origin(arg)
if list_origin is list:
list_args = get_args(arg)
if len(list_args) == 1 and issubclass(list_args[0], BaseModel):
return list_args[0]
return None
if name == "remote_api_tokens":
pass
if bm_type:=matches_optional_list_of_basemodel_subclasses(field_type):
argparse_group.add_argument(
f"--{name}",
dest=name,
action=ParseModelListAction,
model_type=bm_type,
type=str,
default=default,
help=field.description,
)
elif get_origin(field_type) == Literal:
allowed_values = get_args(field.annotation)
allowed_types = set()
for val in allowed_values: