Fix potential race condition in config system (#3466)

There was a potential gotcha in the config system that was previously
merged with main. The `InvokeAIAppConfig` object was configuring itself
from the command line and configuration file within its initialization
routine. However, this could cause it to read `argv` from the command
line at unexpected times. This PR fixes the object so that it only reads
from the init file and command line when its `parse_args()` method is
explicitly called, which should be done at startup time in any top level
script that uses it.

In addition, using the `get_invokeai_config()` function to get a global
version of the config object didn't feel pythonic to me, so I have
changed this to `InvokeAIAppConfig.get_config()` throughout.

## Updated Usage

In the main script, at startup time, do the following:

```
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
config.parse_args()
```

In non-main scripts, it is not necessary (or recommended) to call
`parse_args()`:
```
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
```

The configuration object properties can be overridden when
`get_config()` is called by passing initialization values in the usual
way. If a property is set this way, then it will not be changed by
subsequent calls to `parse_args()`, but can only be changed by
explicitly setting the property.

```
config = InvokeAIAppConfig.get_config(nsfw_checker=True)
config.parse_args(argv=['--no-nsfw_checker'])
config.nsfw_checker
# True
```

You may specify alternative argv lists and configuration files in
`parse_args()`:

```
config.parse_args(argv=['--no-nsfw_checker'],
                             conf = OmegaConf.load('/tmp/test.yaml')
)
```

For backward compatibility, the `get_invokeai_config()` function is
still available from the module, but has been removed from the rest of
the source tree.
This commit is contained in:
Lincoln Stein 2023-06-05 15:26:50 -07:00 committed by GitHub
commit b31fc43bfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 165 additions and 120 deletions

View File

@ -39,7 +39,8 @@ socket_io = SocketIO(app)
# initialize config # initialize config
# this is a module global # this is a module global
app_config = InvokeAIAppConfig() app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
# Add startup event to load dependencies # Add startup event to load dependencies
@app.on_event("startup") @app.on_event("startup")

View File

@ -38,7 +38,7 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
from .services.config import get_invokeai_config from .services.config import InvokeAIAppConfig
class CliCommand(BaseModel): class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -197,7 +197,8 @@ logger = logger.InvokeAILogger.getLogger()
def invoke_cli(): def invoke_cli():
# this gets the basic configuration # this gets the basic configuration
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
config.parse_args()
# get the optional list of invocations to execute on the command line # get the optional list of invocations to execute on the command line
parser = config.get_parser() parser = config.get_parser()

View File

@ -51,18 +51,32 @@ in INVOKEAI_ROOT. You can replace supersede this by providing any
OmegaConf dictionary object initialization time: OmegaConf dictionary object initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml') omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig(conf=omegaconf) conf = InvokeAIAppConfig()
conf.parse_args(conf=omegaconf)
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
initialization time. You may pass a list of strings in the optional at initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv: `argv` argument to use instead of the system argv:
conf = InvokeAIAppConfig(arg=['--xformers_enabled']) conf.parse_args(argv=['--xformers_enabled'])
It is also possible to set a value at initialization time. This value It is also possible to set a value at initialization time. However, if
has highest priority. you call parse_args() it may be overwritten.
conf = InvokeAIAppConfig(xformers_enabled=True) conf = InvokeAIAppConfig(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# False
To avoid this, use `get_config()` to retrieve the application-wide
configuration object. This will retain any properties set at object
creation time:
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# True
Any setting can be overwritten by setting an environment variable of Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<setting>", as in: form: "INVOKEAI_<setting>", as in:
@ -76,18 +90,23 @@ Order of precedence (from highest):
4) config file options 4) config file options
5) pydantic defaults 5) pydantic defaults
Typical usage: Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.invocations.generate import TextToImageInvocation
# get global configuration and print its nsfw_checker value # get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig() conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.nsfw_checker)
Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig.get_config()
print(conf.nsfw_checker) print(conf.nsfw_checker)
# get the text2image invocation and print its step value
text2image = TextToImageInvocation()
print(text2image.steps)
Computed properties: Computed properties:
@ -103,10 +122,11 @@ a Path object:
lora_path - path to the LoRA directory lora_path - path to the LoRA directory
In most cases, you will want to create a single InvokeAIAppConfig In most cases, you will want to create a single InvokeAIAppConfig
object for the entire application. The get_invokeai_config() function object for the entire application. The InvokeAIAppConfig.get_config() function
does this: does this:
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
config.parse_args() # read values from the command line/config file
print(config.root) print(config.root)
# Subclassing # Subclassing
@ -140,7 +160,9 @@ two configs are kept in separate sections of the config file:
legacy_conf_dir: configs/stable-diffusion legacy_conf_dir: configs/stable-diffusion
outdir: outputs outdir: outputs
... ...
''' '''
from __future__ import annotations
import argparse import argparse
import pydoc import pydoc
import os import os
@ -154,9 +176,6 @@ from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_t
INIT_FILE = Path('invokeai.yaml') INIT_FILE = Path('invokeai.yaml')
LEGACY_INIT_FILE = Path('invokeai.init') LEGACY_INIT_FILE = Path('invokeai.init')
# This global stores a singleton InvokeAIAppConfig configuration object
global_config = None
class InvokeAISettings(BaseSettings): class InvokeAISettings(BaseSettings):
''' '''
Runtime configuration settings in which default values are Runtime configuration settings in which default values are
@ -329,6 +348,9 @@ the command-line client (recommended for experts only), or
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>. setting environment variables INVOKEAI_<setting>.
''' '''
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
#fmt: off #fmt: off
type: Literal["InvokeAI"] = "InvokeAI" type: Literal["InvokeAI"] = "InvokeAI"
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
@ -373,33 +395,44 @@ setting environment variables INVOKEAI_<setting>.
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging") log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
#fmt: on #fmt: on
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs): def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
''' '''
Initialize InvokeAIAppconfig. Update settings with contents of init file, environment, and
command-line settings.
:param conf: alternate Omegaconf dictionary object :param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list :param argv: aternate sys.argv list
:param **kwargs: attributes to initialize with :param clobber: ovewrite any initialization parameters passed during initialization
''' '''
super().__init__(**kwargs)
# Set the runtime root directory. We parse command-line switches here # Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option. # in order to pick up the --root_dir option.
self.parse_args(argv) super().parse_args(argv)
if conf is None: if conf is None:
try: try:
conf = OmegaConf.load(self.root_dir / INIT_FILE) conf = OmegaConf.load(self.root_dir / INIT_FILE)
except: except:
pass pass
InvokeAISettings.initconf = conf InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file # parse args again in order to pick up settings in configuration file
self.parse_args(argv) super().parse_args(argv)
# restore initialization values if self.singleton_init and not clobber:
hints = get_type_hints(self) hints = get_type_hints(self.__class__)
for k in kwargs: for k in self.singleton_init:
setattr(self,k,parse_obj_as(hints[k],kwargs[k])) setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
@classmethod
def get_config(cls,**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
'''
if cls.singleton_config is None \
or type(cls.singleton_config)!=cls \
or (kwargs and cls.singleton_init != kwargs):
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
@property @property
def root_path(self)->Path: def root_path(self)->Path:
''' '''
@ -517,11 +550,8 @@ class PagingArgumentParser(argparse.ArgumentParser):
text = self.format_help() text = self.format_help()
pydoc.pager(text) pydoc.pager(text)
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig: def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
''' '''
This returns a singleton InvokeAIAppConfig configuration object. Legacy function which returns InvokeAIAppConfig.get_config()
''' '''
global global_config return InvokeAIAppConfig.get_config(**kwargs)
if global_config is None or type(global_config)!=cls:
global_config = cls(**kwargs)
return global_config

View File

@ -26,7 +26,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._table_name = table_name self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock() self._lock = Lock()
self._conn = sqlite3.connect( self._conn = sqlite3.connect(
self._filename, check_same_thread=False self._filename, check_same_thread=False
) # TODO: figure out a better threading solution ) # TODO: figure out a better threading solution

View File

@ -56,6 +56,8 @@ from invokeai.backend.config.model_install_backend import (
recommended_datasets, recommended_datasets,
) )
from invokeai.app.services.config import InvokeAIAppConfig
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
@ -63,7 +65,7 @@ transformers.logging.set_verbosity_error()
# --------------------------globals----------------------- # --------------------------globals-----------------------
config = get_invokeai_config(argv=[]) config = InvokeAIAppConfig.get_config()
Model_dir = "models" Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/" Weights_dir = "ldm/stable-diffusion-v1/"
@ -635,7 +637,7 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
def default_startup_options(init_file: Path) -> Namespace: def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig(argv=[]) opts = InvokeAIAppConfig.get_config()
outdir = Path(opts.outdir) outdir = Path(opts.outdir)
if not outdir.is_absolute(): if not outdir.is_absolute():
opts.outdir = str(config.root / opts.outdir) opts.outdir = str(config.root / opts.outdir)
@ -700,7 +702,7 @@ def write_opts(opts: Namespace, init_file: Path):
""" """
# this will load current settings # this will load current settings
config = InvokeAIAppConfig(argv=[]) config = InvokeAIAppConfig.get_config()
for key,value in opts.__dict__.items(): for key,value in opts.__dict__.items():
if hasattr(config,key): if hasattr(config,key):
setattr(config,key,value) setattr(config,key,value)
@ -732,7 +734,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
# yaml format. # yaml format.
def migrate_init_file(legacy_format:Path): def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}']) old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
new = InvokeAIAppConfig(conf={}) new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys()) fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields: for attr in fields:
@ -821,8 +823,9 @@ def main():
if old_init_file.exists() and not new_init_file.exists(): if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml') print('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file) migrate_init_file(old_init_file)
config = get_invokeai_config(argv=[]) # reread defaults
# Load new init file into config
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
if not config.model_conf_path.exists(): if not config.model_conf_path.exists():
initialize_rootdir(config.root, opt.yes_to_all) initialize_rootdir(config.root, opt.yes_to_all)

View File

@ -19,7 +19,7 @@ from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ..model_management import ModelManager from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline from ..stable_diffusion import StableDiffusionGeneratorPipeline
@ -27,7 +27,8 @@ from ..stable_diffusion import StableDiffusionGeneratorPipeline
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
# --------------------------globals----------------------- # --------------------------globals-----------------------
config = get_invokeai_config(argv=[]) config = InvokeAIAppConfig.get_config()
Model_dir = "models" Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/" Weights_dir = "ldm/stable-diffusion-v1/"

View File

@ -6,7 +6,8 @@ be suppressed or deferred
""" """
import numpy as np import numpy as np
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class PatchMatch: class PatchMatch:
""" """
@ -21,7 +22,6 @@ class PatchMatch:
@classmethod @classmethod
def _load_patch_match(self): def _load_patch_match(self):
config = get_invokeai_config()
if self.tried_load: if self.tried_load:
return return
if config.try_patchmatch: if config.try_patchmatch:

View File

@ -33,10 +33,11 @@ from PIL import Image, ImageOps
from transformers import AutoProcessor, CLIPSegForImageSegmentation from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined" CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352 CLIPSEG_SIZE = 352
config = InvokeAIAppConfig.get_config()
class SegmentedGrayscale(object): class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor): def __init__(self, image: Image, heatmap: torch.Tensor):
@ -83,7 +84,6 @@ class Txt2Mask(object):
def __init__(self, device="cpu", refined=False): def __init__(self, device="cpu", refined=False):
logger.info("Initializing clipseg model for text to mask inference") logger.info("Initializing clipseg model for text to mask inference")
config = get_invokeai_config()
# BUG: we are not doing anything with the device option at this time # BUG: we are not doing anything with the device option at this time
self.device = device self.device = device

View File

@ -26,7 +26,7 @@ import torch
from safetensors.torch import load_file from safetensors.torch import load_file
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager, SDLegacyType from .model_manager import ModelManager, SDLegacyType
@ -842,7 +842,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint): def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained( text_model = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir "openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
) )
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
@ -897,7 +897,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint): def convert_paint_by_example_checkpoint(checkpoint):
cache_dir = get_invokeai_config().cache_dir cache_dir = InvokeAIAppConfig.get_config().cache_dir
config = CLIPVisionConfig.from_pretrained( config = CLIPVisionConfig.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir "openai/clip-vit-large-patch14", cache_dir=cache_dir
) )
@ -969,7 +969,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
def convert_open_clip_checkpoint(checkpoint): def convert_open_clip_checkpoint(checkpoint):
cache_dir = get_invokeai_config().cache_dir cache_dir = InvokeAIAppConfig.get_config().cache_dir
text_model = CLIPTextModel.from_pretrained( text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir "stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
) )
@ -1092,7 +1092,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param vae: A diffusers VAE to load into the pipeline. :param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline. :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
""" """
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity() verbosity = dlogging.get_verbosity()

View File

@ -47,7 +47,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
from ..stable_diffusion import ( from ..stable_diffusion import (
StableDiffusionGeneratorPipeline, StableDiffusionGeneratorPipeline,
) )
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CUDA_DEVICE, ask_user, download_with_resume from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum): class SDLegacyType(Enum):
@ -98,7 +98,7 @@ class ModelManager(object):
if not isinstance(config, DictConfig): if not isinstance(config, DictConfig):
config = OmegaConf.load(config) config = OmegaConf.load(config)
self.config = config self.config = config
self.globals = get_invokeai_config() self.globals = InvokeAIAppConfig.get_config()
self.precision = precision self.precision = precision
self.device = torch.device(device_type) self.device = torch.device(device_type)
self.max_loaded_models = max_loaded_models self.max_loaded_models = max_loaded_models
@ -1057,7 +1057,7 @@ class ModelManager(object):
""" """
# Three transformer models to check: bert, clip and safety checker, and # Three transformer models to check: bert, clip and safety checker, and
# the diffusers as well # the diffusers as well
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
models_dir = config.root_dir / "models" models_dir = config.root_dir / "models"
legacy_locations = [ legacy_locations = [
Path( Path(
@ -1287,7 +1287,7 @@ class ModelManager(object):
@classmethod @classmethod
def _delete_model_from_cache(cls,repo_id): def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(get_invokeai_config().cache_dir) cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
# I'm sure there is a way to do this with comprehensions # I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible! # but the code quickly became incomprehensible!
@ -1304,7 +1304,7 @@ class ModelManager(object):
@staticmethod @staticmethod
def _abs_path(path: str | Path) -> Path: def _abs_path(path: str | Path) -> Path:
globals = get_invokeai_config() globals = InvokeAIAppConfig.get_config()
if path is None or Path(path).is_absolute(): if path is None or Path(path).is_absolute():
return path return path
return Path(globals.root_dir, path).resolve() return Path(globals.root_dir, path).resolve()

View File

@ -21,10 +21,12 @@ from compel.prompt_parser import (
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ..stable_diffusion import InvokeAIDiffuserComponent from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype from ..util import torch_dtype
config = InvokeAIAppConfig.get_config()
def get_uc_and_c_and_ec(prompt_string, def get_uc_and_c_and_ec(prompt_string,
model: InvokeAIDiffuserComponent, model: InvokeAIDiffuserComponent,
log_tokens=False, skip_normalize_legacy_blend=False): log_tokens=False, skip_normalize_legacy_blend=False):
@ -37,9 +39,7 @@ def get_uc_and_c_and_ec(prompt_string,
textual_inversion_manager=model.textual_inversion_manager, textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype, dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False, truncate_long_prompts=False,
) )
config = get_invokeai_config()
# get rid of any newline characters # get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ") prompt_string = prompt_string.replace("\n", " ")

View File

@ -6,7 +6,7 @@ import numpy as np
import torch import torch
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
pretrained_model_url = ( pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
@ -18,7 +18,7 @@ class CodeFormerRestoration:
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth" self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
) -> None: ) -> None:
self.globals = get_invokeai_config() self.globals = InvokeAIAppConfig.get_config()
codeformer_dir = self.globals.root_dir / codeformer_dir codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists() self.codeformer_model_exists = self.model_path.exists()

View File

@ -7,11 +7,11 @@ import torch
from PIL import Image from PIL import Image
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
class GFPGAN: class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None: def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = get_invokeai_config() self.globals = InvokeAIAppConfig.get_config()
if not os.path.isabs(gfpgan_model_path): if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
self.model_path = gfpgan_model_path self.model_path = gfpgan_model_path

View File

@ -6,8 +6,8 @@ from PIL import Image
from PIL.Image import Image as ImageType from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
class ESRGAN: class ESRGAN:
def __init__(self, bg_tile_size=400) -> None: def __init__(self, bg_tile_size=400) -> None:

View File

@ -15,9 +15,11 @@ from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from .util import CPU_DEVICE from .util import CPU_DEVICE
config = InvokeAIAppConfig.get_config()
class SafetyChecker(object): class SafetyChecker(object):
CAUTION_IMG = "caution.png" CAUTION_IMG = "caution.png"
@ -26,7 +28,6 @@ class SafetyChecker(object):
caution = Image.open(path) caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device self.device = device
config = get_invokeai_config()
try: try:
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"

View File

@ -17,15 +17,16 @@ from huggingface_hub import (
hf_hub_url, hf_hub_url,
) )
import invokeai.backend.util.logging as logger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
class HuggingFaceConceptsLibrary(object): class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None): def __init__(self, root=None):
""" """
Initialize the Concepts object. May optionally pass a root directory. Initialize the Concepts object. May optionally pass a root directory.
""" """
self.config = get_invokeai_config() self.config = InvokeAIAppConfig.get_config()
self.root = root or self.config.root self.root = root or self.config.root
self.hf_api = HfApi() self.hf_api = HfApi()
self.local_concepts = dict() self.local_concepts = dict()

View File

@ -40,7 +40,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CPU_DEVICE, normalize_device from ..util import CPU_DEVICE, normalize_device
from .diffusion import ( from .diffusion import (
AttentionMapSaver, AttentionMapSaver,
@ -364,7 +364,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
""" """
if xformers is available, use it, otherwise use sliced attention. if xformers is available, use it, otherwise use sliced attention.
""" """
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
if ( if (
torch.cuda.is_available() torch.cuda.is_available()
and is_xformers_available() and is_xformers_available()

View File

@ -10,7 +10,7 @@ from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from .cross_attention_control import ( from .cross_attention_control import (
Arguments, Arguments,
@ -72,7 +72,7 @@ class InvokeAIDiffuserComponent:
:param model: the unet model to pass through to cross attention control :param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
""" """
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
self.conditioning = None self.conditioning = None
self.model = model self.model = model
self.is_running_diffusers = is_running_diffusers self.is_running_diffusers = is_running_diffusers
@ -112,23 +112,25 @@ class InvokeAIDiffuserComponent:
# TODO resuscitate attention map saving # TODO resuscitate attention map saving
# self.remove_attention_map_saving() # self.remove_attention_map_saving()
def override_cross_attention( # apparently unused code
self, conditioning: ExtraConditioningInfo, step_count: int # TODO: delete
) -> Dict[str, AttentionProcessor]: # def override_cross_attention(
""" # self, conditioning: ExtraConditioningInfo, step_count: int
setup cross attention .swap control. for diffusers this replaces the attention processor, so # ) -> Dict[str, AttentionProcessor]:
the previous attention processor is returned so that the caller can restore it later. # """
""" # setup cross attention .swap control. for diffusers this replaces the attention processor, so
self.conditioning = conditioning # the previous attention processor is returned so that the caller can restore it later.
self.cross_attention_control_context = Context( # """
arguments=self.conditioning.cross_attention_control_args, # self.conditioning = conditioning
step_count=step_count, # self.cross_attention_control_context = Context(
) # arguments=self.conditioning.cross_attention_control_args,
return override_cross_attention( # step_count=step_count,
self.model, # )
self.cross_attention_control_context, # return override_cross_attention(
is_running_diffusers=self.is_running_diffusers, # self.model,
) # self.cross_attention_control_context,
# is_running_diffusers=self.is_running_diffusers,
# )
def restore_default_cross_attention( def restore_default_cross_attention(
self, restore_attention_processor: Optional["AttentionProcessor"] = None self, restore_attention_processor: Optional["AttentionProcessor"] = None

View File

@ -88,7 +88,7 @@ def save_progress(
def parse_args(): def parse_args():
config = InvokeAIAppConfig(argv=[]) config = InvokeAIAppConfig.get_config()
parser = PagingArgumentParser( parser = PagingArgumentParser(
description="Textual inversion training" description="Textual inversion training"
) )

View File

@ -4,15 +4,15 @@ from contextlib import nullcontext
import torch import torch
from torch import autocast from torch import autocast
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda") CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps") MPS_DEVICE = torch.device("mps")
config = InvokeAIAppConfig.get_config()
def choose_torch_device() -> torch.device: def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on""" """Convenience routine for guessing which GPU device to run model on"""
config = get_invokeai_config()
if config.always_use_cpu: if config.always_use_cpu:
return CPU_DEVICE return CPU_DEVICE
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -32,7 +32,6 @@ def choose_precision(device: torch.device) -> str:
def torch_dtype(device: torch.device) -> torch.dtype: def torch_dtype(device: torch.device) -> torch.dtype:
config = get_invokeai_config()
if config.full_precision: if config.full_precision:
return torch.float32 return torch.float32
if choose_precision(device) == "float16": if choose_precision(device) == "float16":

View File

@ -40,13 +40,13 @@ from .widgets import (
TextBox, TextBox,
set_min_terminal_size, set_min_terminal_size,
) )
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
# minimum size for the UI # minimum size for the UI
MIN_COLS = 120 MIN_COLS = 120
MIN_LINES = 45 MIN_LINES = 45
config = get_invokeai_config(argv=[]) config = InvokeAIAppConfig.get_config()
class addModelsForm(npyscreen.FormMultiPage): class addModelsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled # for responsive resizing - disabled

View File

@ -20,12 +20,12 @@ from npyscreen import widget
from omegaconf import OmegaConf from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.services.config import get_invokeai_config from invokeai.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider from ...frontend.install.widgets import FloatTitleSlider
DEST_MERGED_MODEL_DIR = "merged_models" DEST_MERGED_MODEL_DIR = "merged_models"
config = get_invokeai_config() config = InvokeAIAppConfig.get_config()
def merge_diffusion_models( def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]], model_ids_or_paths: List[Union[str, Path]],

View File

@ -22,7 +22,7 @@ from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.training import ( from ...backend.training import (
do_textual_inversion_training, do_textual_inversion_training,
parse_args parse_args
@ -423,7 +423,7 @@ def do_front_end(args: Namespace):
save_args(args) save_args(args)
try: try:
do_textual_inversion_training(get_invokeai_config(),**args) do_textual_inversion_training(InvokeAIAppConfig.get_config(),**args)
copy_to_embeddings_folder(args) copy_to_embeddings_folder(args)
except Exception as e: except Exception as e:
logger.error("An exception occurred during training. The exception was:") logger.error("An exception occurred during training. The exception was:")
@ -436,7 +436,7 @@ def main():
global config global config
args = parse_args() args = parse_args()
config = get_invokeai_config(argv=[]) config = InvokeAIAppConfig.get_config()
# change root if needed # change root if needed
if args.root_dir: if args.root_dir:

View File

@ -6,9 +6,8 @@ from omegaconf import OmegaConf
from pathlib import Path from pathlib import Path
os.environ['INVOKEAI_ROOT']='/tmp' os.environ['INVOKEAI_ROOT']='/tmp'
sys.argv = [] # to prevent config from trying to parse pytest arguments
from invokeai.app.services.config import InvokeAIAppConfig, InvokeAISettings from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.invocations.generate import TextToImageInvocation from invokeai.app.invocations.generate import TextToImageInvocation
@ -36,48 +35,56 @@ def test_use_init():
# note that we explicitly set omegaconf dict and argv here # note that we explicitly set omegaconf dict and argv here
# so that the values aren't read from ~invokeai/invokeai.yaml and # so that the values aren't read from ~invokeai/invokeai.yaml and
# sys.argv respectively. # sys.argv respectively.
conf1 = InvokeAIAppConfig(init1,[]) conf1 = InvokeAIAppConfig.get_config()
assert conf1 assert conf1
conf1.parse_args(conf=init1)
assert conf1.max_loaded_models==5 assert conf1.max_loaded_models==5
assert not conf1.nsfw_checker assert not conf1.nsfw_checker
conf2 = InvokeAIAppConfig(init2,[]) conf2 = InvokeAIAppConfig.get_config()
assert conf2 assert conf2
conf2.parse_args(conf=init2)
assert conf2.nsfw_checker assert conf2.nsfw_checker
assert conf2.max_loaded_models==2 assert conf2.max_loaded_models==2
assert not hasattr(conf2,'invalid_attribute') assert not hasattr(conf2,'invalid_attribute')
def test_argv_override(): def test_argv_override():
conf = InvokeAIAppConfig(init1,['--nsfw_checker','--max_loaded=10']) conf = InvokeAIAppConfig.get_config()
conf.parse_args(conf=init1,argv=['--nsfw_checker','--max_loaded=10'])
assert conf.nsfw_checker assert conf.nsfw_checker
assert conf.max_loaded_models==10 assert conf.max_loaded_models==10
assert conf.outdir==Path('outputs') # this is the default assert conf.outdir==Path('outputs') # this is the default
def test_env_override(): def test_env_override():
# argv overrides # argv overrides
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10']) conf = InvokeAIAppConfig()
conf.parse_args(conf=init1,argv=['--max_loaded=10'])
assert conf.nsfw_checker==False assert conf.nsfw_checker==False
os.environ['INVOKEAI_nsfw_checker'] = 'True' os.environ['INVOKEAI_nsfw_checker'] = 'True'
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10']) conf.parse_args(conf=init1,argv=['--max_loaded=10'])
assert conf.nsfw_checker==True assert conf.nsfw_checker==True
# environment variables should be case insensitive # environment variables should be case insensitive
os.environ['InvokeAI_Max_Loaded_Models'] = '15' os.environ['InvokeAI_Max_Loaded_Models'] = '15'
conf = InvokeAIAppConfig(conf=init1) conf = InvokeAIAppConfig()
conf.parse_args(conf=init1)
assert conf.max_loaded_models == 15 assert conf.max_loaded_models == 15
conf = InvokeAIAppConfig(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10']) conf = InvokeAIAppConfig()
conf.parse_args(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
assert conf.nsfw_checker==False assert conf.nsfw_checker==False
assert conf.max_loaded_models==10 assert conf.max_loaded_models==10
conf = InvokeAIAppConfig(conf=init1,argv=[],max_loaded_models=20) conf = InvokeAIAppConfig.get_config(max_loaded_models=20)
conf.parse_args(conf=init1,argv=[])
assert conf.max_loaded_models==20 assert conf.max_loaded_models==20
def test_type_coercion(): def test_type_coercion():
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar']) conf = InvokeAIAppConfig().get_config()
conf.parse_args(argv=['--root=/tmp/foobar'])
assert conf.root==Path('/tmp/foobar') assert conf.root==Path('/tmp/foobar')
assert isinstance(conf.root,Path) assert isinstance(conf.root,Path)
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'],root='/tmp/different') conf = InvokeAIAppConfig.get_config(root='/tmp/different')
conf.parse_args(argv=['--root=/tmp/foobar'])
assert conf.root==Path('/tmp/different') assert conf.root==Path('/tmp/different')
assert isinstance(conf.root,Path) assert isinstance(conf.root,Path)