fix type mismatches in invokeai.app.services.config.config_base & config_default

This commit is contained in:
Lincoln Stein 2023-11-26 17:00:27 -05:00 committed by psychedelicious
parent e509d719ee
commit eee863e380
4 changed files with 62 additions and 86 deletions

View File

@ -1,4 +1,5 @@
from typing import Any from typing import Any
import sys
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
@ -7,8 +8,12 @@ from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config # parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the # which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file. # values from the command line or config file.
from invokeai.version.invokeai_version import __version__
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
app_config.parse_args() app_config.parse_args()
if app_config.version:
print(f"InvokeAI version {__version__}")
sys.exit(0)
if True: # hack to make flake8 happy with imports coming after setting up the config if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio import asyncio
@ -34,7 +39,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from invokeai.version.invokeai_version import __version__
from ..backend.util.logging import InvokeAILogger from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
@ -222,6 +226,7 @@ app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")
def invoke_api() -> None: def invoke_api() -> None:
def find_port(port: int) -> int: def find_port(port: int) -> int:
"""Find a port not in use starting at given port""" """Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon! # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
@ -273,7 +278,4 @@ def invoke_api() -> None:
if __name__ == "__main__": if __name__ == "__main__":
if app_config.version:
print(f"InvokeAI version {__version__}")
else:
invoke_api() invoke_api()

View File

@ -15,7 +15,7 @@ import os
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf import DictConfig, ListConfig, OmegaConf
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
@ -24,10 +24,7 @@ from invokeai.app.services.config.config_common import PagingArgumentParser, int
class InvokeAISettings(BaseSettings): class InvokeAISettings(BaseSettings):
""" """Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
"""
initconf: ClassVar[Optional[DictConfig]] = None initconf: ClassVar[Optional[DictConfig]] = None
argparse_groups: ClassVar[Dict] = {} argparse_groups: ClassVar[Dict] = {}
@ -35,6 +32,7 @@ class InvokeAISettings(BaseSettings):
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True) model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
def parse_args(self, argv: Optional[list] = sys.argv[1:]): def parse_args(self, argv: Optional[list] = sys.argv[1:]):
"""Call to parse command-line arguments."""
parser = self.get_parser() parser = self.get_parser()
opt, unknown_opts = parser.parse_known_args(argv) opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0: if len(unknown_opts) > 0:
@ -49,20 +47,19 @@ class InvokeAISettings(BaseSettings):
setattr(self, name, value) setattr(self, name, value)
def to_yaml(self) -> str: def to_yaml(self) -> str:
""" """Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later."""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__ cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0] type = get_args(get_type_hints(cls)["type"])[0]
field_dict = {type: {}} field_dict: Dict[str, Dict[str, Any]] = {type: {}}
for name, field in self.model_fields.items(): for name, field in self.model_fields.items():
if name in cls._excluded_from_yaml(): if name in cls._excluded_from_yaml():
continue continue
assert isinstance(field.json_schema_extra, dict)
category = ( category = (
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized" field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
) )
value = getattr(self, name) value = getattr(self, name)
assert isinstance(category, str)
if category not in field_dict[type]: if category not in field_dict[type]:
field_dict[type][category] = {} field_dict[type][category] = {}
# keep paths as strings to make it easier to read # keep paths as strings to make it easier to read
@ -72,6 +69,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def add_parser_arguments(cls, parser): def add_parser_arguments(cls, parser):
"""Dynamically create arguments for a settings parser."""
if "type" in get_type_hints(cls): if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0] settings_stanza = get_args(get_type_hints(cls)["type"])[0]
else: else:
@ -116,6 +114,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def cmd_name(cls, command_field: str = "type") -> str: def cmd_name(cls, command_field: str = "type") -> str:
"""Return the category of a setting."""
hints = get_type_hints(cls) hints = get_type_hints(cls)
if command_field in hints: if command_field in hints:
return get_args(hints[command_field])[0] return get_args(hints[command_field])[0]
@ -124,6 +123,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def get_parser(cls) -> ArgumentParser: def get_parser(cls) -> ArgumentParser:
"""Get the command-line parser for a setting."""
parser = PagingArgumentParser( parser = PagingArgumentParser(
prog=cls.cmd_name(), prog=cls.cmd_name(),
description=cls.__doc__, description=cls.__doc__,
@ -156,6 +156,7 @@ class InvokeAISettings(BaseSettings):
@classmethod @classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None): def add_field_argument(cls, command_parser, name: str, field, default_override=None):
"""Add the argparse arguments for a setting parser."""
field_type = get_type_hints(cls).get(name) field_type = get_type_hints(cls).get(name)
default = ( default = (
default_override default_override

View File

@ -177,6 +177,7 @@ from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hint
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter
from pydantic.fields import JsonDict
from pydantic_settings import SettingsConfigDict from pydantic_settings import SettingsConfigDict
from .config_base import InvokeAISettings from .config_base import InvokeAISettings
@ -188,28 +189,24 @@ DEFAULT_MAX_VRAM = 0.5
class Categories(object): class Categories(object):
WebServer = {"category": "Web Server"} """Category headers for configuration variable groups."""
Features = {"category": "Features"}
Paths = {"category": "Paths"} WebServer: JsonDict = {"category": "Web Server"}
Logging = {"category": "Logging"} Features: JsonDict = {"category": "Features"}
Development = {"category": "Development"} Paths: JsonDict = {"category": "Paths"}
Other = {"category": "Other"} Logging: JsonDict = {"category": "Logging"}
ModelCache = {"category": "Model Cache"} Development: JsonDict = {"category": "Development"}
Device = {"category": "Device"} Other: JsonDict = {"category": "Other"}
Generation = {"category": "Generation"} ModelCache: JsonDict = {"category": "Model Cache"}
Queue = {"category": "Queue"} Device: JsonDict = {"category": "Device"}
Nodes = {"category": "Nodes"} Generation: JsonDict = {"category": "Generation"}
MemoryPerformance = {"category": "Memory/Performance"} Queue: JsonDict = {"category": "Queue"}
Nodes: JsonDict = {"category": "Nodes"}
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
class InvokeAIAppConfig(InvokeAISettings): class InvokeAIAppConfig(InvokeAISettings):
""" """Configuration object for InvokeAI App."""
Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
"""
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict]] = None singleton_init: ClassVar[Optional[Dict]] = None
@ -303,8 +300,8 @@ class InvokeAIAppConfig(InvokeAISettings):
clobber=False, clobber=False,
): ):
""" """
Update settings with contents of init file, environment, and Update settings with contents of init file, environment, and command-line settings.
command-line settings.
:param conf: alternate Omegaconf dictionary object :param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list :param argv: aternate sys.argv list
:param clobber: ovewrite any initialization parameters passed during initialization :param clobber: ovewrite any initialization parameters passed during initialization
@ -337,13 +334,9 @@ class InvokeAIAppConfig(InvokeAISettings):
@classmethod @classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig: def get_config(cls, **kwargs) -> InvokeAIAppConfig:
""" """Return a singleton InvokeAIAppConfig configuration object."""
This returns a singleton InvokeAIAppConfig configuration object.
"""
if ( if (
cls.singleton_config is None cls.singleton_config is None or type(cls.singleton_config) is not cls or (kwargs and cls.singleton_init != kwargs)
or type(cls.singleton_config) is not cls
or (kwargs and cls.singleton_init != kwargs)
): ):
cls.singleton_config = cls(**kwargs) cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs cls.singleton_init = kwargs
@ -351,9 +344,7 @@ class InvokeAIAppConfig(InvokeAISettings):
@property @property
def root_path(self) -> Path: def root_path(self) -> Path:
""" """Path to the runtime root directory."""
Path to the runtime root directory
"""
if self.root: if self.root:
root = Path(self.root).expanduser().absolute() root = Path(self.root).expanduser().absolute()
else: else:
@ -363,9 +354,7 @@ class InvokeAIAppConfig(InvokeAISettings):
@property @property
def root_dir(self) -> Path: def root_dir(self) -> Path:
""" """Alias for above."""
Alias for above.
"""
return self.root_path return self.root_path
def _resolve(self, partial_path: Path) -> Path: def _resolve(self, partial_path: Path) -> Path:
@ -373,108 +362,94 @@ class InvokeAIAppConfig(InvokeAISettings):
@property @property
def init_file_path(self) -> Path: def init_file_path(self) -> Path:
""" """Path to invokeai.yaml."""
Path to invokeai.yaml
"""
return self._resolve(INIT_FILE) return self._resolve(INIT_FILE)
@property @property
def output_path(self) -> Path: def output_path(self) -> Path:
""" """Path to defaults outputs directory."""
Path to defaults outputs directory. assert self.outdir
"""
return self._resolve(self.outdir) return self._resolve(self.outdir)
@property @property
def db_path(self) -> Path: def db_path(self) -> Path:
""" """Path to the invokeai.db file."""
Path to the invokeai.db file. assert self.db_dir
"""
return self._resolve(self.db_dir) / DB_FILE return self._resolve(self.db_dir) / DB_FILE
@property @property
def model_conf_path(self) -> Path: def model_conf_path(self) -> Path:
""" """Path to models configuration file."""
Path to models configuration file. assert self.conf_path
"""
return self._resolve(self.conf_path) return self._resolve(self.conf_path)
@property @property
def legacy_conf_path(self) -> Path: def legacy_conf_path(self) -> Path:
""" """Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
Path to directory of legacy configuration files (e.g. v1-inference.yaml) assert self.legacy_conf_dir
"""
return self._resolve(self.legacy_conf_dir) return self._resolve(self.legacy_conf_dir)
@property @property
def models_path(self) -> Path: def models_path(self) -> Path:
""" """Path to the models directory."""
Path to the models directory assert self.models_dir
"""
return self._resolve(self.models_dir) return self._resolve(self.models_dir)
@property @property
def custom_nodes_path(self) -> Path: def custom_nodes_path(self) -> Path:
""" """Path to the custom nodes directory."""
Path to the custom nodes directory
"""
return self._resolve(self.custom_nodes_dir) return self._resolve(self.custom_nodes_dir)
# the following methods support legacy calls leftover from the Globals era # the following methods support legacy calls leftover from the Globals era
@property @property
def full_precision(self) -> bool: def full_precision(self) -> bool:
"""Return true if precision set to float32""" """Return true if precision set to float32."""
return self.precision == "float32" return self.precision == "float32"
@property @property
def try_patchmatch(self) -> bool: def try_patchmatch(self) -> bool:
"""Return true if patchmatch true""" """Return true if patchmatch true."""
return self.patchmatch return self.patchmatch
@property @property
def nsfw_checker(self) -> bool: def nsfw_checker(self) -> bool:
"""NSFW node is always active and disabled from Web UIe""" """Return value for NSFW checker. The NSFW node is always active and disabled from Web UI."""
return True return True
@property @property
def invisible_watermark(self) -> bool: def invisible_watermark(self) -> bool:
"""invisible watermark node is always active and disabled from Web UIe""" """Return value of invisible watermark. It is always active and disabled from Web UI."""
return True return True
@property @property
def ram_cache_size(self) -> Union[Literal["auto"], float]: def ram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the ram cache size using the legacy or modern setting."""
return self.max_cache_size or self.ram return self.max_cache_size or self.ram
@property @property
def vram_cache_size(self) -> Union[Literal["auto"], float]: def vram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the vram cache size using the legacy or modern setting."""
return self.max_vram_cache_size or self.vram return self.max_vram_cache_size or self.vram
@property @property
def use_cpu(self) -> bool: def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
return self.always_use_cpu or self.device == "cpu" return self.always_use_cpu or self.device == "cpu"
@property @property
def disable_xformers(self) -> bool: def disable_xformers(self) -> bool:
""" """Return true if enable_xformers is false (reversed logic) and attention type is not set to xformers."""
Return true if enable_xformers is false (reversed logic)
and attention type is not set to xformers.
"""
disabled_in_config = not self.xformers_enabled disabled_in_config = not self.xformers_enabled
return disabled_in_config and self.attention_type != "xformers" return disabled_in_config and self.attention_type != "xformers"
@staticmethod @staticmethod
def find_root() -> Path: def find_root() -> Path:
""" """Choose the runtime root directory when not specified on command line or init file."""
Choose the runtime root directory when not specified on command line or
init file.
"""
return _find_root() return _find_root()
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig: def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
""" """Legacy function which returns InvokeAIAppConfig.get_config()."""
Legacy function which returns InvokeAIAppConfig.get_config()
"""
return InvokeAIAppConfig.get_config(**kwargs) return InvokeAIAppConfig.get_config(**kwargs)

View File

@ -227,8 +227,6 @@ module = [
"invokeai.app.api.routers.models", "invokeai.app.api.routers.models",
"invokeai.app.invocations.compel", "invokeai.app.invocations.compel",
"invokeai.app.invocations.latent", "invokeai.app.invocations.latent",
"invokeai.app.services.config.config_base",
"invokeai.app.services.config.config_default",
"invokeai.app.services.invocation_stats.invocation_stats_default", "invokeai.app.services.invocation_stats.invocation_stats_default",
"invokeai.app.services.model_manager.model_manager_base", "invokeai.app.services.model_manager.model_manager_base",
"invokeai.app.services.model_manager.model_manager_default", "invokeai.app.services.model_manager.model_manager_default",