mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix potential race condition in config system (#3466)
There was a potential gotcha in the config system that was previously merged with main. The `InvokeAIAppConfig` object was configuring itself from the command line and configuration file within its initialization routine. However, this could cause it to read `argv` from the command line at unexpected times. This PR fixes the object so that it only reads from the init file and command line when its `parse_args()` method is explicitly called, which should be done at startup time in any top level script that uses it. In addition, using the `get_invokeai_config()` function to get a global version of the config object didn't feel pythonic to me, so I have changed this to `InvokeAIAppConfig.get_config()` throughout. ## Updated Usage In the main script, at startup time, do the following: ``` from invokeai.app.services.config import InvokeAIAppConfig config = InvokeAIAppConfig.get_config() config.parse_args() ``` In non-main scripts, it is not necessary (or recommended) to call `parse_args()`: ``` from invokeai.app.services.config import InvokeAIAppConfig config = InvokeAIAppConfig.get_config() ``` The configuration object properties can be overridden when `get_config()` is called by passing initialization values in the usual way. If a property is set this way, then it will not be changed by subsequent calls to `parse_args()`, but can only be changed by explicitly setting the property. ``` config = InvokeAIAppConfig.get_config(nsfw_checker=True) config.parse_args(argv=['--no-nsfw_checker']) config.nsfw_checker # True ``` You may specify alternative argv lists and configuration files in `parse_args()`: ``` config.parse_args(argv=['--no-nsfw_checker'], conf = OmegaConf.load('/tmp/test.yaml') ) ``` For backward compatibility, the `get_invokeai_config()` function is still available from the module, but has been removed from the rest of the source tree.
This commit is contained in:
commit
b31fc43bfa
@ -39,7 +39,8 @@ socket_io = SocketIO(app)
|
|||||||
|
|
||||||
# initialize config
|
# initialize config
|
||||||
# this is a module global
|
# this is a module global
|
||||||
app_config = InvokeAIAppConfig()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
|
app_config.parse_args()
|
||||||
|
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
|
@ -38,7 +38,7 @@ from .services.invocation_services import InvocationServices
|
|||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
from .services.processor import DefaultInvocationProcessor
|
from .services.processor import DefaultInvocationProcessor
|
||||||
from .services.sqlite import SqliteItemStorage
|
from .services.sqlite import SqliteItemStorage
|
||||||
from .services.config import get_invokeai_config
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
class CliCommand(BaseModel):
|
class CliCommand(BaseModel):
|
||||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||||
@ -197,7 +197,8 @@ logger = logger.InvokeAILogger.getLogger()
|
|||||||
|
|
||||||
def invoke_cli():
|
def invoke_cli():
|
||||||
# this gets the basic configuration
|
# this gets the basic configuration
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args()
|
||||||
|
|
||||||
# get the optional list of invocations to execute on the command line
|
# get the optional list of invocations to execute on the command line
|
||||||
parser = config.get_parser()
|
parser = config.get_parser()
|
||||||
|
@ -51,18 +51,32 @@ in INVOKEAI_ROOT. You can replace supersede this by providing any
|
|||||||
OmegaConf dictionary object initialization time:
|
OmegaConf dictionary object initialization time:
|
||||||
|
|
||||||
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
omegaconf = OmegaConf.load('/tmp/init.yaml')
|
||||||
conf = InvokeAIAppConfig(conf=omegaconf)
|
conf = InvokeAIAppConfig()
|
||||||
|
conf.parse_args(conf=omegaconf)
|
||||||
|
|
||||||
By default, InvokeAIAppConfig will parse the contents of `sys.argv` at
|
InvokeAIAppConfig.parse_args() will parse the contents of `sys.argv`
|
||||||
initialization time. You may pass a list of strings in the optional
|
at initialization time. You may pass a list of strings in the optional
|
||||||
`argv` argument to use instead of the system argv:
|
`argv` argument to use instead of the system argv:
|
||||||
|
|
||||||
conf = InvokeAIAppConfig(arg=['--xformers_enabled'])
|
conf.parse_args(argv=['--xformers_enabled'])
|
||||||
|
|
||||||
It is also possible to set a value at initialization time. This value
|
It is also possible to set a value at initialization time. However, if
|
||||||
has highest priority.
|
you call parse_args() it may be overwritten.
|
||||||
|
|
||||||
conf = InvokeAIAppConfig(xformers_enabled=True)
|
conf = InvokeAIAppConfig(xformers_enabled=True)
|
||||||
|
conf.parse_args(argv=['--no-xformers'])
|
||||||
|
conf.xformers_enabled
|
||||||
|
# False
|
||||||
|
|
||||||
|
|
||||||
|
To avoid this, use `get_config()` to retrieve the application-wide
|
||||||
|
configuration object. This will retain any properties set at object
|
||||||
|
creation time:
|
||||||
|
|
||||||
|
conf = InvokeAIAppConfig.get_config(xformers_enabled=True)
|
||||||
|
conf.parse_args(argv=['--no-xformers'])
|
||||||
|
conf.xformers_enabled
|
||||||
|
# True
|
||||||
|
|
||||||
Any setting can be overwritten by setting an environment variable of
|
Any setting can be overwritten by setting an environment variable of
|
||||||
form: "INVOKEAI_<setting>", as in:
|
form: "INVOKEAI_<setting>", as in:
|
||||||
@ -76,18 +90,23 @@ Order of precedence (from highest):
|
|||||||
4) config file options
|
4) config file options
|
||||||
5) pydantic defaults
|
5) pydantic defaults
|
||||||
|
|
||||||
Typical usage:
|
Typical usage at the top level file:
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.invocations.generate import TextToImageInvocation
|
|
||||||
|
|
||||||
# get global configuration and print its nsfw_checker value
|
# get global configuration and print its nsfw_checker value
|
||||||
conf = InvokeAIAppConfig()
|
conf = InvokeAIAppConfig.get_config()
|
||||||
|
conf.parse_args()
|
||||||
|
print(conf.nsfw_checker)
|
||||||
|
|
||||||
|
Typical usage in a backend module:
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
# get global configuration and print its nsfw_checker value
|
||||||
|
conf = InvokeAIAppConfig.get_config()
|
||||||
print(conf.nsfw_checker)
|
print(conf.nsfw_checker)
|
||||||
|
|
||||||
# get the text2image invocation and print its step value
|
|
||||||
text2image = TextToImageInvocation()
|
|
||||||
print(text2image.steps)
|
|
||||||
|
|
||||||
Computed properties:
|
Computed properties:
|
||||||
|
|
||||||
@ -103,10 +122,11 @@ a Path object:
|
|||||||
lora_path - path to the LoRA directory
|
lora_path - path to the LoRA directory
|
||||||
|
|
||||||
In most cases, you will want to create a single InvokeAIAppConfig
|
In most cases, you will want to create a single InvokeAIAppConfig
|
||||||
object for the entire application. The get_invokeai_config() function
|
object for the entire application. The InvokeAIAppConfig.get_config() function
|
||||||
does this:
|
does this:
|
||||||
|
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
config.parse_args() # read values from the command line/config file
|
||||||
print(config.root)
|
print(config.root)
|
||||||
|
|
||||||
# Subclassing
|
# Subclassing
|
||||||
@ -140,7 +160,9 @@ two configs are kept in separate sections of the config file:
|
|||||||
legacy_conf_dir: configs/stable-diffusion
|
legacy_conf_dir: configs/stable-diffusion
|
||||||
outdir: outputs
|
outdir: outputs
|
||||||
...
|
...
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
from __future__ import annotations
|
||||||
import argparse
|
import argparse
|
||||||
import pydoc
|
import pydoc
|
||||||
import os
|
import os
|
||||||
@ -154,9 +176,6 @@ from typing import ClassVar, Dict, List, Literal, Type, Union, get_origin, get_t
|
|||||||
INIT_FILE = Path('invokeai.yaml')
|
INIT_FILE = Path('invokeai.yaml')
|
||||||
LEGACY_INIT_FILE = Path('invokeai.init')
|
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||||
|
|
||||||
# This global stores a singleton InvokeAIAppConfig configuration object
|
|
||||||
global_config = None
|
|
||||||
|
|
||||||
class InvokeAISettings(BaseSettings):
|
class InvokeAISettings(BaseSettings):
|
||||||
'''
|
'''
|
||||||
Runtime configuration settings in which default values are
|
Runtime configuration settings in which default values are
|
||||||
@ -329,6 +348,9 @@ the command-line client (recommended for experts only), or
|
|||||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||||
setting environment variables INVOKEAI_<setting>.
|
setting environment variables INVOKEAI_<setting>.
|
||||||
'''
|
'''
|
||||||
|
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||||
|
singleton_init: ClassVar[Dict] = None
|
||||||
|
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||||
@ -373,33 +395,44 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
|
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="debug", description="Emit logging messages at this level or higher", category="Logging")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def __init__(self, conf: DictConfig = None, argv: List[str]=None, **kwargs):
|
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||||
'''
|
'''
|
||||||
Initialize InvokeAIAppconfig.
|
Update settings with contents of init file, environment, and
|
||||||
|
command-line settings.
|
||||||
:param conf: alternate Omegaconf dictionary object
|
:param conf: alternate Omegaconf dictionary object
|
||||||
:param argv: aternate sys.argv list
|
:param argv: aternate sys.argv list
|
||||||
:param **kwargs: attributes to initialize with
|
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||||
'''
|
'''
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
# Set the runtime root directory. We parse command-line switches here
|
# Set the runtime root directory. We parse command-line switches here
|
||||||
# in order to pick up the --root_dir option.
|
# in order to pick up the --root_dir option.
|
||||||
self.parse_args(argv)
|
super().parse_args(argv)
|
||||||
if conf is None:
|
if conf is None:
|
||||||
try:
|
try:
|
||||||
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
conf = OmegaConf.load(self.root_dir / INIT_FILE)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
InvokeAISettings.initconf = conf
|
InvokeAISettings.initconf = conf
|
||||||
|
|
||||||
# parse args again in order to pick up settings in configuration file
|
# parse args again in order to pick up settings in configuration file
|
||||||
self.parse_args(argv)
|
super().parse_args(argv)
|
||||||
|
|
||||||
# restore initialization values
|
if self.singleton_init and not clobber:
|
||||||
hints = get_type_hints(self)
|
hints = get_type_hints(self.__class__)
|
||||||
for k in kwargs:
|
for k in self.singleton_init:
|
||||||
setattr(self,k,parse_obj_as(hints[k],kwargs[k]))
|
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_config(cls,**kwargs)->InvokeAIAppConfig:
|
||||||
|
'''
|
||||||
|
This returns a singleton InvokeAIAppConfig configuration object.
|
||||||
|
'''
|
||||||
|
if cls.singleton_config is None \
|
||||||
|
or type(cls.singleton_config)!=cls \
|
||||||
|
or (kwargs and cls.singleton_init != kwargs):
|
||||||
|
cls.singleton_config = cls(**kwargs)
|
||||||
|
cls.singleton_init = kwargs
|
||||||
|
return cls.singleton_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_path(self)->Path:
|
def root_path(self)->Path:
|
||||||
'''
|
'''
|
||||||
@ -517,11 +550,8 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
|||||||
text = self.format_help()
|
text = self.format_help()
|
||||||
pydoc.pager(text)
|
pydoc.pager(text)
|
||||||
|
|
||||||
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig:
|
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
||||||
'''
|
'''
|
||||||
This returns a singleton InvokeAIAppConfig configuration object.
|
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||||
'''
|
'''
|
||||||
global global_config
|
return InvokeAIAppConfig.get_config(**kwargs)
|
||||||
if global_config is None or type(global_config)!=cls:
|
|
||||||
global_config = cls(**kwargs)
|
|
||||||
return global_config
|
|
||||||
|
@ -26,7 +26,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
|
||||||
self._conn = sqlite3.connect(
|
self._conn = sqlite3.connect(
|
||||||
self._filename, check_same_thread=False
|
self._filename, check_same_thread=False
|
||||||
) # TODO: figure out a better threading solution
|
) # TODO: figure out a better threading solution
|
||||||
|
@ -56,6 +56,8 @@ from invokeai.backend.config.model_install_backend import (
|
|||||||
recommended_datasets,
|
recommended_datasets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
@ -63,7 +65,7 @@ transformers.logging.set_verbosity_error()
|
|||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
|
|
||||||
config = get_invokeai_config(argv=[])
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
Model_dir = "models"
|
Model_dir = "models"
|
||||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||||
@ -635,7 +637,7 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
|||||||
|
|
||||||
|
|
||||||
def default_startup_options(init_file: Path) -> Namespace:
|
def default_startup_options(init_file: Path) -> Namespace:
|
||||||
opts = InvokeAIAppConfig(argv=[])
|
opts = InvokeAIAppConfig.get_config()
|
||||||
outdir = Path(opts.outdir)
|
outdir = Path(opts.outdir)
|
||||||
if not outdir.is_absolute():
|
if not outdir.is_absolute():
|
||||||
opts.outdir = str(config.root / opts.outdir)
|
opts.outdir = str(config.root / opts.outdir)
|
||||||
@ -700,7 +702,7 @@ def write_opts(opts: Namespace, init_file: Path):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# this will load current settings
|
# this will load current settings
|
||||||
config = InvokeAIAppConfig(argv=[])
|
config = InvokeAIAppConfig.get_config()
|
||||||
for key,value in opts.__dict__.items():
|
for key,value in opts.__dict__.items():
|
||||||
if hasattr(config,key):
|
if hasattr(config,key):
|
||||||
setattr(config,key,value)
|
setattr(config,key,value)
|
||||||
@ -732,7 +734,7 @@ def write_default_options(program_opts: Namespace, initfile: Path):
|
|||||||
# yaml format.
|
# yaml format.
|
||||||
def migrate_init_file(legacy_format:Path):
|
def migrate_init_file(legacy_format:Path):
|
||||||
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||||
new = InvokeAIAppConfig(conf={})
|
new = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||||
for attr in fields:
|
for attr in fields:
|
||||||
@ -821,8 +823,9 @@ def main():
|
|||||||
if old_init_file.exists() and not new_init_file.exists():
|
if old_init_file.exists() and not new_init_file.exists():
|
||||||
print('** Migrating invokeai.init to invokeai.yaml')
|
print('** Migrating invokeai.init to invokeai.yaml')
|
||||||
migrate_init_file(old_init_file)
|
migrate_init_file(old_init_file)
|
||||||
config = get_invokeai_config(argv=[]) # reread defaults
|
|
||||||
|
|
||||||
|
# Load new init file into config
|
||||||
|
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||||
|
|
||||||
if not config.model_conf_path.exists():
|
if not config.model_conf_path.exists():
|
||||||
initialize_rootdir(config.root, opt.yes_to_all)
|
initialize_rootdir(config.root, opt.yes_to_all)
|
||||||
|
@ -19,7 +19,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ..model_management import ModelManager
|
from ..model_management import ModelManager
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
@ -27,7 +27,8 @@ from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
|||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
config = get_invokeai_config(argv=[])
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
Model_dir = "models"
|
Model_dir = "models"
|
||||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||||
|
|
||||||
|
@ -6,7 +6,8 @@ be suppressed or deferred
|
|||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
"""
|
"""
|
||||||
@ -21,7 +22,6 @@ class PatchMatch:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_patch_match(self):
|
def _load_patch_match(self):
|
||||||
config = get_invokeai_config()
|
|
||||||
if self.tried_load:
|
if self.tried_load:
|
||||||
return
|
return
|
||||||
if config.try_patchmatch:
|
if config.try_patchmatch:
|
||||||
|
@ -33,10 +33,11 @@ from PIL import Image, ImageOps
|
|||||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||||
CLIPSEG_SIZE = 352
|
CLIPSEG_SIZE = 352
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class SegmentedGrayscale(object):
|
class SegmentedGrayscale(object):
|
||||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
def __init__(self, image: Image, heatmap: torch.Tensor):
|
||||||
@ -83,7 +84,6 @@ class Txt2Mask(object):
|
|||||||
|
|
||||||
def __init__(self, device="cpu", refined=False):
|
def __init__(self, device="cpu", refined=False):
|
||||||
logger.info("Initializing clipseg model for text to mask inference")
|
logger.info("Initializing clipseg model for text to mask inference")
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
# BUG: we are not doing anything with the device option at this time
|
# BUG: we are not doing anything with the device option at this time
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -26,7 +26,7 @@ import torch
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from .model_manager import ModelManager, SDLegacyType
|
from .model_manager import ModelManager, SDLegacyType
|
||||||
|
|
||||||
@ -842,7 +842,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
|||||||
|
|
||||||
def convert_ldm_clip_checkpoint(checkpoint):
|
def convert_ldm_clip_checkpoint(checkpoint):
|
||||||
text_model = CLIPTextModel.from_pretrained(
|
text_model = CLIPTextModel.from_pretrained(
|
||||||
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir
|
"openai/clip-vit-large-patch14", cache_dir=InvokeAIAppConfig.get_config().cache_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
@ -897,7 +897,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||||||
|
|
||||||
|
|
||||||
def convert_paint_by_example_checkpoint(checkpoint):
|
def convert_paint_by_example_checkpoint(checkpoint):
|
||||||
cache_dir = get_invokeai_config().cache_dir
|
cache_dir = InvokeAIAppConfig.get_config().cache_dir
|
||||||
config = CLIPVisionConfig.from_pretrained(
|
config = CLIPVisionConfig.from_pretrained(
|
||||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||||
)
|
)
|
||||||
@ -969,7 +969,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
|
|||||||
|
|
||||||
|
|
||||||
def convert_open_clip_checkpoint(checkpoint):
|
def convert_open_clip_checkpoint(checkpoint):
|
||||||
cache_dir = get_invokeai_config().cache_dir
|
cache_dir = InvokeAIAppConfig.get_config().cache_dir
|
||||||
text_model = CLIPTextModel.from_pretrained(
|
text_model = CLIPTextModel.from_pretrained(
|
||||||
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
||||||
)
|
)
|
||||||
@ -1092,7 +1092,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
:param vae: A diffusers VAE to load into the pipeline.
|
:param vae: A diffusers VAE to load into the pipeline.
|
||||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||||
"""
|
"""
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
verbosity = dlogging.get_verbosity()
|
verbosity = dlogging.get_verbosity()
|
||||||
|
@ -47,7 +47,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|||||||
from ..stable_diffusion import (
|
from ..stable_diffusion import (
|
||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||||
|
|
||||||
class SDLegacyType(Enum):
|
class SDLegacyType(Enum):
|
||||||
@ -98,7 +98,7 @@ class ModelManager(object):
|
|||||||
if not isinstance(config, DictConfig):
|
if not isinstance(config, DictConfig):
|
||||||
config = OmegaConf.load(config)
|
config = OmegaConf.load(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.globals = get_invokeai_config()
|
self.globals = InvokeAIAppConfig.get_config()
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.device = torch.device(device_type)
|
self.device = torch.device(device_type)
|
||||||
self.max_loaded_models = max_loaded_models
|
self.max_loaded_models = max_loaded_models
|
||||||
@ -1057,7 +1057,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
# Three transformer models to check: bert, clip and safety checker, and
|
# Three transformer models to check: bert, clip and safety checker, and
|
||||||
# the diffusers as well
|
# the diffusers as well
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
models_dir = config.root_dir / "models"
|
models_dir = config.root_dir / "models"
|
||||||
legacy_locations = [
|
legacy_locations = [
|
||||||
Path(
|
Path(
|
||||||
@ -1287,7 +1287,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _delete_model_from_cache(cls,repo_id):
|
def _delete_model_from_cache(cls,repo_id):
|
||||||
cache_info = scan_cache_dir(get_invokeai_config().cache_dir)
|
cache_info = scan_cache_dir(InvokeAIAppConfig.get_config().cache_dir)
|
||||||
|
|
||||||
# I'm sure there is a way to do this with comprehensions
|
# I'm sure there is a way to do this with comprehensions
|
||||||
# but the code quickly became incomprehensible!
|
# but the code quickly became incomprehensible!
|
||||||
@ -1304,7 +1304,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _abs_path(path: str | Path) -> Path:
|
def _abs_path(path: str | Path) -> Path:
|
||||||
globals = get_invokeai_config()
|
globals = InvokeAIAppConfig.get_config()
|
||||||
if path is None or Path(path).is_absolute():
|
if path is None or Path(path).is_absolute():
|
||||||
return path
|
return path
|
||||||
return Path(globals.root_dir, path).resolve()
|
return Path(globals.root_dir, path).resolve()
|
||||||
|
@ -21,10 +21,12 @@ from compel.prompt_parser import (
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||||
from ..util import torch_dtype
|
from ..util import torch_dtype
|
||||||
|
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def get_uc_and_c_and_ec(prompt_string,
|
def get_uc_and_c_and_ec(prompt_string,
|
||||||
model: InvokeAIDiffuserComponent,
|
model: InvokeAIDiffuserComponent,
|
||||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||||
@ -37,9 +39,7 @@ def get_uc_and_c_and_ec(prompt_string,
|
|||||||
textual_inversion_manager=model.textual_inversion_manager,
|
textual_inversion_manager=model.textual_inversion_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=False,
|
truncate_long_prompts=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
# get rid of any newline characters
|
# get rid of any newline characters
|
||||||
prompt_string = prompt_string.replace("\n", " ")
|
prompt_string = prompt_string.replace("\n", " ")
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
pretrained_model_url = (
|
pretrained_model_url = (
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||||
@ -18,7 +18,7 @@ class CodeFormerRestoration:
|
|||||||
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.globals = get_invokeai_config()
|
self.globals = InvokeAIAppConfig.get_config()
|
||||||
codeformer_dir = self.globals.root_dir / codeformer_dir
|
codeformer_dir = self.globals.root_dir / codeformer_dir
|
||||||
self.model_path = codeformer_dir / codeformer_model_path
|
self.model_path = codeformer_dir / codeformer_model_path
|
||||||
self.codeformer_model_exists = self.model_path.exists()
|
self.codeformer_model_exists = self.model_path.exists()
|
||||||
|
@ -7,11 +7,11 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
class GFPGAN:
|
class GFPGAN:
|
||||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||||
self.globals = get_invokeai_config()
|
self.globals = InvokeAIAppConfig.get_config()
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
if not os.path.isabs(gfpgan_model_path):
|
||||||
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
||||||
self.model_path = gfpgan_model_path
|
self.model_path = gfpgan_model_path
|
||||||
|
@ -6,8 +6,8 @@ from PIL import Image
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class ESRGAN:
|
class ESRGAN:
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
def __init__(self, bg_tile_size=400) -> None:
|
||||||
|
@ -15,9 +15,11 @@ from transformers import AutoFeatureExtractor
|
|||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
import invokeai.assets.web as web_assets
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from .util import CPU_DEVICE
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class SafetyChecker(object):
|
class SafetyChecker(object):
|
||||||
CAUTION_IMG = "caution.png"
|
CAUTION_IMG = "caution.png"
|
||||||
|
|
||||||
@ -26,7 +28,6 @@ class SafetyChecker(object):
|
|||||||
caution = Image.open(path)
|
caution = Image.open(path)
|
||||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||||
self.device = device
|
self.device = device
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
@ -17,15 +17,16 @@ from huggingface_hub import (
|
|||||||
hf_hub_url,
|
hf_hub_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
logger = InvokeAILogger.getLogger()
|
||||||
|
|
||||||
class HuggingFaceConceptsLibrary(object):
|
class HuggingFaceConceptsLibrary(object):
|
||||||
def __init__(self, root=None):
|
def __init__(self, root=None):
|
||||||
"""
|
"""
|
||||||
Initialize the Concepts object. May optionally pass a root directory.
|
Initialize the Concepts object. May optionally pass a root directory.
|
||||||
"""
|
"""
|
||||||
self.config = get_invokeai_config()
|
self.config = InvokeAIAppConfig.get_config()
|
||||||
self.root = root or self.config.root
|
self.root = root or self.config.root
|
||||||
self.hf_api = HfApi()
|
self.hf_api = HfApi()
|
||||||
self.local_concepts = dict()
|
self.local_concepts = dict()
|
||||||
|
@ -40,7 +40,7 @@ from torchvision.transforms.functional import resize as tv_resize
|
|||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ..util import CPU_DEVICE, normalize_device
|
from ..util import CPU_DEVICE, normalize_device
|
||||||
from .diffusion import (
|
from .diffusion import (
|
||||||
AttentionMapSaver,
|
AttentionMapSaver,
|
||||||
@ -364,7 +364,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
"""
|
"""
|
||||||
if xformers is available, use it, otherwise use sliced attention.
|
if xformers is available, use it, otherwise use sliced attention.
|
||||||
"""
|
"""
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
if (
|
if (
|
||||||
torch.cuda.is_available()
|
torch.cuda.is_available()
|
||||||
and is_xformers_available()
|
and is_xformers_available()
|
||||||
|
@ -10,7 +10,7 @@ from diffusers.models.attention_processor import AttentionProcessor
|
|||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from .cross_attention_control import (
|
from .cross_attention_control import (
|
||||||
Arguments,
|
Arguments,
|
||||||
@ -72,7 +72,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
:param model: the unet model to pass through to cross attention control
|
:param model: the unet model to pass through to cross attention control
|
||||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||||
"""
|
"""
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.model = model
|
self.model = model
|
||||||
self.is_running_diffusers = is_running_diffusers
|
self.is_running_diffusers = is_running_diffusers
|
||||||
@ -112,23 +112,25 @@ class InvokeAIDiffuserComponent:
|
|||||||
# TODO resuscitate attention map saving
|
# TODO resuscitate attention map saving
|
||||||
# self.remove_attention_map_saving()
|
# self.remove_attention_map_saving()
|
||||||
|
|
||||||
def override_cross_attention(
|
# apparently unused code
|
||||||
self, conditioning: ExtraConditioningInfo, step_count: int
|
# TODO: delete
|
||||||
) -> Dict[str, AttentionProcessor]:
|
# def override_cross_attention(
|
||||||
"""
|
# self, conditioning: ExtraConditioningInfo, step_count: int
|
||||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
# ) -> Dict[str, AttentionProcessor]:
|
||||||
the previous attention processor is returned so that the caller can restore it later.
|
# """
|
||||||
"""
|
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||||
self.conditioning = conditioning
|
# the previous attention processor is returned so that the caller can restore it later.
|
||||||
self.cross_attention_control_context = Context(
|
# """
|
||||||
arguments=self.conditioning.cross_attention_control_args,
|
# self.conditioning = conditioning
|
||||||
step_count=step_count,
|
# self.cross_attention_control_context = Context(
|
||||||
)
|
# arguments=self.conditioning.cross_attention_control_args,
|
||||||
return override_cross_attention(
|
# step_count=step_count,
|
||||||
self.model,
|
# )
|
||||||
self.cross_attention_control_context,
|
# return override_cross_attention(
|
||||||
is_running_diffusers=self.is_running_diffusers,
|
# self.model,
|
||||||
)
|
# self.cross_attention_control_context,
|
||||||
|
# is_running_diffusers=self.is_running_diffusers,
|
||||||
|
# )
|
||||||
|
|
||||||
def restore_default_cross_attention(
|
def restore_default_cross_attention(
|
||||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
||||||
|
@ -88,7 +88,7 @@ def save_progress(
|
|||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
config = InvokeAIAppConfig(argv=[])
|
config = InvokeAIAppConfig.get_config()
|
||||||
parser = PagingArgumentParser(
|
parser = PagingArgumentParser(
|
||||||
description="Textual inversion training"
|
description="Textual inversion training"
|
||||||
)
|
)
|
||||||
|
@ -4,15 +4,15 @@ from contextlib import nullcontext
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
CPU_DEVICE = torch.device("cpu")
|
CPU_DEVICE = torch.device("cpu")
|
||||||
CUDA_DEVICE = torch.device("cuda")
|
CUDA_DEVICE = torch.device("cuda")
|
||||||
MPS_DEVICE = torch.device("mps")
|
MPS_DEVICE = torch.device("mps")
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def choose_torch_device() -> torch.device:
|
def choose_torch_device() -> torch.device:
|
||||||
"""Convenience routine for guessing which GPU device to run model on"""
|
"""Convenience routine for guessing which GPU device to run model on"""
|
||||||
config = get_invokeai_config()
|
|
||||||
if config.always_use_cpu:
|
if config.always_use_cpu:
|
||||||
return CPU_DEVICE
|
return CPU_DEVICE
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -32,7 +32,6 @@ def choose_precision(device: torch.device) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||||
config = get_invokeai_config()
|
|
||||||
if config.full_precision:
|
if config.full_precision:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
if choose_precision(device) == "float16":
|
if choose_precision(device) == "float16":
|
||||||
|
@ -40,13 +40,13 @@ from .widgets import (
|
|||||||
TextBox,
|
TextBox,
|
||||||
set_min_terminal_size,
|
set_min_terminal_size,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
# minimum size for the UI
|
# minimum size for the UI
|
||||||
MIN_COLS = 120
|
MIN_COLS = 120
|
||||||
MIN_LINES = 45
|
MIN_LINES = 45
|
||||||
|
|
||||||
config = get_invokeai_config(argv=[])
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
class addModelsForm(npyscreen.FormMultiPage):
|
class addModelsForm(npyscreen.FormMultiPage):
|
||||||
# for responsive resizing - disabled
|
# for responsive resizing - disabled
|
||||||
|
@ -20,12 +20,12 @@ from npyscreen import widget
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.services.config import get_invokeai_config
|
from invokeai.services.config import InvokeAIAppConfig
|
||||||
from ...backend.model_management import ModelManager
|
from ...backend.model_management import ModelManager
|
||||||
from ...frontend.install.widgets import FloatTitleSlider
|
from ...frontend.install.widgets import FloatTitleSlider
|
||||||
|
|
||||||
DEST_MERGED_MODEL_DIR = "merged_models"
|
DEST_MERGED_MODEL_DIR = "merged_models"
|
||||||
config = get_invokeai_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
def merge_diffusion_models(
|
def merge_diffusion_models(
|
||||||
model_ids_or_paths: List[Union[str, Path]],
|
model_ids_or_paths: List[Union[str, Path]],
|
||||||
|
@ -22,7 +22,7 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from ...backend.training import (
|
from ...backend.training import (
|
||||||
do_textual_inversion_training,
|
do_textual_inversion_training,
|
||||||
parse_args
|
parse_args
|
||||||
@ -423,7 +423,7 @@ def do_front_end(args: Namespace):
|
|||||||
save_args(args)
|
save_args(args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
do_textual_inversion_training(get_invokeai_config(),**args)
|
do_textual_inversion_training(InvokeAIAppConfig.get_config(),**args)
|
||||||
copy_to_embeddings_folder(args)
|
copy_to_embeddings_folder(args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("An exception occurred during training. The exception was:")
|
logger.error("An exception occurred during training. The exception was:")
|
||||||
@ -436,7 +436,7 @@ def main():
|
|||||||
global config
|
global config
|
||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
config = get_invokeai_config(argv=[])
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
# change root if needed
|
# change root if needed
|
||||||
if args.root_dir:
|
if args.root_dir:
|
||||||
|
@ -6,9 +6,8 @@ from omegaconf import OmegaConf
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
os.environ['INVOKEAI_ROOT']='/tmp'
|
os.environ['INVOKEAI_ROOT']='/tmp'
|
||||||
sys.argv = [] # to prevent config from trying to parse pytest arguments
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig, InvokeAISettings
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.invocations.generate import TextToImageInvocation
|
from invokeai.app.invocations.generate import TextToImageInvocation
|
||||||
|
|
||||||
|
|
||||||
@ -36,48 +35,56 @@ def test_use_init():
|
|||||||
# note that we explicitly set omegaconf dict and argv here
|
# note that we explicitly set omegaconf dict and argv here
|
||||||
# so that the values aren't read from ~invokeai/invokeai.yaml and
|
# so that the values aren't read from ~invokeai/invokeai.yaml and
|
||||||
# sys.argv respectively.
|
# sys.argv respectively.
|
||||||
conf1 = InvokeAIAppConfig(init1,[])
|
conf1 = InvokeAIAppConfig.get_config()
|
||||||
assert conf1
|
assert conf1
|
||||||
|
conf1.parse_args(conf=init1)
|
||||||
assert conf1.max_loaded_models==5
|
assert conf1.max_loaded_models==5
|
||||||
assert not conf1.nsfw_checker
|
assert not conf1.nsfw_checker
|
||||||
|
|
||||||
conf2 = InvokeAIAppConfig(init2,[])
|
conf2 = InvokeAIAppConfig.get_config()
|
||||||
assert conf2
|
assert conf2
|
||||||
|
conf2.parse_args(conf=init2)
|
||||||
assert conf2.nsfw_checker
|
assert conf2.nsfw_checker
|
||||||
assert conf2.max_loaded_models==2
|
assert conf2.max_loaded_models==2
|
||||||
assert not hasattr(conf2,'invalid_attribute')
|
assert not hasattr(conf2,'invalid_attribute')
|
||||||
|
|
||||||
def test_argv_override():
|
def test_argv_override():
|
||||||
conf = InvokeAIAppConfig(init1,['--nsfw_checker','--max_loaded=10'])
|
conf = InvokeAIAppConfig.get_config()
|
||||||
|
conf.parse_args(conf=init1,argv=['--nsfw_checker','--max_loaded=10'])
|
||||||
assert conf.nsfw_checker
|
assert conf.nsfw_checker
|
||||||
assert conf.max_loaded_models==10
|
assert conf.max_loaded_models==10
|
||||||
assert conf.outdir==Path('outputs') # this is the default
|
assert conf.outdir==Path('outputs') # this is the default
|
||||||
|
|
||||||
def test_env_override():
|
def test_env_override():
|
||||||
# argv overrides
|
# argv overrides
|
||||||
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
conf = InvokeAIAppConfig()
|
||||||
|
conf.parse_args(conf=init1,argv=['--max_loaded=10'])
|
||||||
assert conf.nsfw_checker==False
|
assert conf.nsfw_checker==False
|
||||||
|
|
||||||
os.environ['INVOKEAI_nsfw_checker'] = 'True'
|
os.environ['INVOKEAI_nsfw_checker'] = 'True'
|
||||||
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
conf.parse_args(conf=init1,argv=['--max_loaded=10'])
|
||||||
assert conf.nsfw_checker==True
|
assert conf.nsfw_checker==True
|
||||||
|
|
||||||
# environment variables should be case insensitive
|
# environment variables should be case insensitive
|
||||||
os.environ['InvokeAI_Max_Loaded_Models'] = '15'
|
os.environ['InvokeAI_Max_Loaded_Models'] = '15'
|
||||||
conf = InvokeAIAppConfig(conf=init1)
|
conf = InvokeAIAppConfig()
|
||||||
|
conf.parse_args(conf=init1)
|
||||||
assert conf.max_loaded_models == 15
|
assert conf.max_loaded_models == 15
|
||||||
|
|
||||||
conf = InvokeAIAppConfig(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
|
conf = InvokeAIAppConfig()
|
||||||
|
conf.parse_args(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
|
||||||
assert conf.nsfw_checker==False
|
assert conf.nsfw_checker==False
|
||||||
assert conf.max_loaded_models==10
|
assert conf.max_loaded_models==10
|
||||||
|
|
||||||
conf = InvokeAIAppConfig(conf=init1,argv=[],max_loaded_models=20)
|
conf = InvokeAIAppConfig.get_config(max_loaded_models=20)
|
||||||
|
conf.parse_args(conf=init1,argv=[])
|
||||||
assert conf.max_loaded_models==20
|
assert conf.max_loaded_models==20
|
||||||
|
|
||||||
def test_type_coercion():
|
def test_type_coercion():
|
||||||
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'])
|
conf = InvokeAIAppConfig().get_config()
|
||||||
|
conf.parse_args(argv=['--root=/tmp/foobar'])
|
||||||
assert conf.root==Path('/tmp/foobar')
|
assert conf.root==Path('/tmp/foobar')
|
||||||
assert isinstance(conf.root,Path)
|
assert isinstance(conf.root,Path)
|
||||||
conf = InvokeAIAppConfig(argv=['--root=/tmp/foobar'],root='/tmp/different')
|
conf = InvokeAIAppConfig.get_config(root='/tmp/different')
|
||||||
|
conf.parse_args(argv=['--root=/tmp/foobar'])
|
||||||
assert conf.root==Path('/tmp/different')
|
assert conf.root==Path('/tmp/different')
|
||||||
assert isinstance(conf.root,Path)
|
assert isinstance(conf.root,Path)
|
||||||
|
Loading…
Reference in New Issue
Block a user