mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix type mismatches in invokeai.app.services.config.config_base & config_default
This commit is contained in:
parent
e509d719ee
commit
eee863e380
@ -1,4 +1,5 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
import sys
|
||||||
|
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
@ -7,8 +8,12 @@ from .services.config import InvokeAIAppConfig
|
|||||||
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
|
||||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||||
# values from the command line or config file.
|
# values from the command line or config file.
|
||||||
|
from invokeai.version.invokeai_version import __version__
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
app_config.parse_args()
|
app_config.parse_args()
|
||||||
|
if app_config.version:
|
||||||
|
print(f"InvokeAI version {__version__}")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
if True: # hack to make flake8 happy with imports coming after setting up the config
|
if True: # hack to make flake8 happy with imports coming after setting up the config
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -34,7 +39,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
|||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
from invokeai.version.invokeai_version import __version__
|
|
||||||
|
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
@ -222,6 +226,7 @@ app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")
|
|||||||
|
|
||||||
|
|
||||||
def invoke_api() -> None:
|
def invoke_api() -> None:
|
||||||
|
|
||||||
def find_port(port: int) -> int:
|
def find_port(port: int) -> int:
|
||||||
"""Find a port not in use starting at given port"""
|
"""Find a port not in use starting at given port"""
|
||||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||||
@ -273,7 +278,4 @@ def invoke_api() -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if app_config.version:
|
invoke_api()
|
||||||
print(f"InvokeAI version {__version__}")
|
|
||||||
else:
|
|
||||||
invoke_api()
|
|
||||||
|
@ -15,7 +15,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
@ -24,10 +24,7 @@ from invokeai.app.services.config.config_common import PagingArgumentParser, int
|
|||||||
|
|
||||||
|
|
||||||
class InvokeAISettings(BaseSettings):
|
class InvokeAISettings(BaseSettings):
|
||||||
"""
|
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||||
Runtime configuration settings in which default values are
|
|
||||||
read from an omegaconf .yaml file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
initconf: ClassVar[Optional[DictConfig]] = None
|
initconf: ClassVar[Optional[DictConfig]] = None
|
||||||
argparse_groups: ClassVar[Dict] = {}
|
argparse_groups: ClassVar[Dict] = {}
|
||||||
@ -35,6 +32,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||||
|
|
||||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||||
|
"""Call to parse command-line arguments."""
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
opt, unknown_opts = parser.parse_known_args(argv)
|
opt, unknown_opts = parser.parse_known_args(argv)
|
||||||
if len(unknown_opts) > 0:
|
if len(unknown_opts) > 0:
|
||||||
@ -49,20 +47,19 @@ class InvokeAISettings(BaseSettings):
|
|||||||
setattr(self, name, value)
|
setattr(self, name, value)
|
||||||
|
|
||||||
def to_yaml(self) -> str:
|
def to_yaml(self) -> str:
|
||||||
"""
|
"""Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later."""
|
||||||
Return a YAML string representing our settings. This can be used
|
|
||||||
as the contents of `invokeai.yaml` to restore settings later.
|
|
||||||
"""
|
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
type = get_args(get_type_hints(cls)["type"])[0]
|
type = get_args(get_type_hints(cls)["type"])[0]
|
||||||
field_dict = {type: {}}
|
field_dict: Dict[str, Dict[str, Any]] = {type: {}}
|
||||||
for name, field in self.model_fields.items():
|
for name, field in self.model_fields.items():
|
||||||
if name in cls._excluded_from_yaml():
|
if name in cls._excluded_from_yaml():
|
||||||
continue
|
continue
|
||||||
|
assert isinstance(field.json_schema_extra, dict)
|
||||||
category = (
|
category = (
|
||||||
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
|
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
|
||||||
)
|
)
|
||||||
value = getattr(self, name)
|
value = getattr(self, name)
|
||||||
|
assert isinstance(category, str)
|
||||||
if category not in field_dict[type]:
|
if category not in field_dict[type]:
|
||||||
field_dict[type][category] = {}
|
field_dict[type][category] = {}
|
||||||
# keep paths as strings to make it easier to read
|
# keep paths as strings to make it easier to read
|
||||||
@ -72,6 +69,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_parser_arguments(cls, parser):
|
def add_parser_arguments(cls, parser):
|
||||||
|
"""Dynamically create arguments for a settings parser."""
|
||||||
if "type" in get_type_hints(cls):
|
if "type" in get_type_hints(cls):
|
||||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||||
else:
|
else:
|
||||||
@ -116,6 +114,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def cmd_name(cls, command_field: str = "type") -> str:
|
def cmd_name(cls, command_field: str = "type") -> str:
|
||||||
|
"""Return the category of a setting."""
|
||||||
hints = get_type_hints(cls)
|
hints = get_type_hints(cls)
|
||||||
if command_field in hints:
|
if command_field in hints:
|
||||||
return get_args(hints[command_field])[0]
|
return get_args(hints[command_field])[0]
|
||||||
@ -124,6 +123,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_parser(cls) -> ArgumentParser:
|
def get_parser(cls) -> ArgumentParser:
|
||||||
|
"""Get the command-line parser for a setting."""
|
||||||
parser = PagingArgumentParser(
|
parser = PagingArgumentParser(
|
||||||
prog=cls.cmd_name(),
|
prog=cls.cmd_name(),
|
||||||
description=cls.__doc__,
|
description=cls.__doc__,
|
||||||
@ -156,6 +156,7 @@ class InvokeAISettings(BaseSettings):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||||
|
"""Add the argparse arguments for a setting parser."""
|
||||||
field_type = get_type_hints(cls).get(name)
|
field_type = get_type_hints(cls).get(name)
|
||||||
default = (
|
default = (
|
||||||
default_override
|
default_override
|
||||||
|
@ -177,6 +177,7 @@ from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hint
|
|||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
|
from pydantic.fields import JsonDict
|
||||||
from pydantic_settings import SettingsConfigDict
|
from pydantic_settings import SettingsConfigDict
|
||||||
|
|
||||||
from .config_base import InvokeAISettings
|
from .config_base import InvokeAISettings
|
||||||
@ -188,28 +189,24 @@ DEFAULT_MAX_VRAM = 0.5
|
|||||||
|
|
||||||
|
|
||||||
class Categories(object):
|
class Categories(object):
|
||||||
WebServer = {"category": "Web Server"}
|
"""Category headers for configuration variable groups."""
|
||||||
Features = {"category": "Features"}
|
|
||||||
Paths = {"category": "Paths"}
|
WebServer: JsonDict = {"category": "Web Server"}
|
||||||
Logging = {"category": "Logging"}
|
Features: JsonDict = {"category": "Features"}
|
||||||
Development = {"category": "Development"}
|
Paths: JsonDict = {"category": "Paths"}
|
||||||
Other = {"category": "Other"}
|
Logging: JsonDict = {"category": "Logging"}
|
||||||
ModelCache = {"category": "Model Cache"}
|
Development: JsonDict = {"category": "Development"}
|
||||||
Device = {"category": "Device"}
|
Other: JsonDict = {"category": "Other"}
|
||||||
Generation = {"category": "Generation"}
|
ModelCache: JsonDict = {"category": "Model Cache"}
|
||||||
Queue = {"category": "Queue"}
|
Device: JsonDict = {"category": "Device"}
|
||||||
Nodes = {"category": "Nodes"}
|
Generation: JsonDict = {"category": "Generation"}
|
||||||
MemoryPerformance = {"category": "Memory/Performance"}
|
Queue: JsonDict = {"category": "Queue"}
|
||||||
|
Nodes: JsonDict = {"category": "Nodes"}
|
||||||
|
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIAppConfig(InvokeAISettings):
|
class InvokeAIAppConfig(InvokeAISettings):
|
||||||
"""
|
"""Configuration object for InvokeAI App."""
|
||||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
|
||||||
the command-line client (recommended for experts only), or
|
|
||||||
"invokeai-web" to launch the web server. Global options
|
|
||||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
|
||||||
setting environment variables INVOKEAI_<setting>.
|
|
||||||
"""
|
|
||||||
|
|
||||||
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
|
||||||
singleton_init: ClassVar[Optional[Dict]] = None
|
singleton_init: ClassVar[Optional[Dict]] = None
|
||||||
@ -303,8 +300,8 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
clobber=False,
|
clobber=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Update settings with contents of init file, environment, and
|
Update settings with contents of init file, environment, and command-line settings.
|
||||||
command-line settings.
|
|
||||||
:param conf: alternate Omegaconf dictionary object
|
:param conf: alternate Omegaconf dictionary object
|
||||||
:param argv: aternate sys.argv list
|
:param argv: aternate sys.argv list
|
||||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||||
@ -337,13 +334,9 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
||||||
"""
|
"""Return a singleton InvokeAIAppConfig configuration object."""
|
||||||
This returns a singleton InvokeAIAppConfig configuration object.
|
|
||||||
"""
|
|
||||||
if (
|
if (
|
||||||
cls.singleton_config is None
|
cls.singleton_config is None or type(cls.singleton_config) is not cls or (kwargs and cls.singleton_init != kwargs)
|
||||||
or type(cls.singleton_config) is not cls
|
|
||||||
or (kwargs and cls.singleton_init != kwargs)
|
|
||||||
):
|
):
|
||||||
cls.singleton_config = cls(**kwargs)
|
cls.singleton_config = cls(**kwargs)
|
||||||
cls.singleton_init = kwargs
|
cls.singleton_init = kwargs
|
||||||
@ -351,9 +344,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def root_path(self) -> Path:
|
def root_path(self) -> Path:
|
||||||
"""
|
"""Path to the runtime root directory."""
|
||||||
Path to the runtime root directory
|
|
||||||
"""
|
|
||||||
if self.root:
|
if self.root:
|
||||||
root = Path(self.root).expanduser().absolute()
|
root = Path(self.root).expanduser().absolute()
|
||||||
else:
|
else:
|
||||||
@ -363,9 +354,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def root_dir(self) -> Path:
|
def root_dir(self) -> Path:
|
||||||
"""
|
"""Alias for above."""
|
||||||
Alias for above.
|
|
||||||
"""
|
|
||||||
return self.root_path
|
return self.root_path
|
||||||
|
|
||||||
def _resolve(self, partial_path: Path) -> Path:
|
def _resolve(self, partial_path: Path) -> Path:
|
||||||
@ -373,108 +362,94 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def init_file_path(self) -> Path:
|
def init_file_path(self) -> Path:
|
||||||
"""
|
"""Path to invokeai.yaml."""
|
||||||
Path to invokeai.yaml
|
|
||||||
"""
|
|
||||||
return self._resolve(INIT_FILE)
|
return self._resolve(INIT_FILE)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_path(self) -> Path:
|
def output_path(self) -> Path:
|
||||||
"""
|
"""Path to defaults outputs directory."""
|
||||||
Path to defaults outputs directory.
|
assert self.outdir
|
||||||
"""
|
|
||||||
return self._resolve(self.outdir)
|
return self._resolve(self.outdir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_path(self) -> Path:
|
def db_path(self) -> Path:
|
||||||
"""
|
"""Path to the invokeai.db file."""
|
||||||
Path to the invokeai.db file.
|
assert self.db_dir
|
||||||
"""
|
|
||||||
return self._resolve(self.db_dir) / DB_FILE
|
return self._resolve(self.db_dir) / DB_FILE
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_conf_path(self) -> Path:
|
def model_conf_path(self) -> Path:
|
||||||
"""
|
"""Path to models configuration file."""
|
||||||
Path to models configuration file.
|
assert self.conf_path
|
||||||
"""
|
|
||||||
return self._resolve(self.conf_path)
|
return self._resolve(self.conf_path)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def legacy_conf_path(self) -> Path:
|
def legacy_conf_path(self) -> Path:
|
||||||
"""
|
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
|
||||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
assert self.legacy_conf_dir
|
||||||
"""
|
|
||||||
return self._resolve(self.legacy_conf_dir)
|
return self._resolve(self.legacy_conf_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def models_path(self) -> Path:
|
def models_path(self) -> Path:
|
||||||
"""
|
"""Path to the models directory."""
|
||||||
Path to the models directory
|
assert self.models_dir
|
||||||
"""
|
|
||||||
return self._resolve(self.models_dir)
|
return self._resolve(self.models_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def custom_nodes_path(self) -> Path:
|
def custom_nodes_path(self) -> Path:
|
||||||
"""
|
"""Path to the custom nodes directory."""
|
||||||
Path to the custom nodes directory
|
|
||||||
"""
|
|
||||||
return self._resolve(self.custom_nodes_dir)
|
return self._resolve(self.custom_nodes_dir)
|
||||||
|
|
||||||
# the following methods support legacy calls leftover from the Globals era
|
# the following methods support legacy calls leftover from the Globals era
|
||||||
@property
|
@property
|
||||||
def full_precision(self) -> bool:
|
def full_precision(self) -> bool:
|
||||||
"""Return true if precision set to float32"""
|
"""Return true if precision set to float32."""
|
||||||
return self.precision == "float32"
|
return self.precision == "float32"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def try_patchmatch(self) -> bool:
|
def try_patchmatch(self) -> bool:
|
||||||
"""Return true if patchmatch true"""
|
"""Return true if patchmatch true."""
|
||||||
return self.patchmatch
|
return self.patchmatch
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nsfw_checker(self) -> bool:
|
def nsfw_checker(self) -> bool:
|
||||||
"""NSFW node is always active and disabled from Web UIe"""
|
"""Return value for NSFW checker. The NSFW node is always active and disabled from Web UI."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def invisible_watermark(self) -> bool:
|
def invisible_watermark(self) -> bool:
|
||||||
"""invisible watermark node is always active and disabled from Web UIe"""
|
"""Return value of invisible watermark. It is always active and disabled from Web UI."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ram_cache_size(self) -> Union[Literal["auto"], float]:
|
def ram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||||
|
"""Return the ram cache size using the legacy or modern setting."""
|
||||||
return self.max_cache_size or self.ram
|
return self.max_cache_size or self.ram
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vram_cache_size(self) -> Union[Literal["auto"], float]:
|
def vram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||||
|
"""Return the vram cache size using the legacy or modern setting."""
|
||||||
return self.max_vram_cache_size or self.vram
|
return self.max_vram_cache_size or self.vram
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_cpu(self) -> bool:
|
def use_cpu(self) -> bool:
|
||||||
|
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
|
||||||
return self.always_use_cpu or self.device == "cpu"
|
return self.always_use_cpu or self.device == "cpu"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def disable_xformers(self) -> bool:
|
def disable_xformers(self) -> bool:
|
||||||
"""
|
"""Return true if enable_xformers is false (reversed logic) and attention type is not set to xformers."""
|
||||||
Return true if enable_xformers is false (reversed logic)
|
|
||||||
and attention type is not set to xformers.
|
|
||||||
"""
|
|
||||||
disabled_in_config = not self.xformers_enabled
|
disabled_in_config = not self.xformers_enabled
|
||||||
return disabled_in_config and self.attention_type != "xformers"
|
return disabled_in_config and self.attention_type != "xformers"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_root() -> Path:
|
def find_root() -> Path:
|
||||||
"""
|
"""Choose the runtime root directory when not specified on command line or init file."""
|
||||||
Choose the runtime root directory when not specified on command line or
|
|
||||||
init file.
|
|
||||||
"""
|
|
||||||
return _find_root()
|
return _find_root()
|
||||||
|
|
||||||
|
|
||||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
||||||
"""
|
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
|
||||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
|
||||||
"""
|
|
||||||
return InvokeAIAppConfig.get_config(**kwargs)
|
return InvokeAIAppConfig.get_config(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -227,8 +227,6 @@ module = [
|
|||||||
"invokeai.app.api.routers.models",
|
"invokeai.app.api.routers.models",
|
||||||
"invokeai.app.invocations.compel",
|
"invokeai.app.invocations.compel",
|
||||||
"invokeai.app.invocations.latent",
|
"invokeai.app.invocations.latent",
|
||||||
"invokeai.app.services.config.config_base",
|
|
||||||
"invokeai.app.services.config.config_default",
|
|
||||||
"invokeai.app.services.invocation_stats.invocation_stats_default",
|
"invokeai.app.services.invocation_stats.invocation_stats_default",
|
||||||
"invokeai.app.services.model_manager.model_manager_base",
|
"invokeai.app.services.model_manager.model_manager_base",
|
||||||
"invokeai.app.services.model_manager.model_manager_default",
|
"invokeai.app.services.model_manager.model_manager_default",
|
||||||
|
Loading…
Reference in New Issue
Block a user