mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Allow optional base model lists to be passed in argparse
This commit is contained in:
parent
628639c565
commit
a828ea5de9
@ -11,11 +11,12 @@ the command line.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, Type, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
|
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -23,6 +24,23 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
||||||
|
|
||||||
|
class ParseModelListAction(argparse.Action):
|
||||||
|
"""An argparse action that parses a JSON string into a list of Pydantic models."""
|
||||||
|
|
||||||
|
model_type: Type[BaseModel]
|
||||||
|
|
||||||
|
def __init__(self, model_type: Type[BaseModel], *args, **kwargs): # type: ignore
|
||||||
|
super(ParseModelListAction, self).__init__(*args, **kwargs) # type: ignore
|
||||||
|
self.model_type = model_type
|
||||||
|
|
||||||
|
def __call__(self, parser, namespace, values, option_string=None): # type: ignore
|
||||||
|
try:
|
||||||
|
items_data = json.loads(values) # type: ignore
|
||||||
|
items = [self.model_type(**item_data) for item_data in items_data]
|
||||||
|
setattr(namespace, self.dest, items)
|
||||||
|
except Exception as e:
|
||||||
|
parser.error(f"Could not parse models: {e}")
|
||||||
|
|
||||||
|
|
||||||
class InvokeAISettings(BaseSettings):
|
class InvokeAISettings(BaseSettings):
|
||||||
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||||
@ -152,7 +170,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _excluded(cls) -> List[str]:
|
def _excluded(cls) -> List[str]:
|
||||||
# internal fields that shouldn't be exposed as command line options
|
# internal fields that shouldn't be exposed as command line options
|
||||||
return ["type", "initconf", "remote_api_tokens"]
|
return ["type", "initconf"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded_from_yaml(cls) -> List[str]:
|
def _excluded_from_yaml(cls) -> List[str]:
|
||||||
@ -194,7 +212,28 @@ class InvokeAISettings(BaseSettings):
|
|||||||
else:
|
else:
|
||||||
argparse_group = command_parser
|
argparse_group = command_parser
|
||||||
|
|
||||||
if get_origin(field_type) == Literal:
|
def matches_optional_list_of_basemodel_subclasses(field_type):
|
||||||
|
args = get_args(field_type)
|
||||||
|
for arg in args:
|
||||||
|
list_origin = get_origin(arg)
|
||||||
|
if list_origin is list:
|
||||||
|
list_args = get_args(arg)
|
||||||
|
if len(list_args) == 1 and issubclass(list_args[0], BaseModel):
|
||||||
|
return list_args[0]
|
||||||
|
return None
|
||||||
|
if name == "remote_api_tokens":
|
||||||
|
pass
|
||||||
|
if bm_type:=matches_optional_list_of_basemodel_subclasses(field_type):
|
||||||
|
argparse_group.add_argument(
|
||||||
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
|
action=ParseModelListAction,
|
||||||
|
model_type=bm_type,
|
||||||
|
type=str,
|
||||||
|
default=default,
|
||||||
|
help=field.description,
|
||||||
|
)
|
||||||
|
elif get_origin(field_type) == Literal:
|
||||||
allowed_values = get_args(field.annotation)
|
allowed_values = get_args(field.annotation)
|
||||||
allowed_types = set()
|
allowed_types = set()
|
||||||
for val in allowed_values:
|
for val in allowed_values:
|
||||||
|
Loading…
Reference in New Issue
Block a user