mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Allow optional base model lists to be passed in argparse
This commit is contained in:
parent
628639c565
commit
a828ea5de9
@ -11,11 +11,12 @@ the command line.
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, Type, get_args, get_origin, get_type_hints
|
||||
|
||||
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
|
||||
from pydantic import BaseModel
|
||||
@ -23,6 +24,23 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
||||
|
||||
class ParseModelListAction(argparse.Action):
|
||||
"""An argparse action that parses a JSON string into a list of Pydantic models."""
|
||||
|
||||
model_type: Type[BaseModel]
|
||||
|
||||
def __init__(self, model_type: Type[BaseModel], *args, **kwargs): # type: ignore
|
||||
super(ParseModelListAction, self).__init__(*args, **kwargs) # type: ignore
|
||||
self.model_type = model_type
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None): # type: ignore
|
||||
try:
|
||||
items_data = json.loads(values) # type: ignore
|
||||
items = [self.model_type(**item_data) for item_data in items_data]
|
||||
setattr(namespace, self.dest, items)
|
||||
except Exception as e:
|
||||
parser.error(f"Could not parse models: {e}")
|
||||
|
||||
|
||||
class InvokeAISettings(BaseSettings):
|
||||
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||
@ -152,7 +170,7 @@ class InvokeAISettings(BaseSettings):
|
||||
@classmethod
|
||||
def _excluded(cls) -> List[str]:
|
||||
# internal fields that shouldn't be exposed as command line options
|
||||
return ["type", "initconf", "remote_api_tokens"]
|
||||
return ["type", "initconf"]
|
||||
|
||||
@classmethod
|
||||
def _excluded_from_yaml(cls) -> List[str]:
|
||||
@ -194,7 +212,28 @@ class InvokeAISettings(BaseSettings):
|
||||
else:
|
||||
argparse_group = command_parser
|
||||
|
||||
if get_origin(field_type) == Literal:
|
||||
def matches_optional_list_of_basemodel_subclasses(field_type):
|
||||
args = get_args(field_type)
|
||||
for arg in args:
|
||||
list_origin = get_origin(arg)
|
||||
if list_origin is list:
|
||||
list_args = get_args(arg)
|
||||
if len(list_args) == 1 and issubclass(list_args[0], BaseModel):
|
||||
return list_args[0]
|
||||
return None
|
||||
if name == "remote_api_tokens":
|
||||
pass
|
||||
if bm_type:=matches_optional_list_of_basemodel_subclasses(field_type):
|
||||
argparse_group.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
action=ParseModelListAction,
|
||||
model_type=bm_type,
|
||||
type=str,
|
||||
default=default,
|
||||
help=field.description,
|
||||
)
|
||||
elif get_origin(field_type) == Literal:
|
||||
allowed_values = get_args(field.annotation)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
|
Loading…
Reference in New Issue
Block a user