Allow optional base model lists to be passed in argparse

This commit is contained in:
Brandon Rising 2024-03-11 19:00:46 -04:00
parent 628639c565
commit a828ea5de9

View File

@ -11,11 +11,12 @@ the command line.
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import json
import os import os
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, Type, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
from pydantic import BaseModel from pydantic import BaseModel
@ -23,6 +24,23 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
class ParseModelListAction(argparse.Action):
"""An argparse action that parses a JSON string into a list of Pydantic models."""
model_type: Type[BaseModel]
def __init__(self, model_type: Type[BaseModel], *args, **kwargs): # type: ignore
super(ParseModelListAction, self).__init__(*args, **kwargs) # type: ignore
self.model_type = model_type
def __call__(self, parser, namespace, values, option_string=None): # type: ignore
try:
items_data = json.loads(values) # type: ignore
items = [self.model_type(**item_data) for item_data in items_data]
setattr(namespace, self.dest, items)
except Exception as e:
parser.error(f"Could not parse models: {e}")
class InvokeAISettings(BaseSettings): class InvokeAISettings(BaseSettings):
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" """Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
@ -152,7 +170,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def _excluded(cls) -> List[str]: def _excluded(cls) -> List[str]:
# internal fields that shouldn't be exposed as command line options # internal fields that shouldn't be exposed as command line options
return ["type", "initconf", "remote_api_tokens"] return ["type", "initconf"]
@classmethod @classmethod
def _excluded_from_yaml(cls) -> List[str]: def _excluded_from_yaml(cls) -> List[str]:
@ -194,7 +212,28 @@ class InvokeAISettings(BaseSettings):
else: else:
argparse_group = command_parser argparse_group = command_parser
if get_origin(field_type) == Literal: def matches_optional_list_of_basemodel_subclasses(field_type):
args = get_args(field_type)
for arg in args:
list_origin = get_origin(arg)
if list_origin is list:
list_args = get_args(arg)
if len(list_args) == 1 and issubclass(list_args[0], BaseModel):
return list_args[0]
return None
if name == "remote_api_tokens":
pass
if bm_type:=matches_optional_list_of_basemodel_subclasses(field_type):
argparse_group.add_argument(
f"--{name}",
dest=name,
action=ParseModelListAction,
model_type=bm_type,
type=str,
default=default,
help=field.description,
)
elif get_origin(field_type) == Literal:
allowed_values = get_args(field.annotation) allowed_values = get_args(field.annotation)
allowed_types = set() allowed_types = set()
for val in allowed_values: for val in allowed_values: