mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add new console frontend to initial model selection, and other model mgmt improvements (#2644)
## Major Changes

The `invokeai-configure` script has now been refactored. The work of selecting and downloading initial models at install time is now done by a script named `invokeai-model-install` (module name is `ldm.invoke.config.model_install`).

Screen 1 - adjust startup options (screenshot)

Screen 2 - select SD models (screenshot)

The calling arguments for `invokeai-configure` have not changed, so nothing should break. After initializing the root directory, the script calls `invokeai-model-install` to let the user select the starting models to install. `invokeai-model-install` puts up a console GUI with checkboxes to indicate which models to install. It respects the `--default_only` and `--yes` arguments so that CI will continue to work. Here are the various effects you can achieve:

- `invokeai-configure`: Use the console-based UI to initialize invokeai.init, download support models, and choose and download SD models.
- `invokeai-configure --yes`: Without activating the GUI, populate invokeai.init with default values, download support models, and download the "recommended" SD models.
- `invokeai-configure --default_only`: Activate the GUI for changing init options, but don't show the SD download form, and automatically download the default SD model (currently SD-1.5).
- `invokeai-model-install`: Select and install models. This can be used to download arbitrary models from the Internet, install HuggingFace models using their repo_id, or watch a directory for models to load at startup time.
- `invokeai-model-install --yes`: Import the recommended SD models without a GUI.
- `invokeai-model-install --default_only`: As above, but only import the default model.

## Flexible Model Imports

The console GUI allows the user to import arbitrary models into InvokeAI using:

1. A HuggingFace repo_id
2. A URL (http/https/ftp) that points to a checkpoint or safetensors file
3. A local path on disk pointing to a checkpoint/safetensors file or diffusers directory
4. A directory to be scanned for all checkpoint/safetensors files to be imported

The UI allows the user to specify multiple models to bulk import. The user can specify whether to import the ckpt/safetensors as-is or convert it to `diffusers`. The user can also designate a directory to be scanned at startup time for checkpoint/safetensors files.

## Backend Changes

To support the model selection GUI, this PR introduces a new method in `ldm.invoke.model_manager` called `heuristic_import()`. This accepts a string-like object which can be a repo_id, URL, local path, or directory. It will figure out what the object is and import it. It interrogates the contents of checkpoint and safetensors files to determine what type of SD model they are -- v1.x, v2.x or v1.x inpainting.

## Installer

I am attaching a zip file of the installer if you would like to try the process from end to end.

[InvokeAI-installer-v2.3.0.zip](https://github.com/invoke-ai/InvokeAI/files/10785474/InvokeAI-installer-v2.3.0.zip)
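As a quick illustration of the `heuristic_import()` entry point described under Backend Changes above, here is a minimal sketch of driving it directly. It only mirrors the calls this PR itself makes in `CLI.py` and `model_install_backend.py`; the config path, URL, and directory below are placeholders, not part of the PR.

```python
# Minimal sketch, not part of the diff: exercise the new
# ModelManager.heuristic_import() the same way CLI.py and
# model_install_backend.py do. All paths/URLs are placeholders.
from omegaconf import OmegaConf
from ldm.invoke.model_manager import ModelManager

conf = "configs/models.yaml"  # placeholder models.yaml path
manager = ModelManager(OmegaConf.load(conf), precision="float16")

# heuristic_import() accepts a repo_id, a URL, a local checkpoint/
# safetensors file, or a directory, and works out which it was given.
for source in (
    "stabilityai/stable-diffusion-2-1",       # HuggingFace repo_id
    "https://example.com/model.safetensors",  # URL (placeholder)
    "/data/checkpoints",                      # directory to scan
):
    manager.heuristic_import(source, convert=False, commit_to_conf=conf)
```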
This commit is contained in: commit cf2eca7c60
@@ -67,6 +67,8 @@ del /q .tmp1 .tmp2
+@rem -------------- Install and Configure ---------------
+
 call python .\lib\main.py
 pause
 exit /b

 @rem ------------------------ Subroutines ---------------
 @rem routine to do comparison of semantic version numbers
@@ -9,13 +9,16 @@ cd $scriptdir
 function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }

 MINIMUM_PYTHON_VERSION=3.9.0
+MAXIMUM_PYTHON_VERSION=3.11.0
 PYTHON=""
-for candidate in python3.10 python3.9 python3 python python3.11 ; do
+for candidate in python3.10 python3.9 python3 python ; do
     if ppath=`which $candidate`; then
         python_version=$($ppath -V | awk '{ print $2 }')
         if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
-            PYTHON=$ppath
-            break
+            if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then
+                PYTHON=$ppath
+                break
+            fi
         fi
     fi
 done

@@ -28,3 +31,4 @@ if [ -z "$PYTHON" ]; then
 fi

 exec $PYTHON ./lib/main.py ${@}
+read -p "Press any key to exit"
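A note on the `version` helper used above: the awk `printf("%d%03d%03d%03d\n", ...)` packs up to four dot-separated version components into one zero-padded integer, so the script can order versions with plain integer `-ge`/`-lt` tests. A small Python re-implementation of the same idea (illustrative only, not part of this commit):

```python
# Illustrative re-implementation of the shell `version` helper above:
# pack "A.B.C.D" into the integer A BBB CCC DDD so ordinary integer
# comparison orders versions correctly (e.g. 3.10.x sorts above 3.9.x).
def version_key(v: str) -> int:
    major, *rest = (v.split(".") + ["0", "0", "0"])[:4]
    return int(str(int(major)) + "".join(f"{int(p):03d}" for p in rest))

assert version_key("3.10.2") > version_key("3.9.16")  # 3010002000 > 3009016000
assert version_key("3.9") == version_key("3.9.0")
```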
@@ -11,11 +11,13 @@ echo 1. command-line
 echo 2. browser-based UI
 echo 3. run textual inversion training
 echo 4. merge models (diffusers type only)
-echo 5. re-run the configure script to download new models
-echo 6. update InvokeAI
-echo 7. open the developer console
-echo 8. command-line help
-set /P restore="Please enter 1, 2, 3, 4, 5, 6, 7 or 8: [2] "
+echo 5. download and install models
+echo 6. change InvokeAI startup options
+echo 7. re-run the configure script to fix a broken install
+echo 8. open the developer console
+echo 9. update InvokeAI
+echo 10. command-line help
+set /P restore="Please enter 1-10: [2] "
 if not defined restore set restore=2
 IF /I "%restore%" == "1" (
     echo Starting the InvokeAI command-line..

@@ -25,17 +27,20 @@ IF /I "%restore%" == "1" (
     python .venv\Scripts\invokeai.exe --web %*
 ) ELSE IF /I "%restore%" == "3" (
     echo Starting textual inversion training..
-    python .venv\Scripts\invokeai-ti.exe --gui %*
+    python .venv\Scripts\invokeai-ti.exe --gui
 ) ELSE IF /I "%restore%" == "4" (
     echo Starting model merging script..
-    python .venv\Scripts\invokeai-merge.exe --gui %*
+    python .venv\Scripts\invokeai-merge.exe --gui
 ) ELSE IF /I "%restore%" == "5" (
-    echo Running invokeai-configure...
-    python .venv\Scripts\invokeai-configure.exe %*
+    echo Running invokeai-model-install...
+    python .venv\Scripts\invokeai-model-install.exe
 ) ELSE IF /I "%restore%" == "6" (
-    echo Running invokeai-update...
-    python .venv\Scripts\invokeai-update.exe %*
+    echo Running invokeai-configure...
+    python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
 ) ELSE IF /I "%restore%" == "7" (
+    echo Running invokeai-configure...
+    python .venv\Scripts\invokeai-configure.exe --yes --default_only
+) ELSE IF /I "%restore%" == "8" (
     echo Developer Console
     echo Python command is:
     where python

@@ -47,7 +52,10 @@ IF /I "%restore%" == "1" (
     echo *************************
     echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
     call cmd /k
-) ELSE IF /I "%restore%" == "8" (
+) ELSE IF /I "%restore%" == "9" (
+    echo Running invokeai-update...
+    python .venv\Scripts\invokeai-update.exe %*
+) ELSE IF /I "%restore%" == "10" (
     echo Displaying command line help...
     python .venv\Scripts\invokeai.exe --help %*
     pause
@@ -30,12 +30,14 @@ if [ "$0" != "bash" ]; then
    echo "2. browser-based UI"
    echo "3. run textual inversion training"
    echo "4. merge models (diffusers type only)"
-   echo "5. re-run the configure script to download new models"
-   echo "6. update InvokeAI"
-   echo "7. open the developer console"
-   echo "8. command-line help"
+   echo "5. download and install models"
+   echo "6. change InvokeAI startup options"
+   echo "7. re-run the configure script to fix a broken install"
+   echo "8. open the developer console"
+   echo "9. update InvokeAI"
+   echo "10. command-line help "
    echo ""
-   read -p "Please enter 1, 2, 3, 4, 5, 6, 7 or 8: [2] " yn
+   read -p "Please enter 1-10: [2] " yn
    choice=${yn:='2'}
    case $choice in
        1)

@@ -55,19 +57,24 @@ if [ "$0" != "bash" ]; then
            exec invokeai-merge --gui $@
            ;;
        5)
            echo "Configuration:"
-           exec invokeai-configure --root ${INVOKEAI_ROOT}
+           exec invokeai-model-install --root ${INVOKEAI_ROOT}
            ;;
        6)
-           echo "Update:"
-           exec invokeai-update
+           exec invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
            ;;
        7)
            echo "Developer Console:"
+           exec invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
+           ;;
+       8)
+           echo "Developer Console:"
            file_name=$(basename "${BASH_SOURCE[0]}")
            bash --init-file "$file_name"
            ;;
-       8)
+       9)
+           echo "Update:"
+           exec invokeai-update
+           ;;
+       10)
            exec invokeai --help
            ;;
        *)
@@ -56,33 +56,3 @@ trinart-2.0:
   vae:
     repo_id: stabilityai/sd-vae-ft-mse
   recommended: False
-trinart-characters-2_0:
-  description: An SD model finetuned with 19.2M anime/manga style images (ckpt version) (4.27 GB)
-  repo_id: naclbit/trinart_derrida_characters_v2_stable_diffusion
-  config: v1-inference.yaml
-  file: derrida_final.ckpt
-  format: ckpt
-  vae:
-    repo_id: naclbit/trinart_derrida_characters_v2_stable_diffusion
-    file: autoencoder_fix_kl-f8-trinart_characters.ckpt
-  width: 512
-  height: 512
-  recommended: False
-ft-mse-improved-autoencoder-840000:
-  description: StabilityAI improved autoencoder fine-tuned for human faces. Improves legacy .ckpt models (335 MB)
-  repo_id: stabilityai/sd-vae-ft-mse-original
-  format: ckpt
-  config: VAE/default
-  file: vae-ft-mse-840000-ema-pruned.ckpt
-  width: 512
-  height: 512
-  recommended: True
-trinart_vae:
-  description: Custom autoencoder for trinart_characters for legacy .ckpt models only (335 MB)
-  repo_id: naclbit/trinart_characters_19.2m_stable_diffusion_v1
-  config: VAE/trinart
-  format: ckpt
-  file: autoencoder_fix_kl-f8-trinart_characters.ckpt
-  width: 512
-  height: 512
-  recommended: False
@@ -211,7 +211,7 @@ class Generate:
         Globals.full_precision = self.precision == "float32"

         if is_xformers_available():
-            if not Globals.disable_xformers:
+            if torch.cuda.is_available() and not Globals.disable_xformers:
                 print(">> xformers memory-efficient attention is available and enabled")
             else:
                 print(

@@ -221,9 +221,13 @@ class Generate:
             print(">> xformers not installed")

         # model caching system for fast switching
-        self.model_manager = ModelManager(mconfig, self.device, self.precision,
-                                          max_loaded_models=max_loaded_models,
-                                          sequential_offload=self.free_gpu_mem)
+        self.model_manager = ModelManager(
+            mconfig,
+            self.device,
+            self.precision,
+            max_loaded_models=max_loaded_models,
+            sequential_offload=self.free_gpu_mem,
+        )
         # don't accept invalid models
         fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
         model = model or fallback

@@ -246,7 +250,7 @@ class Generate:
         # load safety checker if requested
         if safety_checker:
             try:
-                print(">> Initializing safety checker")
+                print(">> Initializing NSFW checker")
                 from diffusers.pipelines.stable_diffusion.safety_checker import (
                     StableDiffusionSafetyChecker,
                 )

@@ -270,6 +274,8 @@ class Generate:
                     "** An error was encountered while installing the safety checker:"
                 )
                 print(traceback.format_exc())
+        else:
+            print(">> NSFW checker is disabled")

     def prompt2png(self, prompt, outdir, **kwargs):
         """
@@ -5,7 +5,7 @@ import sys
 import traceback
 from argparse import Namespace
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Union

 import click

@@ -17,6 +17,7 @@ if sys.platform == "darwin":
 import pyparsing  # type: ignore

+import ldm.invoke

 from ..generate import Generate
 from .args import (Args, dream_cmd_from_png, metadata_dumps,
                    metadata_from_png)

@@ -83,6 +84,7 @@ def main():
     import transformers  # type: ignore

+    from ldm.generate import Generate

     transformers.logging.set_verbosity_error()
     import diffusers

@@ -155,11 +157,14 @@ def main():
         report_model_error(opt, e)

-    # try to autoconvert new models
+    # autoimport new .ckpt files
+    if path := opt.autoimport:
+        gen.model_manager.heuristic_import(
+            str(path), convert=False, commit_to_conf=opt.conf
+        )
+
     if path := opt.autoconvert:
-        gen.model_manager.autoconvert_weights(
-            conf_path=opt.conf,
-            weights_directory=path,
+        gen.model_manager.heuristic_import(
+            str(path), convert=True, commit_to_conf=opt.conf
         )

     # web server loops forever
@@ -529,32 +534,25 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
                 "** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
             )
         else:
-            import_model(path[1], gen, opt, completer)
-            completer.add_history(command)
+            try:
+                import_model(path[1], gen, opt, completer)
+                completer.add_history(command)
+            except KeyboardInterrupt:
+                print('\n')
         operation = None

-    elif command.startswith("!convert"):
+    elif command.startswith(("!convert","!optimize")):
         path = shlex.split(command)
         if len(path) < 2:
             print("** please provide the path to a .ckpt or .safetensors model")
         elif not os.path.exists(path[1]):
             print(f"** {path[1]}: model not found")
         else:
-            optimize_model(path[1], gen, opt, completer)
-            completer.add_history(command)
+            try:
+                convert_model(path[1], gen, opt, completer)
+                completer.add_history(command)
+            except KeyboardInterrupt:
+                print('\n')
         operation = None

-    elif command.startswith("!optimize"):
-        path = shlex.split(command)
-        if len(path) < 2:
-            print("** please provide an installed model name")
-        elif not path[1] in gen.model_manager.list_models():
-            print(f"** {path[1]}: model not found")
-        else:
-            optimize_model(path[1], gen, opt, completer)
-            completer.add_history(command)
-        operation = None
-
     elif command.startswith("!edit"):
         path = shlex.split(command)
         if len(path) < 2:
@@ -626,190 +624,69 @@ def set_default_output_dir(opt: Args, completer: Completer):
     completer.set_default_dir(opt.outdir)


-def import_model(model_path: str, gen, opt, completer):
+def import_model(model_path: str, gen, opt, completer, convert=False) -> str:
     """
     model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
     (3) a huggingface repository id; or (4) a local directory containing a
     diffusers model.
     """
-    model_path = model_path.replace('\\','/') # windows
+    model_path = model_path.replace("\\", "/")  # windows
+    default_name = Path(model_path).stem
+    model_name = None
+    model_desc = None

-    if model_path.startswith(("http:", "https:", "ftp:")):
-        model_name = import_ckpt_model(model_path, gen, opt, completer)
-
-    elif (
-        os.path.exists(model_path)
-        and model_path.endswith((".ckpt", ".safetensors"))
-        and os.path.isfile(model_path)
+    if (
+        Path(model_path).is_dir()
+        and not (Path(model_path) / "model_index.json").exists()
     ):
-        model_name = import_ckpt_model(model_path, gen, opt, completer)
-
-    elif os.path.isdir(model_path):
-        # Allow for a directory containing multiple models.
-        models = list(Path(model_path).rglob("*.ckpt")) + list(
-            Path(model_path).rglob("*.safetensors")
-        )
-
-        if models:
-            # Only the last model name will be used below.
-            for model in sorted(models):
-                if click.confirm(f"Import {model.stem} ?", default=True):
-                    model_name = import_ckpt_model(model, gen, opt, completer)
-                    print()
-        else:
-            model_name = import_diffuser_model(Path(model_path), gen, opt, completer)
-
-    elif re.match(r"^[\w.+-]+/[\w.+-]+$", model_path):
-        model_name = import_diffuser_model(model_path, gen, opt, completer)
-
+        pass
     else:
-        print(
-            f"** {model_path} is neither the path to a .ckpt file nor a diffusers repository id. Can't import."
-        )
+        if model_path.startswith(('http:','https:')):
+            try:
+                default_name = url_attachment_name(model_path)
+                default_name = Path(default_name).stem
+            except Exception as e:
+                print(f'** URL: {str(e)}')
+        model_name, model_desc = _get_model_name_and_desc(
+            gen.model_manager,
+            completer,
+            model_name=default_name,
+        )
+    imported_name = gen.model_manager.heuristic_import(
+        model_path,
+        model_name=model_name,
+        description=model_desc,
+        convert=convert,
+    )

-    if not model_name:
+    if not imported_name:
+        print("** Import failed or was skipped")
         return

-    if not _verify_load(model_name, gen):
+    if not _verify_load(imported_name, gen):
         print("** model failed to load. Discarding configuration entry")
-        gen.model_manager.del_model(model_name)
+        gen.model_manager.del_model(imported_name)
         return
-    if click.confirm('Make this the default model?', default=False):
-        gen.model_manager.set_default_model(model_name)
+    if click.confirm("Make this the default model?", default=False):
+        gen.model_manager.set_default_model(imported_name)

     gen.model_manager.commit(opt.conf)
     completer.update_models(gen.model_manager.list_models())
     print(f">> {model_name} successfully installed")
-def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
-    '''
-    Does a mass import of all the checkpoint/safetensors on a path list
-    '''
-    model_names = list()
-    choice = input('** Directory of checkpoint/safetensors models detected. Install <a>ll or <s>elected models? [a] ') or 'a'
-    do_all = choice.startswith('a')
-    if do_all:
-        config_file = _ask_for_config_file(models[0], completer, plural=True)
-        manager = gen.model_manager
-        for model in sorted(models):
-            model_name = f'{model.stem}'
-            model_description = f'Imported model {model_name}'
-            if model_name in manager.model_names():
-                print(f'** {model_name} is already imported. Skipping.')
-            elif manager.import_ckpt_model(
-                    model,
-                    config = config_file,
-                    model_name = model_name,
-                    model_description = model_description,
-                    commit_to_conf = opt.conf):
-                model_names.append(model_name)
-                print(f'>> Model {model_name} imported successfully')
-            else:
-                print(f'** Model {model} failed to import')
-    else:
-        for model in sorted(models):
-            if click.confirm(f'Import {model.stem} ?', default=True):
-                if model_name := import_ckpt_model(model, gen, opt, completer):
-                    print(f'>> Model {model.stem} imported successfully')
-                    model_names.append(model_name)
-                else:
-                    print(f'** Model {model} failed to import')
-            print()
-    return model_names
-
-def import_diffuser_model(
-    path_or_repo: Union[Path, str], gen, _, completer
-) -> Optional[str]:
-    path_or_repo = path_or_repo.replace('\\','/') # windows
-    manager = gen.model_manager
-    default_name = Path(path_or_repo).stem
-    default_description = f"Imported model {default_name}"
-    model_name, model_description = _get_model_name_and_desc(
-        manager,
-        completer,
-        model_name=default_name,
-        model_description=default_description,
-    )
-    vae = None
-    if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
-        vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
-
-    if not manager.import_diffuser_model(
-        path_or_repo, model_name=model_name, vae=vae, description=model_description
-    ):
-        print("** model failed to import")
-        return None
-    return model_name
-
-def import_ckpt_model(
-    path_or_url: Union[Path, str], gen, opt, completer
-) -> Optional[str]:
-    path_or_url = path_or_url.replace('\\','/')
-    manager = gen.model_manager
-    is_a_url = str(path_or_url).startswith(('http:','https:'))
-    base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
-    default_name = Path(base_name).stem
-    default_description = f"Imported model {default_name}"
-
-    model_name, model_description = _get_model_name_and_desc(
-        manager,
-        completer,
-        model_name=default_name,
-        model_description=default_description,
-    )
-    config_file = None
-    default = (
-        Path(Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml")
-        if re.search("inpaint", default_name, flags=re.IGNORECASE)
-        else Path(Globals.root, "configs/stable-diffusion/v1-inference.yaml")
-    )
-
-    completer.complete_extensions((".yaml", ".yml"))
-    completer.set_line(str(default))
-    done = False
-    while not done:
-        config_file = input("Configuration file for this model: ").strip()
-        done = os.path.exists(config_file)
-
-    completer.complete_extensions((".ckpt", ".safetensors"))
-    vae = None
-    default = Path(
-        Globals.root, "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt"
-    )
-    completer.set_line(str(default))
-    done = False
-    while not done:
-        vae = input("VAE file for this model (leave blank for none): ").strip() or None
-        done = (not vae) or os.path.exists(vae)
-    completer.complete_extensions(None)
-
-    if not manager.import_ckpt_model(
-        path_or_url,
-        config=config_file,
-        vae=vae,
-        model_name=model_name,
-        model_description=model_description,
-        commit_to_conf=opt.conf,
-    ):
-        print("** model failed to import")
-        return None
-
-    return model_name
 def _verify_load(model_name: str, gen) -> bool:
     print(">> Verifying that new model loads...")
     current_model = gen.model_name
     try:
-        if not gen.model_manager.get_model(model_name):
+        if not gen.set_model(model_name):
             return False
     except Exception as e:
-        print(f'** model failed to load: {str(e)}')
-        print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.')
+        print(f"** model failed to load: {str(e)}")
+        print(
+            "** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
+        )
         return False
-    if click.confirm('Keep model loaded?', default=True):
+    if click.confirm("Keep model loaded?", default=True):
         gen.set_model(model_name)
     else:
         print(">> Restoring previous model")
@@ -821,6 +698,7 @@ def _get_model_name_and_desc(
     model_manager, completer, model_name: str = "", model_description: str = ""
 ):
     model_name = _get_model_name(model_manager.list_models(), completer, model_name)
+    model_description = model_description or f"Imported model {model_name}"
     completer.set_line(model_description)
     model_description = (
         input(f"Description for this model [{model_description}]: ").strip()
@@ -828,46 +706,11 @@ def _get_model_name_and_desc(
     )
     return model_name, model_description

-def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path:
-    default = '1'
-    if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
-        default = '3'
-    choices={
-        '1': 'v1-inference.yaml',
-        '2': 'v2-inference-v.yaml',
-        '3': 'v1-inpainting-inference.yaml',
-    }
-
-    prompt = '''What type of models are these?:
-[1] Models based on Stable Diffusion 1.X
-[2] Models based on Stable Diffusion 2.X
-[3] Inpainting models based on Stable Diffusion 1.X
-[4] Something else''' if plural else '''What type of model is this?:
-[1] A model based on Stable Diffusion 1.X
-[2] A model based on Stable Diffusion 2.X
-[3] An inpainting models based on Stable Diffusion 1.X
-[4] Something else'''
-    print(prompt)
-    choice = input(f'Your choice: [{default}] ')
-    choice = choice.strip() or default
-    if config_file := choices.get(choice,None):
-        return Path('configs','stable-diffusion',config_file)
-
-    # otherwise ask user to select
-    done = False
-    completer.complete_extensions(('.yaml','.yml'))
-    completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
-    while not done:
-        config_path = input('Configuration file for this model (leave blank to abort): ').strip()
-        done = not config_path or os.path.exists(config_path)
-    return config_path
-
-def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
-    model_name_or_path = model_name_or_path.replace('\\','/') # windows
+def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer) -> str:
+    model_name_or_path = model_name_or_path.replace("\\", "/")  # windows
     manager = gen.model_manager
     ckpt_path = None
-    original_config_file=None
+    original_config_file = None
     if model_name_or_path == gen.model_name:
         print("** Can't convert the active model. !switch to another model first. **")
         return
@@ -877,61 +720,39 @@ def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
             original_config_file = Path(model_info["config"])
             model_name = model_name_or_path
             model_description = model_info["description"]
+            vae = model_info["vae"]
         else:
             print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
             return
-    elif os.path.exists(model_name_or_path):
-        original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
-        if not original_config_file:
-            return
-        ckpt_path = Path(model_name_or_path)
-        model_name, model_description = _get_model_name_and_desc(
-            manager, completer, ckpt_path.stem, f"Converted model {ckpt_path.stem}"
+        if vae_repo := ldm.invoke.model_manager.VAE_TO_REPO_ID.get(Path(vae).stem):
+            vae_repo = dict(repo_id=vae_repo)
+        else:
+            vae_repo = None
+        model_name = manager.convert_and_import(
+            ckpt_path,
+            diffusers_path=Path(
+                Globals.root, "models", Globals.converted_ckpts_dir, model_name_or_path
+            ),
+            model_name=model_name,
+            model_description=model_description,
+            original_config_file=original_config_file,
+            vae=vae_repo,
         )
     else:
-        print(
-            f"** {model_name_or_path} is neither an existing model nor the path to a .ckpt file"
-        )
+        try:
+            model_name = import_model(model_name_or_path, gen, opt, completer, convert=True)
+        except KeyboardInterrupt:
+            return

+    if not model_name:
+        print("** Conversion failed. Aborting.")
         return

-    if not ckpt_path.is_absolute():
-        ckpt_path = Path(Globals.root, ckpt_path)
-
-    if original_config_file and not original_config_file.is_absolute():
-        original_config_file = Path(Globals.root, original_config_file)
-
-    diffuser_path = Path(
-        Globals.root, "models", Globals.converted_ckpts_dir, model_name
-    )
-    if diffuser_path.exists():
-        print(
-            f"** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again."
-        )
-        return
-
-    vae = None
-    if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
-        vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
-
-    new_config = gen.model_manager.convert_and_import(
-        ckpt_path,
-        diffuser_path,
-        model_name=model_name,
-        model_description=model_description,
-        vae=vae,
-        original_config_file=original_config_file,
-        commit_to_conf=opt.conf,
-    )
-    if not new_config:
-        return
-
-    completer.update_models(gen.model_manager.list_models())
-    if click.confirm(f'Load optimized model {model_name}?', default=True):
-        gen.set_model(model_name)
-
-    if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False):
+    manager.commit(opt.conf)
+    if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
         ckpt_path.unlink(missing_ok=True)
         print(f"{ckpt_path} deleted")
+    return model_name
@@ -943,11 +764,15 @@ def del_config(model_name: str, gen, opt, completer):
         print(f"** Unknown model {model_name}")
         return

-    if not click.confirm(f'Remove {model_name} from the list of models known to InvokeAI?',default=True):
+    if not click.confirm(
+        f"Remove {model_name} from the list of models known to InvokeAI?", default=True
+    ):
         return

-    delete_completely = click.confirm('Completely remove the model file or directory from disk?',default=False)
-    gen.model_manager.del_model(model_name,delete_files=delete_completely)
+    delete_completely = click.confirm(
+        "Completely remove the model file or directory from disk?", default=False
+    )
+    gen.model_manager.del_model(model_name, delete_files=delete_completely)
     gen.model_manager.commit(opt.conf)
     print(f"** {model_name} deleted")
     completer.update_models(gen.model_manager.list_models())
@@ -970,13 +795,30 @@ def edit_model(model_name: str, gen, opt, completer):
         completer.set_line(info[attribute])
         info[attribute] = input(f"{attribute}: ") or info[attribute]

+    if info["format"] == "diffusers":
+        vae = info.get("vae", dict(repo_id=None, path=None, subfolder=None))
+        completer.set_line(vae.get("repo_id") or "stabilityai/sd-vae-ft-mse")
+        vae["repo_id"] = input("External VAE repo_id: ").strip() or None
+        if not vae["repo_id"]:
+            completer.set_line(vae.get("path") or "")
+            vae["path"] = (
+                input("Path to a local diffusers VAE model (usually none): ").strip()
+                or None
+            )
+            completer.set_line(vae.get("subfolder") or "")
+            vae["subfolder"] = (
+                input("Name of subfolder containing the VAE model (usually none): ").strip()
+                or None
+            )
+        info["vae"] = vae
+
     if new_name != model_name:
         manager.del_model(model_name)

     # this does the update
     manager.add_model(new_name, info, True)

-    if click.confirm('Make this the default model?',default=False):
+    if click.confirm("Make this the default model?", default=False):
         manager.set_default_model(new_name)
     manager.commit(opt.conf)
     completer.update_models(manager.list_models())
@@ -1354,7 +1196,10 @@ def report_model_error(opt: Namespace, e: Exception):
             "** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
         )
     else:
-        if click.confirm('Do you want to run invokeai-configure script to select and/or reinstall models?', default=True):
+        if not click.confirm(
+            'Do you want to run invokeai-configure script to select and/or reinstall models?',
+            default=False
+        ):
+            return

     print("invokeai-configure is launching....\n")
@@ -93,6 +93,7 @@ import shlex
 import sys
 from argparse import Namespace
 from pathlib import Path
+from typing import List

 import ldm.invoke
 import ldm.invoke.pngwriter

@@ -173,10 +174,10 @@ class Args(object):
         self._arg_switches = self.parse_cmd('') # fill in defaults
         self._cmd_switches = self.parse_cmd('') # fill in defaults

-    def parse_args(self):
+    def parse_args(self, args: List[str]=None):
         '''Parse the shell switches and store.'''
+        sysargs = args if args is not None else sys.argv[1:]
         try:
-            sysargs = sys.argv[1:]
             # pre-parse before we do any initialization to get root directory
             # and intercept --version request
             switches = self._arg_parser.parse_args(sysargs)
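The `parse_args()` change above follows a common pattern: accept an optional argv list and fall back to `sys.argv[1:]`, so the parser can be driven programmatically (for example from tests or from another script). A self-contained sketch of the pattern (not InvokeAI code):

```python
# Sketch of the optional-argv pattern adopted by Args.parse_args():
# pass an explicit list for programmatic use, or None to read sys.argv.
import argparse
import sys
from typing import List, Optional

def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
    sysargs = args if args is not None else sys.argv[1:]
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", default="outputs")
    return parser.parse_args(sysargs)

print(parse_args(["--outdir", "foo"]).outdir)  # -> foo
```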
@@ -539,11 +540,17 @@ class Args(object):
         default=False,
         help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
     )
+    model_group.add_argument(
+        '--autoimport',
+        default=None,
+        type=str,
+        help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
+    )
     model_group.add_argument(
         '--autoconvert',
         default=None,
         type=str,
-        help='Check the indicated directory for .ckpt weights files at startup and import as optimized diffuser models',
+        help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
     )
     model_group.add_argument(
         '--patchmatch',

@@ -561,8 +568,8 @@ class Args(object):
         '--outdir',
         '-o',
         type=str,
-        help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
-        default='outputs/img-samples',
+        help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
+        default='outputs',
     )
     file_group.add_argument(
         '--prompt_as_dir',
@@ -803,6 +803,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     extract_ema:bool=True,
     upcast_attn:bool=False,
     vae:AutoencoderKL=None,
+    precision:torch.dtype=torch.float32,
     return_generator_pipeline:bool=False,
 )->Union[StableDiffusionPipeline,StableDiffusionGeneratorPipeline]:
     '''

@@ -828,6 +829,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
        checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
        or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
        quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
+    :param precision: precision to use - torch.float16, torch.float32 or torch.autocast
     :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
        running stable diffusion 2.1.
     '''

@@ -988,12 +990,12 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
         safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker',cache_dir=global_cache_dir("hub"))
         feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker",cache_dir=cache_dir)
         pipe = pipeline_class(
-            vae=vae,
-            text_encoder=text_model,
+            vae=vae.to(precision),
+            text_encoder=text_model.to(precision),
             tokenizer=tokenizer,
-            unet=unet,
+            unet=unet.to(precision),
             scheduler=scheduler,
-            safety_checker=safety_checker,
+            safety_checker=safety_checker.to(precision),
             feature_extractor=feature_extractor,
         )
     else:
(File diff suppressed because it is too large)

ldm/invoke/config/model_install.py (new file, 495 lines)
@@ -0,0 +1,495 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.

"""
This is the npyscreen frontend to the model installation application.
The work is actually done in backend code in model_install_backend.py.
"""

import argparse
import curses
import os
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import List

import npyscreen
import torch
from npyscreen import widget
from omegaconf import OmegaConf

from ..devices import choose_precision, choose_torch_device
from ..globals import Globals, global_config_dir
from .model_install_backend import (Dataset_path, default_config_file,
                                    default_dataset, get_root,
                                    install_requested_models,
                                    recommended_datasets)
from .widgets import (MultiSelectColumns, TextBox,
                      OffsetButtonPress, CenteredTitleText)


class addModelsForm(npyscreen.FormMultiPage):
    # for responsive resizing - disabled
    # FIX_MINIMUM_SIZE_WHEN_CREATED = False

    def __init__(self, parentApp, name, multipage=False, *args, **keywords):
        self.multipage = multipage
        self.initial_models = OmegaConf.load(Dataset_path)
        try:
            self.existing_models = OmegaConf.load(default_config_file())
        except:
            self.existing_models = dict()
        self.starter_model_list = [
            x for x in list(self.initial_models.keys()) if x not in self.existing_models
        ]
        self.installed_models = dict()
        super().__init__(parentApp=parentApp, name=name, *args, **keywords)

    def create(self):
        window_height, window_width = curses.initscr().getmaxyx()
        starter_model_labels = self._get_starter_model_labels()
        recommended_models = [
            x
            for x in self.starter_model_list
            if self.initial_models[x].get("recommended", False)
        ]
        self.installed_models = sorted(
            [x for x in list(self.initial_models.keys()) if x in self.existing_models]
        )
        self.nextrely -= 1
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields,",
            editable=False,
            color='CAUTION',
        )
        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use cursor arrows to make a selection, and space to toggle checkboxes.",
            editable=False,
            color='CAUTION'
        )
        self.nextrely += 1
        if len(self.installed_models) > 0:
            self.add_widget_intelligent(
                CenteredTitleText,
                name="== INSTALLED STARTER MODELS ==",
                editable=False,
                color="CONTROL",
            )
            self.nextrely -= 1
            self.add_widget_intelligent(
                CenteredTitleText,
                name="Currently installed starter models. Uncheck to delete:",
                editable=False,
                labelColor="CAUTION",
            )
            self.nextrely -= 1
            columns = self._get_columns()
            self.previously_installed_models = self.add_widget_intelligent(
                MultiSelectColumns,
                columns=columns,
                values=self.installed_models,
                value=[x for x in range(0, len(self.installed_models))],
                max_height=1 + len(self.installed_models) // columns,
                relx=4,
                slow_scroll=True,
                scroll_exit=True,
            )
            self.purge_deleted = self.add_widget_intelligent(
                npyscreen.Checkbox,
                name="Purge deleted models from disk",
                value=False,
                scroll_exit=True,
                relx=4,
            )
        self.nextrely += 1
        self.add_widget_intelligent(
            CenteredTitleText,
            name="== STARTER MODELS (recommended ones selected) ==",
            editable=False,
            color="CONTROL",
        )
        self.nextrely -= 1
        self.add_widget_intelligent(
            CenteredTitleText,
            name="Select from a starter set of Stable Diffusion models from HuggingFace:",
            editable=False,
            labelColor="CAUTION",
        )

        self.nextrely -= 1
        # if user has already installed some initial models, then don't patronize them
        # by showing more recommendations
        show_recommended = not self.existing_models
        self.models_selected = self.add_widget_intelligent(
            npyscreen.MultiSelect,
            name="Install Starter Models",
            values=starter_model_labels,
            value=[
                self.starter_model_list.index(x)
                for x in self.starter_model_list
                if show_recommended and x in recommended_models
            ],
            max_height=len(starter_model_labels) + 1,
            relx=4,
            scroll_exit=True,
        )
        self.add_widget_intelligent(
            CenteredTitleText,
            name='== IMPORT LOCAL AND REMOTE MODELS ==',
            editable=False,
            color="CONTROL",
        )
        self.nextrely -= 1

        for line in [
            "In the box below, enter URLs, file paths, or HuggingFace repository IDs.",
            "Separate model names by lines or whitespace (Use shift-control-V to paste):",
        ]:
            self.add_widget_intelligent(
                CenteredTitleText,
                name=line,
                editable=False,
                labelColor="CONTROL",
                relx=4,
            )
            self.nextrely -= 1
        self.import_model_paths = self.add_widget_intelligent(
            TextBox, max_height=5, scroll_exit=True, editable=True, relx=4
        )
        self.nextrely += 1
        self.show_directory_fields = self.add_widget_intelligent(
            npyscreen.FormControlCheckbox,
            name="Select a directory for models to import",
            value=False,
        )
        self.autoload_directory = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Directory (<tab> autocompletes):",
            select_dir=True,
            must_exist=True,
            use_two_lines=False,
            labelColor="DANGER",
            begin_entry_at=34,
            scroll_exit=True,
        )
        self.autoscan_on_startup = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scan this directory each time InvokeAI starts for new models to import",
            value=False,
            relx=4,
            scroll_exit=True,
        )
        self.nextrely += 1
        self.convert_models = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
            values=["Keep original format", "Convert to diffusers"],
            value=0,
            begin_entry_at=4,
            max_height=4,
            hidden=True,  # will appear when imported models box is edited
            scroll_exit=True,
        )
        self.cancel = self.add_widget_intelligent(
            npyscreen.ButtonPress,
            name="CANCEL",
            rely=-3,
            when_pressed_function=self.on_cancel,
        )
        done_label = "DONE"
        back_label = "BACK"
        button_length = len(done_label)
        button_offset = 0
        if self.multipage:
            button_length += len(back_label) + 1
            button_offset += len(back_label) + 1
            self.back_button = self.add_widget_intelligent(
                OffsetButtonPress,
                name=back_label,
                relx=(window_width - button_length) // 2,
                offset=-3,
                rely=-3,
                when_pressed_function=self.on_back,
            )
        self.ok_button = self.add_widget_intelligent(
            OffsetButtonPress,
            name=done_label,
            offset=+3,
            relx=button_offset + 1 + (window_width - button_length) // 2,
            rely=-3,
            when_pressed_function=self.on_ok,
        )

        for i in [self.autoload_directory, self.autoscan_on_startup]:
            self.show_directory_fields.addVisibleWhenSelected(i)

        self.show_directory_fields.when_value_edited = self._clear_scan_directory
        self.import_model_paths.when_value_edited = self._show_hide_convert
        self.autoload_directory.when_value_edited = self._show_hide_convert

    def resize(self):
        super().resize()
        self.models_selected.values = self._get_starter_model_labels()

    def _clear_scan_directory(self):
        if not self.show_directory_fields.value:
            self.autoload_directory.value = ""

    def _show_hide_convert(self):
        model_paths = self.import_model_paths.value or ""
        autoload_directory = self.autoload_directory.value or ""
        self.convert_models.hidden = (
            len(model_paths) == 0 and len(autoload_directory) == 0
        )

    def _get_starter_model_labels(self) -> List[str]:
        window_height, window_width = curses.initscr().getmaxyx()
        label_width = 25
        checkbox_width = 4
        spacing_width = 2
        description_width = window_width - label_width - checkbox_width - spacing_width
        im = self.initial_models
        names = self.starter_model_list
        descriptions = [
            im[x].description[0 : description_width - 3] + "..."
            if len(im[x].description) > description_width
            else im[x].description
            for x in names
        ]
        return [
            f"%-{label_width}s %s" % (names[x], descriptions[x])
            for x in range(0, len(names))
        ]

    def _get_columns(self) -> int:
        window_height, window_width = curses.initscr().getmaxyx()
        cols = (
            4
            if window_width > 240
            else 3
            if window_width > 160
            else 2
            if window_width > 80
            else 1
        )
        return min(cols, len(self.installed_models))

    def on_ok(self):
        self.parentApp.setNextForm(None)
        self.editing = False
        self.parentApp.user_cancelled = False
        self.marshall_arguments()

    def on_back(self):
        self.parentApp.switchFormPrevious()
        self.editing = False

    def on_cancel(self):
        if npyscreen.notify_yes_no(
            "Are you sure you want to cancel?\nYou may re-run this script later using the invoke.sh or invoke.bat command.\n"
        ):
            self.parentApp.setNextForm(None)
            self.parentApp.user_cancelled = True
            self.editing = False

    def marshall_arguments(self):
        """
        Assemble arguments and store as attributes of the application:
        .starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
                         True  => Install
                         False => Remove
        .scan_directory: Path to a directory of models to scan and import
        .autoscan_on_startup: True if invokeai should scan and import at startup time
        .import_model_paths: list of URLs, repo_ids and file paths to import
        .convert_to_diffusers: if True, convert legacy checkpoints into diffusers
        """
        # we're using a global here rather than storing the result in the parentapp
        # due to some bug in npyscreen that is causing attributes to be lost
        selections = self.parentApp.user_selections

        # starter models to install/remove
        starter_models = dict(
            map(
                lambda x: (self.starter_model_list[x], True), self.models_selected.value
            )
        )
        selections.purge_deleted_models = False
        if hasattr(self, "previously_installed_models"):
            unchecked = [
                self.previously_installed_models.values[x]
                for x in range(0, len(self.previously_installed_models.values))
                if x not in self.previously_installed_models.value
            ]
            starter_models.update(map(lambda x: (x, False), unchecked))
            selections.purge_deleted_models = self.purge_deleted.value
        selections.starter_models = starter_models

        # load directory and whether to scan on startup
        if self.show_directory_fields.value:
            selections.scan_directory = self.autoload_directory.value
            selections.autoscan_on_startup = self.autoscan_on_startup.value
        else:
            selections.scan_directory = None
            selections.autoscan_on_startup = False

        # URLs and the like
        selections.import_model_paths = self.import_model_paths.value.split()
        selections.convert_to_diffusers = self.convert_models.value[0] == 1


class AddModelApplication(npyscreen.NPSAppManaged):
    def __init__(self):
        super().__init__()
        self.user_cancelled = False
        self.user_selections = Namespace(
            starter_models=None,
            purge_deleted_models=False,
            scan_directory=None,
            autoscan_on_startup=None,
            import_model_paths=None,
            convert_to_diffusers=None,
        )

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main_form = self.addForm(
            "MAIN", addModelsForm, name="Install Stable Diffusion Models"
        )


# --------------------------------------------------------
def process_and_execute(opt: Namespace, selections: Namespace):
    models_to_remove = [
        x for x in selections.starter_models if not selections.starter_models[x]
    ]
    models_to_install = [
        x for x in selections.starter_models if selections.starter_models[x]
    ]
    directory_to_scan = selections.scan_directory
    scan_at_startup = selections.autoscan_on_startup
    potential_models_to_install = selections.import_model_paths
    convert_to_diffusers = selections.convert_to_diffusers

    install_requested_models(
        install_initial_models=models_to_install,
        remove_models=models_to_remove,
        scan_directory=Path(directory_to_scan) if directory_to_scan else None,
        external_models=potential_models_to_install,
        scan_at_startup=scan_at_startup,
        convert_to_diffusers=convert_to_diffusers,
        precision="float32"
        if opt.full_precision
        else choose_precision(torch.device(choose_torch_device())),
        purge_deleted=selections.purge_deleted_models,
        config_file_path=Path(opt.config_file) if opt.config_file else None,
    )


# --------------------------------------------------------
def select_and_download_models(opt: Namespace):
    precision = (
        "float32"
        if opt.full_precision
        else choose_precision(torch.device(choose_torch_device()))
    )
    if opt.default_only:
        install_requested_models(
            install_initial_models=default_dataset(),
            precision=precision,
        )
    elif opt.yes_to_all:
        install_requested_models(
            install_initial_models=recommended_datasets(),
            precision=precision,
        )
    else:
        installApp = AddModelApplication()
        installApp.run()

        if not installApp.user_cancelled:
            process_and_execute(opt, installApp.user_selections)


# -------------------------------------
def main():
    parser = argparse.ArgumentParser(description="InvokeAI model downloader")
    parser.add_argument(
        "--full-precision",
        dest="full_precision",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        help="use 32-bit weights instead of faster 16-bit weights",
    )
    parser.add_argument(
        "--yes",
        "-y",
        dest="yes_to_all",
        action="store_true",
        help='answer "yes" to all prompts',
    )
    parser.add_argument(
        "--default_only",
        action="store_true",
        help="only install the default model",
    )
    parser.add_argument(
        "--config_file",
        "-c",
        dest="config_file",
        type=str,
        default=None,
        help="path to configuration file to create",
    )
    parser.add_argument(
        "--root_dir",
        dest="root",
        type=str,
        default=None,
        help="path to root of install directory",
    )
    opt = parser.parse_args()

    # setting a global here
    Globals.root = os.path.expanduser(get_root(opt.root) or "")

    if not global_config_dir().exists():
        print(
            ">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
        )
        import ldm.invoke.config.invokeai_configure

        ldm.invoke.config.invokeai_configure.main()
        sys.exit(0)

    try:
        select_and_download_models(opt)
    except AssertionError as e:
        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        print("\nGoodbye! Come back soon.")
    except (widget.NotEnoughSpaceForWidget, Exception) as e:
        if str(e).startswith("Height of 1 allocated"):
            print(
                "** Insufficient vertical space for the interface. Please make your window taller and try again"
            )
        elif str(e).startswith("addwstr"):
            print(
                "** Insufficient horizontal space for the interface. Please make your window wider and try again."
            )
        else:
            print(f"** An error has occurred: {str(e)}")
            traceback.print_exc()
        sys.exit(-1)


# -------------------------------------
if __name__ == "__main__":
    main()
ldm/invoke/config/model_install_backend.py (new file, 452 lines)
@@ -0,0 +1,452 @@
|
||||
"""
|
||||
Utility (backend) functions used by model_install.py
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
from huggingface_hub import hf_hub_url
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
|
||||
import invokeai.configs as configs
|
||||
from ..generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..globals import Globals, global_cache_dir, global_config_dir
|
||||
from ..model_manager import ModelManager
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
Model_dir = "models"
|
||||
Weights_dir = "ldm/stable-diffusion-v1/"
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
# initial models omegaconf
|
||||
Datasets = None
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
def default_config_file():
|
||||
return Path(global_config_dir()) / "models.yaml"
|
||||
|
||||
def sd_configs():
|
||||
return Path(global_config_dir()) / "stable-diffusion"
|
||||
|
||||
def initial_models():
|
||||
global Datasets
|
||||
if Datasets:
|
||||
return Datasets
|
||||
return (Datasets := OmegaConf.load(Dataset_path))
|
||||
|
||||
def install_requested_models(
    install_initial_models: List[str] = None,
    remove_models: List[str] = None,
    scan_directory: Path = None,
    external_models: List[str] = None,
    scan_at_startup: bool = False,
    convert_to_diffusers: bool = False,
    precision: str = "float16",
    purge_deleted: bool = False,
    config_file_path: Path = None,
):
    config_file_path = config_file_path or default_config_file()
    if not config_file_path.exists():
        open(config_file_path, "w").close()  # create an empty config file

    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)

    if remove_models and len(remove_models) > 0:
        print("== DELETING UNCHECKED STARTER MODELS ==")
        for model in remove_models:
            print(f"{model}...")
            model_manager.del_model(model, delete_files=purge_deleted)
        model_manager.commit(config_file_path)

    if install_initial_models and len(install_initial_models) > 0:
        print("== INSTALLING SELECTED STARTER MODELS ==")
        successfully_downloaded = download_weight_datasets(
            models=install_initial_models,
            access_token=None,
            precision=precision,
        )  # FIX: for historical reasons, we don't use model manager here
        update_config_file(successfully_downloaded, config_file_path)
        if len(successfully_downloaded) < len(install_initial_models):
            print("** Some of the model downloads were not successful")

        # due to above, we have to reload the model manager because the conf file
        # was changed behind its back
        model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)

    external_models = external_models or list()
    if scan_directory:
        external_models.append(str(scan_directory))

    if len(external_models) > 0:
        print("== INSTALLING EXTERNAL MODELS ==")
        for path_url_or_repo in external_models:
            try:
                model_manager.heuristic_import(
                    path_url_or_repo,
                    convert=convert_to_diffusers,
                    commit_to_conf=config_file_path,
                )
            except KeyboardInterrupt:
                sys.exit(-1)
            except Exception:
                pass

    if scan_at_startup and scan_directory.is_dir():
        argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
        initfile = Path(Globals.root, Globals.initfile)
        replacement = Path(Globals.root, f"{Globals.initfile}.new")
        # rewrite the init file, dropping any stale autoimport/autoconvert line
        # and appending the new directory argument at the end
        with open(initfile, "r") as infile:
            with open(replacement, "w") as outfile:
                while line := infile.readline():
                    if not line.startswith(argument):
                        outfile.writelines([line])
                outfile.writelines([f"{argument} {str(scan_directory)}"])
        os.replace(replacement, initfile)

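# A minimal usage sketch (illustrative, not part of this commit): the console
# frontend collects the user's checkbox selections and then makes a single
# call like the following. The model name and scan path are assumptions.
#
#   install_requested_models(
#       install_initial_models=["stable-diffusion-1.5"],
#       remove_models=[],
#       scan_directory=Path("/home/me/sd-models"),
#       scan_at_startup=True,
#       convert_to_diffusers=False,
#       precision="float16",
#   )
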
# -------------------------------------
def yes_or_no(prompt: str, default_yes=True):
    default = "y" if default_yes else "n"
    response = input(f"{prompt} [{default}] ") or default
    if default_yes:
        return response[0] not in ("n", "N")
    else:
        return response[0] in ("y", "Y")


# -------------------------------------
def get_root(root: str = None) -> str:
    if root:
        return root
    elif os.environ.get("INVOKEAI_ROOT"):
        return os.environ.get("INVOKEAI_ROOT")
    else:
        return Globals.root


# ---------------------------------------------
def recommended_datasets() -> dict:
    datasets = dict()
    for ds in initial_models().keys():
        if initial_models()[ds].get("recommended", False):
            datasets[ds] = True
    return datasets


# ---------------------------------------------
def default_dataset() -> dict:
    datasets = dict()
    for ds in initial_models().keys():
        if initial_models()[ds].get("default", False):
            datasets[ds] = True
    return datasets


# ---------------------------------------------
def all_datasets() -> dict:
    datasets = dict()
    for ds in initial_models().keys():
        datasets[ds] = True
    return datasets


# ---------------------------------------------
# look for legacy model.ckpt in models directory and offer to
# normalize its name
def migrate_models_ckpt():
    model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
    if not os.path.exists(os.path.join(model_path, "model.ckpt")):
        return
    new_name = initial_models()["stable-diffusion-1.4"]["file"]
    print(
        f'The Stable Diffusion v1.4 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.'
    )
    print(f"model.ckpt => {new_name}")
    os.replace(
        os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
    )


# ---------------------------------------------
def download_weight_datasets(
    models: List[str], access_token: str, precision: str = "float32"
):
    migrate_models_ckpt()
    successful = dict()
    for mod in models:
        print(f"Downloading {mod}:")
        successful[mod] = _download_repo_or_file(
            initial_models()[mod], access_token, precision=precision
        )
    return successful


def _download_repo_or_file(
    mconfig: DictConfig, access_token: str, precision: str = "float32"
) -> Path:
    path = None
    if mconfig["format"] == "ckpt":
        path = _download_ckpt_weights(mconfig, access_token)
    else:
        path = _download_diffusion_weights(mconfig, access_token, precision=precision)
        if "vae" in mconfig and "repo_id" in mconfig["vae"]:
            _download_diffusion_weights(
                mconfig["vae"], access_token, precision=precision
            )
    return path


def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
    repo_id = mconfig["repo_id"]
    filename = mconfig["file"]
    cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
    return hf_download_with_resume(
        repo_id=repo_id,
        model_dir=cache_dir,
        model_name=filename,
        access_token=access_token,
    )


# ---------------------------------------------
def download_from_hf(
    model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
):
    print("", file=sys.stderr)  # to prevent tqdm from overwriting
    path = global_cache_dir(cache_subdir)
    model = model_class.from_pretrained(
        model_name,
        cache_dir=path,
        resume_download=True,
        **kwargs,
    )
    model_name = "--".join(("models", *model_name.split("/")))
    return path / model_name if model else None


def _download_diffusion_weights(
    mconfig: DictConfig, access_token: str, precision: str = "float32"
):
    repo_id = mconfig["repo_id"]
    model_class = (
        StableDiffusionGeneratorPipeline
        if mconfig.get("format", None) == "diffusers"
        else AutoencoderKL
    )
    # try the fp16 revision first when half precision is requested, then fall
    # back to the default revision
    extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
    path = None
    for extra_args in extra_arg_list:
        try:
            path = download_from_hf(
                model_class,
                repo_id,
                cache_subdir="diffusers",
                safety_checker=None,
                **extra_args,
            )
        except OSError as e:
            if str(e).startswith("fp16 is not a valid"):
                pass
            else:
                print(f"An unexpected error occurred while downloading the model: {e}")
        if path:
            break
    return path


# ---------------------------------------------
def hf_download_with_resume(
    repo_id: str, model_dir: str, model_name: str, access_token: str = None
) -> Path:
    model_dest = Path(os.path.join(model_dir, model_name))
    os.makedirs(model_dir, exist_ok=True)

    url = hf_hub_url(repo_id, model_name)

    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    open_mode = "wb"
    exist_size = 0

    if os.path.exists(model_dest):
        # ask the server for only the bytes we are still missing
        exist_size = os.path.getsize(model_dest)
        header["Range"] = f"bytes={exist_size}-"
        open_mode = "ab"

    resp = requests.get(url, headers=header, stream=True)
    total = int(resp.headers.get("content-length", 0))

    if (
        resp.status_code == 416
    ):  # "range not satisfiable", which means nothing to return
        print(f"* {model_name}: complete file found. Skipping.")
        return model_dest
    elif resp.status_code != 200:
        print(f"** An error occurred during downloading {model_name}: {resp.reason}")
    elif exist_size > 0:
        print(f"* {model_name}: partial file found. Resuming...")
    else:
        print(f"* {model_name}: Downloading...")

    try:
        if total < 2000:
            print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
            return None

        with open(model_dest, open_mode) as file, tqdm(
            desc=model_name,
            initial=exist_size,
            total=total + exist_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1000,
        ) as bar:
            for data in resp.iter_content(chunk_size=1024):
                size = file.write(data)
                bar.update(size)
    except Exception as e:
        print(f"An error occurred while downloading {model_name}: {str(e)}")
        return None
    return model_dest
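

# Example (illustrative; repo and filename are assumptions, real values come
# from INITIAL_MODELS.yaml): fetch one checkpoint from the HuggingFace Hub,
# resuming any partial download already present in the destination directory.
#
#   dest = hf_download_with_resume(
#       repo_id="runwayml/stable-diffusion-v1-5",
#       model_dir=os.path.join(Globals.root, Model_dir, Weights_dir),
#       model_name="v1-5-pruned-emaonly.ckpt",
#   )
#   if dest is None:
#       print("download failed")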


# ---------------------------------------------
def update_config_file(successfully_downloaded: dict, config_file: Path):
    config_file = (
        Path(config_file) if config_file is not None else default_config_file()
    )

    # In some cases (incomplete setup, etc), the default configs directory might be missing.
    # Create it if it doesn't exist.
    # This check is skipped if opt.config_file is specified - the user is assumed to know what they
    # are doing if they are passing a custom config file from elsewhere.
    if config_file == default_config_file() and not config_file.parent.exists():
        configs_src = Dataset_path.parent
        configs_dest = default_config_file().parent
        shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)

    yaml = new_config_file_contents(successfully_downloaded, config_file)

    try:
        backup = None
        if os.path.exists(config_file):
            print(
                f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
            )
            backup = config_file.with_suffix(".yaml.orig")
            ## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
            if sys.platform == "win32" and backup.is_file():
                backup.unlink()
            config_file.rename(backup)

        with TemporaryFile() as tmp:
            tmp.write(Config_preamble.encode())
            tmp.write(yaml.encode())

            with open(str(config_file.expanduser().resolve()), "wb") as new_config:
                tmp.seek(0)
                new_config.write(tmp.read())

    except Exception as e:
        print(f"** Error creating config file {config_file}: {str(e)} **")
        if backup is not None:
            print("restoring previous config file")
            ## workaround for WinError 183, see above
            if sys.platform == "win32" and config_file.is_file():
                config_file.unlink()
            backup.rename(config_file)
        return

    print(f"Successfully created new configuration file {config_file}")


# ---------------------------------------------
def new_config_file_contents(
    successfully_downloaded: dict,
    config_file: Path,
) -> str:
    if config_file.exists():
        conf = OmegaConf.load(str(config_file.expanduser().resolve()))
    else:
        conf = OmegaConf.create()

    default_selected = None
    for model in successfully_downloaded:
        # a bit hacky - what we are doing here is seeing whether a checkpoint
        # version of the model was previously defined, and whether the current
        # model is a diffusers (indicated with a path)
        if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
            delete_weights(model, conf[model])

        stanza = {}
        mod = initial_models()[model]
        stanza["description"] = mod["description"]
        stanza["repo_id"] = mod["repo_id"]
        stanza["format"] = mod["format"]
        # diffusers don't need width and height (probably .ckpt doesn't either)
        # so we no longer require these in INITIAL_MODELS.yaml
        if "width" in mod:
            stanza["width"] = mod["width"]
        if "height" in mod:
            stanza["height"] = mod["height"]
        if "file" in mod:
            stanza["weights"] = os.path.relpath(
                successfully_downloaded[model], start=Globals.root
            )
            stanza["config"] = os.path.normpath(os.path.join(sd_configs(), mod["config"]))
        if "vae" in mod:
            if "file" in mod["vae"]:
                stanza["vae"] = os.path.normpath(
                    os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
                )
            else:
                stanza["vae"] = mod["vae"]
        if mod.get("default", False):
            stanza["default"] = True
            default_selected = True

        conf[model] = stanza

    # if no default model was chosen, then we select the first
    # one in the list
    if not default_selected:
        conf[list(successfully_downloaded.keys())[0]]["default"] = True

    return OmegaConf.to_yaml(conf)

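
# For reference, a stanza produced by new_config_file_contents() comes out
# roughly like this in models.yaml (the values shown are hypothetical):
#
#   stable-diffusion-1.5:
#     description: Stable Diffusion version 1.5
#     repo_id: runwayml/stable-diffusion-v1-5
#     format: diffusers
#     vae:
#       repo_id: stabilityai/sd-vae-ft-mse
#     default: true
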
# ---------------------------------------------
def delete_weights(model_name: str, conf_stanza: dict):
    if not (weights := conf_stanza.get("weights")):
        return
    # re.search, not re.match: "/VAE/" can occur anywhere in the config path
    if re.search("/VAE/", conf_stanza.get("config", "")):
        return

    print(
        f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}"
    )

    weights = Path(weights)
    if not weights.is_absolute():
        weights = Path(Globals.root) / weights
    try:
        weights.unlink()
    except OSError as e:
        print(str(e))
139 ldm/invoke/config/widgets.py Normal file
@ -0,0 +1,139 @@
'''
Widget class definitions used by model_select.py, merge_diffusers.py and textual_inversion.py
'''
import math
import npyscreen
import curses


class IntSlider(npyscreen.Slider):
    def translate_value(self):
        stri = "%2d / %2d" % (self.value, self.out_of)
        l = (len(str(self.out_of))) * 2 + 4
        stri = stri.rjust(l)
        return stri


# -------------------------------------
class CenteredTitleText(npyscreen.TitleText):
    def __init__(self, *args, **keywords):
        super().__init__(*args, **keywords)
        self.resize()

    def resize(self):
        super().resize()
        maxy, maxx = self.parent.curses_pad.getmaxyx()
        label = self.name
        self.relx = (maxx - len(label)) // 2
        begin_entry_at = -self.relx + 2


# -------------------------------------
class CenteredButtonPress(npyscreen.ButtonPress):
    def resize(self):
        super().resize()
        maxy, maxx = self.parent.curses_pad.getmaxyx()
        label = self.name
        self.relx = (maxx - len(label)) // 2


# -------------------------------------
class OffsetButtonPress(npyscreen.ButtonPress):
    def __init__(self, screen, offset=0, *args, **keywords):
        super().__init__(screen, *args, **keywords)
        self.offset = offset

    def resize(self):
        maxy, maxx = self.parent.curses_pad.getmaxyx()
        width = len(self.name)
        self.relx = self.offset + (maxx - width) // 2


class IntTitleSlider(npyscreen.TitleText):
    _entry_type = IntSlider


class FloatSlider(npyscreen.Slider):
    # this is supposed to adjust display precision, but doesn't
    def translate_value(self):
        stri = "%3.2f / %3.2f" % (self.value, self.out_of)
        l = (len(str(self.out_of))) * 2 + 4
        stri = stri.rjust(l)
        return stri


class FloatTitleSlider(npyscreen.TitleText):
    _entry_type = FloatSlider


class MultiSelectColumns(npyscreen.MultiSelect):
    def __init__(self, screen, columns: int = 1, values: list = [], **keywords):
        self.columns = columns
        self.value_cnt = len(values)
        self.rows = math.ceil(self.value_cnt / self.columns)
        super().__init__(screen, values=values, **keywords)

    def make_contained_widgets(self):
        # lay the checkboxes out column-major within the available width
        self._my_widgets = []
        column_width = self.width // self.columns
        for h in range(self.value_cnt):
            self._my_widgets.append(
                self._contained_widgets(
                    self.parent,
                    rely=self.rely + (h % self.rows) * self._contained_widget_height,
                    relx=self.relx + (h // self.rows) * column_width,
                    max_width=column_width,
                    max_height=self.__class__._contained_widget_height,
                )
            )

    def set_up_handlers(self):
        super().set_up_handlers()
        self.handlers.update(
            {
                curses.KEY_UP: self.h_cursor_line_left,
                curses.KEY_DOWN: self.h_cursor_line_right,
            }
        )

    def h_cursor_line_down(self, ch):
        self.cursor_line += self.rows
        if self.cursor_line >= len(self.values):
            if self.scroll_exit:
                self.cursor_line = len(self.values) - self.rows
                self.h_exit_down(ch)
                return True
            else:
                self.cursor_line -= self.rows
                return True

    def h_cursor_line_up(self, ch):
        self.cursor_line -= self.rows
        if self.cursor_line < 0:
            if self.scroll_exit:
                self.cursor_line = 0
                self.h_exit_up(ch)
            else:
                self.cursor_line = 0

    def h_cursor_line_left(self, ch):
        super().h_cursor_line_up(ch)

    def h_cursor_line_right(self, ch):
        super().h_cursor_line_down(ch)


class TextBox(npyscreen.MultiLineEdit):
    def update(self, clear=True):
        if clear:
            self.clear()

        HEIGHT = self.height
        WIDTH = self.width
        # draw box.
        self.parent.curses_pad.hline(self.rely, self.relx, curses.ACS_HLINE, WIDTH)
        self.parent.curses_pad.hline(self.rely + HEIGHT, self.relx, curses.ACS_HLINE, WIDTH)
        self.parent.curses_pad.vline(self.rely, self.relx, curses.ACS_VLINE, self.height)
        self.parent.curses_pad.vline(self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT)

        # draw corners
        self.parent.curses_pad.addch(self.rely, self.relx, curses.ACS_ULCORNER)
        self.parent.curses_pad.addch(self.rely, self.relx + WIDTH, curses.ACS_URCORNER)
        self.parent.curses_pad.addch(self.rely + HEIGHT, self.relx, curses.ACS_LLCORNER)
        self.parent.curses_pad.addch(self.rely + HEIGHT, self.relx + WIDTH, curses.ACS_LRCORNER)

        # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work
        (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width)
        self.relx += 1
        self.rely += 1
        self.height -= 1
        self.width -= 1
        super().update(clear=False)
        (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width)
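
# A minimal sketch of how these widgets compose into a form (assumes a real
# terminal for curses; the form name and value strings are invented for
# illustration):
#
#   class DemoApp(npyscreen.NPSAppManaged):
#       def onStart(self):
#           form = self.addForm("MAIN", npyscreen.FormMultiPage, name="Demo")
#           form.add_widget_intelligent(
#               MultiSelectColumns,
#               name="Pick models",
#               columns=3,
#               values=["model-a", "model-b", "model-c"],
#               max_height=4,
#           )
#           form.add_widget_intelligent(TextBox, name="Notes", max_height=6)
#
#   DemoApp().run()
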
@ -339,7 +339,6 @@ class Generator:
         if self.caution_img:
             return self.caution_img
         path = Path(web_assets.__path__[0]) / CAUTION_IMG
-        print(f'DEBUG: path to caution = {path}')
         caution = Image.open(path)
         self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
         return self.caution_img
@ -2,6 +2,7 @@ from __future__ import annotations
 
 import dataclasses
 import inspect
+import psutil
 import secrets
 from collections.abc import Sequence
 from dataclasses import dataclass, field
@ -308,7 +309,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         """
         if xformers is available, use it, otherwise use sliced attention.
         """
-        if is_xformers_available() and not Globals.disable_xformers:
+        if torch.cuda.is_available() and is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
         else:
             if torch.backends.mps.is_available():
@ -40,8 +40,6 @@ class Omnibus(Img2Img,Txt2Img):
         self.perlin = perlin
         num_samples = 1
 
-        print('DEBUG: IN OMNIBUS')
-
         sampler.make_schedule(
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )
@ -20,6 +20,7 @@ from diffusers import logging as dlogging
 from npyscreen import widget
 from omegaconf import OmegaConf
 
+from ldm.invoke.config.widgets import FloatTitleSlider
 from ldm.invoke.globals import (Globals, global_cache_dir, global_config_file,
                                 global_models_dir, global_set_root)
 from ldm.invoke.model_manager import ModelManager
@ -172,18 +173,6 @@ def _parse_args() -> Namespace:
 
 
 # ------------------------- GUI HERE -------------------------
-class FloatSlider(npyscreen.Slider):
-    def translate_value(self):
-        stri = "%3.2f / %3.2f" % (self.value, self.out_of)
-        l = (len(str(self.out_of))) * 2 + 4
-        stri = stri.rjust(l)
-        return stri
-
-
-class FloatTitleSlider(npyscreen.TitleText):
-    _entry_type = FloatSlider
-
-
 class mergeModelsForm(npyscreen.FormMultiPageAction):
     interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
 
@ -11,10 +11,12 @@ import gc
 import hashlib
 import io
 import os
 import re
 import sys
+import textwrap
 import time
 import warnings
+from enum import Enum
 from pathlib import Path
 from shutil import move, rmtree
 from typing import Any, Optional, Union
@ -31,12 +33,22 @@ from omegaconf.dictconfig import DictConfig
 from picklescan.scanner import scan_file_path
 
 from ldm.invoke.devices import CPU_DEVICE
-from ldm.invoke.generator.diffusers_pipeline import \
-    StableDiffusionGeneratorPipeline
-from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
-                                global_models_dir)
-from ldm.util import (ask_user, download_with_resume,
-                      url_attachment_name, instantiate_from_config)
+from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
+from ldm.invoke.globals import Globals, global_cache_dir
+from ldm.util import (
+    ask_user,
+    download_with_resume,
+    instantiate_from_config,
+    url_attachment_name,
+)
 
 
+class SDLegacyType(Enum):
+    V1 = 1
+    V1_INPAINT = 2
+    V2 = 3
+    UNKNOWN = 99
+
+
 DEFAULT_MAX_MODELS = 2
 VAE_TO_REPO_ID = {  # hack, see note in convert_and_import()
@ -51,7 +63,7 @@ class ModelManager(object):
         device_type: torch.device = CPU_DEVICE,
         precision: str = "float16",
         max_loaded_models=DEFAULT_MAX_MODELS,
-        sequential_offload = False
+        sequential_offload=False,
     ):
         """
         Initialize with the path to the models.yaml config file,
@ -129,6 +141,7 @@ class ModelManager(object):
         for model_name in self.config:
             if self.config[model_name].get("default"):
                 return model_name
+        return list(self.config.keys())[0]  # first one
 
     def set_default_model(self, model_name: str) -> None:
         """
@ -375,21 +388,31 @@ class ModelManager(object):
             print(
                 f">> Converting legacy checkpoint {model_name} into a diffusers model..."
             )
-            from ldm.invoke.ckpt_to_diffuser import \
-                load_pipeline_from_original_stable_diffusion_ckpt
+            from ldm.invoke.ckpt_to_diffuser import (
+                load_pipeline_from_original_stable_diffusion_ckpt,
+            )
 
+            self.offload_model(self.current_model)
             if vae_config := self._choose_diffusers_vae(model_name):
                 vae = self._load_vae(vae_config)
+            if self._has_cuda():
+                torch.cuda.empty_cache()
             pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
                 checkpoint_path=weights,
                 original_config_file=config,
                 vae=vae,
                 return_generator_pipeline=True,
+                precision=torch.float16
+                if self.precision == "float16"
+                else torch.float32,
             )
+            if self.sequential_offload:
+                pipeline.enable_offload_submodels(self.device)
+            else:
+                pipeline.to(self.device)
+
             return (
-                pipeline.to(self.device).to(
-                    torch.float16 if self.precision == "float16" else torch.float32
-                ),
+                pipeline,
                 width,
                 height,
                 "NOHASH",
@ -466,19 +489,6 @@ class ModelManager(object):
         for module in model.modules():
             if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 module._orig_padding_mode = module.padding_mode
 
-        # usage statistics
-        toc = time.time()
-        print(">> Model loaded in", "%4.2fs" % (toc - tic))
-
-        if self._has_cuda():
-            print(
-                ">> Max VRAM used to load the model:",
-                "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
-                "\n>> Current VRAM usage:"
-                "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
-            )
-
         return model, width, height, model_hash
 
     def _load_diffusers_model(self, mconfig):
@ -496,8 +506,8 @@ class ModelManager(object):
             safety_checker=None, local_files_only=not Globals.internet_available
         )
         if "vae" in mconfig and mconfig["vae"] is not None:
-            vae = self._load_vae(mconfig["vae"])
-            pipeline_args.update(vae=vae)
+            if vae := self._load_vae(mconfig["vae"]):
+                pipeline_args.update(vae=vae)
         if not isinstance(name_or_path, Path):
             pipeline_args.update(cache_dir=global_cache_dir("diffusers"))
         if using_fp16:
@ -555,7 +565,7 @@ class ModelManager(object):
             f'"{model_name}" is not a known model name. Please check your models.yaml file'
         )
 
-        if "path" in mconfig:
+        if "path" in mconfig and mconfig["path"] is not None:
             path = Path(mconfig["path"])
             if not path.is_absolute():
                 path = Path(Globals.root, path).resolve()
@ -610,13 +620,13 @@ class ModelManager(object):
             print("### Exiting InvokeAI")
             sys.exit()
         else:
-            print(">> Model scanned ok!")
+            print(">> Model scanned ok")
 
     def import_diffuser_model(
         self,
         repo_or_path: Union[str, Path],
         model_name: str = None,
-        description: str = None,
+        model_description: str = None,
         vae: dict = None,
         commit_to_conf: Path = None,
     ) -> bool:
@ -632,21 +642,24 @@ class ModelManager(object):
         models.yaml file.
         """
         model_name = model_name or Path(repo_or_path).stem
-        description = description or f"imported diffusers model {model_name}"
+        description = model_description or f"imported diffusers model {model_name}"
         new_config = dict(
-            description=description,
+            description=model_description,
             vae=vae,
            format="diffusers",
         )
+        print(f"DEBUG: here i am 1")
         if isinstance(repo_or_path, Path) and repo_or_path.exists():
             new_config.update(path=str(repo_or_path))
         else:
             new_config.update(repo_id=repo_or_path)
+        print(f"DEBUG: here i am 2")
 
         self.add_model(model_name, new_config, True)
+        print(f"DEBUG: config = {self.config}")
         if commit_to_conf:
             self.commit(commit_to_conf)
-        return True
+        return model_name
 
     def import_ckpt_model(
         self,
@ -656,7 +669,7 @@ class ModelManager(object):
         model_name: str = None,
         model_description: str = None,
         commit_to_conf: Path = None,
-    ) -> bool:
+    ) -> str:
         """
         Attempts to install the indicated ckpt file and returns True if successful.
 
@ -673,19 +686,23 @@ class ModelManager(object):
         then these will be derived from the weight file name. If you provide a commit_to_conf
         path to the configuration file, then the new entry will be committed to the
         models.yaml file.
+
+        Return value is the name of the imported file, or None if an error occurred.
         """
         if str(weights).startswith(("http:", "https:")):
             model_name = model_name or url_attachment_name(weights)
 
         weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
         config_path = self._resolve_path(config, "configs/stable-diffusion")
 
         if weights_path is None or not weights_path.exists():
-            return False
+            return
         if config_path is None or not config_path.exists():
-            return False
+            return
 
-        model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
+        model_name = (
+            model_name or Path(weights).stem
+        )  # note this gives ugly pathnames if used on a URL without a Content-Disposition header
         model_description = (
             model_description or f"imported stable diffusion weights file {model_name}"
         )
@ -702,43 +719,205 @@ class ModelManager(object):
         self.add_model(model_name, new_config, True)
         if commit_to_conf:
             self.commit(commit_to_conf)
-        return True
+        return model_name
 
-    def autoconvert_weights(
+    @classmethod
+    def probe_model_type(self, checkpoint: dict) -> SDLegacyType:
+        """
+        Given a pickle or safetensors model object, probes contents
+        of the object and returns an SDLegacyType indicating its
+        format. Valid return values include:
+        SDLegacyType.V1
+        SDLegacyType.V1_INPAINT
+        SDLegacyType.V2
+        SDLegacyType.UNKNOWN
+        """
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            return SDLegacyType.V2
+
+        try:
+            state_dict = checkpoint.get("state_dict") or checkpoint
+            in_channels = state_dict[
+                "model.diffusion_model.input_blocks.0.0.weight"
+            ].shape[1]
+            if in_channels == 9:
+                return SDLegacyType.V1_INPAINT
+            elif in_channels == 4:
+                return SDLegacyType.V1
+            else:
+                return SDLegacyType.UNKNOWN
+        except KeyError:
+            return SDLegacyType.UNKNOWN
 
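# Illustrative sketch of the probe in use (hypothetical checkpoint name; the
# 1024-wide attention key identifies SD-v2, and the UNet's input channel count
# separates 9-channel inpainting models from ordinary 4-channel SD-v1 ones):
#
#   ckpt = torch.load("some-model.ckpt", map_location="cpu")
#   kind = ModelManager.probe_model_type(ckpt)
#   if kind == SDLegacyType.V1_INPAINT:
#       config = "configs/stable-diffusion/v1-inpainting-inference.yaml"
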
+    def heuristic_import(
         self,
-        conf_path: Path,
-        weights_directory: Path = None,
-        dest_directory: Path = None,
-    ):
+        path_url_or_repo: str,
+        convert: bool = False,
+        model_name: str = None,
+        description: str = None,
+        commit_to_conf: Path = None,
+    ) -> str:
         """
-        Scan the indicated directory for .ckpt files, convert into diffuser models,
-        and import.
+        Accept a string which could be:
+        - a HF diffusers repo_id
+        - a URL pointing to a legacy .ckpt or .safetensors file
+        - a local path pointing to a legacy .ckpt or .safetensors file
+        - a local directory containing .ckpt and .safetensors files
+        - a local directory containing a diffusers model
+
+        After determining the nature of the model and downloading it
+        (if necessary), the file is probed to determine the correct
+        configuration file (if needed) and it is imported.
+
+        The model_name and/or description can be provided. If not, they will
+        be generated automatically.
+
+        If convert is true, legacy models will be converted to diffusers
+        before importing.
+
+        If commit_to_conf is provided, the newly loaded model will be written
+        to the `models.yaml` file at the indicated path. Otherwise, the changes
+        will only remain in memory.
+
+        The (potentially derived) name of the model is returned on success, or None
+        on failure. When multiple models are added from a directory, only the last
+        imported one is returned.
         """
-        weights_directory = weights_directory or global_autoscan_dir()
-        dest_directory = dest_directory or Path(
-            global_models_dir(), Globals.converted_ckpts_dir
-        )
+        model_path: Path = None
+        thing = path_url_or_repo  # to save typing
 
-        print(f">> Checking for unconverted .ckpt files in {weights_directory}")
-        ckpt_files = dict()
-        for root, dirs, files in os.walk(weights_directory):
-            for f in files:
-                if not f.endswith(".ckpt"):
-                    continue
-                basename = Path(f).stem
-                dest = Path(dest_directory, basename)
-                if not dest.exists():
-                    ckpt_files[Path(root, f)] = dest
+        print(f">> Probing {thing} for import")
 
-        if len(ckpt_files) == 0:
-            return
+        if thing.startswith(("http:", "https:", "ftp:")):
+            print(f" | {thing} appears to be a URL")
+            model_path = self._resolve_path(
+                thing, "models/ldm/stable-diffusion-v1"
+            )  # _resolve_path does a download if needed
+
+        elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
+            if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
+                print(
+                    f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
+                )
+                return
+            else:
+                print(f" | {thing} appears to be a checkpoint file on disk")
+                model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
+
+        elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
+            print(f" | {thing} appears to be a diffusers file on disk")
+            model_name = self.import_diffuser_model(
+                thing,
+                vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
+                model_name=model_name,
+                description=description,
+                commit_to_conf=commit_to_conf,
+            )
+
+        elif Path(thing).is_dir():
+            if (Path(thing) / "model_index.json").exists():
+                print(f">> {thing} appears to be a diffusers model.")
+                model_name = self.import_diffuser_model(
+                    thing, commit_to_conf=commit_to_conf
+                )
+            else:
+                print(
+                    f">> {thing} appears to be a directory. Will scan for models to import"
+                )
+                for m in list(Path(thing).rglob("*.ckpt")) + list(
+                    Path(thing).rglob("*.safetensors")
+                ):
+                    if model_name := self.heuristic_import(
+                        str(m), convert, commit_to_conf=commit_to_conf
+                    ):
+                        print(f" >> {model_name} successfully imported")
+                return model_name
+
+        elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
+            print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
+            model_name = self.import_diffuser_model(
+                thing, commit_to_conf=commit_to_conf
+            )
+            pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
+
+        else:
+            print(
+                f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
+            )
+
+        # Model_path is set in the event of a legacy checkpoint file.
+        # If not set, we're all done
+        if not model_path:
+            return
 
-        print(
-            f">> New .ckpt file(s) found in {weights_directory}. Optimizing and importing..."
-        )
+        if model_path.stem in self.config:  # already imported
+            print(" | Already imported. Skipping")
+            return
+
+        # another round of heuristics to guess the correct config file.
+        checkpoint = (
+            safetensors.torch.load_file(model_path)
+            if model_path.suffix == ".safetensors"
+            else torch.load(model_path)
+        )
-        for ckpt in ckpt_files:
-            self.convert_and_import(ckpt, ckpt_files[ckpt])
-        self.commit(conf_path)
+        model_type = self.probe_model_type(checkpoint)
+
+        model_config_file = None
+        if model_type == SDLegacyType.V1:
+            print(" | SD-v1 model detected")
+            model_config_file = Path(
+                Globals.root, "configs/stable-diffusion/v1-inference.yaml"
+            )
+        elif model_type == SDLegacyType.V1_INPAINT:
+            print(" | SD-v1 inpainting model detected")
+            model_config_file = Path(
+                Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
+            )
+        elif model_type == SDLegacyType.V2:
+            print(
+                " | SD-v2 model detected; model will be converted to diffusers format"
+            )
+            model_config_file = Path(
+                Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            convert = True
+        else:
+            print(
+                f"** {thing} is a legacy checkpoint file but not in a known Stable Diffusion model. Skipping import"
+            )
+            return
+
+        if convert:
+            diffuser_path = Path(
+                Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
+            )
+            model_name = self.convert_and_import(
+                model_path,
+                diffusers_path=diffuser_path,
+                vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
+                model_name=model_name,
+                model_description=description,
+                original_config_file=model_config_file,
+                commit_to_conf=commit_to_conf,
+            )
+        else:
+            model_name = self.import_ckpt_model(
+                model_path,
+                config=model_config_file,
+                model_name=model_name,
+                model_description=description,
+                vae=str(
+                    Path(
+                        Globals.root,
+                        "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
+                    )
+                ),
+                commit_to_conf=commit_to_conf,
+            )
+        if commit_to_conf:
+            self.commit(commit_to_conf)
+        return model_name
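
# Sketch of the intended call pattern for heuristic_import() (the repo_id and
# config path here are illustrative assumptions, not taken from this diff):
#
#   mgr = ModelManager(OmegaConf.load(config_file), precision="float16")
#   name = mgr.heuristic_import(
#       "runwayml/stable-diffusion-v1-5",   # repo_id, URL, file or directory
#       convert=True,
#       commit_to_conf=config_file,
#   )
#   print(f"imported as {name}")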
 
     def convert_and_import(
         self,
@ -754,6 +933,12 @@ class ModelManager(object):
         Convert a legacy ckpt weights file to diffuser model and import
         into models.yaml.
         """
+        ckpt_path = self._resolve_path(ckpt_path, "models/ldm/stable-diffusion-v1")
+        if original_config_file:
+            original_config_file = self._resolve_path(
+                original_config_file, "configs/stable-diffusion"
+            )
+
         new_config = None
 
         from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
@ -768,7 +953,7 @@ class ModelManager(object):
         model_description = model_description or f"Optimized version of {model_name}"
         print(f">> Optimizing {model_name} (30-60s)")
         try:
-            # By passing the specified VAE too the conversion function, the autoencoder
+            # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
             vae_model = self._load_vae(vae) if vae else None
             convert_ckpt_to_diffuser(
@ -795,9 +980,11 @@ class ModelManager(object):
             print(">> Conversion succeeded")
         except Exception as e:
             print(f"** Conversion failed: {str(e)}")
-            print("** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)")
+            print(
+                "** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
+            )
 
-        return new_config
+        return model_name
 
     def search_models(self, search_folder):
         print(f">> Finding Models In: {search_folder}")
@ -812,10 +999,11 @@ class ModelManager(object):
         found_models = []
         for file in files:
             location = str(file.resolve()).replace("\\", "/")
-            if 'model.safetensors' not in location and 'diffusion_pytorch_model.safetensors' not in location:
-                found_models.append(
-                    {"name": file.stem, "location": location}
-                )
+            if (
+                "model.safetensors" not in location
+                and "diffusion_pytorch_model.safetensors" not in location
+            ):
+                found_models.append({"name": file.stem, "location": location})
 
         return search_folder, found_models
 
@ -975,7 +1163,7 @@ class ModelManager(object):
         print("** Migration is done. Continuing...")
 
     def _resolve_path(
         self, source: Union[str, Path], dest_directory: str
     ) -> Optional[Path]:
         resolved_path = None
         if str(source).startswith(("http:", "https:", "ftp:")):
@ -1113,7 +1301,12 @@ class ModelManager(object):
 
     def _load_vae(self, vae_config) -> AutoencoderKL:
         vae_args = {}
-        name_or_path = self.model_name_or_path(vae_config)
+        try:
+            name_or_path = self.model_name_or_path(vae_config)
+        except Exception:
+            return None
+        if name_or_path is None:
+            return None
         using_fp16 = self.precision == "float16"
 
         vae_args.update(
42 ldm/util.py
@ -306,8 +306,12 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
         dest/filename
     :param access_token: Access token to access this resource
     '''
-    resp = requests.get(url, stream=True)
-    total = int(resp.headers.get("content-length", 0))
+    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
+    open_mode = "wb"
+    exist_size = 0
+
+    resp = requests.get(url, headers=header, stream=True)
+    content_length = int(resp.headers.get("content-length", 0))
 
     if dest.is_dir():
         try:
@ -318,41 +322,41 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
     else:
         dest.parent.mkdir(parents=True, exist_ok=True)
 
-    print(f'DEBUG: after many manipulations, dest={dest}')
-
-    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
-    open_mode = "wb"
-    exist_size = 0
-
     if dest.exists():
         exist_size = dest.stat().st_size
         header["Range"] = f"bytes={exist_size}-"
         open_mode = "ab"
+        resp = requests.get(url, headers=header, stream=True)  # new request with range
+
+    if exist_size > content_length:
+        print('* corrupt existing file found. re-downloading')
+        os.remove(dest)
+        exist_size = 0
 
     if (
-        resp.status_code == 416
-    ):  # "range not satisfiable", which means nothing to return
+        resp.status_code == 416 or exist_size == content_length
+    ):
         print(f"* {dest}: complete file found. Skipping.")
         return dest
+    elif resp.status_code == 206 or exist_size > 0:
+        print(f"* {dest}: partial file found. Resuming...")
     elif resp.status_code != 200:
         print(f"** An error occurred during downloading {dest}: {resp.reason}")
-    elif exist_size > 0:
-        print(f"* {dest}: partial file found. Resuming...")
     else:
         print(f"* {dest}: Downloading...")
 
     try:
-        if total < 2000:
+        if content_length < 2000:
             print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
             return None
 
         with open(dest, open_mode) as file, tqdm(
             desc=str(dest),
             initial=exist_size,
-            total=total + exist_size,
+            total=content_length,
             unit="iB",
             unit_scale=True,
             unit_divisor=1000,
         ) as bar:
             for data in resp.iter_content(chunk_size=1024):
                 size = file.write(data)
@ -109,6 +109,7 @@ dependencies = [
 "invokeai-configure" = "ldm.invoke.config.invokeai_configure:main"
 "invokeai-merge" = "ldm.invoke.merge_diffusers:main" # note name munging
 "invokeai-ti" = "ldm.invoke.training.textual_inversion:main"
+"invokeai-model-install" = "ldm.invoke.config.model_install:main"
 "invokeai-update" = "ldm.invoke.config.invokeai_update:main"
 
 [project.urls]