Fix potential race condition in config system (#3466)

There was a potential gotcha in the config system that was previously
merged with main. The `InvokeAIAppConfig` object was configuring itself
from the command line and configuration file within its initialization
routine. However, this could cause it to read `argv` from the command
line at unexpected times. This PR fixes the object so that it only reads
from the init file and command line when its `parse_args()` method is
explicitly called, which should be done at startup time in any top level
script that uses it.

In addition, using the `get_invokeai_config()` function to get a global
version of the config object didn't feel pythonic to me, so I have
changed this to `InvokeAIAppConfig.get_config()` throughout.

## Updated Usage

In the main script, at startup time, do the following:

```
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
config.parse_args()
```

In non-main scripts, it is not necessary (or recommended) to call
`parse_args()`:
```
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
```

The configuration object properties can be overridden when
`get_config()` is called by passing initialization values in the usual
way. If a property is set this way, then it will not be changed by
subsequent calls to `parse_args()`, but can only be changed by
explicitly setting the property.

```
config = InvokeAIAppConfig.get_config(nsfw_checker=True)
config.parse_args(argv=['--no-nsfw_checker'])
config.nsfw_checker
# True
```

You may specify alternative argv lists and configuration files in
`parse_args()`:

```
config.parse_args(argv=['--no-nsfw_checker'],
                             conf = OmegaConf.load('/tmp/test.yaml')
)
```

For backward compatibility, the `get_invokeai_config()` function is
still available from the module, but has been removed from the rest of
the source tree.
This commit is contained in:
Lincoln Stein 2023-06-05 15:26:50 -07:00 committed by GitHub
commit b31fc43bfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 165 additions and 120 deletions

View File

@ -39,7 +39,8 @@ socket_io = SocketIO(app)
# initialize config
# this is a module global
app_config = InvokeAIAppConfig()
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
# Add startup event to load dependencies
@app.on_event("startup")

View File

@ -38,7 +38,7 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
from .services.config import get_invokeai_config
from .services.config import InvokeAIAppConfig
class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -197,7 +197,8 @@ logger = logger.InvokeAILogger.getLogger()
def invoke_cli():
# this gets the basic configuration
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
config.parse_args()
# get the optional list of invocations to execute on the command line
parser = config.get_parser()

View File

@ -51,18 +51,32 @@ in INVOKEAI_ROOT. You can replace supersede this by providing any
OmegaConf dictionary object initialization time:
omegaconf = OmegaConf.load('/tmp/init.yaml')
conf = InvokeAIAppConfig(conf=omegaconf)
conf = InvokeAIAppConfig()
conf.parse_args(conf=omegaconf)
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
initialization time. You may pass a list of strings in the optional
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
at initialization time. You may pass a list of strings in the optional
`argv` argument to use instead of the system argv:
conf = InvokeAIAppConfig(arg=['--xformers_enabled'])
conf.parse_args(argv=['--xformers_enabled'])
It is also possible to set a value at initialization time. This value
has highest priority.
It is also possible to set a value at initialization time. However, if
you call parse_args() it may be overwritten.
conf = InvokeAIAppConfig(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# False
To avoid this, use `get_config()` to retrieve the application-wide
configuration object. This will retain any properties set at object
creation time:
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
conf.parse_args(argv=['--no-xformers'])
conf.xformers_enabled
# True
Any setting can be overwritten by setting an environment variable of
form: "INVOKEAI_<setting>", as in:
@ -76,18 +90,23 @@ Order of precedence (from highest):
4) config file options
5) pydantic defaults
Typical usage:
Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.invocations.generate import TextToImageInvocation
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig()
conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.nsfw_checker)
Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig.get_config()
print(conf.nsfw_checker)
# get the text2image invocation and print its step value
text2image = TextToImageInvocation()
print(text2image.steps)
Computed properties:
@ -103,10 +122,11 @@ a Path object:
lora_path - path to the LoRA directory
In most cases, you will want to create a single InvokeAIAppConfig
object for the entire application. The get_invokeai_config() function
object for the entire application. The InvokeAIAppConfig.get_config() function
does this:
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
config.parse_args() # read values from the command line/config file
print(config.root)
# Subclassing
@ -140,7 +160,9 @@ two configs are kept in separate sections of the config file:
legacy_conf_dir: configs/stable-diffusion
outdir: outputs
...
'''
from __future__ import annotations
import argparse
import pydoc
import os
@ -154,9 +176,6 @@ from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_t
INIT_FILE = Path('invokeai.yaml')
LEGACY_INIT_FILE = Path('invokeai.init')
# This global stores a singleton InvokeAIAppConfig configuration object
global_config = None
class InvokeAISettings(BaseSettings):
'''
Runtime configuration settings in which default values are
@ -329,6 +348,9 @@ the command-line client (recommended for experts only), or
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
'''
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
#fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
@ -373,33 +395,44 @@ setting environment variables INVOKEAI_<setting>.
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
#fmt: on
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
'''
Initialize InvokeAIAppconfig.
Update settings with contents of init file, environment, and
command-line settings.
:param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list
:param **kwargs: attributes to initialize with
:param clobber: ovewrite any initialization parameters passed during initialization
'''
super().__init__(**kwargs)
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
self.parse_args(argv)
super().parse_args(argv)
if conf is None:
try:
conf = OmegaConf.load(self.root_dir / INIT_FILE)
except:
pass
InvokeAISettings.initconf = conf
# parse args again in order to pick up settings in configuration file
self.parse_args(argv)
super().parse_args(argv)
# restore initialization values
hints = get_type_hints(self)
for k in kwargs:
setattr(self,k,parse_obj_as(hints[k],kwargs[k]))
if self.singleton_init and not clobber:
hints = get_type_hints(self.__class__)
for k in self.singleton_init:
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
@classmethod
def get_config(cls,**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
'''
if cls.singleton_config is None \
or type(cls.singleton_config)!=cls \
or (kwargs and cls.singleton_init != kwargs):
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
@property
def root_path(self)->Path:
'''
@ -517,11 +550,8 @@ class PagingArgumentParser(argparse.ArgumentParser):
text = self.format_help()
pydoc.pager(text)
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig:
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
Legacy function which returns InvokeAIAppConfig.get_config()
'''
global global_config
if global_config is None or type(global_config)!=cls:
global_config = cls(**kwargs)
return global_config
return InvokeAIAppConfig.get_config(**kwargs)

View File

@ -26,7 +26,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution

View File

@ -56,6 +56,8 @@ from invokeai.backend.config.model_install_backend import (
recommended_datasets,
)
from invokeai.app.services.config import InvokeAIAppConfig
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
@ -63,7 +65,7 @@ transformers.logging.set_verbosity_error()
# --------------------------globals-----------------------
config = get_invokeai_config(argv=[])
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"
@ -635,7 +637,7 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig(argv=[])
opts = InvokeAIAppConfig.get_config()
outdir = Path(opts.outdir)
if not outdir.is_absolute():
opts.outdir = str(config.root / opts.outdir)
@ -700,7 +702,7 @@ def write_opts(opts: Namespace, init_file: Path):
"""
# this will load current settings
config = InvokeAIAppConfig(argv=[])
config = InvokeAIAppConfig.get_config()
for key,value in opts.__dict__.items():
if hasattr(config,key):
setattr(config,key,value)
@ -732,7 +734,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
# yaml format.
def migrate_init_file(legacy_format:Path):
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
new = InvokeAIAppConfig(conf={})
new = InvokeAIAppConfig.get_config()
fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields:
@ -821,8 +823,9 @@ def main():
if old_init_file.exists() and not new_init_file.exists():
print('** Migrating invokeai.init to invokeai.yaml')
migrate_init_file(old_init_file)
config = get_invokeai_config(argv=[]) # reread defaults
# Load new init file into config
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
if not config.model_conf_path.exists():
initialize_rootdir(config.root, opt.yes_to_all)

View File

@ -19,7 +19,7 @@ from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from ..model_management import ModelManager
from ..stable_diffusion import StableDiffusionGeneratorPipeline
@ -27,7 +27,8 @@ from ..stable_diffusion import StableDiffusionGeneratorPipeline
warnings.filterwarnings("ignore")
# --------------------------globals-----------------------
config = get_invokeai_config(argv=[])
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"

View File

@ -6,7 +6,8 @@ be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class PatchMatch:
"""
@ -21,7 +22,6 @@ class PatchMatch:
@classmethod
def _load_patch_match(self):
config = get_invokeai_config()
if self.tried_load:
return
if config.try_patchmatch:

View File

@ -33,10 +33,11 @@ from PIL import Image, ImageOps
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352
config = InvokeAIAppConfig.get_config()
class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor):
@ -83,7 +84,6 @@ class Txt2Mask(object):
def __init__(self, device="cpu", refined=False):
logger.info("Initializing clipseg model for text to mask inference")
config = get_invokeai_config()
# BUG: we are not doing anything with the device option at this time
self.device = device

View File

@ -26,7 +26,7 @@ import torch
from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager, SDLegacyType
@ -842,7 +842,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir
"openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
)
keys = list(checkpoint.keys())
@ -897,7 +897,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint):
cache_dir = get_invokeai_config().cache_dir
cache_dir = InvokeAIAppConfig.get_config().cache_dir
config = CLIPVisionConfig.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
@ -969,7 +969,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
def convert_open_clip_checkpoint(checkpoint):
cache_dir = get_invokeai_config().cache_dir
cache_dir = InvokeAIAppConfig.get_config().cache_dir
text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
)
@ -1092,7 +1092,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
"""
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()

View File

@ -47,7 +47,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
@ -98,7 +98,7 @@ class ModelManager(object):
if not isinstance(config, DictConfig):
config = OmegaConf.load(config)
self.config = config
self.globals = get_invokeai_config()
self.globals = InvokeAIAppConfig.get_config()
self.precision = precision
self.device = torch.device(device_type)
self.max_loaded_models = max_loaded_models
@ -1057,7 +1057,7 @@ class ModelManager(object):
"""
# Three transformer models to check: bert, clip and safety checker, and
# the diffusers as well
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
models_dir = config.root_dir / "models"
legacy_locations = [
Path(
@ -1287,7 +1287,7 @@ class ModelManager(object):
@classmethod
def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(get_invokeai_config().cache_dir)
cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
# I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible!
@ -1304,7 +1304,7 @@ class ModelManager(object):
@staticmethod
def _abs_path(path: str | Path) -> Path:
globals = get_invokeai_config()
globals = InvokeAIAppConfig.get_config()
if path is None or Path(path).is_absolute():
return path
return Path(globals.root_dir, path).resolve()

View File

@ -21,10 +21,12 @@ from compel.prompt_parser import (
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype
config = InvokeAIAppConfig.get_config()
def get_uc_and_c_and_ec(prompt_string,
model: InvokeAIDiffuserComponent,
log_tokens=False, skip_normalize_legacy_blend=False):
@ -37,9 +39,7 @@ def get_uc_and_c_and_ec(prompt_string,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
config = get_invokeai_config()
)
# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")

View File

@ -6,7 +6,7 @@ import numpy as np
import torch
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
@ -18,7 +18,7 @@ class CodeFormerRestoration:
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
self.globals = get_invokeai_config()
self.globals = InvokeAIAppConfig.get_config()
codeformer_dir = self.globals.root_dir / codeformer_dir
self.model_path = codeformer_dir / codeformer_model_path
self.codeformer_model_exists = self.model_path.exists()

View File

@ -7,11 +7,11 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
self.globals = get_invokeai_config()
self.globals = InvokeAIAppConfig.get_config()
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
self.model_path = gfpgan_model_path

View File

@ -6,8 +6,8 @@ from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
config = get_invokeai_config()
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig.get_config()
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:

View File

@ -15,9 +15,11 @@ from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from .util import CPU_DEVICE
config = InvokeAIAppConfig.get_config()
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
@ -26,7 +28,6 @@ class SafetyChecker(object):
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device
config = get_invokeai_config()
try:
safety_model_id = "CompVis/stable-diffusion-safety-checker"

View File

@ -17,15 +17,16 @@ from huggingface_hub import (
hf_hub_url,
)
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None):
"""
Initialize the Concepts object. May optionally pass a root directory.
"""
self.config = get_invokeai_config()
self.config = InvokeAIAppConfig.get_config()
self.root = root or self.config.root
self.hf_api = HfApi()
self.local_concepts = dict()

View File

@ -40,7 +40,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
AttentionMapSaver,
@ -364,7 +364,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
"""
if xformers is available, use it, otherwise use sliced attention.
"""
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
if (
torch.cuda.is_available()
and is_xformers_available()

View File

@ -10,7 +10,7 @@ from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from .cross_attention_control import (
Arguments,
@ -72,7 +72,7 @@ class InvokeAIDiffuserComponent:
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
"""
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
self.conditioning = None
self.model = model
self.is_running_diffusers = is_running_diffusers
@ -112,23 +112,25 @@ class InvokeAIDiffuserComponent:
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
def override_cross_attention(
self, conditioning: ExtraConditioningInfo, step_count: int
) -> Dict[str, AttentionProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later.
"""
self.conditioning = conditioning
self.cross_attention_control_context = Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count,
)
return override_cross_attention(
self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers,
)
# apparently unused code
# TODO: delete
# def override_cross_attention(
# self, conditioning: ExtraConditioningInfo, step_count: int
# ) -> Dict[str, AttentionProcessor]:
# """
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
# the previous attention processor is returned so that the caller can restore it later.
# """
# self.conditioning = conditioning
# self.cross_attention_control_context = Context(
# arguments=self.conditioning.cross_attention_control_args,
# step_count=step_count,
# )
# return override_cross_attention(
# self.model,
# self.cross_attention_control_context,
# is_running_diffusers=self.is_running_diffusers,
# )
def restore_default_cross_attention(
self, restore_attention_processor: Optional["AttentionProcessor"] = None

View File

@ -88,7 +88,7 @@ def save_progress(
def parse_args():
config = InvokeAIAppConfig(argv=[])
config = InvokeAIAppConfig.get_config()
parser = PagingArgumentParser(
description="Textual inversion training"
)

View File

@ -4,15 +4,15 @@ from contextlib import nullcontext
import torch
from torch import autocast
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
config = InvokeAIAppConfig.get_config()
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
config = get_invokeai_config()
if config.always_use_cpu:
return CPU_DEVICE
if torch.cuda.is_available():
@ -32,7 +32,6 @@ def choose_precision(device: torch.device) -> str:
def torch_dtype(device: torch.device) -> torch.dtype:
config = get_invokeai_config()
if config.full_precision:
return torch.float32
if choose_precision(device) == "float16":

View File

@ -40,13 +40,13 @@ from .widgets import (
TextBox,
set_min_terminal_size,
)
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
# minimum size for the UI
MIN_COLS = 120
MIN_LINES = 45
config = get_invokeai_config(argv=[])
config = InvokeAIAppConfig.get_config()
class addModelsForm(npyscreen.FormMultiPage):
# for responsive resizing - disabled

View File

@ -20,12 +20,12 @@ from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.services.config import get_invokeai_config
from invokeai.services.config import InvokeAIAppConfig
from ...backend.model_management import ModelManager
from ...frontend.install.widgets import FloatTitleSlider
DEST_MERGED_MODEL_DIR = "merged_models"
config = get_invokeai_config()
config = InvokeAIAppConfig.get_config()
def merge_diffusion_models(
model_ids_or_paths: List[Union[str, Path]],

View File

@ -22,7 +22,7 @@ from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from invokeai.app.services.config import InvokeAIAppConfig
from ...backend.training import (
do_textual_inversion_training,
parse_args
@ -423,7 +423,7 @@ def do_front_end(args: Namespace):
save_args(args)
try:
do_textual_inversion_training(get_invokeai_config(),**args)
do_textual_inversion_training(InvokeAIAppConfig.get_config(),**args)
copy_to_embeddings_folder(args)
except Exception as e:
logger.error("An exception occurred during training. The exception was:")
@ -436,7 +436,7 @@ def main():
global config
args = parse_args()
config = get_invokeai_config(argv=[])
config = InvokeAIAppConfig.get_config()
# change root if needed
if args.root_dir:

View File

@ -6,9 +6,8 @@ from omegaconf import OmegaConf
from pathlib import Path
os.environ['INVOKEAI_ROOT']='/tmp'
sys.argv = [] # to prevent config from trying to parse pytest arguments
from invokeai.app.services.config import InvokeAIAppConfig, InvokeAISettings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.invocations.generate import TextToImageInvocation
@ -36,48 +35,56 @@ def test_use_init():
# note that we explicitly set omegaconf dict and argv here
# so that the values aren't read from ~invokeai/invokeai.yaml and
# sys.argv respectively.
conf1 = InvokeAIAppConfig(init1,[])
conf1 = InvokeAIAppConfig.get_config()
assert conf1
conf1.parse_args(conf=init1)
assert conf1.max_loaded_models==5
assert not conf1.nsfw_checker
conf2 = InvokeAIAppConfig(init2,[])
conf2 = InvokeAIAppConfig.get_config()
assert conf2
conf2.parse_args(conf=init2)
assert conf2.nsfw_checker
assert conf2.max_loaded_models==2
assert not hasattr(conf2,'invalid_attribute')
def test_argv_override():
conf = InvokeAIAppConfig(init1,['--nsfw_checker','--max_loaded=10'])
conf = InvokeAIAppConfig.get_config()
conf.parse_args(conf=init1,argv=['--nsfw_checker','--max_loaded=10'])
assert conf.nsfw_checker
assert conf.max_loaded_models==10
assert conf.outdir==Path('outputs') # this is the default
def test_env_override():
# argv overrides
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
conf = InvokeAIAppConfig()
conf.parse_args(conf=init1,argv=['--max_loaded=10'])
assert conf.nsfw_checker==False
os.environ['INVOKEAI_nsfw_checker'] = 'True'
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
conf.parse_args(conf=init1,argv=['--max_loaded=10'])
assert conf.nsfw_checker==True
# environment variables should be case insensitive
os.environ['InvokeAI_Max_Loaded_Models'] = '15'
conf = InvokeAIAppConfig(conf=init1)
conf = InvokeAIAppConfig()
conf.parse_args(conf=init1)
assert conf.max_loaded_models == 15
conf = InvokeAIAppConfig(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
conf = InvokeAIAppConfig()
conf.parse_args(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
assert conf.nsfw_checker==False
assert conf.max_loaded_models==10
conf = InvokeAIAppConfig(conf=init1,argv=[],max_loaded_models=20)
conf = InvokeAIAppConfig.get_config(max_loaded_models=20)
conf.parse_args(conf=init1,argv=[])
assert conf.max_loaded_models==20
def test_type_coercion():
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'])
conf = InvokeAIAppConfig().get_config()
conf.parse_args(argv=['--root=/tmp/foobar'])
assert conf.root==Path('/tmp/foobar')
assert isinstance(conf.root,Path)
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'],root='/tmp/different')
conf = InvokeAIAppConfig.get_config(root='/tmp/different')
conf.parse_args(argv=['--root=/tmp/foobar'])
assert conf.root==Path('/tmp/different')
assert isinstance(conf.root,Path)