mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fixes to env parsing, textual inversion & help text
- Make environment variable settings case InSenSiTive: INVOKEAI_MAX_LOADED_MODELS and InvokeAI_Max_Loaded_Models environment variables will both set `max_loaded_models` - Updated realesrgan to use new config system. - Updated textual_inversion_training to use new config system. - Discovered a race condition when InvokeAIAppConfig is created at module load time, which makes it impossible to customize or replace the help message produced with --help on the command line. To fix this, moved all instances of get_invokeai_config() from module load time to object initialization time. Makes code cleaner, too. - Added `--from_file` argument to `invokeai-node-cli` and changed github action to match. CI tests will hopefully work now.
This commit is contained in:
parent
f9710dd6ed
commit
7ea995149e
9
.github/workflows/test-invoke-pip.yml
vendored
9
.github/workflows/test-invoke-pip.yml
vendored
@ -80,12 +80,7 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: set test prompt to main branch validation
|
- name: set test prompt to main branch validation
|
||||||
if: ${{ github.ref == 'refs/heads/main' }}
|
run:echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||||
run: echo "TEST_PROMPTS=tests/preflight_prompts.txt" >> ${{ matrix.github-env }}
|
|
||||||
|
|
||||||
- name: set test prompt to Pull Request validation
|
|
||||||
if: ${{ github.ref != 'refs/heads/main' }}
|
|
||||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
|
||||||
|
|
||||||
- name: setup python
|
- name: setup python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
@ -131,7 +126,7 @@ jobs:
|
|||||||
--precision=float32
|
--precision=float32
|
||||||
--always_use_cpu
|
--always_use_cpu
|
||||||
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||||
< ${{ env.TEST_PROMPTS }}
|
--from_file ${{ env.TEST_PROMPTS }}
|
||||||
|
|
||||||
- name: Archive results
|
- name: Archive results
|
||||||
id: archive-results
|
id: archive-results
|
||||||
|
@ -4,6 +4,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import (
|
from typing import (
|
||||||
Union,
|
Union,
|
||||||
@ -196,6 +197,11 @@ def invoke_cli():
|
|||||||
parser.add_argument('commands',nargs='*')
|
parser.add_argument('commands',nargs='*')
|
||||||
invocation_commands = parser.parse_args().commands
|
invocation_commands = parser.parse_args().commands
|
||||||
|
|
||||||
|
# get the optional file to read commands from.
|
||||||
|
# Simplest is to use it for STDIN
|
||||||
|
if infile := config.from_file:
|
||||||
|
sys.stdin = open(infile,"r")
|
||||||
|
|
||||||
model_manager = get_model_manager(config,logger=logger)
|
model_manager = get_model_manager(config,logger=logger)
|
||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein)
|
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||||
|
|
||||||
'''Invokeai configuration system.
|
'''Invokeai configuration system.
|
||||||
|
|
||||||
@ -206,8 +206,16 @@ class InvokeAISettings(BaseSettings):
|
|||||||
if cls.initconf and settings_stanza in cls.initconf \
|
if cls.initconf and settings_stanza in cls.initconf \
|
||||||
else OmegaConf.create()
|
else OmegaConf.create()
|
||||||
|
|
||||||
|
# create an upcase version of the environment in
|
||||||
|
# order to achieve case-insensitive environment
|
||||||
|
# variables (the way Windows does)
|
||||||
|
upcase_environ = dict()
|
||||||
|
for key,value in os.environ.items():
|
||||||
|
upcase_environ[key.upper()] = value
|
||||||
|
|
||||||
fields = cls.__fields__
|
fields = cls.__fields__
|
||||||
cls.argparse_groups = {}
|
cls.argparse_groups = {}
|
||||||
|
|
||||||
for name, field in fields.items():
|
for name, field in fields.items():
|
||||||
if name not in cls._excluded():
|
if name not in cls._excluded():
|
||||||
current_default = field.default
|
current_default = field.default
|
||||||
@ -216,8 +224,8 @@ class InvokeAISettings(BaseSettings):
|
|||||||
env_name = env_prefix + '_' + name
|
env_name = env_prefix + '_' + name
|
||||||
if category in initconf and name in initconf.get(category):
|
if category in initconf and name in initconf.get(category):
|
||||||
field.default = initconf.get(category).get(name)
|
field.default = initconf.get(category).get(name)
|
||||||
if env_name in os.environ:
|
if env_name.upper() in upcase_environ:
|
||||||
field.default = os.environ[env_name]
|
field.default = upcase_environ[env_name.upper()]
|
||||||
cls.add_field_argument(parser, name, field)
|
cls.add_field_argument(parser, name, field)
|
||||||
|
|
||||||
field.default = current_default
|
field.default = current_default
|
||||||
@ -353,6 +361,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||||
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||||
|
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||||
|
|
||||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
||||||
@ -502,11 +511,11 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
|||||||
text = self.format_help()
|
text = self.format_help()
|
||||||
pydoc.pager(text)
|
pydoc.pager(text)
|
||||||
|
|
||||||
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig)->InvokeAISettings:
|
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings:
|
||||||
'''
|
'''
|
||||||
This returns a singleton InvokeAIAppConfig configuration object.
|
This returns a singleton InvokeAIAppConfig configuration object.
|
||||||
'''
|
'''
|
||||||
global global_config
|
global global_config
|
||||||
if global_config is None or type(global_config)!=cls:
|
if global_config is None or type(global_config)!=cls:
|
||||||
global_config = cls()
|
global_config = cls(**kwargs)
|
||||||
return global_config
|
return global_config
|
||||||
|
@ -389,8 +389,8 @@ class editOptsForm(npyscreen.FormMultiPage):
|
|||||||
)
|
)
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
for i in [
|
for i in [
|
||||||
"If you have an account at HuggingFace you may paste your access token here",
|
"If you have an account at HuggingFace you may optionally paste your access token here",
|
||||||
'to allow InvokeAI to download styles & subjects from the "Concept Library".',
|
'to allow InvokeAI to download restricted styles & subjects from the "Concept Library".',
|
||||||
"See https://huggingface.co/settings/tokens",
|
"See https://huggingface.co/settings/tokens",
|
||||||
]:
|
]:
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
@ -594,6 +594,9 @@ class editOptsForm(npyscreen.FormMultiPage):
|
|||||||
new_opts.license_acceptance = self.license_acceptance.value
|
new_opts.license_acceptance = self.license_acceptance.value
|
||||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||||
|
|
||||||
|
# widget library workaround to make max_loaded_models an int rather than a float
|
||||||
|
new_opts.max_loaded_models = int(new_opts.max_loaded_models)
|
||||||
|
|
||||||
return new_opts
|
return new_opts
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,8 +8,6 @@ import numpy as np
|
|||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
|
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
"""
|
"""
|
||||||
Thin class wrapper around the patchmatch function.
|
Thin class wrapper around the patchmatch function.
|
||||||
@ -23,6 +21,7 @@ class PatchMatch:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_patch_match(self):
|
def _load_patch_match(self):
|
||||||
|
config = get_invokeai_config()
|
||||||
if self.tried_load:
|
if self.tried_load:
|
||||||
return
|
return
|
||||||
if config.try_patchmatch:
|
if config.try_patchmatch:
|
||||||
|
@ -37,7 +37,6 @@ from invokeai.app.services.config import get_invokeai_config
|
|||||||
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||||
CLIPSEG_SIZE = 352
|
CLIPSEG_SIZE = 352
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
class SegmentedGrayscale(object):
|
class SegmentedGrayscale(object):
|
||||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
def __init__(self, image: Image, heatmap: torch.Tensor):
|
||||||
@ -84,6 +83,7 @@ class Txt2Mask(object):
|
|||||||
|
|
||||||
def __init__(self, device="cpu", refined=False):
|
def __init__(self, device="cpu", refined=False):
|
||||||
logger.info("Initializing clipseg model for text to mask inference")
|
logger.info("Initializing clipseg model for text to mask inference")
|
||||||
|
config = get_invokeai_config()
|
||||||
|
|
||||||
# BUG: we are not doing anything with the device option at this time
|
# BUG: we are not doing anything with the device option at this time
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -74,8 +74,6 @@ from transformers import (
|
|||||||
|
|
||||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
def shave_segments(path, n_shave_prefix_segments=1):
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
"""
|
"""
|
||||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||||
@ -844,7 +842,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
|||||||
|
|
||||||
def convert_ldm_clip_checkpoint(checkpoint):
|
def convert_ldm_clip_checkpoint(checkpoint):
|
||||||
text_model = CLIPTextModel.from_pretrained(
|
text_model = CLIPTextModel.from_pretrained(
|
||||||
"openai/clip-vit-large-patch14", cache_dir=config.cache_dir
|
"openai/clip-vit-large-patch14", cache_dir=get_invokeai_config().cache_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
@ -899,7 +897,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
|
|||||||
|
|
||||||
|
|
||||||
def convert_paint_by_example_checkpoint(checkpoint):
|
def convert_paint_by_example_checkpoint(checkpoint):
|
||||||
cache_dir = config.cache_dir
|
cache_dir = get_invokeai_config().cache_dir
|
||||||
config = CLIPVisionConfig.from_pretrained(
|
config = CLIPVisionConfig.from_pretrained(
|
||||||
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
"openai/clip-vit-large-patch14", cache_dir=cache_dir
|
||||||
)
|
)
|
||||||
@ -971,7 +969,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
|
|||||||
|
|
||||||
|
|
||||||
def convert_open_clip_checkpoint(checkpoint):
|
def convert_open_clip_checkpoint(checkpoint):
|
||||||
cache_dir = config.cache_dir
|
cache_dir = get_invokeai_config().cache_dir
|
||||||
text_model = CLIPTextModel.from_pretrained(
|
text_model = CLIPTextModel.from_pretrained(
|
||||||
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
|
||||||
)
|
)
|
||||||
@ -1094,7 +1092,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
:param vae: A diffusers VAE to load into the pipeline.
|
:param vae: A diffusers VAE to load into the pipeline.
|
||||||
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
|
||||||
"""
|
"""
|
||||||
|
config = get_invokeai_config()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
verbosity = dlogging.get_verbosity()
|
verbosity = dlogging.get_verbosity()
|
||||||
|
@ -68,7 +68,6 @@ class SDModelComponent(Enum):
|
|||||||
feature_extractor="feature_extractor"
|
feature_extractor="feature_extractor"
|
||||||
|
|
||||||
DEFAULT_MAX_MODELS = 2
|
DEFAULT_MAX_MODELS = 2
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
"""
|
"""
|
||||||
@ -99,6 +98,7 @@ class ModelManager(object):
|
|||||||
if not isinstance(config, DictConfig):
|
if not isinstance(config, DictConfig):
|
||||||
config = OmegaConf.load(config)
|
config = OmegaConf.load(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.globals = get_invokeai_config()
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.device = torch.device(device_type)
|
self.device = torch.device(device_type)
|
||||||
self.max_loaded_models = max_loaded_models
|
self.max_loaded_models = max_loaded_models
|
||||||
@ -291,7 +291,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
# if we are converting legacy files automatically, then
|
# if we are converting legacy files automatically, then
|
||||||
# there are no legacy ckpts!
|
# there are no legacy ckpts!
|
||||||
if config.ckpt_convert:
|
if self.globals.ckpt_convert:
|
||||||
return False
|
return False
|
||||||
info = self.model_info(model_name)
|
info = self.model_info(model_name)
|
||||||
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
|
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
|
||||||
@ -501,13 +501,13 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# TODO: scan weights maybe?
|
# TODO: scan weights maybe?
|
||||||
pipeline_args: dict[str, Any] = dict(
|
pipeline_args: dict[str, Any] = dict(
|
||||||
safety_checker=None, local_files_only=not config.internet_available
|
safety_checker=None, local_files_only=not self.globals.internet_available
|
||||||
)
|
)
|
||||||
if "vae" in mconfig and mconfig["vae"] is not None:
|
if "vae" in mconfig and mconfig["vae"] is not None:
|
||||||
if vae := self._load_vae(mconfig["vae"]):
|
if vae := self._load_vae(mconfig["vae"]):
|
||||||
pipeline_args.update(vae=vae)
|
pipeline_args.update(vae=vae)
|
||||||
if not isinstance(name_or_path, Path):
|
if not isinstance(name_or_path, Path):
|
||||||
pipeline_args.update(cache_dir=config.cache_dir)
|
pipeline_args.update(cache_dir=self.globals.cache_dir)
|
||||||
if using_fp16:
|
if using_fp16:
|
||||||
pipeline_args.update(torch_dtype=torch.float16)
|
pipeline_args.update(torch_dtype=torch.float16)
|
||||||
fp_args_list = [{"revision": "fp16"}, {}]
|
fp_args_list = [{"revision": "fp16"}, {}]
|
||||||
@ -559,10 +559,9 @@ class ModelManager(object):
|
|||||||
width = mconfig.width
|
width = mconfig.width
|
||||||
height = mconfig.height
|
height = mconfig.height
|
||||||
|
|
||||||
if not os.path.isabs(config):
|
root_dir = self.globals.root_dir
|
||||||
config = os.path.join(config.root, config)
|
config = str(root_dir / config)
|
||||||
if not os.path.isabs(weights):
|
weights = str(root_dir / weights)
|
||||||
weights = os.path.normpath(os.path.join(config.root, weights))
|
|
||||||
|
|
||||||
# Convert to diffusers and return a diffusers pipeline
|
# Convert to diffusers and return a diffusers pipeline
|
||||||
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||||
@ -577,11 +576,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
vae_path = None
|
vae_path = None
|
||||||
if vae:
|
if vae:
|
||||||
vae_path = (
|
vae_path = str(root_dir / vae)
|
||||||
vae
|
|
||||||
if os.path.isabs(vae)
|
|
||||||
else os.path.normpath(os.path.join(config.root, vae))
|
|
||||||
)
|
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||||
@ -613,9 +608,7 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "path" in mconfig and mconfig["path"] is not None:
|
if "path" in mconfig and mconfig["path"] is not None:
|
||||||
path = Path(mconfig["path"])
|
path = self.globals.root_dir / Path(mconfig["path"])
|
||||||
if not path.is_absolute():
|
|
||||||
path = Path(config.root, path).resolve()
|
|
||||||
return path
|
return path
|
||||||
elif "repo_id" in mconfig:
|
elif "repo_id" in mconfig:
|
||||||
return mconfig["repo_id"]
|
return mconfig["repo_id"]
|
||||||
@ -863,16 +856,16 @@ class ModelManager(object):
|
|||||||
model_type = self.probe_model_type(checkpoint)
|
model_type = self.probe_model_type(checkpoint)
|
||||||
if model_type == SDLegacyType.V1:
|
if model_type == SDLegacyType.V1:
|
||||||
self.logger.debug("SD-v1 model detected")
|
self.logger.debug("SD-v1 model detected")
|
||||||
model_config_file = config.legacy_conf_path / "v1-inference.yaml"
|
model_config_file = self.globals.legacy_conf_path / "v1-inference.yaml"
|
||||||
elif model_type == SDLegacyType.V1_INPAINT:
|
elif model_type == SDLegacyType.V1_INPAINT:
|
||||||
self.logger.debug("SD-v1 inpainting model detected")
|
self.logger.debug("SD-v1 inpainting model detected")
|
||||||
model_config_file = config.legacy_conf_path / "v1-inpainting-inference.yaml",
|
model_config_file = self.globals.legacy_conf_path / "v1-inpainting-inference.yaml",
|
||||||
elif model_type == SDLegacyType.V2_v:
|
elif model_type == SDLegacyType.V2_v:
|
||||||
self.logger.debug("SD-v2-v model detected")
|
self.logger.debug("SD-v2-v model detected")
|
||||||
model_config_file = config.legacy_conf_path / "v2-inference-v.yaml"
|
model_config_file = self.globals.legacy_conf_path / "v2-inference-v.yaml"
|
||||||
elif model_type == SDLegacyType.V2_e:
|
elif model_type == SDLegacyType.V2_e:
|
||||||
self.logger.debug("SD-v2-e model detected")
|
self.logger.debug("SD-v2-e model detected")
|
||||||
model_config_file = config.legacy_conf_path / "v2-inference.yaml"
|
model_config_file = self.globals.legacy_conf_path / "v2-inference.yaml"
|
||||||
elif model_type == SDLegacyType.V2:
|
elif model_type == SDLegacyType.V2:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||||
@ -899,7 +892,7 @@ class ModelManager(object):
|
|||||||
self.logger.debug(f"Using VAE file {vae_path.name}")
|
self.logger.debug(f"Using VAE file {vae_path.name}")
|
||||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||||
|
|
||||||
diffuser_path = config.root / "models/converted_ckpts" / model_path.stem
|
diffuser_path = self.globals.root_dir / "models/converted_ckpts" / model_path.stem
|
||||||
model_name = self.convert_and_import(
|
model_name = self.convert_and_import(
|
||||||
model_path,
|
model_path,
|
||||||
diffusers_path=diffuser_path,
|
diffusers_path=diffuser_path,
|
||||||
@ -1032,7 +1025,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
yaml_str = OmegaConf.to_yaml(self.config)
|
yaml_str = OmegaConf.to_yaml(self.config)
|
||||||
if not os.path.isabs(config_file_path):
|
if not os.path.isabs(config_file_path):
|
||||||
config_file_path = config.model_conf_path
|
config_file_path = self.globals.model_conf_path
|
||||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||||
outfile.write(self.preamble())
|
outfile.write(self.preamble())
|
||||||
@ -1064,7 +1057,8 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
# Three transformer models to check: bert, clip and safety checker, and
|
# Three transformer models to check: bert, clip and safety checker, and
|
||||||
# the diffusers as well
|
# the diffusers as well
|
||||||
models_dir = config.root / "models"
|
config = get_invokeai_config()
|
||||||
|
models_dir = config.root_dir / "models"
|
||||||
legacy_locations = [
|
legacy_locations = [
|
||||||
Path(
|
Path(
|
||||||
models_dir,
|
models_dir,
|
||||||
@ -1138,13 +1132,12 @@ class ModelManager(object):
|
|||||||
if str(source).startswith(("http:", "https:", "ftp:")):
|
if str(source).startswith(("http:", "https:", "ftp:")):
|
||||||
dest_directory = Path(dest_directory)
|
dest_directory = Path(dest_directory)
|
||||||
if not dest_directory.is_absolute():
|
if not dest_directory.is_absolute():
|
||||||
dest_directory = config.root / dest_directory
|
dest_directory = self.globals.root_dir / dest_directory
|
||||||
dest_directory.mkdir(parents=True, exist_ok=True)
|
dest_directory.mkdir(parents=True, exist_ok=True)
|
||||||
resolved_path = download_with_resume(str(source), dest_directory)
|
resolved_path = download_with_resume(str(source), dest_directory)
|
||||||
else:
|
else:
|
||||||
if not os.path.isabs(source):
|
source = self.globals.root_dir / source
|
||||||
source = config.root / source
|
resolved_path = source
|
||||||
resolved_path = Path(source)
|
|
||||||
return resolved_path
|
return resolved_path
|
||||||
|
|
||||||
def _invalidate_cached_model(self, model_name: str) -> None:
|
def _invalidate_cached_model(self, model_name: str) -> None:
|
||||||
@ -1194,7 +1187,7 @@ class ModelManager(object):
|
|||||||
path = name_or_path
|
path = name_or_path
|
||||||
else:
|
else:
|
||||||
owner, repo = name_or_path.split("/")
|
owner, repo = name_or_path.split("/")
|
||||||
path = Path(config.cache_dir / f"models--{owner}--{repo}")
|
path = self.globals.cache_dir / f"models--{owner}--{repo}"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return None
|
return None
|
||||||
hashpath = path / "checksum.sha256"
|
hashpath = path / "checksum.sha256"
|
||||||
@ -1255,8 +1248,8 @@ class ModelManager(object):
|
|||||||
using_fp16 = self.precision == "float16"
|
using_fp16 = self.precision == "float16"
|
||||||
|
|
||||||
vae_args.update(
|
vae_args.update(
|
||||||
cache_dir=config.cache_dir,
|
cache_dir=self.globals.cache_dir,
|
||||||
local_files_only=not config.internet_available,
|
local_files_only=not self.globals.internet_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
||||||
@ -1294,7 +1287,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _delete_model_from_cache(cls,repo_id):
|
def _delete_model_from_cache(cls,repo_id):
|
||||||
cache_info = scan_cache_dir(config.cache_dir)
|
cache_info = scan_cache_dir(get_invokeai_config().cache_dir)
|
||||||
|
|
||||||
# I'm sure there is a way to do this with comprehensions
|
# I'm sure there is a way to do this with comprehensions
|
||||||
# but the code quickly became incomprehensible!
|
# but the code quickly became incomprehensible!
|
||||||
@ -1311,9 +1304,10 @@ class ModelManager(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _abs_path(path: str | Path) -> Path:
|
def _abs_path(path: str | Path) -> Path:
|
||||||
|
globals = get_invokeai_config()
|
||||||
if path is None or Path(path).is_absolute():
|
if path is None or Path(path).is_absolute():
|
||||||
return path
|
return path
|
||||||
return Path(config.root, path).resolve()
|
return Path(globals.root_dir, path).resolve()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_huggingface_hub_directory_present() -> bool:
|
def _is_huggingface_hub_directory_present() -> bool:
|
||||||
|
@ -25,8 +25,6 @@ from invokeai.app.services.config import get_invokeai_config
|
|||||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||||
from ..util import torch_dtype
|
from ..util import torch_dtype
|
||||||
|
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
def get_uc_and_c_and_ec(prompt_string,
|
def get_uc_and_c_and_ec(prompt_string,
|
||||||
model: InvokeAIDiffuserComponent,
|
model: InvokeAIDiffuserComponent,
|
||||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||||
@ -41,6 +39,8 @@ def get_uc_and_c_and_ec(prompt_string,
|
|||||||
truncate_long_prompts=False,
|
truncate_long_prompts=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
config = get_invokeai_config()
|
||||||
|
|
||||||
# get rid of any newline characters
|
# get rid of any newline characters
|
||||||
prompt_string = prompt_string.replace("\n", " ")
|
prompt_string = prompt_string.replace("\n", " ")
|
||||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||||
|
@ -7,7 +7,6 @@ import torch
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
pretrained_model_url = (
|
pretrained_model_url = (
|
||||||
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
|
||||||
@ -18,11 +17,11 @@ class CodeFormerRestoration:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
|
||||||
) -> None:
|
) -> None:
|
||||||
if not os.path.isabs(codeformer_dir):
|
|
||||||
codeformer_dir = os.path.join(config.root, codeformer_dir)
|
|
||||||
|
|
||||||
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
|
self.globals = get_invokeai_config()
|
||||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
codeformer_dir = self.globals.root_dir / codeformer_dir
|
||||||
|
self.model_path = codeformer_dir / codeformer_model_path
|
||||||
|
self.codeformer_model_exists = self.model_path.exists()
|
||||||
|
|
||||||
if not self.codeformer_model_exists:
|
if not self.codeformer_model_exists:
|
||||||
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||||
@ -72,9 +71,7 @@ class CodeFormerRestoration:
|
|||||||
upscale_factor=1,
|
upscale_factor=1,
|
||||||
use_parse=True,
|
use_parse=True,
|
||||||
device=device,
|
device=device,
|
||||||
model_rootpath=os.path.join(
|
model_rootpath = self.globals.root_dir / "gfpgan" / "weights"
|
||||||
config.root, "models", "gfpgan", "weights"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
face_helper.clean_all()
|
face_helper.clean_all()
|
||||||
face_helper.read_image(bgr_image_array)
|
face_helper.read_image(bgr_image_array)
|
||||||
|
@ -8,14 +8,12 @@ from PIL import Image
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
class GFPGAN:
|
class GFPGAN:
|
||||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||||
|
self.globals = get_invokeai_config()
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
if not os.path.isabs(gfpgan_model_path):
|
||||||
gfpgan_model_path = os.path.abspath(
|
gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
|
||||||
os.path.join(config.root, gfpgan_model_path)
|
|
||||||
)
|
|
||||||
self.model_path = gfpgan_model_path
|
self.model_path = gfpgan_model_path
|
||||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
@ -34,7 +32,7 @@ class GFPGAN:
|
|||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
os.chdir(os.path.join(config.root, "models"))
|
os.chdir(self.globals.root_dir / 'models')
|
||||||
try:
|
try:
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -7,7 +6,8 @@ from PIL import Image
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
|
config = get_invokeai_config()
|
||||||
|
|
||||||
class ESRGAN:
|
class ESRGAN:
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
def __init__(self, bg_tile_size=400) -> None:
|
||||||
@ -30,12 +30,8 @@ class ESRGAN:
|
|||||||
upscale=4,
|
upscale=4,
|
||||||
act_type="prelu",
|
act_type="prelu",
|
||||||
)
|
)
|
||||||
model_path = os.path.join(
|
model_path = config.root_dir / "models/realesrgan/realesr-general-x4v3.pth"
|
||||||
Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
|
wdn_model_path = config.root_dir / "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
||||||
)
|
|
||||||
wdn_model_path = os.path.join(
|
|
||||||
Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
|
|
||||||
)
|
|
||||||
scale = 4
|
scale = 4
|
||||||
|
|
||||||
bg_upsampler = RealESRGANer(
|
bg_upsampler = RealESRGANer(
|
||||||
|
@ -18,8 +18,6 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
from .util import CPU_DEVICE
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
class SafetyChecker(object):
|
class SafetyChecker(object):
|
||||||
CAUTION_IMG = "caution.png"
|
CAUTION_IMG = "caution.png"
|
||||||
|
|
||||||
@ -28,6 +26,7 @@ class SafetyChecker(object):
|
|||||||
caution = Image.open(path)
|
caution = Image.open(path)
|
||||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||||
self.device = device
|
self.device = device
|
||||||
|
config = get_invokeai_config()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
@ -19,14 +19,14 @@ from huggingface_hub import (
|
|||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
class HuggingFaceConceptsLibrary(object):
|
class HuggingFaceConceptsLibrary(object):
|
||||||
def __init__(self, root=None):
|
def __init__(self, root=None):
|
||||||
"""
|
"""
|
||||||
Initialize the Concepts object. May optionally pass a root directory.
|
Initialize the Concepts object. May optionally pass a root directory.
|
||||||
"""
|
"""
|
||||||
self.root = root or config.root
|
self.config = get_invokeai_config()
|
||||||
|
self.root = root or self.config.root
|
||||||
self.hf_api = HfApi()
|
self.hf_api = HfApi()
|
||||||
self.local_concepts = dict()
|
self.local_concepts = dict()
|
||||||
self.concept_list = None
|
self.concept_list = None
|
||||||
@ -58,7 +58,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
self.concept_list.extend(list(local_concepts_to_add))
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
elif config.internet_available is True:
|
elif self.config.internet_available is True:
|
||||||
try:
|
try:
|
||||||
models = self.hf_api.list_models(
|
models = self.hf_api.list_models(
|
||||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||||
|
@ -43,8 +43,6 @@ from .diffusion import (
|
|||||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||||
from .textual_inversion_manager import TextualInversionManager
|
from .textual_inversion_manager import TextualInversionManager
|
||||||
|
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineIntermediateState:
|
class PipelineIntermediateState:
|
||||||
run_id: str
|
run_id: str
|
||||||
@ -348,6 +346,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
"""
|
"""
|
||||||
if xformers is available, use it, otherwise use sliced attention.
|
if xformers is available, use it, otherwise use sliced attention.
|
||||||
"""
|
"""
|
||||||
|
config = get_invokeai_config()
|
||||||
if (
|
if (
|
||||||
torch.cuda.is_available()
|
torch.cuda.is_available()
|
||||||
and is_xformers_available()
|
and is_xformers_available()
|
||||||
|
@ -32,8 +32,6 @@ ModelForwardCallback: TypeAlias = Union[
|
|||||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
||||||
]
|
]
|
||||||
|
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PostprocessingSettings:
|
class PostprocessingSettings:
|
||||||
threshold: float
|
threshold: float
|
||||||
@ -74,6 +72,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
:param model: the unet model to pass through to cross attention control
|
:param model: the unet model to pass through to cross attention control
|
||||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||||
"""
|
"""
|
||||||
|
config = get_invokeai_config()
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.model = model
|
self.model = model
|
||||||
self.is_running_diffusers = is_running_diffusers
|
self.is_running_diffusers = is_running_diffusers
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
This is the backend to "textual_inversion.py"
|
This is the backend to "textual_inversion.py"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -47,7 +46,7 @@ from tqdm.auto import tqdm
|
|||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
# invokeai stuff
|
# invokeai stuff
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser
|
||||||
|
|
||||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||||
PIL_INTERPOLATION = {
|
PIL_INTERPOLATION = {
|
||||||
@ -89,10 +88,9 @@ def save_progress(
|
|||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
config = InvokeAIAppConfig()
|
config = InvokeAIAppConfig(argv=[])
|
||||||
|
|
||||||
parser = PagingArgumentParser(
|
parser = PagingArgumentParser(
|
||||||
description="Textual inversion training", formatter_class=ArgFormatter
|
description="Textual inversion training"
|
||||||
)
|
)
|
||||||
general_group = parser.add_argument_group("General")
|
general_group = parser.add_argument_group("General")
|
||||||
model_group = parser.add_argument_group("Models and Paths")
|
model_group = parser.add_argument_group("Models and Paths")
|
||||||
@ -529,6 +527,7 @@ def get_full_repo_name(
|
|||||||
|
|
||||||
|
|
||||||
def do_textual_inversion_training(
|
def do_textual_inversion_training(
|
||||||
|
config: InvokeAIAppConfig,
|
||||||
model: str,
|
model: str,
|
||||||
train_data_dir: Path,
|
train_data_dir: Path,
|
||||||
output_dir: Path,
|
output_dir: Path,
|
||||||
@ -629,7 +628,7 @@ def do_textual_inversion_training(
|
|||||||
elif output_dir is not None:
|
elif output_dir is not None:
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
models_conf = OmegaConf.load(os.path.join(config.root, "configs/models.yaml"))
|
models_conf = OmegaConf.load(config.model_conf_path)
|
||||||
model_conf = models_conf.get(model, None)
|
model_conf = models_conf.get(model, None)
|
||||||
assert model_conf is not None, f"Unknown model: {model}"
|
assert model_conf is not None, f"Unknown model: {model}"
|
||||||
assert (
|
assert (
|
||||||
@ -641,7 +640,7 @@ def do_textual_inversion_training(
|
|||||||
assert (
|
assert (
|
||||||
pretrained_model_name_or_path
|
pretrained_model_name_or_path
|
||||||
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
|
||||||
pipeline_args = dict(cache_dir=config.cache_dir())
|
pipeline_args = dict(cache_dir=config.cache_dir)
|
||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
if tokenizer_name:
|
if tokenizer_name:
|
||||||
|
@ -9,10 +9,10 @@ from invokeai.app.services.config import get_invokeai_config
|
|||||||
CPU_DEVICE = torch.device("cpu")
|
CPU_DEVICE = torch.device("cpu")
|
||||||
CUDA_DEVICE = torch.device("cuda")
|
CUDA_DEVICE = torch.device("cuda")
|
||||||
MPS_DEVICE = torch.device("mps")
|
MPS_DEVICE = torch.device("mps")
|
||||||
config = get_invokeai_config()
|
|
||||||
|
|
||||||
def choose_torch_device() -> torch.device:
|
def choose_torch_device() -> torch.device:
|
||||||
"""Convenience routine for guessing which GPU device to run model on"""
|
"""Convenience routine for guessing which GPU device to run model on"""
|
||||||
|
config = get_invokeai_config()
|
||||||
if config.always_use_cpu:
|
if config.always_use_cpu:
|
||||||
return CPU_DEVICE
|
return CPU_DEVICE
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -32,6 +32,7 @@ def choose_precision(device: torch.device) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||||
|
config = get_invokeai_config()
|
||||||
if config.full_precision:
|
if config.full_precision:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
if choose_precision(device) == "float16":
|
if choose_precision(device) == "float16":
|
||||||
|
@ -21,14 +21,17 @@ from npyscreen import widget
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.backend.globals import Globals, global_set_root
|
|
||||||
|
|
||||||
from ...backend.training import do_textual_inversion_training, parse_args
|
from invokeai.app.services.config import get_invokeai_config
|
||||||
|
from ...backend.training import (
|
||||||
|
do_textual_inversion_training,
|
||||||
|
parse_args
|
||||||
|
)
|
||||||
|
|
||||||
TRAINING_DATA = "text-inversion-training-data"
|
TRAINING_DATA = "text-inversion-training-data"
|
||||||
TRAINING_DIR = "text-inversion-output"
|
TRAINING_DIR = "text-inversion-output"
|
||||||
CONF_FILE = "preferences.conf"
|
CONF_FILE = "preferences.conf"
|
||||||
|
config = None
|
||||||
|
|
||||||
class textualInversionForm(npyscreen.FormMultiPageAction):
|
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||||
resolutions = [512, 768, 1024]
|
resolutions = [512, 768, 1024]
|
||||||
@ -122,7 +125,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
value=str(
|
value=str(
|
||||||
saved_args.get(
|
saved_args.get(
|
||||||
"train_data_dir",
|
"train_data_dir",
|
||||||
Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
|
config.root_dir / TRAINING_DATA / default_placeholder_token,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
@ -135,7 +138,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
value=str(
|
value=str(
|
||||||
saved_args.get(
|
saved_args.get(
|
||||||
"output_dir",
|
"output_dir",
|
||||||
Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
|
config.root_dir / TRAINING_DIR / default_placeholder_token,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
@ -241,9 +244,9 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
placeholder = self.placeholder_token.value
|
placeholder = self.placeholder_token.value
|
||||||
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
||||||
self.train_data_dir.value = str(
|
self.train_data_dir.value = str(
|
||||||
Path(Globals.root) / TRAINING_DATA / placeholder
|
config.root_dir / TRAINING_DATA / placeholder
|
||||||
)
|
)
|
||||||
self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
|
self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder)
|
||||||
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||||
|
|
||||||
def on_ok(self):
|
def on_ok(self):
|
||||||
@ -284,7 +287,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def get_model_names(self) -> Tuple[List[str], int]:
|
def get_model_names(self) -> Tuple[List[str], int]:
|
||||||
conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
|
conf = OmegaConf.load(config.root_dir / "configs/models.yaml")
|
||||||
model_names = [
|
model_names = [
|
||||||
idx
|
idx
|
||||||
for idx in sorted(list(conf.keys()))
|
for idx in sorted(list(conf.keys()))
|
||||||
@ -367,7 +370,7 @@ def copy_to_embeddings_folder(args: dict):
|
|||||||
"""
|
"""
|
||||||
source = Path(args["output_dir"], "learned_embeds.bin")
|
source = Path(args["output_dir"], "learned_embeds.bin")
|
||||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||||
destination = Path(Globals.root, "embeddings", dest_dir_name)
|
destination = config.root_dir / "embeddings" / dest_dir_name
|
||||||
os.makedirs(destination, exist_ok=True)
|
os.makedirs(destination, exist_ok=True)
|
||||||
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||||
shutil.copy(source, destination)
|
shutil.copy(source, destination)
|
||||||
@ -383,7 +386,7 @@ def save_args(args: dict):
|
|||||||
"""
|
"""
|
||||||
Save the current argument values to an omegaconf file
|
Save the current argument values to an omegaconf file
|
||||||
"""
|
"""
|
||||||
dest_dir = Path(Globals.root) / TRAINING_DIR
|
dest_dir = config.root_dir / TRAINING_DIR
|
||||||
os.makedirs(dest_dir, exist_ok=True)
|
os.makedirs(dest_dir, exist_ok=True)
|
||||||
conf_file = dest_dir / CONF_FILE
|
conf_file = dest_dir / CONF_FILE
|
||||||
conf = OmegaConf.create(args)
|
conf = OmegaConf.create(args)
|
||||||
@ -394,7 +397,7 @@ def previous_args() -> dict:
|
|||||||
"""
|
"""
|
||||||
Get the previous arguments used.
|
Get the previous arguments used.
|
||||||
"""
|
"""
|
||||||
conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
|
conf_file = config.root_dir / TRAINING_DIR / CONF_FILE
|
||||||
try:
|
try:
|
||||||
conf = OmegaConf.load(conf_file)
|
conf = OmegaConf.load(conf_file)
|
||||||
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
||||||
@ -420,7 +423,7 @@ def do_front_end(args: Namespace):
|
|||||||
save_args(args)
|
save_args(args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
do_textual_inversion_training(**args)
|
do_textual_inversion_training(get_invokeai_config(),**args)
|
||||||
copy_to_embeddings_folder(args)
|
copy_to_embeddings_folder(args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("An exception occurred during training. The exception was:")
|
logger.error("An exception occurred during training. The exception was:")
|
||||||
@ -430,13 +433,20 @@ def do_front_end(args: Namespace):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
global config
|
||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
global_set_root(args.root_dir or Globals.root)
|
config = get_invokeai_config(argv=[])
|
||||||
|
|
||||||
|
# change root if needed
|
||||||
|
if args.root_dir:
|
||||||
|
config.root = args.root_dir
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if args.front_end:
|
if args.front_end:
|
||||||
do_front_end(args)
|
do_front_end(args)
|
||||||
else:
|
else:
|
||||||
do_textual_inversion_training(**vars(args))
|
do_textual_inversion_training(config,**vars(args))
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
@ -1,4 +0,0 @@
|
|||||||
banana sushi -Ak_lms -W640 -H480 -S42 -s20
|
|
||||||
banana sushi -Ak_lms -S42 -G1 -U 2 0.5 -s20
|
|
||||||
banana sushi -Ak_lms -S42 -v0.2 -n3 -s20
|
|
||||||
banana sushi -Ak_lms -S42 -V1349749425:0.1,4145759947:0.1 -s20
|
|
@ -58,6 +58,11 @@ def test_env_override():
|
|||||||
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
conf = InvokeAIAppConfig(conf=init1,argv=['--max_loaded=10'])
|
||||||
assert conf.nsfw_checker==True
|
assert conf.nsfw_checker==True
|
||||||
|
|
||||||
|
# environment variables should be case insensitive
|
||||||
|
os.environ['InvokeAI_Max_Loaded_Models'] = '15'
|
||||||
|
conf = InvokeAIAppConfig(conf=init1)
|
||||||
|
assert conf.max_loaded_models == 15
|
||||||
|
|
||||||
conf = InvokeAIAppConfig(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
|
conf = InvokeAIAppConfig(conf=init1,argv=['--no-nsfw_checker','--max_loaded=10'])
|
||||||
assert conf.nsfw_checker==False
|
assert conf.nsfw_checker==False
|
||||||
assert conf.max_loaded_models==10
|
assert conf.max_loaded_models==10
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
t2i --positive_prompt 'banana sushi' --seed 42
|
t2i --positive_prompt 'banana sushi' --seed 42
|
||||||
compel --prompt 'strawberry sushi' | compel | noise | t2l --scheduler heun --steps 3 --scheduler ddim --link -3 conditioning positive_conditioning --link -2 conditioning negative_conditioning | l2i
|
compel --prompt 'strawberry sushi' | compel | noise | t2l --scheduler heun --steps 3 --scheduler ddim --link -3 conditioning positive_conditioning --link -2 conditioning negative_conditioning | l2i
|
||||||
compel --prompt 'banana sushi' | compel | noise | t2i --scheduler heun --steps 3 --scheduler euler_a --link -3 conditioning positive_conditioning --link -2 conditioning negative_conditioning
|
compel --prompt 'banana sushi' | compel | noise | t2l --scheduler heun --steps 3 --scheduler euler_a --link -3 conditioning positive_conditioning --link -2 conditioning negative_conditioning
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user