Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 12f9bda524: Merge github.com:invoke-ai/InvokeAI into lstein/bugfix/model-install-thread-stop
@@ -9,10 +9,6 @@ set -e -o pipefail
### Set INVOKEAI_ROOT pointing to a valid runtime directory
# Otherwise configure the runtime dir first.

### Configure the InvokeAI runtime directory (done by default):
# docker run --rm -it <this image> --configure
# or skip with --no-configure

### Set the CONTAINER_UID envvar to match your user.
# Ensures files created in the container are owned by you:
# docker run --rm -it -v /some/path:/invokeai -e CONTAINER_UID=$(id -u) <this image>
@@ -22,27 +18,6 @@ USER_ID=${CONTAINER_UID:-1000}
USER=ubuntu
usermod -u ${USER_ID} ${USER} 1>/dev/null

configure() {
    # Configure the runtime directory
    if [[ -f ${INVOKEAI_ROOT}/invokeai.yaml ]]; then
        echo "${INVOKEAI_ROOT}/invokeai.yaml exists. InvokeAI is already configured."
        echo "To reconfigure InvokeAI, delete the above file."
        echo "======================================================================"
    else
        mkdir -p "${INVOKEAI_ROOT}"
        chown --recursive ${USER} "${INVOKEAI_ROOT}"
        gosu ${USER} invokeai-configure --yes --default_only
    fi
}

## Skip attempting to configure.
## Must be passed first, before any other args.
if [[ $1 != "--no-configure" ]]; then
    configure
else
    shift
fi

### Set the $PUBLIC_KEY env var to enable SSH access.
# We do not install openssh-server in the image by default to avoid bloat,
# but it is useful to have the full SSH server e.g. on Runpod.
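
For example, following the CONTAINER_UID invocation shown above (a sketch; the image tag and key path are placeholders, not values from this diff):

```bash
# Run the container with SSH access enabled by passing your public key
# via the PUBLIC_KEY env var described in the comment above.
docker run --rm -it \
  -v /some/path:/invokeai \
  -e PUBLIC_KEY="$(cat ~/.ssh/id_ed25519.pub)" \
  <this image>
```
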
@@ -18,9 +18,6 @@ Settings sources are used in this order:
- `invokeai.yaml` settings
- Fallback: defaults

The most commonly changed settings are also accessible
graphically via the `invokeai-configure` script.

### InvokeAI Root Directory

On startup, InvokeAI searches for its "root" directory. This is the directory
@@ -42,10 +39,9 @@ It has two sections - one for internal use and one for user settings:

```yaml
# Internal metadata - do not edit:
meta:
  schema_version: 4
schema_version: 4

# Put user settings here:
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:
host: 0.0.0.0 # serve the app on your local network
models_dir: D:\invokeai\models # store models on an external drive
precision: float16 # always use fp16 precision
@@ -62,6 +58,12 @@ You can fix a broken `invokeai.yaml` by deleting it and running the
configuration script again -- option [6] in the launcher, "Re-run the
configure script".
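
The same repair can be sketched from a terminal roughly as follows (an illustration, not part of the original docs; it assumes the default `~/invokeai` root and mirrors the flags the launcher script passes to the configure tool):

```bash
# Move the broken config aside, then regenerate a default one.
mv ~/invokeai/invokeai.yaml ~/invokeai/invokeai.yaml.bak
invokeai-configure --root ~/invokeai --yes --default_only
```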

#### Custom Config File Location

You can use any config file with the `--config` CLI arg. Pass in the path to the `invokeai.yaml` file you want to use.
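
For example (a sketch; the path is a placeholder):

```bash
invokeai-web --config /path/to/custom/invokeai.yaml
```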

Note that environment variables will trump any settings in the config file.

### Environment Variables

All settings may be set via environment variables by prefixing `INVOKEAI_`
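
For example, the `host` and `models_dir` settings shown in the YAML example above could be set in the environment before launching (a sketch; adjust the values to your system):

```bash
export INVOKEAI_HOST=0.0.0.0
export INVOKEAI_MODELS_DIR=/mnt/external/invokeai/models
invokeai-web
```
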
@@ -81,13 +83,10 @@ We suggest using `invokeai.yaml`, as it is more user-friendly.
A subset of settings may be specified using CLI args:

- `--root`: specify the root directory
- `--ignore_missing_core_models`: if set, do not check for models needed
  to convert checkpoint/safetensor models to diffusers
- `--config`: override the default `invokeai.yaml` file location

### All Settings

The config is managed by the `InvokeAIAppConfig` class. The below docs are autogenerated from the class.

Following the table are additional explanations for certain settings.

<!-- prettier-ignore-start -->
@@ -122,9 +122,9 @@ experimental versions later.
[latest release](https://github.com/invoke-ai/InvokeAI/releases/latest),
and look for a file named:

- InvokeAI-installer-v3.X.X.zip
- InvokeAI-installer-v4.X.X.zip

where "3.X.X" is the latest released version. The file is located
where "4.X.X" is the latest released version. The file is located
at the very bottom of the release page, under **Assets**.

4. **Unpack the installer**: Unpack the zip file into a convenient directory. This will create a new
@@ -199,136 +199,7 @@ experimental versions later.
![initial-settings-screenshot](../assets/installer-walkthrough/settings-form.png)
</figure>

10. **Post-install Configuration**: After installation completes, the
installer will launch the configuration form, which will guide you
through the first-time process of adjusting some of InvokeAI's
startup settings. To move around this form use ctrl-N for
<N>ext and ctrl-P for <P>revious, or use <tab>
and shift-<tab> to move forward and back. Once you are in a
multi-checkbox field use the up and down cursor keys to select the
item you want, and <space> to toggle it on and off. Within
a directory field, pressing <tab> will provide autocomplete
options.

Generally the defaults are fine, and you can come back to this screen at
any time to tweak your system. Here are the options you can adjust:

- ***HuggingFace Access Token***
InvokeAI has the ability to download embedded styles and subjects
from the HuggingFace Concept Library on-demand. However, some of
the concept library files are password protected. To make downloads
smoother, you can set up an account at huggingface.co, obtain an
access token, and paste it into this field. Note that you paste
into this screen using ctrl-shift-V.

- ***Free GPU memory after each generation***
This is useful for low-memory machines and helps minimize the
amount of GPU VRAM used by InvokeAI.

- ***Enable xformers support if available***
If the xformers library was successfully installed, this will activate
it to reduce memory consumption and increase rendering speed noticeably.
Note that xformers has the side effect of generating slightly different
images even when presented with the same seed and other settings.

- ***Force CPU to be used on GPU systems***
This will use the (slow) CPU rather than the accelerated GPU. This
can be used to generate images on systems that don't have a compatible
GPU.

- ***Precision***
This controls whether to use float32 or float16 arithmetic.
float16 uses less memory but is also slightly less accurate.
Ordinarily the right arithmetic is picked automatically ("auto"),
but you may have to use float32 to get images on certain systems
and graphics cards. The "autocast" option is deprecated and
shouldn't be used unless you are asked to by a member of the team.

- ***Size of the RAM cache used for fast model switching***
This allows you to keep models in memory and switch rapidly among
them rather than having them load from disk each time. This slider
controls how many models to keep loaded at once. A typical SD-1 or SD-2 model
uses 2-3 GB of memory. A typical SDXL model uses 6-7 GB. Providing more
RAM will allow more models to be co-resident.

- ***Output directory for images***
This is the path to a directory in which InvokeAI will store all its
generated images.

- ***Autoimport Folder***
This is the directory in which you can place models you have
downloaded and wish to load into InvokeAI. You can place a variety
of models in this directory, including diffusers folders, .ckpt files,
.safetensors files, as well as LoRAs, ControlNet and Textual Inversion
files (both folder and file versions). To help organize this folder,
you can create several levels of subfolders and drop your models into
whichever ones you want.

- ***LICENSE***

At the bottom of the screen you will see a checkbox for accepting
the CreativeML Responsible AI Licenses. You need to accept the license
in order to download Stable Diffusion models from the next screen.

_You can come back to the startup options form_ as many times as you like.
From the `invoke.sh` or `invoke.bat` launcher, select option (6) to relaunch
this script. On the command line, it is named `invokeai-configure`.

11. **Downloading Models**: After you press `[NEXT]` on the screen, you will be taken
to another screen that prompts you to download a series of starter models. The ones
we recommend are preselected for you, but you are encouraged to use the checkboxes to
pick and choose.
You will probably wish to download `autoencoder-840000` for use with models that
were trained with an older version of the Stability VAE.

<figure markdown>
![select-models-screenshot](../assets/installer-walkthrough/installing-models.png)
</figure>

Below the preselected list of starter models is a large text field which you can use
to specify a series of models to import. You can specify models in a variety of formats,
each separated by a space or newline. The formats accepted are listed below (example entries follow the list):

- The path to a .ckpt or .safetensors file. On most systems, you can drag a file from
the file browser to the textfield to automatically paste the path. Be sure to remove
extraneous quotation marks and other things that come along for the ride.

- The path to a directory containing a combination of `.ckpt` and `.safetensors` files.
The directory will be scanned from top to bottom (including subfolders) and any
file that can be imported will be.

- A URL pointing to a `.ckpt` or `.safetensors` file. You can cut
and paste directly from a web page, or simply drag the link from the web page
or navigation bar. (You can also use ctrl-shift-V to paste into this field.)
The file will be downloaded and installed.

- The HuggingFace repository ID (repo_id) for a `diffusers` model. These IDs have
the format _author_name/model_name_, as in `andite/anything-v4.0`

- The path to a local directory containing a `diffusers`
model. These directories always have the file `model_index.json`
at their top level.

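For illustration only, a batch of entries pasted into that field might look like the following; the paths and URL are hypothetical placeholders, and the repo ID is the one named above:

```
D:\downloads\dreamshaper_8.safetensors
/home/me/models/checkpoints
https://example.com/models/some-model.safetensors
andite/anything-v4.0
/home/me/models/my-diffusers-model
```
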
_Select a directory for models to import_ You may select a local
directory for autoimporting at startup time. If you select this
option, the directory you choose will be scanned for new
.ckpt/.safetensors files each time InvokeAI starts up, and any new
files will be automatically imported and made available for your
use.

_Convert imported models into diffusers_ When legacy checkpoint
files are imported, you may select to use them unmodified (the
default) or to convert them into `diffusers` models. The latter
load much faster and have slightly better rendering performance,
but not all checkpoint files can be converted. Note that Stable Diffusion
Version 2.X files are **only** supported in `diffusers` format and will
be converted regardless.

_You can come back to the model install form_ as many times as you like.
From the `invoke.sh` or `invoke.bat` launcher, select option (5) to relaunch
this script. On the command line, it is named `invokeai-model-install`.

12. **Running InvokeAI for the first time**: The script will now exit and you'll be ready to generate some images. Look
10. **Running InvokeAI for the first time**: The script will now exit and you'll be ready to generate some images. Look
for the directory `invokeai` installed in the location you chose at the
beginning of the install session. Look for a shell script named `invoke.sh`
(Linux/Mac) or `invoke.bat` (Windows). Launch the script by double-clicking
@@ -349,14 +220,14 @@ experimental versions later.
http://localhost:9090. Click on this link to open up a browser
and start exploring InvokeAI's features.

12. **InvokeAI Options**: You can launch InvokeAI with several different command-line arguments that
customize its behavior. For example, you can change the location of the
12. **InvokeAI Options**: You can configure using the `invokeai.yaml` config file.
For example, you can change the location of the
image output directory or balance memory usage vs performance. See
[Configuration](../features/CONFIGURATION.md) for a full list of the options.

- To set defaults that will take effect every time you launch InvokeAI,
use a text editor (e.g. Notepad) to edit the file
`invokeai\invokeai.init`. It contains a variety of examples that you can
`invokeai\invokeai.yaml`. It contains a variety of examples that you can
follow to add and modify launch options.

- The launcher script also offers you an option labeled "open the developer
@@ -394,7 +265,6 @@ rm .\.venv -r -force
python -mvenv .venv
.\.venv\Scripts\activate
pip install invokeai
invokeai-configure --yes --root .
```

If you see anything marked as an error during this process please stop
@@ -426,16 +296,10 @@ error messages:
This failure mode occurs when there is a network glitch during
downloading the very large SDXL model.

To address this, first go to the Web Model Manager and delete the
Stable-Diffusion-XL-base-1.X model. Then navigate to HuggingFace and
manually download the .safetensors version of the model. The 1.0
version is located at
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main
and the file is named `sd_xl_base_1.0.safetensors`.

Save this file to disk and then reenter the Model Manager. Navigate to
Import Models->Add Model, then type (or drag-and-drop) the path to the
.safetensors file. Press "Add Model".
To address this, first go to the Model Manager and delete the
Stable-Diffusion-XL-base-1.X model. Then, click the HuggingFace tab,
paste the Repo ID stabilityai/stable-diffusion-xl-base-1.0 and install
the model.

### _Package dependency conflicts_

@@ -488,15 +352,7 @@ download models, etc), but this doesn't fix the problem.

This issue is often caused by a misconfigured configuration directive in the
`invokeai\invokeai.init` initialization file that contains startup settings. The
easiest way to fix the problem is to move the file out of the way and re-run
`invokeai-configure`. Enter the developer's console (option 3 of the launcher
script) and run this command:

```cmd
invokeai-configure --root=.
```

Note the dot (.) after `--root`. It is part of the command.
easiest way to fix the problem is to move the file out of the way and restart the app.

_If none of these maneuvers fixes the problem_ then please report the problem to
the [InvokeAI Issues](https://github.com/invoke-ai/InvokeAI/issues) section, or
@@ -565,16 +421,4 @@ This distribution is changing rapidly, and we add new features
regularly. Releases are announced at
http://github.com/invoke-ai/InvokeAI/releases, and at
https://pypi.org/project/InvokeAI/. To update to the latest released
version (recommended), follow these steps:

1. Start the `invoke.sh`/`invoke.bat` launch script from within the
`invokeai` root directory.

2. Choose menu item (10) "Update InvokeAI".

3. This will launch a menu that gives you the option of:

1. Updating to the latest official release;
2. Updating to the bleeding-edge development version; or
3. Manually entering the tag or branch name of a version of
InvokeAI you wish to try out.
version (recommended), download the latest release and run the installer.

@@ -26,7 +26,7 @@ driver).

🖥️ **Download the latest installer .zip file here** : https://github.com/invoke-ai/InvokeAI/releases/latest

- *Look for the file labelled "InvokeAI-installer-v3.X.X.zip" at the bottom of the page*
- *Look for the file labelled "InvokeAI-installer-v4.X.X.zip" at the bottom of the page*
- If you experience issues, read through the full [installation instructions](010_INSTALL_AUTOMATED.md) to make sure you have met all of the installation requirements. If you need more help, join the [Discord](discord.gg/invoke-ai) or create an issue on [Github](https://github.com/invoke-ai/InvokeAI).

@@ -149,9 +149,6 @@ class Installer:
        # install the launch/update scripts into the runtime directory
        self.instance.install_user_scripts()

        # run through the configuration flow
        self.instance.configure()


class InvokeAiInstance:
    """
@@ -242,53 +239,6 @@ class InvokeAiInstance:
            )
            sys.exit(1)

    def configure(self):
        """
        Configure the InvokeAI runtime directory
        """

        auto_install = False
        # set sys.argv to a consistent state
        new_argv = [sys.argv[0]]
        for i in range(1, len(sys.argv)):
            el = sys.argv[i]
            if el in ["-r", "--root"]:
                new_argv.append(el)
                new_argv.append(sys.argv[i + 1])
            elif el in ["-y", "--yes", "--yes-to-all"]:
                auto_install = True
        sys.argv = new_argv

        import messages
        import requests  # to catch download exceptions

        auto_install = auto_install or messages.user_wants_auto_configuration()
        if auto_install:
            sys.argv.append("--yes")
        else:
            messages.introduction()

        from invokeai.frontend.install.invokeai_configure import invokeai_configure

        # NOTE: currently the config script does its own arg parsing! this means the command-line switches
        # from the installer will also automatically propagate down to the config script.
        # this may change in the future with config refactoring!
        succeeded = False
        try:
            invokeai_configure()
            succeeded = True
        except requests.exceptions.ConnectionError as e:
            print(f"\nA network error was encountered during configuration and download: {str(e)}")
        except OSError as e:
            print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
        except Exception as e:
            print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
        finally:
            if not succeeded:
                print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
                print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
                print("Alternatively you can relaunch the installer.")

    def install_user_scripts(self):
        """
        Copy the launch and update scripts to the runtime dir

@@ -8,7 +8,7 @@ import platform
from enum import Enum
from pathlib import Path

from prompt_toolkit import HTML, prompt
from prompt_toolkit import prompt
from prompt_toolkit.completion import FuzzyWordCompleter, PathCompleter
from prompt_toolkit.validation import Validator
from rich import box, print
@@ -98,39 +98,6 @@ def choose_version(available_releases: tuple | None = None) -> str:
    return "stable" if response == "" else response


def user_wants_auto_configuration() -> bool:
    """Prompt the user to choose between manual and auto configuration."""
    console.rule("InvokeAI Configuration Section")
    console.print(
        Panel(
            Group(
                "\n".join(
                    [
                        "Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
                        "",
                        " * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
                        " * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
                        "",
                        "Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
                    ]
                ),
            ),
            box=box.MINIMAL,
            padding=(1, 1),
        )
    )
    choice = (
        prompt(
            HTML("Choose <b><a></b>utomatic or <b><m></b>anual configuration [a/m] (a): "),
            validator=Validator.from_callable(
                lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
            ),
        )
        or "a"
    )
    return choice.lower().startswith("a")


def confirm_install(dest: Path) -> bool:
    if dest.exists():
        print(f":stop_sign: Directory {dest} already exists!")
@@ -351,34 +318,6 @@ def windows_long_paths_registry() -> None:
    )


def introduction() -> None:
    """
    Display a banner when starting configuration of the InvokeAI application
    """

    console.rule()

    console.print(
        Panel(
            title=":art: Configuring InvokeAI :art:",
            renderable=Group(
                "",
                "[b]This script will:",
                "",
                "1. Configure the InvokeAI application directory",
                "2. Help download the Stable Diffusion weight files",
                "   and other large models that are needed for text to image generation",
                "3. Create initial configuration files.",
                "",
                "[i]At any point you may interrupt this program and resume later.",
                "",
                "[b]For the best user experience, please enlarge or maximize this window",
            ),
        )
    )
    console.line(2)


def _platform_specific_help() -> Text | None:
    if OS == "Darwin":
        text = Text.from_markup(

@@ -9,15 +9,10 @@ set INVOKEAI_ROOT=.
:start
echo Desired action:
echo 1. Generate images with the browser-based interface
echo 2. Run textual inversion training
echo 3. Merge models (diffusers type only)
echo 4. Download and install models
echo 5. Change InvokeAI startup options
echo 6. Re-run the configure script to fix a broken install or to complete a major upgrade
echo 7. Open the developer console
echo 8. Update InvokeAI (DEPRECATED - please use the installer)
echo 9. Run the InvokeAI image database maintenance script
echo 10. Command-line help
echo 2. Open the developer console
echo 3. Update InvokeAI (DEPRECATED - please use the installer)
echo 4. Run the InvokeAI image database maintenance script
echo 5. Command-line help
echo Q - Quit
set /P choice="Please enter 1-10, Q: [1] "
if not defined choice set choice=1
@@ -25,21 +20,6 @@ IF /I "%choice%" == "1" (
    echo Starting the InvokeAI browser-based UI..
    python .venv\Scripts\invokeai-web.exe %*
) ELSE IF /I "%choice%" == "2" (
    echo Starting textual inversion training..
    python .venv\Scripts\invokeai-ti.exe --gui
) ELSE IF /I "%choice%" == "3" (
    echo Starting model merging script..
    python .venv\Scripts\invokeai-merge.exe --gui
) ELSE IF /I "%choice%" == "4" (
    echo Running invokeai-model-install...
    python .venv\Scripts\invokeai-model-install.exe
) ELSE IF /I "%choice%" == "5" (
    echo Running invokeai-configure...
    python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
) ELSE IF /I "%choice%" == "6" (
    echo Running invokeai-configure...
    python .venv\Scripts\invokeai-configure.exe --yes --skip-sd-weight
) ELSE IF /I "%choice%" == "7" (
    echo Developer Console
    echo Python command is:
    where python
@@ -51,15 +31,15 @@ IF /I "%choice%" == "1" (
    echo *************************
    echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
    call cmd /k
) ELSE IF /I "%choice%" == "8" (
) ELSE IF /I "%choice%" == "3" (
    echo UPDATING FROM WITHIN THE APP IS BEING DEPRECATED.
    echo Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.
    timeout 4
    python -m invokeai.frontend.install.invokeai_update
) ELSE IF /I "%choice%" == "9" (
) ELSE IF /I "%choice%" == "4" (
    echo Running the db maintenance script...
    python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "10" (
) ELSE IF /I "%choice%" == "5" (
    echo Displaying command line help...
    python .venv\Scripts\invokeai-web.exe --help %*
    pause

@@ -58,49 +58,24 @@ do_choice() {
        invokeai-web $PARAMS
        ;;
    2)
        clear
        printf "Textual inversion training\n"
        invokeai-ti --gui $PARAMS
        ;;
    3)
        clear
        printf "Merge models (diffusers type only)\n"
        invokeai-merge --gui $PARAMS
        ;;
    4)
        clear
        printf "Download and install models\n"
        invokeai-model-install --root ${INVOKEAI_ROOT}
        ;;
    5)
        clear
        printf "Change InvokeAI startup options\n"
        invokeai-configure --root ${INVOKEAI_ROOT} --skip-sd-weights --skip-support-models
        ;;
    6)
        clear
        printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
        invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only --skip-sd-weights
        ;;
    7)
        clear
        printf "Open the developer console\n"
        file_name=$(basename "${BASH_SOURCE[0]}")
        bash --init-file "$file_name"
        ;;
    8)
    3)
        clear
        printf "UPDATING FROM WITHIN THE APP IS BEING DEPRECATED\n"
        printf "Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.\n"
        sleep 4
        python -m invokeai.frontend.install.invokeai_update
        ;;
    9)
    4)
        clear
        printf "Running the db maintenance script\n"
        invokeai-db-maintenance --root ${INVOKEAI_ROOT}
        ;;
    10)
    5)
        clear
        printf "Command-line help\n"
        invokeai-web --help
@@ -118,15 +93,10 @@ do_choice() {
do_dialog() {
    options=(
        1 "Generate images with a browser-based interface"
        2 "Textual inversion training"
        3 "Merge models (diffusers type only)"
        4 "Download and install models"
        5 "Change InvokeAI startup options"
        6 "Re-run the configure script to fix a broken install or to complete a major upgrade"
        7 "Open the developer console"
        8 "Update InvokeAI (DEPRECATED - please use the installer)"
        9 "Run the InvokeAI image database maintenance script"
        10 "Command-line help"
        2 "Open the developer console"
        3 "Update InvokeAI (DEPRECATED - please use the installer)"
        4 "Run the InvokeAI image database maintenance script"
        5 "Command-line help"
    )

    choice=$(dialog --clear \
@@ -151,15 +121,10 @@ do_line_input() {
    printf "  ** For a more attractive experience, please install the 'dialog' utility using your package manager. **\n\n"
    printf "What would you like to do?\n"
    printf "1: Generate images using the browser-based interface\n"
    printf "2: Run textual inversion training\n"
    printf "3: Merge models (diffusers type only)\n"
    printf "4: Download and install models\n"
    printf "5: Change InvokeAI startup options\n"
    printf "6: Re-run the configure script to fix a broken install\n"
    printf "7: Open the developer console\n"
    printf "8: Update InvokeAI\n"
    printf "9: Run the InvokeAI image database maintenance script\n"
    printf "10: Command-line help\n"
    printf "2: Open the developer console\n"
    printf "3: Update InvokeAI\n"
    printf "4: Run the InvokeAI image database maintenance script\n"
    printf "5: Command-line help\n"
    printf "Q: Quit\n\n"
    read -p "Please enter 1-10, Q: [1] " yn
    choice=${yn:='1'}

@@ -1,11 +0,0 @@
Organization of the source tree:

app -- Home of nodes invocations and services
assets -- Images and other data files used by InvokeAI
backend -- Non-user facing libraries, including the rendering
    core.
configs -- Configuration files used at install and run times
frontend -- User-facing scripts, including the CLI and the WebUI
version -- Current InvokeAI version string, stored
    in version/invokeai_version.py

@@ -1,12 +1,16 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""

import contextlib
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional

import huggingface_hub
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
@@ -21,6 +25,7 @@ from invokeai.app.services.model_records import (
    UnknownModelException,
)
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    BaseModelType,
@@ -32,6 +37,7 @@ from invokeai.backend.model_manager.config import (
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.starter_models import STARTER_MODELS, StarterModel

from ..dependencies import ApiDependencies

@@ -780,3 +786,69 @@ async def convert_model(
    # except ValueError as e:
    #     raise HTTPException(status_code=400, detail=str(e))
    # return response


@model_manager_router.get("/starter_models", operation_id="get_starter_models", response_model=list[StarterModel])
async def get_starter_models() -> list[StarterModel]:
    installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
    installed_model_sources = {m.source for m in installed_models}
    starter_models = deepcopy(STARTER_MODELS)
    for model in starter_models:
        if model.source in installed_model_sources:
            model.is_installed = True
        # Remove already-installed dependencies
        missing_deps: list[str] = []
        for dep in model.dependencies or []:
            if dep not in installed_model_sources:
                missing_deps.append(dep)
        model.dependencies = missing_deps

    return starter_models


class HFTokenStatus(str, Enum):
    VALID = "valid"
    INVALID = "invalid"
    UNKNOWN = "unknown"


class HFTokenHelper:
    @classmethod
    def get_status(cls) -> HFTokenStatus:
        try:
            if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
                # Valid token!
                return HFTokenStatus.VALID
            # No token set
            return HFTokenStatus.INVALID
        except Exception:
            return HFTokenStatus.UNKNOWN

    @classmethod
    def set_token(cls, token: str) -> HFTokenStatus:
        with SuppressOutput(), contextlib.suppress(Exception):
            huggingface_hub.login(token=token, add_to_git_credential=False)
        return cls.get_status()


@model_manager_router.get("/hf_login", operation_id="get_hf_login_status", response_model=HFTokenStatus)
async def get_hf_login_status() -> HFTokenStatus:
    token_status = HFTokenHelper.get_status()

    if token_status is HFTokenStatus.UNKNOWN:
        ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")

    return token_status


@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
async def do_hf_login(
    token: str = Body(description="Hugging Face token to use for login", embed=True),
) -> HFTokenStatus:
    HFTokenHelper.set_token(token)
    token_status = HFTokenHelper.get_status()

    if token_status is HFTokenStatus.UNKNOWN:
        ApiDependencies.invoker.services.logger.warning("Unable to verify HF token")

    return token_status
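
As a rough usage sketch, these two endpoints can be exercised with curl. The `/api/v2/models` prefix and the port are assumptions based on a default local install and are not shown in this hunk:

```bash
# Check the current Hugging Face token status
curl http://localhost:9090/api/v2/models/hf_login

# Log in with a token (the token value is a placeholder)
curl -X POST http://localhost:9090/api/v2/models/hf_login \
  -H "Content-Type: application/json" \
  -d '{"token": "hf_xxxxxxxxxxxxxxxx"}'
```
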
@@ -233,10 +233,6 @@ def invoke_api() -> None:
    else:
        return port

    from invokeai.backend.install.check_directories import check_directories

    check_directories(app_config)  # note, may exit with an exception if root not set up

    if app_config.dev_reload:
        try:
            import jurigged

@@ -9,6 +9,7 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.fields import ColorField, ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
@@ -217,6 +218,13 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)

        # Downloads the LaMa model if it doesn't already exist
        download_with_progress_bar(
            name="LaMa Inpainting Model",
            url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
            dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
        )

        infilled = infill_lama(image.copy())

        image_dto = context.images.save(image=infilled)

@@ -11,6 +11,7 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
@@ -27,6 +28,13 @@ ESRGAN_MODELS = Literal[
    "RealESRGAN_x2plus.pth",
]

ESRGAN_MODEL_URLS: dict[str, str] = {
    "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    "RealESRGAN_x4plus_anime_6B.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    "ESRGAN_SRx4_DF2KOST_official.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
    "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
}

if choose_torch_device() == torch.device("mps"):
    from torch import mps

@@ -45,7 +53,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)
        models_path = context.config.get().models_path

        rrdbnet_model = None
        netscale = None
@@ -92,11 +99,16 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
            context.logger.error(msg)
            raise ValueError(msg)

        esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
        esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")

        # Downloads the ESRGAN model if it doesn't already exist
        download_with_progress_bar(
            name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
        )

        upscaler = RealESRGAN(
            scale=netscale,
            model_path=models_path / esrgan_model_path,
            model_path=esrgan_model_path,
            model=rrdbnet_model,
            half=False,
            tile=self.tile_size,

@@ -10,10 +10,12 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Literal, Optional

import psutil
import yaml
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

import invokeai.configs as model_configs
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs

@@ -29,7 +31,25 @@ ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = 4
CONFIG_SCHEMA_VERSION = "4.0.0"


def get_default_ram_cache_size() -> float:
    """Run a heuristic for the default RAM cache based on installed RAM."""

    # On some machines, psutil.virtual_memory().total gives a value that is slightly less than the actual RAM, so the
    # limits are set slightly lower than what we expect the actual RAM to be.

    GB = 1024**3
    max_ram = psutil.virtual_memory().total / GB

    if max_ram >= 60:
        return 15.0
    if max_ram >= 30:
        return 7.5
    if max_ram >= 14:
        return 4.0
    return 2.1  # 2.1 is just large enough for sd 1.5 ;-)


class URLRegexTokenPair(BaseModel):
@@ -63,7 +83,6 @@ class InvokeAIAppConfig(BaseSettings):
        ssl_keyfile: SSL key file for HTTPS. See https://www.uvicorn.org/settings/#https.
        log_tokenization: Enable logging of parsed prompt tokens.
        patchmatch: Enable patchmatch inpaint code.
        ignore_missing_core_models: Ignore missing core models on startup. If `True`, the app will attempt to download missing models on startup.
        autoimport_dir: Path to a directory of models files to be imported on startup.
        models_dir: Path to the models directory.
        convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.
@@ -101,11 +120,13 @@ class InvokeAIAppConfig(BaseSettings):
    """

    _root: Optional[Path] = PrivateAttr(default=None)
    _config_file: Optional[Path] = PrivateAttr(default=None)

    # fmt: off

    # INTERNAL
    schema_version: int = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
    schema_version: str = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
    # This is only used during v3 models.yaml migration
    legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.")

    # WEB
@@ -121,13 +142,12 @@ class InvokeAIAppConfig(BaseSettings):
    # MISC FEATURES
    log_tokenization: bool = Field(default=False, description="Enable logging of parsed prompt tokens.")
    patchmatch: bool = Field(default=True, description="Enable patchmatch inpaint code.")
    ignore_missing_core_models: bool = Field(default=False, description="Ignore missing core models on startup. If `True`, the app will attempt to download missing models on startup.")

    # PATHS
    autoimport_dir: Path = Field(default=Path("autoimport"), description="Path to a directory of models files to be imported on startup.")
    models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
    convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and stored on disk at this location.")
    legacy_conf_dir: Path = Field(default=Path("configs/stable-diffusion"), description="Path to directory of legacy checkpoint config files.")
    legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
    db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
    outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
    custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
@@ -147,7 +167,7 @@ class InvokeAIAppConfig(BaseSettings):
    profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")

    # CACHE
    ram: float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
    ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
    vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
    convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
    lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
@@ -202,7 +222,7 @@ class InvokeAIAppConfig(BaseSettings):
        if new_value != current_value:
            setattr(self, field_name, new_value)

    def write_file(self, dest_path: Path) -> None:
    def write_file(self, dest_path: Path, as_example: bool = False) -> None:
        """Write the current configuration to file. This will overwrite the existing file.

        A `meta` stanza is added to the top of the file, containing metadata about the config file. This is not stored in the config object.
@@ -210,46 +230,31 @@ class InvokeAIAppConfig(BaseSettings):
        Args:
            dest_path: Path to write the config to.
        """
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        with open(dest_path, "w") as file:
            # Meta fields should be written in a separate stanza
            # Meta fields should be written in a separate stanza - skip legacy_models_yaml_path
            meta_dict = self.model_dump(mode="json", include={"schema_version"})
            # Only include the legacy_models_yaml_path if it's set
            if self.legacy_models_yaml_path:
                meta_dict.update(self.model_dump(mode="json", include={"legacy_models_yaml_path"}))

            # User settings
            config_dict = self.model_dump(
                mode="json",
                exclude_unset=True,
                exclude_defaults=True,
                exclude_unset=False if as_example else True,
                exclude_defaults=False if as_example else True,
                exclude_none=True if as_example else False,
                exclude={"schema_version", "legacy_models_yaml_path"},
            )

            if as_example:
                file.write(
                    "# This is an example file with default and example settings. Use the values here as a baseline.\n\n"
                )
            file.write("# Internal metadata - do not edit:\n")
            file.write(yaml.dump(meta_dict, sort_keys=False))
            file.write("\n")
            file.write("# Put user settings here:\n")
            file.write("# Put user settings here - see https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/:\n")
            if len(config_dict) > 0:
                file.write(yaml.dump(config_dict, sort_keys=False))

    def merge_from_file(self, source_path: Optional[Path] = None) -> None:
        """Read the config from the `invokeai.yaml` file, migrating it if necessary and merging it into the singleton config.

        This function will write to the `invokeai.yaml` file if the config is migrated.

        Args:
            source_path: Path to the config file. If not provided, the default path is used.
        """
        path = source_path or self.init_file_path
        config_from_file = load_and_migrate_config(path)
        # Clobbering here will overwrite any settings that were set via environment variables
        self.update_config(config_from_file, clobber=False)

    def set_root(self, root: Path) -> None:
        """Set the runtime root directory. This is typically set using a CLI arg."""
        assert isinstance(root, Path)
        self._root = root

    def _resolve(self, partial_path: Path) -> Path:
        return (self.root_path / partial_path).resolve()

@@ -264,9 +269,9 @@ class InvokeAIAppConfig(BaseSettings):
        return root.resolve()

    @property
    def init_file_path(self) -> Path:
    def config_file_path(self) -> Path:
        """Path to invokeai.yaml, resolved to an absolute path."""
        resolved_path = self._resolve(INIT_FILE)
        resolved_path = self._resolve(self._config_file or INIT_FILE)
        assert resolved_path is not None
        return resolved_path

@@ -351,6 +356,14 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
            parsed_config_dict["vram"] = v
        if k == "conf_path":
            parsed_config_dict["legacy_models_yaml_path"] = v
        if k == "legacy_conf_dir":
            # The old default for this was "configs/stable-diffusion". If the incoming config has that as the value, we won't set it.
            # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
            # Else we do not attempt to migrate this setting
            if v != "configs/stable-diffusion":
                parsed_config_dict["legacy_conf_dir"] = v
            elif Path(v).name == "stable-diffusion":
                parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
        elif k in InvokeAIAppConfig.model_fields:
            # skip unknown fields
            parsed_config_dict[k] = v
@@ -378,7 +391,8 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
        # This is a v3 config file, attempt to migrate it
        shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
        try:
            config = migrate_v3_config_dict(loaded_config_dict)
            # This could be the wrong shape, but we will catch all exceptions below
            config = migrate_v3_config_dict(loaded_config_dict)  # pyright: ignore [reportUnknownArgumentType]
        except Exception as e:
            shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
            raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
@@ -400,32 +414,50 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:

@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
    """Return the global singleton app config.
    """Get the global singleton app config.

    When called, this function will parse the CLI args and merge in config from the `invokeai.yaml` config file.
    When first called, this function:
    - Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
    - Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
    - Sets the root dir, if provided via CLI args.
    - Logs in to HF if there is no valid token already.
    - Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
    - Reads and merges in settings from the config file if it exists, else writes out a default config file.

    On subsequent calls, the object is returned from the cache.
    """
    config = InvokeAIAppConfig()

    args = InvokeAIArgs.args

    # CLI args trump environment variables
    # This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
    # If it is False, we should just return a default config and not set the root, log in to HF, etc.
    if not InvokeAIArgs.did_parse:
        return config

    # Set CLI args
    if root := getattr(args, "root", None):
        config.set_root(Path(root))
    if ignore_missing_core_models := getattr(args, "ignore_missing_core_models", None):
        config.ignore_missing_core_models = ignore_missing_core_models
        config._root = Path(root)
    if config_file := getattr(args, "config_file", None):
        config._config_file = Path(config_file)

    # TODO(psyche): This shouldn't be wrapped in a try/catch. The configuration script imports a number of classes
    # from throughout the app, which in turn call get_config(). At that time, there may not be a config file to
    # read from, and this raises.
    #
    # Once we move all* model installation to the web app, the configure script will not be reaching into the main app
    # and we can make this an unhandled error, which feels correct.
    #
    # *all user-facing models. Core model installation will still be handled by the configure/install script.
    # Create the example file from a deep copy, with some extra values provided
    example_config = config.model_copy(deep=True)
    example_config.remote_api_tokens = [
        URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
        URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
    ]
    example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)

    try:
        config.merge_from_file()
    except FileNotFoundError:
        pass
    # Copy all legacy configs - We know `__path__[0]` is correct here
    configs_src = Path(model_configs.__path__[0])  # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
    shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)

    if config.config_file_path.exists():
        incoming_config = load_and_migrate_config(config.config_file_path)
        # Clobbering here will overwrite any settings that were set via environment variables
        config.update_config(incoming_config, clobber=False)
    else:
        config.write_file(config.config_file_path)

    return config

@@ -114,8 +114,10 @@ class HFModelSource(StringLikeSource):
    def __str__(self) -> str:
        """Return string version of repoid when string rep needed."""
        base: str = self.repo_id
        if self.variant:
            base += f":{self.variant or ''}"
        base += f":{self.subfolder}" if self.subfolder else ""
        if self.subfolder:
            base += f":{self.subfolder}"
        return base


@@ -309,6 +309,10 @@ class ModelInstallService(ModelInstallServiceBase):
            self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
        )

        # The old path may be relative to the root path
        if not legacy_models_yaml_path.exists():
            legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)

        if legacy_models_yaml_path.exists():
            legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())

@@ -348,7 +352,7 @@ class ModelInstallService(ModelInstallServiceBase):

        # Remove `legacy_models_yaml_path` from the config file - we are done with it either way
        self._app_config.legacy_models_yaml_path = None
        self._app_config.write_file(self._app_config.init_file_path)
        self._app_config.write_file(self._app_config.config_file_path)

    def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:  # noqa D102
        self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
@@ -615,7 +619,7 @@ class ModelInstallService(ModelInstallServiceBase):

        info.path = model_path.as_posix()

        # add 'main' specific fields
        # Checkpoints have a config file needed for conversion - resolve this to an absolute path
        if isinstance(info, CheckpointConfigBase):
            legacy_conf = (self.app_config.legacy_conf_path / info.config_path).resolve()
            info.config_path = legacy_conf.as_posix()

invokeai/app/util/download_with_progress.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from pathlib import Path
from urllib import request

from tqdm import tqdm

from invokeai.backend.util.logging import InvokeAILogger


class ProgressBar:
    """Simple progress bar for urllib.request.urlretrieve using tqdm."""

    def __init__(self, model_name: str = "file"):
        self.pbar = None
        self.name = model_name

    def __call__(self, block_num: int, block_size: int, total_size: int):
        if not self.pbar:
            self.pbar = tqdm(
                desc=self.name,
                initial=0,
                unit="iB",
                unit_scale=True,
                unit_divisor=1000,
                total=total_size,
            )
        self.pbar.update(block_size)


def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
    """Download a file from a URL to a destination path, with a progress bar.
    If the file already exists, it will not be downloaded again.

    Exceptions are not caught.

    Args:
        name (str): Name of the file being downloaded.
        url (str): URL to download the file from.
        dest_path (Path): Destination path to save the file to.

    Returns:
        bool: True if the file was downloaded, False if it already existed.
    """
    if dest_path.exists():
        return False  # already downloaded

    InvokeAILogger.get_logger().info(f"Downloading {name}...")

    dest_path.parent.mkdir(parents=True, exist_ok=True)
    request.urlretrieve(url, dest_path, ProgressBar(name))

    return True

invokeai/app/util/suppress_output.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import io
import sys
from typing import Any


class SuppressOutput:
    """Context manager to suppress stdout.

    Example:
    ```
    with SuppressOutput():
        print("This will not be printed")
    ```
    """

    def __enter__(self):
        # Save the original stdout
        self._original_stdout = sys.stdout
        # Redirect stdout to a dummy StringIO object
        sys.stdout = io.StringIO()

    def __exit__(self, *args: Any, **kwargs: Any):
        # Restore stdout
        sys.stdout = self._original_stdout

@@ -1,4 +0,0 @@
"""Initialization file for invokeai.backend.embeddings modules."""

# from .model_patcher import ModelPatcher
# __all__ = ["ModelPatcher"]
@@ -1,12 +0,0 @@
"""Base class for LoRA and Textual Inversion models.

The EmbeddingRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher.

The use of "Raw" here is a historical artifact, and carried forward in
order to avoid confusion.
"""


class EmbeddingModelRaw:
    """Base class for LoRA and Textual Inversion models."""
@@ -5,21 +5,4 @@ Initialization file for invokeai.backend.image_util methods.
from .patchmatch import PatchMatch  # noqa: F401
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata  # noqa: F401
from .seamless import configure_model_padding  # noqa: F401
from .txt2mask import Txt2Mask  # noqa: F401
from .util import InitImageResizer, make_grid  # noqa: F401


def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
    from PIL import ImageDraw

    if not debug_status:
        return

    image_copy = debug_image.copy().convert("RGBA")
    ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))

    if debug_show:
        image_copy.show()

    if debug_result:
        return image_copy
@@ -10,11 +10,11 @@ from PIL import Image
from torchvision.transforms import Compose

from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import download_with_progress_bar

config = get_config()
logger = InvokeAILogger.get_logger(config=config)
@@ -59,9 +59,12 @@ class DepthAnythingDetector:
        self.device = choose_torch_device()

    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
        DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
        if not DEPTH_ANYTHING_MODEL_PATH.exists():
            download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
        DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
        download_with_progress_bar(
            pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
            DEPTH_ANYTHING_MODELS[model_size]["url"],
            DEPTH_ANYTHING_MODEL_PATH,
        )

        if not self.model or model_size != self.model_size:
            del self.model
@@ -1,14 +1,13 @@
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
# Modified pathing to suit Invoke

import pathlib

import numpy as np
import onnxruntime as ort

from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.util import download_with_progress_bar

from .onnxdet import inference_detector
from .onnxpose import inference_pose
@@ -24,7 +23,7 @@ DWPOSE_MODELS = {
    },
}

config = get_config
config = get_config()


class Wholebody:
@@ -33,13 +32,13 @@ class Wholebody:

        providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]

        DET_MODEL_PATH = pathlib.Path(config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"])
        if not DET_MODEL_PATH.exists():
            download_with_progress_bar(DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
        DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
        download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)

        POSE_MODEL_PATH = pathlib.Path(config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"])
        if not POSE_MODEL_PATH.exists():
            download_with_progress_bar(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH)
        POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
        download_with_progress_bar(
            "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
        )

        onnx_det = DET_MODEL_PATH
        onnx_pose = POSE_MODEL_PATH
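Note: both callers above follow the same migration pattern. The explicit existence check is dropped because the new helper returns early when the destination already exists, and a human-readable name is now passed as the first argument. A before/after sketch based on the lines in this hunk:

```python
# Before: the caller guarded the download itself (old helper took url, dest).
if not DET_MODEL_PATH.exists():
    download_with_progress_bar(DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)

# After: the helper skips existing files, so the caller passes name, url, dest.
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
```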
@@ -1,46 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team

"""Very simple functions to fetch and print metadata from InvokeAI-generated images."""

import json
import sys
from pathlib import Path
from typing import Any, Dict

from PIL import Image


def get_invokeai_metadata(image_path: Path) -> Dict[str, Any]:
    """
    Retrieve "invokeai_metadata" field from png image.

    :param image_path: Path to the image to read metadata from.
    May raise:
      OSError -- image path not found
      KeyError -- image doesn't contain the metadata field
    """
    image: Image = Image.open(image_path)
    return json.loads(image.text["invokeai_metadata"])


def print_invokeai_metadata(image_path: Path):
    """Pretty-print the metadata."""
    try:
        metadata = get_invokeai_metadata(image_path)
        print(f"{image_path}:\n{json.dumps(metadata, sort_keys=True, indent=4)}")
    except OSError:
        print(f"{image_path}:\nNo file found.")
    except KeyError:
        print(f"{image_path}:\nNo metadata found.")
    print()


def main():
    """Run the command-line utility."""
    image_paths = sys.argv[1:]
    if not image_paths:
        print(f"Usage: {Path(sys.argv[0]).name} image1 image2 image3 ...")
        print("\nPretty-print InvokeAI image metadata from the listed png files.")
        sys.exit(-1)
    for img in image_paths:
        print_invokeai_metadata(img)
@@ -4,6 +4,8 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""

from pathlib import Path

import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image
@@ -34,22 +36,21 @@ class SafetyChecker:
        try:
            cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
            cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
            logger.info("NSFW checker initialized")
        except Exception as e:
            logger.warning(f"Could not load NSFW checker: {str(e)}")
        cls.tried_load = True

    @classmethod
    def safety_checker_available(cls) -> bool:
        cls._load_safety_checker()
        return cls.safety_checker is not None
        return Path(get_config().models_path, CHECKER_PATH).exists()

    @classmethod
    def has_nsfw_concept(cls, image: Image.Image) -> bool:
        if not cls.safety_checker_available():
        if not cls.safety_checker_available() and cls.tried_load:
            return False
        cls._load_safety_checker()
        if cls.safety_checker is None or cls.feature_extractor is None:
            return False
        assert cls.safety_checker is not None
        assert cls.feature_extractor is not None
        device = choose_torch_device()
        features = cls.feature_extractor([image], return_tensors="pt")
        features.to(device)
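Note: with this change `safety_checker_available()` only checks that the checker's files exist under the models directory, while `has_nsfw_concept()` performs the lazy model load on first use. A rough sketch of the resulting call flow; the module path and the image below are assumptions for illustration:

```python
from PIL import Image

from invokeai.backend.image_util.safety_checker import SafetyChecker  # assumed module path

image = Image.new("RGB", (512, 512))  # placeholder image
if SafetyChecker.safety_checker_available():         # cheap on-disk existence check
    flagged = SafetyChecker.has_nsfw_concept(image)  # loads the model lazily on first call
```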
@@ -1,114 +0,0 @@
"""Makes available the Txt2Mask class, which assists in the automatic
assignment of masks via text prompt using clipseg.

Here is typical usage:

    from invokeai.backend.image_util.txt2mask import Txt2Mask, SegmentedGrayscale
    from PIL import Image

    txt2mask = Txt2Mask(self.device)
    segmented = txt2mask.segment(Image.open('/path/to/img.png'),'a bagel')

    # this will return a grayscale Image of the segmented data
    grayscale = segmented.to_grayscale()

    # this will return a semi-transparent image in which the
    # selected object(s) are opaque and the rest is at various
    # levels of transparency
    transparent = segmented.to_transparent()

    # this will return a masked image suitable for use in inpainting:
    mask = segmented.to_mask(threshold=0.5)

The threshold used in the call to to_mask() selects pixels for use in
the mask that exceed the indicated confidence threshold. Values range
from 0.0 to 1.0. The higher the threshold, the more confident the
algorithm is. In limited testing, I have found that values around 0.5
work fine.
"""

import numpy as np
import torch
from PIL import Image, ImageOps
from transformers import AutoProcessor, CLIPSegForImageSegmentation

import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config

CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352


class SegmentedGrayscale(object):
    def __init__(self, image: Image.Image, heatmap: torch.Tensor):
        self.heatmap = heatmap
        self.image = image

    def to_grayscale(self, invert: bool = False) -> Image.Image:
        return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))

    def to_mask(self, threshold: float = 0.5) -> Image.Image:
        discrete_heatmap = self.heatmap.lt(threshold).int()
        return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))

    def to_transparent(self, invert: bool = False) -> Image.Image:
        transparent_image = self.image.copy()
        # For img2img, we want the selected regions to be transparent,
        # but to_grayscale() returns the opposite. Thus invert.
        gs = self.to_grayscale(not invert)
        transparent_image.putalpha(gs)
        return transparent_image

    # unscales and uncrops the 352x352 heatmap so that it matches the image again
    def _rescale(self, heatmap: Image.Image) -> Image.Image:
        size = self.image.width if (self.image.width > self.image.height) else self.image.height
        resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
        return resized_image.crop((0, 0, self.image.width, self.image.height))


class Txt2Mask(object):
    """
    Create new Txt2Mask object. The optional device argument can be one of
    'cuda', 'mps' or 'cpu'.
    """

    def __init__(self, device="cpu", refined=False):
        logger.info("Initializing clipseg model for text to mask inference")

        # BUG: we are not doing anything with the device option at this time
        self.device = device
        self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=get_config().cache_dir)
        self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=get_config().cache_dir)

    @torch.no_grad()
    def segment(self, image: Image.Image, prompt: str) -> SegmentedGrayscale:
        """
        Given a prompt string such as "a bagel", tries to identify the object in the
        provided image and returns a SegmentedGrayscale object in which the brighter
        pixels indicate where the object is inferred to be.
        """
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        image = ImageOps.exif_transpose(image)
        img = self._scale_and_crop(image)

        inputs = self.processor(text=[prompt], images=[img], padding=True, return_tensors="pt")
        outputs = self.model(**inputs)
        heatmap = torch.sigmoid(outputs.logits)
        return SegmentedGrayscale(image, heatmap)

    def _scale_and_crop(self, image: Image.Image) -> Image.Image:
        scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
        if image.width > image.height:  # width is constraint
            scale = CLIPSEG_SIZE / image.width
        else:
            scale = CLIPSEG_SIZE / image.height
        scaled_image.paste(
            image.resize(
                (int(scale * image.width), int(scale * image.height)),
                resample=Image.Resampling.LANCZOS,
            ),
            box=(0, 0),
        )
        return scaled_image
@@ -1,30 +0,0 @@
"""
Check that the invokeai_root is correctly configured and exit if not.
"""

import sys

from invokeai.app.services.config import InvokeAIAppConfig


# TODO(psyche): Should this also check for things like ESRGAN models, database, etc?
def validate_directories(config: InvokeAIAppConfig) -> None:
    assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
    assert config.models_path.exists(), f"{config.models_path} not found"


def check_directories(config: InvokeAIAppConfig):
    try:
        validate_directories(config)
    except Exception as e:
        print()
        print(f"An exception has occurred: {str(e)}")
        print("== STARTUP ABORTED ==")
        print("** One or more necessary files is missing from your InvokeAI directories **")
        print("** Please rerun the configuration script to fix this problem. **")
        print("** From the launcher, selection option [6]. **")
        print(
            '** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
        )
        input("Press any key to continue...")
        sys.exit(0)
@ -1,267 +0,0 @@
|
||||
"""Utility (backend) functions used by model_install.py"""
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import omegaconf
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import (
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
|
||||
|
||||
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
"""Return an initialized ModelConfigRecordServiceBase object."""
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
image_files = DiskImageFileStorage(f"{app_config.outputs_path}/images")
|
||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
return obj
|
||||
|
||||
|
||||
def initialize_installer(
|
||||
app_config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None
|
||||
) -> ModelInstallServiceBase:
|
||||
"""Return an initialized ModelInstallService object."""
|
||||
record_store = initialize_record_store(app_config)
|
||||
download_queue = DownloadQueueService()
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=record_store,
|
||||
download_queue=download_queue,
|
||||
event_bus=event_bus,
|
||||
)
|
||||
download_queue.start()
|
||||
installer.start()
|
||||
return installer
|
||||
|
||||
|
||||
class UnifiedModelInfo(BaseModel):
|
||||
"""Catchall class for information in INITIAL_MODELS2.yaml."""
|
||||
|
||||
name: Optional[str] = None
|
||||
base: Optional[BaseModelType] = None
|
||||
type: Optional[ModelType] = None
|
||||
source: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
recommended: bool = False
|
||||
installed: bool = False
|
||||
default: bool = False
|
||||
requires: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
"""Lists of models to install and remove."""
|
||||
|
||||
install_models: List[UnifiedModelInfo] = Field(default_factory=list)
|
||||
remove_models: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TqdmEventService(EventServiceBase):
|
||||
"""An event service to track downloads."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create a new TqdmEventService object."""
|
||||
super().__init__()
|
||||
self._bars: Dict[str, tqdm] = {}
|
||||
self._last: Dict[str, int] = {}
|
||||
self._logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
data = payload["data"]
|
||||
source = data["source"]
|
||||
if payload["event"] == "model_install_downloading":
|
||||
dest = data["local_path"]
|
||||
total_bytes = data["total_bytes"]
|
||||
bytes = data["bytes"]
|
||||
if dest not in self._bars:
|
||||
self._bars[dest] = tqdm(desc=Path(dest).name, initial=0, total=total_bytes, unit="iB", unit_scale=True)
|
||||
self._last[dest] = 0
|
||||
self._bars[dest].update(bytes - self._last[dest])
|
||||
self._last[dest] = bytes
|
||||
elif payload["event"] == "model_install_completed":
|
||||
self._logger.info(f"{source}: installed successfully.")
|
||||
elif payload["event"] == "model_install_error":
|
||||
self._logger.warning(f"{source}: installation failed with error {data['error']}")
|
||||
elif payload["event"] == "model_install_cancelled":
|
||||
self._logger.warning(f"{source}: installation cancelled")
|
||||
|
||||
|
||||
class InstallHelper(object):
|
||||
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
|
||||
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger):
|
||||
"""Create new InstallHelper object."""
|
||||
self._app_config = app_config
|
||||
self.all_models: Dict[str, UnifiedModelInfo] = {}
|
||||
|
||||
omega = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
|
||||
assert isinstance(omega, omegaconf.dictconfig.DictConfig)
|
||||
|
||||
self._installer = initialize_installer(app_config, TqdmEventService())
|
||||
self._initial_models = omega
|
||||
self._installed_models: List[str] = []
|
||||
self._starter_models: List[str] = []
|
||||
self._default_model: Optional[str] = None
|
||||
self._logger = logger
|
||||
self._initialize_model_lists()
|
||||
|
||||
@property
|
||||
def installer(self) -> ModelInstallServiceBase:
|
||||
"""Return the installer object used internally."""
|
||||
return self._installer
|
||||
|
||||
def _initialize_model_lists(self) -> None:
|
||||
"""
|
||||
Initialize our model slots.
|
||||
|
||||
Set up the following:
|
||||
installed_models -- list of installed model keys
|
||||
starter_models -- list of starter model keys from INITIAL_MODELS
|
||||
all_models -- dict of key => UnifiedModelInfo
|
||||
default_model -- key to default model
|
||||
"""
|
||||
# previously-installed models
|
||||
for model in self._installer.record_store.all_models():
|
||||
info = UnifiedModelInfo.model_validate(model.model_dump())
|
||||
info.installed = True
|
||||
model_key = f"{model.base.value}/{model.type.value}/{model.name}"
|
||||
self.all_models[model_key] = info
|
||||
self._installed_models.append(model_key)
|
||||
|
||||
for key in self._initial_models.keys():
|
||||
assert isinstance(key, str)
|
||||
if key in self.all_models:
|
||||
# we want to preserve the description
|
||||
description = self.all_models[key].description or self._initial_models[key].get("description")
|
||||
self.all_models[key].description = description
|
||||
else:
|
||||
base_model, model_type, model_name = key.split("/")
|
||||
info = UnifiedModelInfo(
|
||||
name=model_name,
|
||||
type=ModelType(model_type),
|
||||
base=BaseModelType(base_model),
|
||||
source=self._initial_models[key].source,
|
||||
description=self._initial_models[key].get("description"),
|
||||
recommended=self._initial_models[key].get("recommended", False),
|
||||
default=self._initial_models[key].get("default", False),
|
||||
subfolder=self._initial_models[key].get("subfolder"),
|
||||
requires=list(self._initial_models[key].get("requires", [])),
|
||||
)
|
||||
self.all_models[key] = info
|
||||
if not self.default_model():
|
||||
self._default_model = key
|
||||
elif self._initial_models[key].get("default", False):
|
||||
self._default_model = key
|
||||
self._starter_models.append(key)
|
||||
|
||||
# previously-installed models
|
||||
for model in self._installer.record_store.all_models():
|
||||
info = UnifiedModelInfo.model_validate(model.model_dump())
|
||||
info.installed = True
|
||||
model_key = f"{model.base.value}/{model.type.value}/{model.name}"
|
||||
self.all_models[model_key] = info
|
||||
self._installed_models.append(model_key)
|
||||
|
||||
def recommended_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of the models recommended in INITIAL_MODELS.yaml."""
|
||||
return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended]
|
||||
|
||||
def installed_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of models already installed."""
|
||||
return [self._to_model(x) for x in self._installed_models]
|
||||
|
||||
def starter_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of starter models."""
|
||||
return [self._to_model(x) for x in self._starter_models]
|
||||
|
||||
def default_model(self) -> Optional[UnifiedModelInfo]:
|
||||
"""Return the default model."""
|
||||
return self._to_model(self._default_model) if self._default_model else None
|
||||
|
||||
def _to_model(self, key: str) -> UnifiedModelInfo:
|
||||
return self.all_models[key]
|
||||
|
||||
def _add_required_models(self, model_list: List[UnifiedModelInfo]) -> None:
|
||||
installed = {x.source for x in self.installed_models()}
|
||||
reverse_source = {x.source: x for x in self.all_models.values()}
|
||||
additional_models: List[UnifiedModelInfo] = []
|
||||
for model_info in model_list:
|
||||
for requirement in model_info.requires:
|
||||
if requirement not in installed and reverse_source.get(requirement):
|
||||
additional_models.append(reverse_source[requirement])
|
||||
model_list.extend(additional_models)
|
||||
|
||||
def add_or_delete(self, selections: InstallSelections) -> None:
|
||||
"""Add or delete selected models."""
|
||||
installer = self._installer
|
||||
self._add_required_models(selections.install_models)
|
||||
for model in selections.install_models:
|
||||
assert model.source
|
||||
model_path_id_or_url = model.source.strip("\"' ")
|
||||
config = (
|
||||
{
|
||||
"description": model.description,
|
||||
"name": model.name,
|
||||
}
|
||||
if model.name
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
installer.heuristic_import(
|
||||
source=model_path_id_or_url,
|
||||
config=config,
|
||||
)
|
||||
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
|
||||
self._logger.warning(f"{model.source}: {e}")
|
||||
|
||||
for model_to_remove in selections.remove_models:
|
||||
parts = model_to_remove.split("/")
|
||||
if len(parts) == 1:
|
||||
base_model, model_type, model_name = (None, None, model_to_remove)
|
||||
else:
|
||||
base_model, model_type, model_name = parts
|
||||
matches = installer.record_store.search_by_attr(
|
||||
base_model=BaseModelType(base_model) if base_model else None,
|
||||
model_type=ModelType(model_type) if model_type else None,
|
||||
model_name=model_name,
|
||||
)
|
||||
if len(matches) > 1:
|
||||
self._logger.error(
|
||||
"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
|
||||
)
|
||||
elif not matches:
|
||||
self._logger.error(f"{model_to_remove}: unknown model")
|
||||
else:
|
||||
for m in matches:
|
||||
self._logger.info(f"Deleting {m.type}:{m.name}")
|
||||
installer.delete(m.key)
|
||||
|
||||
installer.wait_for_installs()
|
@ -1,837 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
# Before running stable-diffusion on an internet-isolated machine,
|
||||
# run this script from one with internet connectivity. The
|
||||
# two machines must share a common .cache directory.
|
||||
#
|
||||
# Coauthor: Kevin Turner http://github.com/keturn
|
||||
#
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from shutil import copy, get_terminal_size, move
|
||||
from typing import Any, Optional, Tuple, Type, get_args, get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import psutil
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import ModelMixin
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub import login as hf_hub_login
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm
|
||||
|
||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||
from invokeai.frontend.install.widgets import (
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
CenteredButtonPress,
|
||||
CyclingForm,
|
||||
FileBox,
|
||||
MultiSelectColumns,
|
||||
SingleSelectColumnsSimple,
|
||||
WindowTooSmallException,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
def get_literal_fields(field: str) -> Tuple[Any]:
|
||||
return get_args(get_type_hints(InvokeAIAppConfig).get(field))
|
||||
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = None
|
||||
|
||||
PRECISION_CHOICES = get_literal_fields("precision")
|
||||
DEVICE_CHOICES = get_literal_fields("device")
|
||||
ATTENTION_CHOICES = get_literal_fields("attention_type")
|
||||
ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size")
|
||||
GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"]
|
||||
GB = 1073741824 # GB in bytes
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0)
|
||||
|
||||
MAX_VRAM /= GB
|
||||
MAX_RAM = psutil.virtual_memory().total / GB
|
||||
|
||||
FORCE_FULL_PRECISION = False
|
||||
|
||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||
# Feel free to edit. If anything goes wrong, you can re-initialize this file by deleting
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
"""
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class DummyWidgetValue(Enum):
|
||||
"""Dummy widget values."""
|
||||
|
||||
zero = 0
|
||||
true = True
|
||||
false = False
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: set[str]) -> None:
|
||||
if not any(errors):
|
||||
message = f"""
|
||||
** INVOKEAI INSTALLATION SUCCESSFUL **
|
||||
If you installed manually from source or with 'pip install': activate the virtual environment
|
||||
then run one of the following commands to start InvokeAI.
|
||||
|
||||
Web UI:
|
||||
invokeai-web
|
||||
|
||||
If you installed using an installation script, run:
|
||||
{config.root_path}/invoke.{"bat" if sys.platform == "win32" else "sh"}
|
||||
|
||||
Add the '--help' argument to see all of the command-line switches available for use.
|
||||
"""
|
||||
|
||||
else:
|
||||
message = (
|
||||
"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
||||
)
|
||||
for err in errors:
|
||||
message += f"\t - {err}\n"
|
||||
message += "Please check the logs above and correct any issues."
|
||||
|
||||
print(message)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
default = "y" if default_yes else "n"
|
||||
response = input(f"{prompt} [{default}] ") or default
|
||||
if default_yes:
|
||||
return response[0] not in ("n", "N")
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def HfLogin(access_token) -> None:
|
||||
"""
|
||||
Helper for logging in to Huggingface
|
||||
The stdout capture is needed to hide the irrelevant "git credential helper" warning
|
||||
"""
|
||||
|
||||
capture = io.StringIO()
|
||||
sys.stdout = capture
|
||||
try:
|
||||
hf_hub_login(token=access_token, add_to_git_credential=False)
|
||||
sys.stdout = sys.__stdout__
|
||||
except Exception as exc:
|
||||
sys.stdout = sys.__stdout__
|
||||
print(exc)
|
||||
raise exc
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class ProgressBar:
|
||||
def __init__(self, model_name: str = "file"):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
def __call__(self, block_num, block_size, total_size):
|
||||
if not self.pbar:
|
||||
self.pbar = tqdm(
|
||||
desc=self.name,
|
||||
initial=0,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
total=total_size,
|
||||
)
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any):
|
||||
filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731
|
||||
logger.addFilter(filter)
|
||||
try:
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
finally:
|
||||
logger.removeFilter(filter)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str | Path, label: str = "the"):
|
||||
try:
|
||||
logger.info(f"Installing {label} model file {model_url}...")
|
||||
if not os.path.exists(model_dest):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(model_url, model_dest, ProgressBar(os.path.basename(model_dest)))
|
||||
logger.info("...downloaded successfully")
|
||||
else:
|
||||
logger.info("...exists")
|
||||
except Exception:
|
||||
logger.info("...download failed")
|
||||
logger.info(f"Error downloading {label} model")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
def download_safety_checker():
|
||||
target_dir = config.models_path / "core/convert"
|
||||
kwargs = {} # for future use
|
||||
try:
|
||||
# safety checking
|
||||
logger.info("Downloading safety checker")
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
|
||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# TO DO: use the download queue here.
|
||||
def download_realesrgan():
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
"description": "RealESRGAN_x4plus.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"description": "RealESRGAN_x4plus_anime_6B.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"description": "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||
},
|
||||
{
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
"dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
"description": "RealESRGAN_x2plus.pth",
|
||||
},
|
||||
]
|
||||
for model in URLs:
|
||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_lama():
|
||||
logger.info("Installing lama infill model")
|
||||
download_with_progress_bar(
|
||||
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
config.models_path / "core/misc/lama/lama.pt",
|
||||
"lama infill model",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_support_models() -> None:
|
||||
download_realesrgan()
|
||||
download_lama()
|
||||
download_safety_checker()
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: Optional[str] = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif root := os.environ.get("INVOKEAI_ROOT"):
|
||||
assert root is not None
|
||||
return root
|
||||
else:
|
||||
return str(config.root_path)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# for responsive resizing - disabled
|
||||
# FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
|
||||
def create(self):
|
||||
program_opts = self.parentApp.program_opts
|
||||
old_opts: InvokeAIAppConfig = self.parentApp.invokeai_opts
|
||||
first_time = not (config.root_path / "invokeai.yaml").exists()
|
||||
access_token = HfFolder.get_token()
|
||||
window_width, window_height = get_terminal_size()
|
||||
label = """Configure startup settings. You can come back and change these later.
|
||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
"""
|
||||
self.nextrely -= 1
|
||||
for i in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
||||
for line in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=line,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
|
||||
self.hf_token = self.add_widget_intelligent(
|
||||
npyscreen.TitlePassword,
|
||||
name="Access Token (ctrl-shift-V pastes):",
|
||||
value=access_token,
|
||||
begin_entry_at=42,
|
||||
use_two_lines=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
|
||||
# old settings for defaults
|
||||
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
|
||||
device = old_opts.device
|
||||
attention_type = old_opts.attention_type
|
||||
attention_slice_size = old_opts.attention_slice_size
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Image Generation Options:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.generation_options = self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=3,
|
||||
values=GENERATION_OPT_CHOICES,
|
||||
value=[GENERATION_OPT_CHOICES.index(x) for x in GENERATION_OPT_CHOICES if getattr(old_opts, x)],
|
||||
relx=30,
|
||||
max_height=2,
|
||||
max_width=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Floating Point Precision:",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.precision = self.add_widget_intelligent(
|
||||
SingleSelectColumnsSimple,
|
||||
columns=len(PRECISION_CHOICES),
|
||||
name="Precision",
|
||||
values=PRECISION_CHOICES,
|
||||
value=PRECISION_CHOICES.index(precision),
|
||||
begin_entry_at=3,
|
||||
max_height=2,
|
||||
relx=30,
|
||||
max_width=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Generation Device:",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.device = self.add_widget_intelligent(
|
||||
SingleSelectColumnsSimple,
|
||||
columns=len(DEVICE_CHOICES),
|
||||
values=DEVICE_CHOICES,
|
||||
value=[DEVICE_CHOICES.index(device)],
|
||||
begin_entry_at=3,
|
||||
relx=30,
|
||||
max_height=2,
|
||||
max_width=60,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Attention Type:",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.attention_type = self.add_widget_intelligent(
|
||||
SingleSelectColumnsSimple,
|
||||
columns=len(ATTENTION_CHOICES),
|
||||
values=ATTENTION_CHOICES,
|
||||
value=[ATTENTION_CHOICES.index(attention_type)],
|
||||
begin_entry_at=3,
|
||||
max_height=2,
|
||||
relx=30,
|
||||
max_width=80,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.attention_type.on_changed = self.show_hide_slice_sizes
|
||||
self.attention_slice_label = self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Attention Slice Size:",
|
||||
relx=5,
|
||||
editable=False,
|
||||
hidden=attention_type != "sliced",
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 2
|
||||
self.attention_slice_size = self.add_widget_intelligent(
|
||||
SingleSelectColumnsSimple,
|
||||
columns=len(ATTENTION_SLICE_CHOICES),
|
||||
values=ATTENTION_SLICE_CHOICES,
|
||||
value=[ATTENTION_SLICE_CHOICES.index(attention_slice_size)],
|
||||
relx=30,
|
||||
hidden=attention_type != "sliced",
|
||||
max_height=2,
|
||||
max_width=110,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.disk = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.convert_cache, range=(0, 100), step=0.5),
|
||||
out_of=100,
|
||||
lowest=0.0,
|
||||
step=0.5,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.ram = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.ram, range=(3.0, MAX_RAM), step=0.5),
|
||||
out_of=round(MAX_RAM),
|
||||
lowest=0.0,
|
||||
step=0.5,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
if HAS_CUDA:
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.vram = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.vram, range=(0, MAX_VRAM), step=0.25),
|
||||
out_of=round(MAX_VRAM * 2) / 2,
|
||||
lowest=0.0,
|
||||
relx=8,
|
||||
step=0.25,
|
||||
scroll_exit=True,
|
||||
)
|
||||
else:
|
||||
self.vram = DummyWidgetValue.zero
|
||||
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Location of the database used to store model path and configuration information:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name="Output directory for images (<tab> autocompletes, ctrl-N advances):",
|
||||
value=str(default_output_dir()),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.autoimport_path),
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
|
||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license and
|
||||
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
|
||||
"""
|
||||
for i in textwrap.wrap(label, width=window_width - 6):
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=i,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.license_acceptance = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="I accept the CreativeML Responsible AI Licenses",
|
||||
value=not first_time,
|
||||
relx=2,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT"
|
||||
self.ok_button = self.add_widget_intelligent(
|
||||
CenteredButtonPress,
|
||||
name=label,
|
||||
relx=(window_width - len(label)) // 2,
|
||||
when_pressed_function=self.on_ok,
|
||||
)
|
||||
|
||||
def show_hide_slice_sizes(self, value):
|
||||
show = ATTENTION_CHOICES[value[0]] == "sliced"
|
||||
self.attention_slice_label.hidden = not show
|
||||
self.attention_slice_size.hidden = not show
|
||||
|
||||
def show_hide_model_conf_override(self, value):
|
||||
self.model_conf_override.hidden = value
|
||||
self.model_conf_override.display()
|
||||
|
||||
def on_ok(self):
|
||||
options = self.marshall_arguments()
|
||||
if self.validate_field_values(options):
|
||||
self.parentApp.new_opts = options
|
||||
if hasattr(self.parentApp, "model_select"):
|
||||
self.parentApp.setNextForm("MODELS")
|
||||
else:
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def validate_field_values(self, opt: Namespace) -> bool:
|
||||
bad_fields = []
|
||||
if not opt.license_acceptance:
|
||||
bad_fields.append("Please accept the license terms before proceeding to model downloads")
|
||||
if not Path(opt.outdir).parent.exists():
|
||||
bad_fields.append(
|
||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||
)
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:\n"
|
||||
for problem in bad_fields:
|
||||
message += f"* {problem}\n"
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def marshall_arguments(self) -> Namespace:
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"ram",
|
||||
"vram",
|
||||
"convert_cache",
|
||||
"outdir",
|
||||
]:
|
||||
if hasattr(self, attr):
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
for attr in self.autoimport_dirs:
|
||||
if not self.autoimport_dirs[attr].value:
|
||||
continue
|
||||
directory = Path(self.autoimport_dirs[attr].value)
|
||||
if directory.is_relative_to(config.root_path):
|
||||
directory = directory.relative_to(config.root_path)
|
||||
setattr(new_opts, attr, directory)
|
||||
|
||||
new_opts.hf_token = self.hf_token.value
|
||||
new_opts.license_acceptance = self.license_acceptance.value
|
||||
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
|
||||
new_opts.device = DEVICE_CHOICES[self.device.value[0]]
|
||||
new_opts.attention_type = ATTENTION_CHOICES[self.attention_type.value[0]]
|
||||
new_opts.attention_slice_size = ATTENTION_SLICE_CHOICES[self.attention_slice_size.value[0]]
|
||||
generation_options = [GENERATION_OPT_CHOICES[x] for x in self.generation_options.value]
|
||||
for v in GENERATION_OPT_CHOICES:
|
||||
setattr(new_opts, v, v in generation_options)
|
||||
return new_opts
|
||||
|
||||
|
||||
class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = program_opts
|
||||
self.invokeai_opts = invokeai_opts
|
||||
self.user_cancelled = False
|
||||
self.autoload_pending = True
|
||||
self.install_helper = install_helper
|
||||
self.install_selections = default_user_selections(program_opts, install_helper)
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.options = self.addForm(
|
||||
"MAIN",
|
||||
editOptsForm,
|
||||
name="InvokeAI Startup Options",
|
||||
cycle_widgets=False,
|
||||
)
|
||||
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
|
||||
self.model_select = self.addForm(
|
||||
"MODELS",
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
multipage=True,
|
||||
cycle_widgets=False,
|
||||
)
|
||||
|
||||
|
||||
def get_default_ram_cache_size() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
|
||||
# So we adjust everthing down a bit.
|
||||
return (
|
||||
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
|
||||
) # 2.1 is just large enough for sd 1.5 ;-)
|
||||
|
||||
|
||||
def get_default_config() -> InvokeAIAppConfig:
|
||||
"""Builds a new config object, setting the ram and precision using the appropriate heuristic."""
|
||||
config = InvokeAIAppConfig()
|
||||
config.ram = get_default_ram_cache_size()
|
||||
config.precision = "float32" if FORCE_FULL_PRECISION else choose_precision(torch.device(choose_torch_device()))
|
||||
return config
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
|
||||
default_model = install_helper.default_model()
|
||||
assert default_model is not None
|
||||
default_models = [default_model] if program_opts.default_only else install_helper.recommended_models()
|
||||
return InstallSelections(
|
||||
install_models=default_models if program_opts.yes_to_all else [],
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def clip(value: float, range: tuple[float, float], step: float) -> float:
|
||||
minimum, maximum = range
|
||||
if value < minimum:
|
||||
value = minimum
|
||||
if value > maximum:
|
||||
value = maximum
|
||||
return round(value / step) * step
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
logger.info("Initializing InvokeAI runtime directory")
|
||||
for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"):
|
||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||
for model_type in ModelType:
|
||||
Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
configs_src = Path(model_configs.__path__[0])
|
||||
configs_dest = root / "configs"
|
||||
if not os.path.samefile(configs_src, configs_dest):
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
dest = root / "models"
|
||||
dest.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(
|
||||
program_opts: Namespace, install_helper: InstallHelper
|
||||
) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
|
||||
first_time = not config.init_file_path.exists()
|
||||
config_opts = get_default_config() if first_time else config
|
||||
if program_opts.root:
|
||||
config_opts.set_root(Path(program_opts.root))
|
||||
|
||||
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
||||
raise WindowTooSmallException(
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
editApp = EditOptApplication(program_opts, config_opts, install_helper)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
return (None, None)
|
||||
else:
|
||||
return (editApp.new_opts, editApp.install_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def default_output_dir() -> Path:
|
||||
return config.root_path / "outputs"
|
||||
|
||||
|
||||
def is_v2_install(root: Path) -> bool:
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
old_init_file = root / "invokeai.init"
|
||||
new_init_file = root / "invokeai.yaml"
|
||||
old_hub = root / "models/hub"
|
||||
is_v2 = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
|
||||
return is_v2
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main(opt: Namespace) -> None:
|
||||
global FORCE_FULL_PRECISION # FIXME
|
||||
global config
|
||||
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
config = get_config()
|
||||
if opt.full_precision:
|
||||
updates["precision"] = "float32"
|
||||
|
||||
try:
|
||||
# Attempt to read the config file into the config object
|
||||
config.merge_from_file()
|
||||
except FileNotFoundError:
|
||||
# No config file, first time running the app
|
||||
pass
|
||||
|
||||
config.update_config(updates)
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
errors: set[str] = set()
|
||||
FORCE_FULL_PRECISION = opt.full_precision # FIXME global
|
||||
|
||||
# Before we write anything else, make a backup of the existing init file
|
||||
new_init_file = config.init_file_path
|
||||
backup_init_file = new_init_file.with_suffix(".bak")
|
||||
if new_init_file.exists():
|
||||
copy(new_init_file, backup_init_file)
|
||||
|
||||
try:
|
||||
# v2.3 -> v4.0.0 upgrade is no longer supported
|
||||
if is_v2_install(config.root_path):
|
||||
logger.error("Migration from v2.3 to v4.0.0 is no longer supported. Please install a fresh copy.")
|
||||
sys.exit(0)
|
||||
|
||||
# run this unconditionally in case new directories need to be added
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
# this will initialize and populate the models tables if not present
|
||||
install_helper = InstallHelper(config, logger)
|
||||
|
||||
models_to_download = default_user_selections(opt, install_helper)
|
||||
|
||||
if opt.yes_to_all:
|
||||
# We will not show the UI - just write the default config to the file and move on to installing models.
|
||||
get_default_config().write_file(new_init_file)
|
||||
else:
|
||||
# Run the UI to get the user's options & model choices
|
||||
user_opts, models_to_download = run_console_ui(opt, install_helper)
|
||||
if user_opts:
|
||||
# Create a dict of the user's opts, omitting any fields that are not config settings (like `hf_token`)
|
||||
user_opts_dict = {k: v for k, v in vars(user_opts).items() if k in config.model_fields}
|
||||
# Merge the user's opts back into the config object & write it
|
||||
config.update_config(user_opts_dict)
|
||||
config.write_file(config.init_file_path)
|
||||
|
||||
if hasattr(user_opts, "hf_token") and user_opts.hf_token:
|
||||
HfLogin(user_opts.hf_token)
|
||||
else:
|
||||
logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n')
|
||||
sys.exit(0)
|
||||
|
||||
if opt.skip_support_models:
|
||||
logger.info("Skipping support models at user's request")
|
||||
else:
|
||||
logger.info("Installing support models")
|
||||
download_support_models()
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
logger.warning("Skipping diffusion weights download per user request")
|
||||
elif models_to_download:
|
||||
install_helper.add_or_delete(models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
|
||||
if not opt.yes_to_all:
|
||||
input("Press any key to continue...")
|
||||
except WindowTooSmallException as e:
|
||||
logger.error(str(e))
|
||||
if backup_init_file.exists():
|
||||
move(backup_init_file, new_init_file)
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
if backup_init_file.exists():
|
||||
move(backup_init_file, new_init_file)
|
||||
except Exception:
|
||||
print("An error occurred during installation.")
|
||||
if backup_init_file.exists():
|
||||
move(backup_init_file, new_init_file)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,379 +0,0 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team

import argparse
import shlex
from argparse import ArgumentParser

# note that this includes both old sampler names and new scheduler names
# in order to be able to parse both 2.0 and 3.0-pre-nodes versions of invokeai.init
SAMPLER_CHOICES = [
    "ddim",
    "ddpm",
    "deis",
    "lms",
    "lms_k",
    "pndm",
    "heun",
    "heun_k",
    "euler",
    "euler_k",
    "euler_a",
    "kdpm_2",
    "kdpm_2_a",
    "dpmpp_2s",
    "dpmpp_2s_k",
    "dpmpp_2m",
    "dpmpp_2m_k",
    "dpmpp_2m_sde",
    "dpmpp_2m_sde_k",
    "dpmpp_sde",
    "dpmpp_sde_k",
    "unipc",
    "k_dpm_2_a",
    "k_dpm_2",
    "k_dpmpp_2_a",
    "k_dpmpp_2",
    "k_euler_a",
    "k_euler",
    "k_heun",
    "k_lms",
    "plms",
    "lcm",
]

PRECISION_CHOICES = [
    "auto",
    "float32",
    "autocast",
    "float16",
]


class FileArgumentParser(ArgumentParser):
    """
    Supports reading defaults from an init file.
    """

    def convert_arg_line_to_args(self, arg_line):
        return shlex.split(arg_line, comments=True)


legacy_parser = FileArgumentParser(
    description="""
Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
Other command-line arguments are defaults that can usually be overridden
from within the command prompt.
""",
    fromfile_prefix_chars="@",
)
general_group = legacy_parser.add_argument_group("General")
model_group = legacy_parser.add_argument_group("Model selection")
file_group = legacy_parser.add_argument_group("Input/output")
web_server_group = legacy_parser.add_argument_group("Web server")
render_group = legacy_parser.add_argument_group("Rendering")
postprocessing_group = legacy_parser.add_argument_group("Postprocessing")
deprecated_group = legacy_parser.add_argument_group("Deprecated options")

deprecated_group.add_argument("--laion400m")
|
||||
deprecated_group.add_argument("--weights") # deprecated
|
||||
general_group.add_argument("--version", "-V", action="store_true", help="Print InvokeAI version number")
|
||||
model_group.add_argument(
|
||||
"--root_dir",
|
||||
default=None,
|
||||
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--config",
|
||||
"-c",
|
||||
"-config",
|
||||
dest="conf",
|
||||
default="./configs/models.yaml",
|
||||
help="Path to configuration file for alternate models.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--model",
|
||||
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--weight_dirs",
|
||||
nargs="+",
|
||||
type=str,
|
||||
help="List of one or more directories that will be auto-scanned for new model weights to import",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--png_compression",
|
||||
"-z",
|
||||
type=int,
|
||||
default=6,
|
||||
choices=range(0, 9),
|
||||
dest="png_compression",
|
||||
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"-F",
|
||||
"--full_precision",
|
||||
dest="full_precision",
|
||||
action="store_true",
|
||||
help="Deprecated way to set --precision=float32",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--max_loaded_models",
|
||||
dest="max_loaded_models",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--free_gpu_mem",
|
||||
dest="free_gpu_mem",
|
||||
action="store_true",
|
||||
help="Force free gpu memory before final decoding",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--sequential_guidance",
|
||||
dest="sequential_guidance",
|
||||
action="store_true",
|
||||
help="Calculate guidance in serial instead of in parallel, lowering memory requirement " "at the expense of speed",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--xformers",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help="Enable/disable xformers support (default enabled if installed)",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--always_use_cpu", dest="always_use_cpu", action="store_true", help="Force use of CPU even if GPU is available"
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--precision",
|
||||
dest="precision",
|
||||
type=str,
|
||||
choices=PRECISION_CHOICES,
|
||||
metavar="PRECISION",
|
||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||
default="auto",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--ckpt_convert",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest="ckpt_convert",
|
||||
default=True,
|
||||
help="Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--internet",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest="internet_available",
|
||||
default=True,
|
||||
help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--nsfw_checker",
|
||||
"--safety_checker",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
dest="safety_checker",
|
||||
default=False,
|
||||
help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--autoimport",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--autoconvert",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--patchmatch",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.",
|
||||
)
|
||||
file_group.add_argument(
|
||||
"--from_file",
|
||||
dest="infile",
|
||||
type=str,
|
||||
help="If specified, load prompts from this file",
|
||||
)
|
||||
file_group.add_argument(
|
||||
"--outdir",
|
||||
"-o",
|
||||
type=str,
|
||||
help="Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs",
|
||||
default="outputs",
|
||||
)
|
||||
file_group.add_argument(
|
||||
"--prompt_as_dir",
|
||||
"-p",
|
||||
action="store_true",
|
||||
help="Place images in subdirectories named after the prompt.",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"--fnformat",
|
||||
default="{prefix}.{seed}.png",
|
||||
type=str,
|
||||
help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png",
|
||||
)
|
||||
render_group.add_argument("-s", "--steps", type=int, default=50, help="Number of steps")
|
||||
render_group.add_argument(
|
||||
"-W",
|
||||
"--width",
|
||||
type=int,
|
||||
help="Image width, multiple of 64",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"-H",
|
||||
"--height",
|
||||
type=int,
|
||||
help="Image height, multiple of 64",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"-C",
|
||||
"--cfg_scale",
|
||||
default=7.5,
|
||||
type=float,
|
||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||
)
|
||||
render_group.add_argument(
|
||||
"--sampler",
|
||||
"-A",
|
||||
"-m",
|
||||
dest="sampler_name",
|
||||
type=str,
|
||||
choices=SAMPLER_CHOICES,
|
||||
metavar="SAMPLER_NAME",
|
||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||
default="k_lms",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"--log_tokenization", "-t", action="store_true", help="shows how the prompt is split into tokens"
|
||||
)
|
||||
render_group.add_argument(
|
||||
"-f",
|
||||
"--strength",
|
||||
type=float,
|
||||
help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"-T",
|
||||
"-fit",
|
||||
"--fit",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)",
|
||||
)
|
||||
|
||||
render_group.add_argument("--grid", "-g", action=argparse.BooleanOptionalAction, help="generate a grid")
|
||||
render_group.add_argument(
|
||||
"--embedding_directory",
|
||||
"--embedding_path",
|
||||
dest="embedding_path",
|
||||
default="embeddings",
|
||||
type=str,
|
||||
help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"--lora_directory",
|
||||
dest="lora_path",
|
||||
default="loras",
|
||||
type=str,
|
||||
help="Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)",
|
||||
)
|
||||
render_group.add_argument(
|
||||
"--embeddings",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help="Enable embedding directory (default). Use --no-embeddings to disable.",
|
||||
)
|
||||
render_group.add_argument("--enable_image_debugging", action="store_true", help="Generates debugging image to display")
|
||||
render_group.add_argument(
|
||||
"--karras_max",
|
||||
type=int,
|
||||
default=None,
|
||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].",
|
||||
)
|
||||
# Restoration related args
|
||||
postprocessing_group.add_argument(
|
||||
"--no_restore",
|
||||
dest="restore",
|
||||
action="store_false",
|
||||
help="Disable face restoration with GFPGAN or codeformer",
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
"--no_upscale",
|
||||
dest="esrgan",
|
||||
action="store_false",
|
||||
help="Disable upscaling with ESRGAN",
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
"--esrgan_bg_tile",
|
||||
type=int,
|
||||
default=400,
|
||||
help="Tile size for background sampler, 0 for no tile during testing. Default: 400.",
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
"--esrgan_denoise_str",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75",
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
"--gfpgan_model_path",
|
||||
type=str,
|
||||
default="./models/gfpgan/GFPGANv1.4.pth",
|
||||
help="Indicates the path to the GFPGAN model",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--web",
|
||||
dest="web",
|
||||
action="store_true",
|
||||
help="Start in web server mode.",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--web_develop",
|
||||
dest="web_develop",
|
||||
action="store_true",
|
||||
help="Start in web server development mode.",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--web_verbose",
|
||||
action="store_true",
|
||||
help="Enables verbose logging",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--cors",
|
||||
nargs="*",
|
||||
type=str,
|
||||
help="Additional allowed origins, comma-separated",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.",
|
||||
)
|
||||
web_server_group.add_argument("--port", type=int, default="9090", help="Web server: Port to listen on")
|
||||
web_server_group.add_argument(
|
||||
"--certfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Web server: Path to certificate file to use for SSL. Use together with --keyfile",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Web server: Path to private key file to use for SSL. Use together with --certfile",
|
||||
)
|
||||
web_server_group.add_argument(
|
||||
"--gui",
|
||||
dest="gui",
|
||||
action="store_true",
|
||||
help="Start InvokeAI GUI",
|
||||
)
|
@ -35,8 +35,6 @@ class ControlNetLoader(GenericDiffusersLoader):

    def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
        assert isinstance(config, CheckpointConfigBase)
        config_file = config.config_path

        image_size = (
            512
            if config.base == BaseModelType.StableDiffusion1
@ -46,7 +44,7 @@ class ControlNetLoader(GenericDiffusersLoader):
        )

        self._logger.info(f"Converting {model_path} to diffusers format")
        with open(self._app_config.root_path / config_file, "r") as config_stream:
        with open(config.config_path, "r") as config_stream:
            convert_controlnet_to_diffusers(
                model_path,
                output_path,

@ -76,7 +76,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
        assert isinstance(config, MainCheckpointConfig)
        base = config.base

        config_file = config.config_path
        prediction_type = config.prediction_type.value
        upcast_attention = config.upcast_attention
        image_size = (
@ -92,7 +91,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
            model_path,
            output_path,
            model_type=self.model_base_to_model_type[base],
            original_config_file=self._app_config.root_path / config_file,
            original_config_file=config.config_path,
            extract_ema=True,
            from_safetensors=model_path.suffix == ".safetensors",
            precision=self._torch_dtype,

@ -178,13 +178,14 @@ class ModelProbe(object):
            fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
            and fields["format"] is ModelFormat.Checkpoint
        ):
            fields["config_path"] = cls._get_checkpoint_config_path(
            ckpt_config_path = cls._get_checkpoint_config_path(
                model_path,
                model_type=fields["type"],
                base_type=fields["base"],
                variant_type=fields["variant"],
                prediction_type=fields["prediction_type"],
            ).as_posix()
            )
            fields["config_path"] = str(ckpt_config_path)

        # additional fields needed for main non-checkpoint models
        elif fields["type"] == ModelType.Main and fields["format"] in [
@ -298,23 +299,23 @@ class ModelProbe(object):
            config_file = LEGACY_CONFIGS[base_type][variant_type]
            if isinstance(config_file, dict):  # need another tier for sd-2.x models
                config_file = config_file[prediction_type]
            config_file = f"stable-diffusion/{config_file}"
        elif model_type is ModelType.ControlNet:
            config_file = (
                "../controlnet/cldm_v15.yaml"
                "controlnet/cldm_v15.yaml"
                if base_type is BaseModelType.StableDiffusion1
                else "../controlnet/cldm_v21.yaml"
                else "controlnet/cldm_v21.yaml"
            )
        elif model_type is ModelType.VAE:
            config_file = (
                "../stable-diffusion/v1-inference.yaml"
                "stable-diffusion/v1-inference.yaml"
                if base_type is BaseModelType.StableDiffusion1
                else "../stable-diffusion/v2-inference.yaml"
                else "stable-diffusion/v2-inference.yaml"
            )
        else:
            raise InvalidModelConfigException(
                f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
            )
        assert isinstance(config_file, str)
        return Path(config_file)

    @classmethod
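The ModelProbe hunk above switches the legacy config paths from "../"-prefixed values to paths relative to the configs root, stored as plain strings. A minimal sketch of how such a value would be resolved, assuming a configs root location that is not part of this commit:

```python
# Illustration only - the configs root below is an assumed path, not taken from this commit.
from pathlib import Path

configs_root = Path("/opt/invokeai/configs")  # assumption for the example
config_path = Path("stable-diffusion/v1-inference.yaml")  # style of value the probe now returns
print(configs_root / config_path)  # /opt/invokeai/configs/stable-diffusion/v1-inference.yaml
```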
392 invokeai/backend/model_manager/starter_models.py (new file)
@ -0,0 +1,392 @@
from dataclasses import dataclass
from typing import Optional

from invokeai.backend.model_manager.config import BaseModelType, ModelType


@dataclass
class StarterModel:
    description: str
    source: str
    name: str
    base: BaseModelType
    type: ModelType
    # Optional list of model source dependencies that need to be installed before this model can be used
    dependencies: Optional[list[str]] = None
    is_installed: bool = False


# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
# region: Main
|
||||
StarterModel(
|
||||
name="SD 1.5 (base)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="runwayml/stable-diffusion-v1-5",
|
||||
description="Stable Diffusion version 1.5 diffusers model (4.27 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="SD 1.5 (inpainting)",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="runwayml/stable-diffusion-inpainting",
|
||||
description="RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="Analog Diffusion",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="wavymulder/Analog-Diffusion",
|
||||
description="An SD-1.5 model trained on diverse analog photographs (2.13 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="Deliberate v5",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors",
|
||||
description="Versatile model that produces detailed images up to 768px (4.27 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="Dungeons and Diffusion",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="0xJustin/Dungeons-and-Diffusion",
|
||||
description="Dungeons & Dragons characters (2.13 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="dreamlike photoreal v2",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="dreamlike-art/dreamlike-photoreal-2.0",
|
||||
description="A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="Inkpunk Diffusion",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="Envvi/Inkpunk-Diffusion",
|
||||
description='Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)',
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="OpenJourney",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="prompthero/openjourney",
|
||||
description='An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)',
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="seek.art MEGA",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="coreco/seek.art_MEGA",
|
||||
description='A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)',
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="TrinArt v2",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="naclbit/trinart_stable_diffusion_v2",
|
||||
description="An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="SD 2.1 (base)",
|
||||
base=BaseModelType.StableDiffusion2,
|
||||
source="stabilityai/stable-diffusion-2-1",
|
||||
description="Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="SD 2.0 (inpainting)",
|
||||
base=BaseModelType.StableDiffusion2,
|
||||
source="stabilityai/stable-diffusion-2-inpainting",
|
||||
description="Stable Diffusion version 2.0 inpainting model (5.21 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="SDXL (base)",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="stabilityai/stable-diffusion-xl-base-1.0",
|
||||
description="Stable Diffusion XL base model (12 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
StarterModel(
|
||||
name="SDXL Refiner",
|
||||
base=BaseModelType.StableDiffusionXLRefiner,
|
||||
source="stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
description="Stable Diffusion XL refiner model (12 GB)",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
# endregion
|
||||
# region VAE
|
||||
StarterModel(
|
||||
name="sdxl-vae-fp16-fix",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="madebyollin/sdxl-vae-fp16-fix",
|
||||
description="Version of the SDXL-1.0 VAE that works in half precision mode",
|
||||
type=ModelType.VAE,
|
||||
),
|
||||
# endregion
|
||||
# region LoRA
|
||||
StarterModel(
|
||||
name="FlatColor",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="https://civitai.com/models/6433/loraflatcolor",
|
||||
description="A LoRA that generates scenery using solid blocks of color",
|
||||
type=ModelType.LoRA,
|
||||
),
|
||||
StarterModel(
|
||||
name="Ink scenery",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="https://civitai.com/api/download/models/83390",
|
||||
description="Generate india ink-like landscapes",
|
||||
type=ModelType.LoRA,
|
||||
),
|
||||
# endregion
|
||||
# region IP Adapter
|
||||
StarterModel(
|
||||
name="IP Adapter",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_sd15",
|
||||
description="IP-Adapter for SD 1.5 models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=["InvokeAI/ip_adapter_sd_image_encoder"],
|
||||
),
|
||||
StarterModel(
|
||||
name="IP Adapter Plus",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_plus_sd15",
|
||||
description="Refined IP-Adapter for SD 1.5 models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=["InvokeAI/ip_adapter_sd_image_encoder"],
|
||||
),
|
||||
StarterModel(
|
||||
name="IP Adapter Plus Face",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_plus_face_sd15",
|
||||
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=["InvokeAI/ip_adapter_sd_image_encoder"],
|
||||
),
|
||||
StarterModel(
|
||||
name="IP Adapter SDXL",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="InvokeAI/ip_adapter_sdxl",
|
||||
description="IP-Adapter for SDXL models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=["InvokeAI/ip_adapter_sdxl_image_encoder"],
|
||||
),
|
||||
# endregion
|
||||
# region ControlNet
|
||||
StarterModel(
|
||||
name="QRCode Monster",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="monster-labs/control_v1p_sd15_qrcode_monster",
|
||||
description="Controlnet model that generates scannable creative QR codes",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_canny",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="inpaint",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_inpaint",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="mlsd",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_mlsd",
|
||||
description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1p_sd15_depth",
|
||||
description="Controlnet weights trained on sd-1.5 with depth conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="normal_bae",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_normalbae",
|
||||
description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="seg",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_seg",
|
||||
description="Controlnet weights trained on sd-1.5 with seg image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_lineart",
|
||||
description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart_anime",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||
description="Controlnet weights trained on sd-1.5 with anime image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="openpose",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_openpose",
|
||||
description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="scribble",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_scribble",
|
||||
description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11p_sd15_softedge",
|
||||
description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="shuffle",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_shuffle",
|
||||
description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="tile",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11f1e_sd15_tile",
|
||||
description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="ip2p",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="lllyasviel/control_v11e_sd15_ip2p",
|
||||
description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlnet-canny-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with canny conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlnet-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="softedge-dexined-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
|
||||
description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-16bit-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-zoe-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="diffusers/controlnet-zoe-depth-sdxl-1.0",
|
||||
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
# endregion
|
||||
# region T2I Adapter
|
||||
StarterModel(
|
||||
name="canny-sd15",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_canny_sd15v2",
|
||||
description="T2I Adapter weights trained on sd-1.5 with canny conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
StarterModel(
|
||||
name="sketch-sd15",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_sketch_sd15v2",
|
||||
description="T2I Adapter weights trained on sd-1.5 with sketch conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
StarterModel(
|
||||
name="depth-sd15",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_depth_sd15v2",
|
||||
description="T2I Adapter weights trained on sd-1.5 with depth conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
StarterModel(
|
||||
name="zoedepth-sd15",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="TencentARC/t2iadapter_zoedepth_sd15v1",
|
||||
description="T2I Adapter weights trained on sd-1.5 with zoe depth conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
StarterModel(
|
||||
name="canny-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-canny-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with canny conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
StarterModel(
|
||||
name="zoedepth-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with zoe depth conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
StarterModel(
|
||||
name="lineart-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-lineart-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with lineart conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
StarterModel(
|
||||
name="sketch-sdxl",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
||||
description="T2I Adapter weights trained on sdxl-1.0 with sketch conditioning.",
|
||||
type=ModelType.T2IAdapter,
|
||||
),
|
||||
# endregion
|
||||
    # region TI
    StarterModel(
        name="EasyNegative",
        base=BaseModelType.StableDiffusion1,
        source="https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors",
        description="A textual inversion to use in the negative prompt to reduce bad anatomy",
        type=ModelType.TextualInversion,
    ),
    # endregion
]

assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
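As a rough orientation for the new starter catalog above, a minimal sketch of how a caller could filter it, assuming only the imports shown in the file; the filtering itself is illustrative and not part of the commit:

```python
# Sketch only: select SDXL ControlNets from the starter catalog defined above.
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.model_manager.starter_models import STARTER_MODELS

sdxl_controlnets = [
    m
    for m in STARTER_MODELS
    if m.base is BaseModelType.StableDiffusionXL and m.type is ModelType.ControlNet
]
for m in sdxl_controlnets:
    print(f"{m.name}: {m.source}")
```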
@ -25,8 +25,8 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent

from ..util import auto_detect_slice_size, normalize_device
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import normalize_device


@dataclass

@ -11,7 +11,7 @@ from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

from ...util import torch_dtype
from invokeai.backend.util.devices import torch_dtype


class CrossAttentionType(enum.Enum):

@ -1,5 +0,0 @@
"""
Initialization file for invokeai.backend.training
"""

from .textual_inversion_training import do_textual_inversion_training, parse_args  # noqa: F401
@ -1,924 +0,0 @@
# This code was copied from
# https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
# on January 2, 2023
# and modified slightly by Lincoln Stein (@lstein) to work with InvokeAI

"""
This is the backend to "textual_inversion.py"
"""

import logging
import math
import os
import random
from argparse import Namespace
from pathlib import Path
from typing import Optional

import datasets
import diffusers
import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

# invokeai stuff
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.install.install_helper import initialize_record_store
from invokeai.backend.model_manager import BaseModelType, ModelType

if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path):
|
||||
logger.info("Saving embeddings")
|
||||
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
||||
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
|
||||
def parse_args() -> Namespace:
|
||||
config = get_config()
|
||||
parser = PagingArgumentParser(description="Textual inversion training")
|
||||
general_group = parser.add_argument_group("General")
|
||||
model_group = parser.add_argument_group("Models and Paths")
|
||||
image_group = parser.add_argument_group("Training Image Location and Options")
|
||||
trigger_group = parser.add_argument_group("Trigger Token")
|
||||
training_group = parser.add_argument_group("Training Parameters")
|
||||
checkpointing_group = parser.add_argument_group("Checkpointing and Resume")
|
||||
integration_group = parser.add_argument_group("Integration")
|
||||
general_group.add_argument(
|
||||
"--front_end",
|
||||
"--gui",
|
||||
dest="front_end",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
|
||||
)
|
||||
general_group.add_argument(
|
||||
"--root_dir",
|
||||
"--root",
|
||||
type=Path,
|
||||
default=config.root_path,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
general_group.add_argument(
|
||||
"--logging_dir",
|
||||
type=Path,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
general_group.add_argument(
|
||||
"--output_dir",
|
||||
type=Path,
|
||||
default=f"{config.root_path}/text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="sd-1/main/stable-diffusion-v1-5",
|
||||
help="Name of the diffusers model to train against.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
|
||||
model_group.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--train_data_dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="A folder containing the training data.",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--center_crop",
|
||||
action="store_true",
|
||||
help="Whether to center crop images before resizing to resolution",
|
||||
)
|
||||
trigger_group.add_argument(
|
||||
"--placeholder_token",
|
||||
"--trigger_term",
|
||||
dest="placeholder_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help='A token to use as a placeholder for the concept. This token will trigger the concept when included in the prompt as "<trigger>".',
|
||||
)
|
||||
trigger_group.add_argument(
|
||||
"--learnable_property",
|
||||
type=str,
|
||||
choices=["object", "style"],
|
||||
default="object",
|
||||
help="Choose between 'object' and 'style'",
|
||||
)
|
||||
trigger_group.add_argument(
|
||||
"--initializer_token",
|
||||
type=str,
|
||||
default="*",
|
||||
help="A symbol to use as the initializer word.",
|
||||
)
|
||||
checkpointing_group.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
checkpointing_group.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=Path,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
checkpointing_group.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Save learned_embeds.bin every X updates steps.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--repeats",
|
||||
type=int,
|
||||
default=100,
|
||||
help="How many times to repeat the training data.",
|
||||
)
|
||||
training_group.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
training_group.add_argument(
|
||||
"--train_batch_size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
training_group.add_argument("--num_train_epochs", type=int, default=100)
|
||||
training_group.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--lr_warmup_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps for the warmup in the lr scheduler.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--adam_beta1",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="The beta1 parameter for the Adam optimizer.",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--adam_beta2",
|
||||
type=float,
|
||||
default=0.999,
|
||||
help="The beta2 parameter for the Adam optimizer.",
|
||||
)
|
||||
training_group.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
training_group.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
default=1e-08,
|
||||
help="Epsilon value for the Adam optimizer",
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
training_group.add_argument(
|
||||
"--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="For distributed training: local_rank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention",
|
||||
action="store_true",
|
||||
help="Whether or not to use xformers.",
|
||||
)
|
||||
|
||||
integration_group.add_argument(
|
||||
"--only_save_embeds",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Save only the embeddings for the new concept.",
|
||||
)
|
||||
integration_group.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
integration_group.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
integration_group.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether or not to push the model to the Hub.",
|
||||
)
|
||||
integration_group.add_argument(
|
||||
"--hub_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The token to use to push to the Model Hub.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = Path(data_root)
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [
|
||||
self.data_root / file_path
|
||||
for file_path in self.data_root.iterdir()
|
||||
if file_path.is_file()
|
||||
and file_path.name.endswith((".png", ".PNG", ".jpg", ".JPG", ".jpeg", ".JPEG", ".gif", ".GIF"))
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL_INTERPOLATION["linear"],
|
||||
"bilinear": PIL_INTERPOLATION["bilinear"],
|
||||
"bicubic": PIL_INTERPOLATION["bicubic"],
|
||||
"lanczos": PIL_INTERPOLATION["lanczos"],
|
||||
}[interpolation]
|
||||
|
||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def do_textual_inversion_training(
|
||||
config: InvokeAIAppConfig,
|
||||
model: str,
|
||||
train_data_dir: Path,
|
||||
output_dir: Path,
|
||||
placeholder_token: str,
|
||||
initializer_token: str,
|
||||
save_steps: int = 500,
|
||||
only_save_embeds: bool = False,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
learnable_property: str = "object",
|
||||
repeats: int = 100,
|
||||
seed: Optional[int] = None,
|
||||
resolution: int = 512,
|
||||
center_crop: bool = False,
|
||||
train_batch_size: int = 16,
|
||||
num_train_epochs: int = 100,
|
||||
max_train_steps: int = 5000,
|
||||
gradient_accumulation_steps: int = 1,
|
||||
gradient_checkpointing: bool = False,
|
||||
learning_rate: float = 1e-4,
|
||||
scale_lr: bool = True,
|
||||
lr_scheduler: str = "constant",
|
||||
lr_warmup_steps: int = 500,
|
||||
adam_beta1: float = 0.9,
|
||||
adam_beta2: float = 0.999,
|
||||
adam_weight_decay: float = 1e-02,
|
||||
adam_epsilon: float = 1e-08,
|
||||
push_to_hub: bool = False,
|
||||
hub_token: Optional[str] = None,
|
||||
logging_dir: Path = Path("logs"),
|
||||
mixed_precision: str = "fp16",
|
||||
allow_tf32: bool = False,
|
||||
report_to: str = "tensorboard",
|
||||
local_rank: int = -1,
|
||||
checkpointing_steps: int = 500,
|
||||
resume_from_checkpoint: Optional[Path] = None,
|
||||
enable_xformers_memory_efficient_attention: bool = False,
|
||||
hub_model_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
assert model, "Please specify a base model with --model"
|
||||
assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir"
|
||||
assert placeholder_token, "Please specify a trigger term using --placeholder_token"
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != local_rank:
|
||||
local_rank = env_local_rank
|
||||
|
||||
# setting up things the way invokeai expects them
|
||||
if not os.path.isabs(output_dir):
|
||||
output_dir = config.root_path / output_dir
|
||||
|
||||
logging_dir = output_dir / logging_dir
|
||||
|
||||
accelerator_config = ProjectConfiguration()
|
||||
accelerator_config.logging_dir = logging_dir
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
mixed_precision=mixed_precision,
|
||||
log_with=report_to,
|
||||
project_config=accelerator_config,
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if push_to_hub:
|
||||
if hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(output_dir).name, token=hub_token)
|
||||
else:
|
||||
repo_name = hub_model_id
|
||||
repo = Repository(output_dir, clone_from=repo_name)
|
||||
|
||||
with open(os.path.join(output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif output_dir is not None:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
model_records = initialize_record_store(config)
|
||||
base, type, name = model.split("/") # note frontend still returns old-style keys
|
||||
try:
|
||||
model_config = model_records.search_by_attr(
|
||||
model_name=name, model_type=ModelType(type), base_model=BaseModelType(base)
|
||||
)[0]
|
||||
except IndexError:
|
||||
raise Exception(f"Unknown model {model}")
|
||||
model_path = config.models_path / model_config.path
|
||||
|
||||
pipeline_args = {"local_files_only": True}
|
||||
if tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", **pipeline_args)
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler", **pipeline_args)
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_path,
|
||||
subfolder="text_encoder",
|
||||
**pipeline_args,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
model_path,
|
||||
subfolder="vae",
|
||||
**pipeline_args,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
model_path,
|
||||
subfolder="unet",
|
||||
**pipeline_args,
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError(
|
||||
f"The initializer token must be a single token. Provided initializer={initializer_token}. Token ids={token_ids}"
|
||||
)
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# Freeze vae and unet
|
||||
vae.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.text_model.encoder.requires_grad_(False)
|
||||
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
|
||||
if gradient_checkpointing:
|
||||
# Keep unet in train mode if we are using gradient checkpointing to save memory.
|
||||
# The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
|
||||
unet.train()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
if enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
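# scale the base learning rate by the effective batch size (gradient accumulation steps x batch size x number of processes)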
if scale_lr:
|
||||
learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
betas=(adam_beta1, adam_beta2),
|
||||
weight_decay=adam_weight_decay,
|
||||
eps=adam_epsilon,
|
||||
)
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=train_data_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=resolution,
|
||||
placeholder_token=placeholder_token,
|
||||
repeats=repeats,
|
||||
learnable_property=learnable_property,
|
||||
center_crop=center_crop,
|
||||
set="train",
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
||||
if max_train_steps is None:
|
||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
scheduler = get_scheduler(
|
||||
lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
|
||||
num_training_steps=max_train_steps * gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, scheduler
|
||||
)
|
||||
|
||||
# For mixed precision training we cast the unet and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move vae and unet to device and cast to weight_dtype
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
max_train_steps = num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initialize automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
params = locals()
|
||||
for k in params: # init_trackers() doesn't like objects
|
||||
params[k] = str(params[k]) if isinstance(params[k], object) else params[k]
|
||||
accelerator.init_trackers("textual_inversion", config=params)
|
||||
|
||||
# Train!
|
||||
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
resume_step = None
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if resume_from_checkpoint:
|
||||
if resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
|
||||
resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
range(global_step, max_train_steps),
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# keep original embeddings as reference
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
|
||||
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if resume_step and resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||
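# scale latents by the Stable Diffusion VAE scaling factor (0.18215)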
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.config.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
||||
orig_embeds_params[index_no_updates]
|
||||
)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % save_steps == 0:
|
||||
save_path = os.path.join(output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
save_progress(
|
||||
text_encoder,
|
||||
placeholder_token_id,
|
||||
accelerator,
|
||||
placeholder_token,
|
||||
save_path,
|
||||
)
|
||||
|
||||
if global_step % checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if push_to_hub and only_save_embeds:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not only_save_embeds
|
||||
if save_full_model:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
**pipeline_args,
|
||||
)
|
||||
pipeline.save_pretrained(output_dir)
|
||||
# Save the newly trained embeddings
|
||||
save_path = os.path.join(output_dir, "learned_embeds.bin")
|
||||
save_progress(
|
||||
text_encoder,
|
||||
placeholder_token_id,
|
||||
accelerator,
|
||||
placeholder_token,
|
||||
save_path,
|
||||
)
|
||||
|
||||
if push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
|
||||
accelerator.end_training()
|
@ -2,32 +2,14 @@
Initialization file for invokeai.backend.util
"""

from .attention import auto_detect_slice_size  # noqa: F401
from .devices import (  # noqa: F401
    CPU_DEVICE,
    CUDA_DEVICE,
    MPS_DEVICE,
    choose_precision,
    choose_torch_device,
    normalize_device,
    torch_dtype,
)
from .devices import choose_precision, choose_torch_device
from .logging import InvokeAILogger
from .util import (  # TO DO: Clean this up; remove the unused symbols
    GIG,
    Chdir,
    ask_user,  # noqa
    directory_size,
    download_with_resume,
    instantiate_from_config,  # noqa
    url_attachment_name,  # noqa
)
from .util import GIG, Chdir, directory_size

__all__ = [
    "GIG",
    "directory_size",
    "Chdir",
    "download_with_resume",
    "InvokeAILogger",
    "choose_precision",
    "choose_torch_device",
@ -1,67 +0,0 @@
"""
Functions for better format logging
write_log -- logs the name of the output image, prompt, and prompt args to the terminal and different types of file
1 write_log_message -- Writes a message to the console
2 write_log_files -- Writes a message to files
2.1 write_log_default -- File in plain text
2.2 write_log_txt -- File in txt format
2.3 write_log_markdown -- File in markdown format
"""

import os


def write_log(results, log_path, file_types, output_cntr):
    """
    logs the name of the output image, prompt, and prompt args to the terminal and files
    """
    output_cntr = write_log_message(results, output_cntr)
    write_log_files(results, log_path, file_types)
    return output_cntr


def write_log_message(results, output_cntr):
    """logs to the terminal"""
    if len(results) == 0:
        return output_cntr
    log_lines = [f"{path}: {prompt}\n" for path, prompt in results]
    if len(log_lines) > 1:
        subcntr = 1
        for ll in log_lines:
            print(f"[{output_cntr}.{subcntr}] {ll}", end="")
            subcntr += 1
    else:
        print(f"[{output_cntr}] {log_lines[0]}", end="")
    return output_cntr + 1


def write_log_files(results, log_path, file_types):
    for file_type in file_types:
        if file_type == "txt":
            write_log_txt(log_path, results)
        elif file_type == "md" or file_type == "markdown":
            write_log_markdown(log_path, results)
        else:
            print(f"'{file_type}' format is not supported, so write in plain text")
            write_log_default(log_path, results, file_type)


def write_log_default(log_path, results, file_type):
    plain_txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
    with open(log_path + "." + file_type, "a", encoding="utf-8") as file:
        file.writelines(plain_txt_lines)


def write_log_txt(log_path, results):
    txt_lines = [f"{path}: {prompt}\n" for path, prompt in results]
    with open(log_path + ".txt", "a", encoding="utf-8") as file:
        file.writelines(txt_lines)


def write_log_markdown(log_path, results):
    md_lines = []
    for path, prompt in results:
        file_name = os.path.basename(path)
        md_lines.append(f"## {file_name}\n![]({file_name})\n\n{prompt}\n")
    with open(log_path + ".md", "a", encoding="utf-8") as file:
        file.writelines(md_lines)
@ -1,29 +1,13 @@
import base64
import importlib
import io
import math
import multiprocessing as mp
import os
import re
import warnings
from collections import abc
from inspect import isfunction
from pathlib import Path
from queue import Queue
from threading import Thread

import numpy as np
import requests
import torch
from diffusers import logging as diffusers_logging
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from PIL import Image
from transformers import logging as transformers_logging

import invokeai.backend.util.logging as logger

from .devices import torch_dtype

# actual size of a gig
GIG = 1073741824

@ -41,340 +25,6 @@ def directory_size(directory: Path) -> int:
    return sum
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = []
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.load_default()
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
logger.warning("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
logger.debug(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config, **kwargs):
|
||||
if "target" not in config:
|
||||
if config == "__is_first_stage__":
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", {}), **kwargs)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
|
||||
# create dummy dataset instance
|
||||
|
||||
# run prefetching
|
||||
if idx_to_fn:
|
||||
res = func(data, worker_id=idx)
|
||||
else:
|
||||
res = func(data)
|
||||
Q.put([idx, res])
|
||||
Q.put("Done")
|
||||
|
||||
|
||||
def parallel_data_prefetch(
|
||||
func: callable,
|
||||
data,
|
||||
n_proc,
|
||||
target_data_type="ndarray",
|
||||
cpu_intensive=True,
|
||||
use_worker_id=False,
|
||||
):
|
||||
# if target_data_type not in ["ndarray", "list"]:
|
||||
# raise ValueError(
|
||||
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
|
||||
# )
|
||||
if isinstance(data, np.ndarray) and target_data_type == "list":
|
||||
raise ValueError("list expected but function got ndarray.")
|
||||
elif isinstance(data, abc.Iterable):
|
||||
if isinstance(data, dict):
|
||||
logger.warning(
|
||||
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
)
|
||||
data = list(data.values())
|
||||
if target_data_type == "ndarray":
|
||||
data = np.asarray(data)
|
||||
else:
|
||||
data = list(data)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
|
||||
)
|
||||
|
||||
if cpu_intensive:
|
||||
Q = mp.Queue(1000)
|
||||
proc = mp.Process
|
||||
else:
|
||||
Q = Queue(1000)
|
||||
proc = Thread
|
||||
# spawn processes
|
||||
if target_data_type == "ndarray":
|
||||
arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
|
||||
else:
|
||||
step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
|
||||
arguments = [
|
||||
[func, Q, part, i, use_worker_id]
|
||||
for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
|
||||
]
|
||||
processes = []
|
||||
for i in range(n_proc):
|
||||
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
|
||||
processes += [p]
|
||||
|
||||
# start processes
|
||||
logger.info("Start prefetching...")
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
gather_res = [[] for _ in range(n_proc)]
|
||||
try:
|
||||
for p in processes:
|
||||
p.start()
|
||||
|
||||
k = 0
|
||||
while k < n_proc:
|
||||
# get result
|
||||
res = Q.get()
|
||||
if res == "Done":
|
||||
k += 1
|
||||
else:
|
||||
gather_res[res[0]] = res[1]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Exception: ", e)
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
|
||||
raise e
|
||||
finally:
|
||||
for p in processes:
|
||||
p.join()
|
||||
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
|
||||
if target_data_type == "ndarray":
|
||||
if not isinstance(gather_res[0], np.ndarray):
|
||||
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
|
||||
|
||||
# order outputs
|
||||
return np.concatenate(gather_res, axis=0)
|
||||
elif target_data_type == "list":
|
||||
out = []
|
||||
for r in gather_res:
|
||||
out.extend(r)
|
||||
return out
|
||||
else:
|
||||
return gather_res
|
||||
|
||||
|
||||
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||
d = (shape[0] // res[0], shape[1] // res[1])
|
||||
|
||||
grid = (
|
||||
torch.stack(
|
||||
torch.meshgrid(
|
||||
torch.arange(0, res[0], delta[0]),
|
||||
torch.arange(0, res[1], delta[1]),
|
||||
indexing="ij",
|
||||
),
|
||||
dim=-1,
|
||||
).to(device)
|
||||
% 1
|
||||
)
|
||||
|
||||
rand_val = torch.rand(res[0] + 1, res[1] + 1)
|
||||
|
||||
angles = 2 * math.pi * rand_val
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
|
||||
|
||||
def tile_grads(slice1, slice2):
|
||||
return (
|
||||
gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||
.repeat_interleave(d[0], 0)
|
||||
.repeat_interleave(d[1], 1)
|
||||
)
|
||||
|
||||
def dot(grad, shift):
|
||||
return (
|
||||
torch.stack(
|
||||
(
|
||||
grid[: shape[0], : shape[1], 0] + shift[0],
|
||||
grid[: shape[0], : shape[1], 1] + shift[1],
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
* grad[: shape[0], : shape[1]]
|
||||
).sum(dim=-1)
|
||||
|
||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
|
||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
|
||||
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
|
||||
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
|
||||
t = fade(grid[: shape[0], : shape[1]])
|
||||
noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(
|
||||
device
|
||||
)
|
||||
return noise.to(dtype=torch_dtype(device))
|
||||
|
||||
|
||||
def ask_user(question: str, answers: list):
|
||||
from itertools import chain, repeat
|
||||
|
||||
user_prompt = f"\n>> {question} {answers}: "
|
||||
invalid_answer_msg = "Invalid answer. Please try again."
|
||||
pose_question = chain([user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])))
|
||||
user_answers = map(input, pose_question)
|
||||
valid_response = next(filter(answers.__contains__, user_answers))
|
||||
return valid_response
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
|
||||
"""
|
||||
Download a model file.
|
||||
:param url: https, http or ftp URL
|
||||
:param dest: A Path object. If path exists and is a directory, then we try to derive the filename
|
||||
from the URL's Content-Disposition header and copy the URL contents into
|
||||
dest/filename
|
||||
:param access_token: Access token to access this resource
|
||||
"""
|
||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
open_mode = "wb"
|
||||
exist_size = 0
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True, allow_redirects=True)
|
||||
content_length = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if dest.is_dir():
|
||||
try:
|
||||
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
|
||||
except AttributeError:
|
||||
file_name = os.path.basename(url)
|
||||
dest = dest / file_name
|
||||
else:
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if dest.exists():
|
||||
exist_size = dest.stat().st_size
|
||||
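# request only the remaining bytes so an interrupted download can resume where it left off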
header["Range"] = f"bytes={exist_size}-"
|
||||
open_mode = "ab"
|
||||
resp = requests.get(url, headers=header, stream=True) # new request with range
|
||||
|
||||
if exist_size > content_length:
|
||||
logger.warning("corrupt existing file found. re-downloading")
|
||||
os.remove(dest)
|
||||
exist_size = 0
|
||||
|
||||
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
|
||||
logger.warning(f"{dest}: complete file found. Skipping.")
|
||||
return dest
|
||||
elif resp.status_code == 206 or exist_size > 0:
|
||||
logger.warning(f"{dest}: partial file found. Resuming...")
|
||||
elif resp.status_code != 200:
|
||||
logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
|
||||
else:
|
||||
logger.info(f"{dest}: Downloading...")
|
||||
|
||||
try:
|
||||
if content_length < 2000:
|
||||
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
|
||||
return None
|
||||
|
||||
with (
|
||||
open(dest, open_mode) as file,
|
||||
tqdm(
|
||||
desc=str(dest),
|
||||
initial=exist_size,
|
||||
total=content_length,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar,
|
||||
):
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
|
||||
return None
|
||||
|
||||
return dest
|
||||
|
||||
|
||||
def url_attachment_name(url: str) -> dict:
|
||||
try:
|
||||
resp = requests.get(url, stream=True)
|
||||
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
|
||||
return match.group(1)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def download_with_progress_bar(url: str, dest: Path) -> bool:
|
||||
result = download_with_resume(url, dest, access_token=None)
|
||||
return result is not None
|
||||
|
||||
|
||||
def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str:
|
||||
"""
|
||||
Converts an image into a base64 image dataURL.
|
||||
|
@ -1,189 +0,0 @@
|
||||
# This file predefines a few models that the user may want to install.
|
||||
sd-1/main/stable-diffusion-v1-5:
|
||||
description: Stable Diffusion version 1.5 diffusers model (4.27 GB)
|
||||
source: runwayml/stable-diffusion-v1-5
|
||||
recommended: True
|
||||
default: True
|
||||
sd-1/main/stable-diffusion-v1-5-inpainting:
|
||||
description: RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)
|
||||
source: runwayml/stable-diffusion-inpainting
|
||||
recommended: True
|
||||
sd-2/main/stable-diffusion-2-1:
|
||||
description: Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)
|
||||
source: stabilityai/stable-diffusion-2-1
|
||||
recommended: False
|
||||
sd-2/main/stable-diffusion-2-inpainting:
|
||||
description: Stable Diffusion version 2.0 inpainting model (5.21 GB)
|
||||
source: stabilityai/stable-diffusion-2-inpainting
|
||||
recommended: False
|
||||
sdxl/main/stable-diffusion-xl-base-1-0:
|
||||
description: Stable Diffusion XL base model (12 GB)
|
||||
source: stabilityai/stable-diffusion-xl-base-1.0
|
||||
recommended: True
|
||||
sdxl-refiner/main/stable-diffusion-xl-refiner-1-0:
|
||||
description: Stable Diffusion XL refiner model (12 GB)
|
||||
source: stabilityai/stable-diffusion-xl-refiner-1.0
|
||||
recommended: False
|
||||
sdxl/vae/sdxl-vae-fp16-fix:
|
||||
description: Version of the SDXL-1.0 VAE that works in half precision mode
|
||||
source: madebyollin/sdxl-vae-fp16-fix
|
||||
recommended: True
|
||||
sd-1/main/Analog-Diffusion:
|
||||
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
|
||||
source: wavymulder/Analog-Diffusion
|
||||
recommended: False
|
||||
sd-1/main/Deliberate:
|
||||
description: Versatile model that produces detailed images up to 768px (4.27 GB)
|
||||
source: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors?download=true
|
||||
recommended: False
|
||||
sd-1/main/Dungeons-and-Diffusion:
|
||||
description: Dungeons & Dragons characters (2.13 GB)
|
||||
source: 0xJustin/Dungeons-and-Diffusion
|
||||
recommended: False
|
||||
sd-1/main/dreamlike-photoreal-2:
|
||||
description: A photorealistic model trained on 768 pixel images based on SD 1.5 (2.13 GB)
|
||||
source: dreamlike-art/dreamlike-photoreal-2.0
|
||||
recommended: False
|
||||
sd-1/main/Inkpunk-Diffusion:
|
||||
description: Stylized illustrations inspired by Gorillaz, FLCL and Shinkawa; prompt with "nvinkpunk" (4.27 GB)
|
||||
source: Envvi/Inkpunk-Diffusion
|
||||
recommended: False
|
||||
sd-1/main/openjourney:
|
||||
description: An SD 1.5 model fine tuned on Midjourney; prompt with "mdjrny-v4 style" (2.13 GB)
|
||||
source: prompthero/openjourney
|
||||
recommended: False
|
||||
sd-1/main/seek.art_MEGA:
|
||||
source: coreco/seek.art_MEGA
|
||||
description: A general use SD-1.5 "anything" model that supports multiple styles (2.1 GB)
|
||||
recommended: False
|
||||
sd-1/main/trinart_stable_diffusion_v2:
|
||||
description: An SD-1.5 model finetuned with ~40K assorted high resolution manga/anime-style images (2.13 GB)
|
||||
source: naclbit/trinart_stable_diffusion_v2
|
||||
recommended: False
|
||||
sd-1/controlnet/qrcode_monster:
|
||||
source: monster-labs/control_v1p_sd15_qrcode_monster
|
||||
description: Controlnet model that generates scannable creative QR codes
|
||||
subfolder: v2
|
||||
sd-1/controlnet/canny:
|
||||
description: Controlnet weights trained on sd-1.5 with canny conditioning.
|
||||
source: lllyasviel/control_v11p_sd15_canny
|
||||
recommended: True
|
||||
sd-1/controlnet/inpaint:
|
||||
source: lllyasviel/control_v11p_sd15_inpaint
|
||||
description: Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version
|
||||
sd-1/controlnet/mlsd:
|
||||
description: Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version
|
||||
source: lllyasviel/control_v11p_sd15_mlsd
|
||||
sd-1/controlnet/depth:
|
||||
description: Controlnet weights trained on sd-1.5 with depth conditioning
|
||||
source: lllyasviel/control_v11f1p_sd15_depth
|
||||
recommended: True
|
||||
sd-1/controlnet/normal_bae:
|
||||
description: Controlnet weights trained on sd-1.5 with normalbae image conditioning
|
||||
source: lllyasviel/control_v11p_sd15_normalbae
|
||||
sd-1/controlnet/seg:
|
||||
description: Controlnet weights trained on sd-1.5 with seg image conditioning
|
||||
source: lllyasviel/control_v11p_sd15_seg
|
||||
sd-1/controlnet/lineart:
|
||||
description: Controlnet weights trained on sd-1.5 with lineart image conditioning
|
||||
source: lllyasviel/control_v11p_sd15_lineart
|
||||
recommended: True
|
||||
sd-1/controlnet/lineart_anime:
|
||||
description: Controlnet weights trained on sd-1.5 with anime image conditioning
|
||||
source: lllyasviel/control_v11p_sd15s2_lineart_anime
|
||||
sd-1/controlnet/openpose:
|
||||
description: Controlnet weights trained on sd-1.5 with openpose image conditioning
|
||||
source: lllyasviel/control_v11p_sd15_openpose
|
||||
recommended: True
|
||||
sd-1/controlnet/scribble:
|
||||
source: lllyasviel/control_v11p_sd15_scribble
|
||||
description: Controlnet weights trained on sd-1.5 with scribble image conditioning
|
||||
recommended: False
|
||||
sd-1/controlnet/softedge:
|
||||
source: lllyasviel/control_v11p_sd15_softedge
|
||||
description: Controlnet weights trained on sd-1.5 with soft edge conditioning
|
||||
sd-1/controlnet/shuffle:
|
||||
source: lllyasviel/control_v11e_sd15_shuffle
|
||||
description: Controlnet weights trained on sd-1.5 with shuffle image conditioning
|
||||
sd-1/controlnet/tile:
|
||||
source: lllyasviel/control_v11f1e_sd15_tile
|
||||
description: Controlnet weights trained on sd-1.5 with tiled image conditioning
|
||||
sd-1/controlnet/ip2p:
|
||||
source: lllyasviel/control_v11e_sd15_ip2p
|
||||
description: Controlnet weights trained on sd-1.5 with ip2p conditioning.
|
||||
sdxl/controlnet/canny-sdxl:
|
||||
description: Controlnet weights trained on sdxl-1.0 with canny conditioning.
|
||||
source: diffusers/controlnet-canny-sdxl-1.0
|
||||
recommended: True
|
||||
sdxl/controlnet/depth-sdxl:
|
||||
description: Controlnet weights trained on sdxl-1.0 with depth conditioning.
|
||||
source: diffusers/controlnet-depth-sdxl-1.0
|
||||
recommended: True
|
||||
sdxl/controlnet/softedge-dexined-sdxl:
|
||||
description: Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.
|
||||
source: SargeZT/controlnet-sd-xl-1.0-softedge-dexined
|
||||
sdxl/controlnet/depth-16bit-zoe-sdxl:
|
||||
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).
|
||||
source: SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe
|
||||
sdxl/controlnet/depth-zoe-sdxl:
|
||||
description: Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).
|
||||
source: diffusers/controlnet-zoe-depth-sdxl-1.0
|
||||
sd-1/t2i_adapter/canny-sd15:
|
||||
source: TencentARC/t2iadapter_canny_sd15v2
|
||||
sd-1/t2i_adapter/sketch-sd15:
|
||||
source: TencentARC/t2iadapter_sketch_sd15v2
|
||||
sd-1/t2i_adapter/depth-sd15:
|
||||
source: TencentARC/t2iadapter_depth_sd15v2
|
||||
sd-1/t2i_adapter/zoedepth-sd15:
|
||||
source: TencentARC/t2iadapter_zoedepth_sd15v1
|
||||
sdxl/t2i_adapter/canny-sdxl:
|
||||
source: TencentARC/t2i-adapter-canny-sdxl-1.0
|
||||
sdxl/t2i_adapter/zoedepth-sdxl:
|
||||
source: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
|
||||
sdxl/t2i_adapter/lineart-sdxl:
|
||||
source: TencentARC/t2i-adapter-lineart-sdxl-1.0
|
||||
sdxl/t2i_adapter/sketch-sdxl:
|
||||
source: TencentARC/t2i-adapter-sketch-sdxl-1.0
|
||||
sd-1/embedding/EasyNegative:
|
||||
source: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
|
||||
recommended: True
|
||||
description: A textual inversion to use in the negative prompt to reduce bad anatomy
|
||||
sd-1/lora/FlatColor:
|
||||
source: https://civitai.com/models/6433/loraflatcolor
|
||||
recommended: True
|
||||
description: A LoRA that generates scenery using solid blocks of color
|
||||
sd-1/lora/Ink scenery:
|
||||
source: https://civitai.com/api/download/models/83390
|
||||
description: Generate India ink-like landscapes
|
||||
sd-1/ip_adapter/ip_adapter_sd15:
|
||||
source: InvokeAI/ip_adapter_sd15
|
||||
recommended: True
|
||||
requires:
|
||||
- InvokeAI/ip_adapter_sd_image_encoder
|
||||
description: IP-Adapter for SD 1.5 models
|
||||
sd-1/ip_adapter/ip_adapter_plus_sd15:
|
||||
source: InvokeAI/ip_adapter_plus_sd15
|
||||
recommended: False
|
||||
requires:
|
||||
- InvokeAI/ip_adapter_sd_image_encoder
|
||||
description: Refined IP-Adapter for SD 1.5 models
|
||||
sd-1/ip_adapter/ip_adapter_plus_face_sd15:
|
||||
source: InvokeAI/ip_adapter_plus_face_sd15
|
||||
recommended: False
|
||||
requires:
|
||||
- InvokeAI/ip_adapter_sd_image_encoder
|
||||
description: Refined IP-Adapter for SD 1.5 models, adapted for faces
|
||||
sdxl/ip_adapter/ip_adapter_sdxl:
|
||||
source: InvokeAI/ip_adapter_sdxl
|
||||
recommended: False
|
||||
requires:
|
||||
- InvokeAI/ip_adapter_sdxl_image_encoder
|
||||
description: IP-Adapter for SDXL models
|
||||
any/clip_vision/ip_adapter_sd_image_encoder:
|
||||
source: InvokeAI/ip_adapter_sd_image_encoder
|
||||
recommended: False
|
||||
description: Required model for using IP-Adapters with SD-1/2 models
|
||||
any/clip_vision/ip_adapter_sdxl_image_encoder:
|
||||
source: InvokeAI/ip_adapter_sdxl_image_encoder
|
||||
recommended: False
|
||||
description: Required model for using IP-Adapters with SDXL models
|
@ -3,19 +3,16 @@ from typing import Optional

from invokeai.version import __version__

_root_help = r"""Sets a root directory for the app.
If omitted, the app will search for the root directory in the following order:
_root_help = r"""Path to the runtime root directory. If omitted, the app will search for the root directory in the following order:
- The `$INVOKEAI_ROOT` environment variable
- The currently active virtual environment's parent directory
- `$HOME/invokeai`"""

_ignore_missing_core_models_help = r"""If set, the app will ignore missing core diffusers conversion models.
These are required to use checkpoint/safetensors models.
If you only use diffusers models, you can safely enable this."""
_config_file_help = r"""Path to the invokeai.yaml configuration file. If omitted, the app will search for the file in the root directory."""

_parser = ArgumentParser(description="Invoke Studio", formatter_class=RawTextHelpFormatter)
_parser.add_argument("--root", type=str, help=_root_help)
_parser.add_argument("--ignore_missing_core_models", action="store_true", help=_ignore_missing_core_models_help)
_parser.add_argument("--config", dest="config_file", type=str, help=_config_file_help)
_parser.add_argument("--version", action="version", version=__version__, help="Displays the version and exits.")


@ -39,9 +36,11 @@ class InvokeAIArgs:
    """

    args: Optional[Namespace] = None
    did_parse: bool = False

    @staticmethod
    def parse_args() -> Optional[Namespace]:
        """Parse CLI args and store the result."""
        InvokeAIArgs.args = _parser.parse_args()
        InvokeAIArgs.did_parse = True
        return InvokeAIArgs.args
@ -1,60 +0,0 @@
"""
Wrapper for invokeai.backend.configure.invokeai_configure
"""

import argparse


def run_configure() -> None:
    # Before doing _anything_, parse CLI args!
    from invokeai.frontend.cli.arg_parser import InvokeAIArgs

    parser = argparse.ArgumentParser(description="InvokeAI model downloader")
    parser.add_argument(
        "--skip-sd-weights",
        dest="skip_sd_weights",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="skip downloading the large Stable Diffusion weight files",
    )
    parser.add_argument(
        "--skip-support-models",
        dest="skip_support_models",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="skip downloading the support models",
    )
    parser.add_argument(
        "--full-precision",
        dest="full_precision",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        help="use 32-bit weights instead of faster 16-bit weights",
    )
    parser.add_argument(
        "--yes",
        "-y",
        dest="yes_to_all",
        action="store_true",
        help='answer "yes" to all prompts',
    )
    parser.add_argument(
        "--default_only",
        action="store_true",
        help="when --yes specified, only install the default model",
    )
    parser.add_argument(
        "--root_dir",
        dest="root",
        type=str,
        default=None,
        help="path to root of install directory",
    )

    opt = parser.parse_args()
    InvokeAIArgs.args = opt

    from invokeai.backend.install.invokeai_configure import main as invokeai_configure

    invokeai_configure(opt)
@ -1,652 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
# Before running stable-diffusion on an internet-isolated machine,
|
||||
# run this script from one with internet connectivity. The
|
||||
# two machines must share a common .cache directory.
|
||||
|
||||
"""
|
||||
This is the npyscreen frontend to the model installation application.
|
||||
It is currently named model_install2.py, but will ultimately replace model_install.py.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import pathlib
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from shutil import get_terminal_size
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import npyscreen
|
||||
import torch
|
||||
from npyscreen import widget
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.backend.install.check_directories import validate_directories
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import (
|
||||
MIN_COLS,
|
||||
MIN_LINES,
|
||||
CenteredTitleText,
|
||||
CyclingForm,
|
||||
MultiSelectColumns,
|
||||
SingleSelectColumns,
|
||||
TextBox,
|
||||
WindowTooSmallException,
|
||||
set_min_terminal_size,
|
||||
)
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning) # noqa: E402
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger("ModelInstallService", config=config)
|
||||
# logger.setLevel("WARNING")
|
||||
# logger.setLevel('DEBUG')
|
||||
|
||||
# build a table mapping all non-printable characters to None
|
||||
# for stripping control characters
|
||||
# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
|
||||
NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()}
|
||||
|
||||
# maximum number of installed models we can display before overflowing vertically
|
||||
MAX_OTHER_MODELS = 72
|
||||
|
||||
|
||||
def make_printable(s: str) -> str:
|
||||
"""Replace non-printable characters in a string."""
|
||||
return s.translate(NOPRINT_TRANS_TABLE)
|
||||
|
||||
|
||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
"""Main form for interactive TUI."""
|
||||
|
||||
# for responsive resizing set to False, but this seems to cause a crash!
|
||||
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||
|
||||
# for persistence
|
||||
current_tab = 0
|
||||
|
||||
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, multipage: bool = False, **keywords: Any):
|
||||
self.multipage = multipage
|
||||
self.subprocess = None
|
||||
super().__init__(parentApp=parentApp, name=name, **keywords)
|
||||
|
||||
def create(self) -> None:
|
||||
self.installer = self.parentApp.install_helper.installer
|
||||
self.model_labels = self._get_model_labels()
|
||||
self.keypress_timeout = 10
|
||||
self.counter = 0
|
||||
self.subprocess_connection = None
|
||||
|
||||
window_width, window_height = get_terminal_size()
|
||||
|
||||
# npyscreen has no typing hints
|
||||
self.nextrely -= 1 # type: ignore
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields. Cursor keys navigate, and <space> selects.",
|
||||
editable=False,
|
||||
color="CAUTION",
|
||||
)
|
||||
self.nextrely += 1 # type: ignore
|
||||
self.tabs = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[
|
||||
"STARTERS",
|
||||
"MAINS",
|
||||
"CONTROLNETS",
|
||||
"T2I-ADAPTERS",
|
||||
"IP-ADAPTERS",
|
||||
"LORAS",
|
||||
"TI EMBEDDINGS",
|
||||
],
|
||||
value=[self.current_tab],
|
||||
columns=7,
|
||||
max_height=2,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.tabs.on_changed = self._toggle_tables
|
||||
|
||||
top_of_table = self.nextrely # type: ignore
|
||||
self.starter_pipelines = self.add_starter_pipelines()
|
||||
bottom_of_table = self.nextrely # type: ignore
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.pipeline_models = self.add_pipeline_widgets(
|
||||
model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models
|
||||
)
|
||||
# self.pipeline_models['autoload_pending'] = True
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.controlnet_models = self.add_model_widgets(
|
||||
model_type=ModelType.ControlNet,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.t2i_models = self.add_model_widgets(
|
||||
model_type=ModelType.T2IAdapter,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
self.nextrely = top_of_table
|
||||
self.ipadapter_models = self.add_model_widgets(
|
||||
model_type=ModelType.IPAdapter,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.lora_models = self.add_model_widgets(
|
||||
model_type=ModelType.LoRA,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = top_of_table
|
||||
self.ti_models = self.add_model_widgets(
|
||||
model_type=ModelType.TextualInversion,
|
||||
window_width=window_width,
|
||||
)
|
||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||
|
||||
self.nextrely = bottom_of_table + 1
|
||||
|
||||
self.nextrely += 1
|
||||
back_label = "BACK"
|
||||
cancel_label = "CANCEL"
|
||||
current_position = self.nextrely
|
||||
if self.multipage:
|
||||
self.back_button = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name=back_label,
|
||||
when_pressed_function=self.on_back,
|
||||
)
|
||||
else:
|
||||
self.nextrely = current_position
|
||||
self.cancel_button = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel
|
||||
)
|
||||
self.nextrely = current_position
|
||||
|
||||
label = "APPLY CHANGES"
|
||||
self.nextrely = current_position
|
||||
self.done = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name=label,
|
||||
relx=window_width - len(label) - 15,
|
||||
when_pressed_function=self.on_done,
|
||||
)
|
||||
|
||||
# This restores the selected page on return from an installation
|
||||
for _i in range(1, self.current_tab + 1):
|
||||
self.tabs.h_cursor_line_down(1)
|
||||
self._toggle_tables([self.current_tab])
|
||||
|
||||
############# diffusers tab ##########
|
||||
def add_starter_pipelines(self) -> dict[str, npyscreen.widget]:
|
||||
"""Add widgets responsible for selecting diffusers models"""
|
||||
widgets: Dict[str, npyscreen.widget] = {}
|
||||
|
||||
all_models = self.all_models # master dict of all models, indexed by key
|
||||
model_list = [x for x in self.starter_models if all_models[x].type in ["main", "vae"]]
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
widgets.update(
|
||||
label1=self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name="Select from a starter set of Stable Diffusion models from HuggingFace and Civitae.",
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
)
|
||||
|
||||
self.nextrely -= 1
|
||||
# if user has already installed some initial models, then don't patronize them
|
||||
# by showing more recommendations
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
|
||||
checked = [
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
]
|
||||
widgets.update(
|
||||
models_selected=self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=1,
|
||||
name="Install Starter Models",
|
||||
values=model_labels,
|
||||
value=checked,
|
||||
max_height=len(model_list) + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
),
|
||||
models=model_list,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
return widgets
|
||||
|
||||
############# Add a set of model install widgets ########
|
||||
def add_model_widgets(
|
||||
self,
|
||||
model_type: ModelType,
|
||||
window_width: int = 120,
|
||||
install_prompt: Optional[str] = None,
|
||||
exclude: Optional[Set[str]] = None,
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Generic code to create model selection widgets"""
|
||||
if exclude is None:
|
||||
exclude = set()
|
||||
widgets: Dict[str, npyscreen.widget] = {}
|
||||
all_models = self.all_models
|
||||
model_list = sorted(
|
||||
[x for x in all_models if all_models[x].type == model_type and x not in exclude],
|
||||
key=lambda x: all_models[x].name or "",
|
||||
)
|
||||
model_labels = [self.model_labels[x] for x in model_list]
|
||||
|
||||
show_recommended = len(self.installed_models) == 0
|
||||
truncated = False
|
||||
if len(model_list) > 0:
|
||||
max_width = max([len(x) for x in model_labels])
|
||||
columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding
|
||||
columns = min(len(model_list), columns) or 1
|
||||
prompt = (
|
||||
install_prompt
|
||||
or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk."
|
||||
)
|
||||
|
||||
widgets.update(
|
||||
label1=self.add_widget_intelligent(
|
||||
CenteredTitleText,
|
||||
name=prompt,
|
||||
editable=False,
|
||||
labelColor="CAUTION",
|
||||
)
|
||||
)
|
||||
|
||||
if len(model_labels) > MAX_OTHER_MODELS:
|
||||
model_labels = model_labels[0:MAX_OTHER_MODELS]
|
||||
truncated = True
|
||||
|
||||
widgets.update(
|
||||
models_selected=self.add_widget_intelligent(
|
||||
MultiSelectColumns,
|
||||
columns=columns,
|
||||
name=f"Install {model_type} Models",
|
||||
values=model_labels,
|
||||
value=[
|
||||
model_list.index(x)
|
||||
for x in model_list
|
||||
if (show_recommended and all_models[x].recommended) or all_models[x].installed
|
||||
],
|
||||
max_height=len(model_list) // columns + 1,
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
),
|
||||
models=model_list,
|
||||
)
|
||||
|
||||
if truncated:
|
||||
widgets.update(
|
||||
warning_message=self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value=f"Too many models to display (max={MAX_OTHER_MODELS}). Some are not displayed.",
|
||||
editable=False,
|
||||
color="CAUTION",
|
||||
)
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
widgets.update(
|
||||
download_ids=self.add_widget_intelligent(
|
||||
TextBox,
|
||||
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
|
||||
max_height=6,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
)
|
||||
)
|
||||
return widgets
|
||||
|
||||
### Tab for arbitrary diffusers widgets ###
|
||||
def add_pipeline_widgets(
|
||||
self,
|
||||
model_type: ModelType = ModelType.Main,
|
||||
window_width: int = 120,
|
||||
**kwargs,
|
||||
) -> dict[str, npyscreen.widget]:
|
||||
"""Similar to add_model_widgets() but adds some additional widgets at the bottom
|
||||
to support the autoload directory"""
|
||||
widgets = self.add_model_widgets(
|
||||
model_type=model_type,
|
||||
window_width=window_width,
|
||||
install_prompt=f"Installed {model_type.value.title()} models. Unchecked models in the InvokeAI root directory will be deleted. Enter URLs, paths or repo_ids to import.",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return widgets
|
||||
|
||||
def resize(self) -> None:
|
||||
super().resize()
|
||||
if s := self.starter_pipelines.get("models_selected"):
|
||||
if model_list := self.starter_pipelines.get("models"):
|
||||
s.values = [self.model_labels[x] for x in model_list]
|
||||
|
||||
def _toggle_tables(self, value: List[int]) -> None:
|
||||
selected_tab = value[0]
|
||||
widgets = [
|
||||
self.starter_pipelines,
|
||||
self.pipeline_models,
|
||||
self.controlnet_models,
|
||||
self.t2i_models,
|
||||
self.ipadapter_models,
|
||||
self.lora_models,
|
||||
self.ti_models,
|
||||
]
|
||||
|
||||
for group in widgets:
|
||||
for _k, v in group.items():
|
||||
try:
|
||||
v.hidden = True
|
||||
v.editable = False
|
||||
except Exception:
|
||||
pass
|
||||
for _k, v in widgets[selected_tab].items():
|
||||
try:
|
||||
v.hidden = False
|
||||
if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)):
|
||||
v.editable = True
|
||||
except Exception:
|
||||
pass
|
||||
self.__class__.current_tab = selected_tab # for persistence
|
||||
self.display()
|
||||
|
||||
def _get_model_labels(self) -> dict[str, str]:
|
||||
"""Return a list of trimmed labels for all models."""
|
||||
window_width, window_height = get_terminal_size()
|
||||
checkbox_width = 4
|
||||
spacing_width = 2
|
||||
result = {}
|
||||
|
||||
models = self.all_models
|
||||
label_width = max([len(models[x].name or "") for x in self.starter_models])
|
||||
description_width = window_width - label_width - checkbox_width - spacing_width
|
||||
|
||||
for key in self.all_models:
|
||||
description = models[key].description
|
||||
description = (
|
||||
description[0 : description_width - 3] + "..."
|
||||
if description and len(description) > description_width
|
||||
else description
|
||||
if description
|
||||
else ""
|
||||
)
|
||||
result[key] = f"%-{label_width}s %s" % (models[key].name, description)
|
||||
|
||||
return result
|
||||
|
||||
def _get_columns(self) -> int:
|
||||
window_width, window_height = get_terminal_size()
|
||||
cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1
|
||||
return min(cols, len(self.installed_models))
|
||||
|
||||
def confirm_deletions(self, selections: InstallSelections) -> bool:
|
||||
remove_models = selections.remove_models
|
||||
if remove_models:
|
||||
model_names = [self.all_models[x].name or "" for x in remove_models]
|
||||
mods = "\n".join(model_names)
|
||||
is_ok = npyscreen.notify_ok_cancel(
|
||||
f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}"
|
||||
)
|
||||
assert isinstance(is_ok, bool) # npyscreen doesn't have return type annotations
|
||||
return is_ok
|
||||
else:
|
||||
return True
|
||||
|
||||
@property
|
||||
def all_models(self) -> Dict[str, UnifiedModelInfo]:
|
||||
# npyscreen doesn't have typing hints
|
||||
return self.parentApp.install_helper.all_models # type: ignore
|
||||
|
||||
@property
|
||||
def starter_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._starter_models # type: ignore
|
||||
|
||||
@property
|
||||
def installed_models(self) -> List[str]:
|
||||
return self.parentApp.install_helper._installed_models # type: ignore
|
||||
|
||||
def on_back(self) -> None:
|
||||
self.parentApp.switchFormPrevious()
|
||||
self.editing = False
|
||||
|
||||
def on_cancel(self) -> None:
|
||||
self.parentApp.setNextForm(None)
|
||||
self.parentApp.user_cancelled = True
|
||||
self.editing = False
|
||||
|
||||
def on_done(self) -> None:
|
||||
self.marshall_arguments()
|
||||
if not self.confirm_deletions(self.parentApp.install_selections):
|
||||
return
|
||||
self.parentApp.setNextForm(None)
|
||||
self.parentApp.user_cancelled = False
|
||||
self.editing = False
|
||||
|
||||
def marshall_arguments(self) -> None:
|
||||
"""
|
||||
Assemble arguments and store as attributes of the application:
|
||||
.starter_models: dict of model names to install from INITIAL_CONFIGURE.yaml
|
||||
True => Install
|
||||
False => Remove
|
||||
.scan_directory: Path to a directory of models to scan and import
|
||||
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
||||
.import_model_paths: list of URLs, repo_ids and file paths to import
|
||||
"""
|
||||
selections = self.parentApp.install_selections
|
||||
all_models = self.all_models
|
||||
|
||||
# Defined models (in INITIAL_CONFIG.yaml or invokeai.db) to add/remove
|
||||
ui_sections = [
|
||||
self.starter_pipelines,
|
||||
self.pipeline_models,
|
||||
self.controlnet_models,
|
||||
self.t2i_models,
|
||||
self.ipadapter_models,
|
||||
self.lora_models,
|
||||
self.ti_models,
|
||||
]
|
||||
for section in ui_sections:
|
||||
if "models_selected" not in section:
|
||||
continue
|
||||
selected = {section["models"][x] for x in section["models_selected"].value}
|
||||
models_to_install = [x for x in selected if not self.all_models[x].installed]
|
||||
models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed]
|
||||
selections.remove_models.extend(models_to_remove)
|
||||
selections.install_models.extend([all_models[x] for x in models_to_install])
|
||||
|
||||
# models located in the 'download_ids" section
|
||||
for section in ui_sections:
|
||||
if downloads := section.get("download_ids"):
|
||||
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
|
||||
selections.install_models.extend(models)
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore
|
||||
def __init__(self, opt: Namespace, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = opt
|
||||
self.user_cancelled = False
|
||||
self.install_selections = InstallSelections()
|
||||
self.install_helper = install_helper
|
||||
|
||||
def onStart(self) -> None:
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main_form = self.addForm(
|
||||
"MAIN",
|
||||
addModelsForm,
|
||||
name="Install Stable Diffusion Models",
|
||||
cycle_widgets=False,
|
||||
)
|
||||
|
||||
|
||||
def list_models(installer: ModelInstallServiceBase, model_type: ModelType):
|
||||
"""Print out all models of type model_type."""
|
||||
models = installer.record_store.search_by_attr(model_type=model_type)
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
for model in models:
|
||||
path = (config.models_path / model.path).resolve()
|
||||
print(f"{model.name:40}{model.base.value:5}{model.type.value:8}{model.format.value:12}{path}")
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
def select_and_download_models(opt: Namespace) -> None:
|
||||
"""Prompt user for install/delete selections and execute."""
|
||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
config.precision = precision
|
||||
install_helper = InstallHelper(config, logger)
|
||||
installer = install_helper.installer
|
||||
|
||||
if opt.list_models:
|
||||
list_models(installer, opt.list_models)
|
||||
|
||||
elif opt.add or opt.delete:
|
||||
selections = InstallSelections(
|
||||
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
|
||||
)
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
elif opt.default_only:
|
||||
default_model = install_helper.default_model()
|
||||
assert default_model is not None
|
||||
selections = InstallSelections(install_models=[default_model])
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
elif opt.yes_to_all:
|
||||
selections = InstallSelections(install_models=install_helper.recommended_models())
|
||||
install_helper.add_or_delete(selections)
|
||||
|
||||
# this is where the TUI is called
|
||||
else:
|
||||
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
|
||||
raise WindowTooSmallException(
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
installApp = AddModelApplication(opt, install_helper)
|
||||
try:
|
||||
installApp.run()
|
||||
except KeyboardInterrupt:
|
||||
print("Aborted...")
|
||||
sys.exit(-1)
|
||||
|
||||
install_helper.add_or_delete(installApp.install_selections)
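# Illustrative sketch (not part of this file): the programmatic equivalent of
# running this script with `--add ... --delete ...`, mirroring the branch
# above. The model identifiers are placeholders, and the helper/selection
# classes are the same ones imported at the top of this module.
def _example_add_and_delete() -> None:
    to_add = ["runwayml/stable-diffusion-v1-5"]  # URL, local path or repo_id
    to_delete = ["controlnet:my_model"]  # use type:name to disambiguate
    helper = InstallHelper(config, logger)
    selections = InstallSelections(
        install_models=[UnifiedModelInfo(source=x) for x in to_add],
        remove_models=to_delete,
    )
    helper.add_or_delete(selections)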
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--add",
|
||||
nargs="*",
|
||||
help="List of URLs, local paths or repo_ids of models to install",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete",
|
||||
nargs="*",
|
||||
help="List of names of models to delete. Use type:name to disambiguate, as in `controlnet:my_model`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-precision",
|
||||
dest="full_precision",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
type=bool,
|
||||
default=False,
|
||||
help="use 32-bit weights instead of faster 16-bit weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--yes",
|
||||
"-y",
|
||||
dest="yes_to_all",
|
||||
action="store_true",
|
||||
help='answer "yes" to all prompts',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--default_only",
|
||||
action="store_true",
|
||||
help="Only install the default model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-models",
|
||||
choices=[x.value for x in ModelType],
|
||||
help="list installed models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
dest="root",
|
||||
type=pathlib.Path,
|
||||
default=None,
|
||||
help="path to root of install directory",
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
invoke_args: dict[str, Any] = {}
|
||||
if opt.full_precision:
|
||||
invoke_args["precision"] = "float32"
|
||||
config.update_config(invoke_args)
|
||||
if opt.root:
|
||||
config.set_root(opt.root)
|
||||
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
try:
|
||||
validate_directories(config)
|
||||
except AssertionError:
|
||||
logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.")
|
||||
sys.argv = ["invokeai_configure", "--yes", "--skip-sd-weights"]
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||
|
||||
invokeai_configure()
|
||||
sys.exit(0)
|
||||
|
||||
try:
|
||||
select_and_download_models(opt)
|
||||
except AssertionError as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
curses.nocbreak()
|
||||
curses.echo()
|
||||
curses.endwin()
|
||||
logger.info("Goodbye! Come back soon.")
|
||||
except WindowTooSmallException as e:
|
||||
logger.error(str(e))
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")
|
||||
input("Press any key to continue...")
|
||||
except Exception as e:
|
||||
if str(e).startswith("addwstr"):
|
||||
logger.error(
|
||||
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||
)
|
||||
else:
|
||||
print(f"An exception has occurred: {str(e)} Details:")
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
input("Press any key to continue...")
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,441 +0,0 @@
|
||||
"""
|
||||
Widget class definitions used by model_select.py, merge_diffusers.py and textual_inversion.py
|
||||
"""
|
||||
|
||||
import curses
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import struct
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from curses import BUTTON2_CLICKED, BUTTON3_CLICKED
|
||||
from shutil import get_terminal_size
|
||||
from typing import Optional
|
||||
|
||||
import npyscreen
|
||||
import npyscreen.wgmultiline as wgmultiline
|
||||
import pyperclip
|
||||
from npyscreen import fmPopup
|
||||
|
||||
# minimum size for UIs
|
||||
MIN_COLS = 150
|
||||
MIN_LINES = 40
|
||||
|
||||
|
||||
class WindowTooSmallException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def set_terminal_size(columns: int, lines: int) -> bool:
|
||||
OS = platform.uname().system
|
||||
screen_ok = False
|
||||
while not screen_ok:
|
||||
ts = get_terminal_size()
|
||||
width = max(columns, ts.columns)
|
||||
height = max(lines, ts.lines)
|
||||
|
||||
if OS == "Windows":
|
||||
pass
|
||||
# not working reliably - ask user to adjust the window
|
||||
# _set_terminal_size_powershell(width,height)
|
||||
elif OS in ["Darwin", "Linux"]:
|
||||
_set_terminal_size_unix(width, height)
|
||||
|
||||
# check whether it worked....
|
||||
ts = get_terminal_size()
|
||||
if ts.columns < columns or ts.lines < lines:
|
||||
print(
|
||||
f"\033[1mThis window is too small for the interface. InvokeAI requires {columns}x{lines} (w x h) characters, but window is {ts.columns}x{ts.lines}\033[0m"
|
||||
)
|
||||
resp = input(
|
||||
"Maximize the window and/or decrease the font size then press any key to continue. Type [Q] to give up.."
|
||||
)
|
||||
if resp.upper().startswith("Q"):
|
||||
break
|
||||
else:
|
||||
screen_ok = True
|
||||
return screen_ok
|
||||
|
||||
|
||||
def _set_terminal_size_powershell(width: int, height: int):
|
||||
script = f"""
|
||||
$pshost = get-host
|
||||
$pswindow = $pshost.ui.rawui
|
||||
$newsize = $pswindow.buffersize
|
||||
$newsize.height = 3000
|
||||
$newsize.width = {width}
|
||||
$pswindow.buffersize = $newsize
|
||||
$newsize = $pswindow.windowsize
|
||||
$newsize.height = {height}
|
||||
$newsize.width = {width}
|
||||
$pswindow.windowsize = $newsize
|
||||
"""
|
||||
subprocess.run(["powershell", "-Command", "-"], input=script, text=True)
|
||||
|
||||
|
||||
def _set_terminal_size_unix(width: int, height: int):
|
||||
import fcntl
|
||||
import termios
|
||||
|
||||
# These terminals accept the size command and report that the
|
||||
# size changed, but they lie!!!
|
||||
for bad_terminal in ["TERMINATOR_UUID", "ALACRITTY_WINDOW_ID"]:
|
||||
if os.environ.get(bad_terminal):
|
||||
return
|
||||
|
||||
winsize = struct.pack("HHHH", height, width, 0, 0)
|
||||
fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize)
|
||||
sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width))
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
|
||||
# make sure there's enough room for the ui
|
||||
term_cols, term_lines = get_terminal_size()
|
||||
if term_cols >= min_cols and term_lines >= min_lines:
|
||||
return True
|
||||
cols = max(term_cols, min_cols)
|
||||
lines = max(term_lines, min_lines)
|
||||
return set_terminal_size(cols, lines)
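# Illustrative only (not part of the original file): how the helpers above are
# typically combined by callers elsewhere in this commit.
def _example_require_min_size() -> None:
    if not set_min_terminal_size(MIN_COLS, MIN_LINES):
        raise WindowTooSmallException(
            "Could not increase terminal size. Try a larger window or a smaller font size."
        )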
|
||||
|
||||
|
||||
class IntSlider(npyscreen.Slider):
|
||||
def translate_value(self):
|
||||
stri = "%2d / %2d" % (self.value, self.out_of)
|
||||
length = (len(str(self.out_of))) * 2 + 4
|
||||
stri = stri.rjust(length)
|
||||
return stri
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
# fix npyscreen form so that cursor wraps both forward and backward
|
||||
class CyclingForm(object):
|
||||
def find_previous_editable(self, *args):
|
||||
done = False
|
||||
n = self.editw - 1
|
||||
while not done:
|
||||
if self._widgets__[n].editable and not self._widgets__[n].hidden:
|
||||
self.editw = n
|
||||
done = True
|
||||
n -= 1
|
||||
if n < 0:
|
||||
if self.cycle_widgets:
|
||||
n = len(self._widgets__) - 1
|
||||
else:
|
||||
done = True
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class CenteredTitleText(npyscreen.TitleText):
|
||||
def __init__(self, *args, **keywords):
|
||||
super().__init__(*args, **keywords)
|
||||
self.resize()
|
||||
|
||||
def resize(self):
|
||||
super().resize()
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
label = self.name
|
||||
self.relx = (maxx - len(label)) // 2
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class CenteredButtonPress(npyscreen.ButtonPress):
|
||||
def resize(self):
|
||||
super().resize()
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
label = self.name
|
||||
self.relx = (maxx - len(label)) // 2
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
class OffsetButtonPress(npyscreen.ButtonPress):
|
||||
def __init__(self, screen, offset=0, *args, **keywords):
|
||||
super().__init__(screen, *args, **keywords)
|
||||
self.offset = offset
|
||||
|
||||
def resize(self):
|
||||
maxy, maxx = self.parent.curses_pad.getmaxyx()
|
||||
width = len(self.name)
|
||||
self.relx = self.offset + (maxx - width) // 2
|
||||
|
||||
|
||||
class IntTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = IntSlider
|
||||
|
||||
|
||||
class FloatSlider(npyscreen.Slider):
|
||||
# this is supposed to adjust display precision, but doesn't
|
||||
def translate_value(self):
|
||||
stri = "%3.2f / %3.2f" % (self.value, self.out_of)
|
||||
length = (len(str(self.out_of))) * 2 + 4
|
||||
stri = stri.rjust(length)
|
||||
return stri
|
||||
|
||||
|
||||
class FloatTitleSlider(npyscreen.TitleText):
|
||||
_entry_type = npyscreen.Slider
|
||||
|
||||
|
||||
class SelectColumnBase:
|
||||
"""Base class for selection widget arranged in columns."""
|
||||
|
||||
def make_contained_widgets(self):
|
||||
self._my_widgets = []
|
||||
column_width = self.width // self.columns
|
||||
for h in range(self.value_cnt):
|
||||
self._my_widgets.append(
|
||||
self._contained_widgets(
|
||||
self.parent,
|
||||
rely=self.rely + (h % self.rows) * self._contained_widget_height,
|
||||
relx=self.relx + (h // self.rows) * column_width,
|
||||
max_width=column_width,
|
||||
max_height=self.__class__._contained_widget_height,
|
||||
)
|
||||
)
|
||||
|
||||
def set_up_handlers(self):
|
||||
super().set_up_handlers()
|
||||
self.handlers.update(
|
||||
{
|
||||
curses.KEY_UP: self.h_cursor_line_left,
|
||||
curses.KEY_DOWN: self.h_cursor_line_right,
|
||||
}
|
||||
)
|
||||
|
||||
def h_cursor_line_down(self, ch):
|
||||
self.cursor_line += self.rows
|
||||
if self.cursor_line >= len(self.values):
|
||||
if self.scroll_exit:
|
||||
self.cursor_line = len(self.values) - self.rows
|
||||
self.h_exit_down(ch)
|
||||
return True
|
||||
else:
|
||||
self.cursor_line -= self.rows
|
||||
return True
|
||||
|
||||
def h_cursor_line_up(self, ch):
|
||||
self.cursor_line -= self.rows
|
||||
if self.cursor_line < 0:
|
||||
if self.scroll_exit:
|
||||
self.cursor_line = 0
|
||||
self.h_exit_up(ch)
|
||||
else:
|
||||
self.cursor_line = 0
|
||||
|
||||
def h_cursor_line_left(self, ch):
|
||||
super().h_cursor_line_up(ch)
|
||||
|
||||
def h_cursor_line_right(self, ch):
|
||||
super().h_cursor_line_down(ch)
|
||||
|
||||
def handle_mouse_event(self, mouse_event):
|
||||
mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event)
|
||||
column_width = self.width // self.columns
|
||||
column_height = math.ceil(self.value_cnt / self.columns)
|
||||
column_no = rel_x // column_width
|
||||
row_no = rel_y // self._contained_widget_height
|
||||
self.cursor_line = column_no * column_height + row_no
|
||||
if bstate & curses.BUTTON1_DOUBLE_CLICKED:
|
||||
if hasattr(self, "on_mouse_double_click"):
|
||||
self.on_mouse_double_click(self.cursor_line)
|
||||
self.display()
|
||||
|
||||
|
||||
class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect):
|
||||
def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
|
||||
if values is None:
|
||||
values = []
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
def on_mouse_double_click(self, cursor_line):
|
||||
self.h_select_toggle(cursor_line)
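# Rough usage sketch (assumed, not taken from this commit): the widget is added
# to a form via add_widget_intelligent, with `columns` driving the layout. The
# form, app, and values below are hypothetical.
class _DemoForm(npyscreen.FormMultiPageAction):
    def create(self):
        self.picks = self.add_widget_intelligent(
            MultiSelectColumns,
            columns=3,
            values=["alpha", "beta", "gamma", "delta", "epsilon", "zeta"],
            max_height=2,  # rows = ceil(6 values / 3 columns)
            scroll_exit=True,
        )


class _DemoApp(npyscreen.NPSAppManaged):
    def onStart(self):
        self.addForm("MAIN", _DemoForm, name="MultiSelectColumns demo")


# _DemoApp().run()  # would start a curses UI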
|
||||
|
||||
|
||||
class SingleSelectWithChanged(npyscreen.SelectOne):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.on_changed = None
|
||||
|
||||
def h_select(self, ch):
|
||||
super().h_select(ch)
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class CheckboxWithChanged(npyscreen.Checkbox):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.on_changed = None
|
||||
|
||||
def whenToggled(self):
|
||||
super().whenToggled()
|
||||
if self.on_changed:
|
||||
self.on_changed(self.value)
|
||||
|
||||
|
||||
class SingleSelectColumnsSimple(SelectColumnBase, SingleSelectWithChanged):
|
||||
"""Row of radio buttons. Spacebar to select."""
|
||||
|
||||
def __init__(self, screen, columns: int = 1, values: Optional[list] = None, **keywords):
|
||||
if values is None:
|
||||
values = []
|
||||
self.columns = columns
|
||||
self.value_cnt = len(values)
|
||||
self.rows = math.ceil(self.value_cnt / self.columns)
|
||||
self.on_changed = None
|
||||
super().__init__(screen, values=values, **keywords)
|
||||
|
||||
def h_cursor_line_right(self, ch):
|
||||
self.h_exit_down("bye bye")
|
||||
|
||||
def h_cursor_line_left(self, ch):
|
||||
self.h_exit_up("bye bye")
|
||||
|
||||
|
||||
class SingleSelectColumns(SingleSelectColumnsSimple):
|
||||
"""Row of radio buttons. When tabbing over a selection, it is auto selected."""
|
||||
|
||||
def when_cursor_moved(self):
|
||||
self.h_select(self.cursor_line)
|
||||
|
||||
|
||||
class TextBoxInner(npyscreen.MultiLineEdit):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.yank = None
|
||||
self.handlers.update(
|
||||
{
|
||||
"^A": self.h_cursor_to_start,
|
||||
"^E": self.h_cursor_to_end,
|
||||
"^K": self.h_kill,
|
||||
"^F": self.h_cursor_right,
|
||||
"^B": self.h_cursor_left,
|
||||
"^Y": self.h_yank,
|
||||
"^V": self.h_paste,
|
||||
}
|
||||
)
|
||||
|
||||
def h_cursor_to_start(self, input):
|
||||
self.cursor_position = 0
|
||||
|
||||
def h_cursor_to_end(self, input):
|
||||
self.cursor_position = len(self.value)
|
||||
|
||||
def h_kill(self, input):
|
||||
self.yank = self.value[self.cursor_position :]
|
||||
self.value = self.value[: self.cursor_position]
|
||||
|
||||
def h_yank(self, input):
|
||||
if self.yank:
|
||||
self.paste(self.yank)
|
||||
|
||||
def paste(self, text: str):
|
||||
self.value = self.value[: self.cursor_position] + text + self.value[self.cursor_position :]
|
||||
self.cursor_position += len(text)
|
||||
|
||||
def h_paste(self, input: int = 0):
|
||||
try:
|
||||
text = pyperclip.paste()
|
||||
except ModuleNotFoundError:
|
||||
text = "To paste with the mouse on Linux, please install the 'xclip' program."
|
||||
self.paste(text)
|
||||
|
||||
def handle_mouse_event(self, mouse_event):
|
||||
mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event)
|
||||
if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED):
|
||||
self.h_paste()
|
||||
|
||||
|
||||
class TextBox(npyscreen.BoxTitle):
|
||||
_contained_widget = TextBoxInner
|
||||
|
||||
|
||||
class BufferBox(npyscreen.BoxTitle):
|
||||
_contained_widget = npyscreen.BufferPager
|
||||
|
||||
|
||||
class ConfirmCancelPopup(fmPopup.ActionPopup):
|
||||
DEFAULT_COLUMNS = 100
|
||||
|
||||
def on_ok(self):
|
||||
self.value = True
|
||||
|
||||
def on_cancel(self):
|
||||
self.value = False
|
||||
|
||||
|
||||
class FileBox(npyscreen.BoxTitle):
|
||||
_contained_widget = npyscreen.Filename
|
||||
|
||||
|
||||
class PrettyTextBox(npyscreen.BoxTitle):
|
||||
_contained_widget = TextBox
|
||||
|
||||
|
||||
def _wrap_message_lines(message, line_length):
|
||||
lines = []
|
||||
for line in message.split("\n"):
|
||||
lines.extend(textwrap.wrap(line.rstrip(), line_length))
|
||||
return lines
|
||||
|
||||
|
||||
def _prepare_message(message):
|
||||
if isinstance(message, list) or isinstance(message, tuple):
|
||||
return "\n".join([s.rstrip() for s in message])
|
||||
# return "\n".join(message)
|
||||
else:
|
||||
return message
|
||||
|
||||
|
||||
def select_stable_diffusion_config_file(
|
||||
form_color: str = "DANGER",
|
||||
wrap: bool = True,
|
||||
model_name: str = "Unknown",
|
||||
):
|
||||
message = f"Please select the correct prediction type for the checkpoint named '{model_name}'. Press <CANCEL> to skip installation."
|
||||
title = "CONFIG FILE SELECTION"
|
||||
options = [
|
||||
"'epsilon' - most v1.5 models and v2 models trained on 512 pixel images",
|
||||
"'vprediction' - v2 models trained on 768 pixel images and a few v1.5 models)",
|
||||
"Accept the best guess; you can fix it in the Web UI later",
|
||||
]
|
||||
|
||||
F = ConfirmCancelPopup(
|
||||
name=title,
|
||||
color=form_color,
|
||||
cycle_widgets=True,
|
||||
lines=16,
|
||||
)
|
||||
F.preserve_selected_widget = True
|
||||
|
||||
mlw = F.add(
|
||||
wgmultiline.Pager,
|
||||
max_height=4,
|
||||
editable=False,
|
||||
)
|
||||
mlw_width = mlw.width - 1
|
||||
if wrap:
|
||||
message = _wrap_message_lines(message, mlw_width)
|
||||
mlw.values = message
|
||||
|
||||
choice = F.add(
|
||||
npyscreen.SelectOne,
|
||||
values=options,
|
||||
value=[2],
|
||||
max_height=len(options) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
|
||||
F.editw = 1
|
||||
F.edit()
|
||||
if not F.value:
|
||||
return None
|
||||
assert choice.value[0] in range(0, 3), "invalid choice"
|
||||
choices = ["epsilon", "v", "guess"]
|
||||
return choices[choice.value[0]]
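# Hypothetical caller sketch (not from this commit): the function returns
# "epsilon", "v", "guess", or None when the user cancels.
def _example_pick_prediction_type(model_name: str) -> None:
    choice = select_stable_diffusion_config_file(model_name=model_name)
    if choice is None:
        print("User cancelled; skipping installation.")
    elif choice in ("epsilon", "v"):
        print(f"Recording prediction type {choice!r} for {model_name}.")
    else:  # "guess"
        print("Keeping the probe's best guess; it can be fixed later in the Web UI.")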
|
@ -1,22 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--web", action="store_true")
|
||||
opts, _ = parser.parse_known_args()
|
||||
|
||||
if opts.web:
|
||||
sys.argv.pop(sys.argv.index("--web"))
|
||||
from invokeai.app.api_app import invoke_api
|
||||
|
||||
invoke_api()
|
||||
else:
|
||||
from invokeai.app.cli_app import invoke_cli
|
||||
|
||||
invoke_cli()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,5 +0,0 @@
|
||||
"""
|
||||
Initialization file for invokeai.frontend.merge
|
||||
"""
|
||||
|
||||
from .merge_diffusers import main as invokeai_merge_diffusers # noqa: F401
|
@ -1,448 +0,0 @@
|
||||
"""
|
||||
invokeai.frontend.merge exports a single function call merge_diffusion_models()
|
||||
used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
|
||||
|
||||
Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import re
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import ModelMerger
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import FloatTitleSlider, SingleSelectColumns, TextBox
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
BASE_TYPES = [
|
||||
(BaseModelType.StableDiffusion1, "Models Built on SD-1.x"),
|
||||
(BaseModelType.StableDiffusion2, "Models Built on SD-2.x"),
|
||||
(BaseModelType.StableDiffusionXL, "Models Built on SDXL"),
|
||||
]
|
||||
|
||||
|
||||
def _parse_args() -> Namespace:
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model merging")
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=Path,
|
||||
default=config.root_path,
|
||||
help="Path to the invokeai runtime directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--front_end",
|
||||
"--gui",
|
||||
dest="front_end",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
dest="model_names",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Two to three model names to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_model",
|
||||
type=str,
|
||||
choices=[x[0].value for x in BASE_TYPES],
|
||||
help="The base model shared by the models to be merged",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merged_model_name",
|
||||
"--destination",
|
||||
dest="merged_model_name",
|
||||
type=str,
|
||||
help="Name of the output model. If not specified, will be the concatenation of the input model names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interpolation",
|
||||
dest="interp",
|
||||
type=str,
|
||||
choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
|
||||
default="weighted_sum",
|
||||
help='Interpolation method to use. If three models are present, only "add_difference" will work.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force",
|
||||
action="store_true",
|
||||
help="Try to merge models even if they are incompatible with each other",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clobber",
|
||||
"--overwrite",
|
||||
dest="clobber",
|
||||
action="store_true",
|
||||
help="Overwrite the merged model if --merged_model_name already exists",
|
||||
)
|
||||
return parser.parse_args()
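# For intuition only: conventional per-tensor formulas behind the two simplest
# interpolation choices. The real math lives in ModelMerger (not part of this
# diff), so treat these as assumptions rather than the library's implementation.
def _weighted_sum(a, b, alpha: float):
    # merged = (1 - alpha) * A + alpha * B
    return (1.0 - alpha) * a + alpha * b


def _add_difference(a, b, c, alpha: float):
    # merged = A + alpha * (B - C); used when a third model is selected
    return a + alpha * (b - c)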
|
||||
|
||||
|
||||
# ------------------------- GUI HERE -------------------------
|
||||
class mergeModelsForm(npyscreen.FormMultiPageAction):
|
||||
interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
|
||||
|
||||
def __init__(self, parentApp, name):
|
||||
self.parentApp = parentApp
|
||||
self.ALLOW_RESIZE = True
|
||||
self.FIX_MINIMUM_SIZE_WHEN_CREATED = False
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
@property
|
||||
def record_store(self):
|
||||
return self.parentApp.record_store
|
||||
|
||||
def afterEditing(self):
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self):
|
||||
window_height, window_width = curses.initscr().getmaxyx()
|
||||
self.current_base = 0
|
||||
self.models = self.get_models(BASE_TYPES[self.current_base][0])
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
max_width = max([len(x) for x in self.model_names])
|
||||
max_width += 6
|
||||
horizontal_layout = max_width * 3 < window_width
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
color="CONTROL",
|
||||
value="Select two models to merge and optionally a third.",
|
||||
editable=False,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
color="CONTROL",
|
||||
value="Use up and down arrows to move, <space> to select an item, <tab> and <shift-tab> to move from one field to the next.",
|
||||
editable=False,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.base_select = self.add_widget_intelligent(
|
||||
SingleSelectColumns,
|
||||
values=[x[1] for x in BASE_TYPES],
|
||||
value=[self.current_base],
|
||||
columns=4,
|
||||
max_height=2,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.base_select.on_changed = self._populate_models
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 1",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
self.model1 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
values=self.model_names,
|
||||
value=0,
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
rely=7,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 2",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
self.model2 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
name="(2)",
|
||||
values=self.model_names,
|
||||
value=1,
|
||||
max_height=len(self.model_names),
|
||||
max_width=max_width,
|
||||
relx=max_width + 3 if horizontal_layout else None,
|
||||
rely=7 if horizontal_layout else None,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="MODEL 3",
|
||||
color="GOOD",
|
||||
editable=False,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=6 if horizontal_layout else None,
|
||||
)
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
self.model3 = self.add_widget_intelligent(
|
||||
npyscreen.SelectOne,
|
||||
name="(3)",
|
||||
values=models_plus_none,
|
||||
value=0,
|
||||
max_height=len(self.model_names) + 1,
|
||||
max_width=max_width,
|
||||
scroll_exit=True,
|
||||
relx=max_width * 2 + 3 if horizontal_layout else None,
|
||||
rely=7 if horizontal_layout else None,
|
||||
)
|
||||
for m in [self.model1, self.model2, self.model3]:
|
||||
m.when_value_edited = self.models_changed
|
||||
self.merged_model_name = self.add_widget_intelligent(
|
||||
TextBox,
|
||||
name="Name for merged model:",
|
||||
labelColor="CONTROL",
|
||||
max_height=3,
|
||||
value="",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.force = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Force merge of models created by different diffusers library versions",
|
||||
labelColor="CONTROL",
|
||||
value=True,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.merge_method = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Merge Method:",
|
||||
values=self.interpolations,
|
||||
value=0,
|
||||
labelColor="CONTROL",
|
||||
max_height=len(self.interpolations) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.alpha = self.add_widget_intelligent(
|
||||
FloatTitleSlider,
|
||||
name="Weight (alpha) to assign to second and third models:",
|
||||
out_of=1.0,
|
||||
step=0.01,
|
||||
lowest=0,
|
||||
value=0.5,
|
||||
labelColor="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.model1.editing = True
|
||||
|
||||
def models_changed(self):
|
||||
models = self.model1.values
|
||||
selected_model1 = self.model1.value[0]
|
||||
selected_model2 = self.model2.value[0]
|
||||
selected_model3 = self.model3.value[0]
|
||||
merged_model_name = f"{models[selected_model1]}+{models[selected_model2]}"
|
||||
self.merged_model_name.value = merged_model_name
|
||||
|
||||
if selected_model3 > 0:
|
||||
self.merge_method.values = ["add_difference ( A+(B-C) )"]
|
||||
self.merged_model_name.value += f"+{models[selected_model3 - 1]}"  # model3's value list has an extra leading "None" entry, so subtract one.
|
||||
else:
|
||||
self.merge_method.values = self.interpolations
|
||||
self.merge_method.value = 0
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values() and self.check_for_overwrite():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.merge_arguments = self.marshall_arguments()
|
||||
npyscreen.notify("Starting the merge...")
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def on_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
model_keys = [x[0] for x in self.models]
|
||||
models = [
|
||||
model_keys[self.model1.value[0]],
|
||||
model_keys[self.model2.value[0]],
|
||||
]
|
||||
if self.model3.value[0] > 0:
|
||||
models.append(model_keys[self.model3.value[0] - 1])
|
||||
interp = "add_difference"
|
||||
else:
|
||||
interp = self.interpolations[self.merge_method.value[0]]
|
||||
|
||||
args = {
|
||||
"model_keys": models,
|
||||
"base_model": tuple(BaseModelType)[self.base_select.value[0]],
|
||||
"alpha": self.alpha.value,
|
||||
"interp": interp,
|
||||
"force": self.force.value,
|
||||
"merged_model_name": self.merged_model_name.value,
|
||||
}
|
||||
return args
|
||||
|
||||
def check_for_overwrite(self) -> bool:
|
||||
model_out = self.merged_model_name.value
|
||||
if model_out not in self.model_names:
|
||||
return True
|
||||
else:
|
||||
return npyscreen.notify_yes_no(
|
||||
f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
|
||||
)
|
||||
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
model_names = self.model_names
|
||||
selected_models = {model_names[self.model1.value[0]], model_names[self.model2.value[0]]}
|
||||
if self.model3.value[0] > 0:
|
||||
selected_models.add(model_names[self.model3.value[0] - 1])
|
||||
if len(selected_models) < 2:
|
||||
bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}")
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:"
|
||||
for problem in bad_fields:
|
||||
message += f"\n* {problem}"
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_models(self, base_model: Optional[BaseModelType] = None) -> List[Tuple[str, str]]: # key to name
|
||||
models = [
|
||||
(x.key, x.name)
|
||||
for x in self.record_store.search_by_attr(model_type=ModelType.Main, base_model=base_model)
|
||||
if x.format == ModelFormat("diffusers") and x.variant == ModelVariantType("normal")
|
||||
]
|
||||
return sorted(models, key=lambda x: x[1])
|
||||
|
||||
def _populate_models(self, value: List[int]):
|
||||
base_model = BASE_TYPES[value[0]][0]
|
||||
self.models = self.get_models(base_model)
|
||||
self.model_names = [x[1] for x in self.models]
|
||||
|
||||
models_plus_none = self.model_names.copy()
|
||||
models_plus_none.insert(0, "None")
|
||||
self.model1.values = self.model_names
|
||||
self.model2.values = self.model_names
|
||||
self.model3.values = models_plus_none
|
||||
|
||||
self.display()
|
||||
|
||||
|
||||
class Mergeapp(npyscreen.NPSAppManaged):
|
||||
def __init__(self, record_store: ModelRecordServiceBase):
|
||||
super().__init__()
|
||||
self.record_store = record_store
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.ElegantTheme)
|
||||
self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
|
||||
|
||||
|
||||
def run_gui(args: Namespace) -> None:
|
||||
record_store: ModelRecordServiceBase = get_config_store()
|
||||
mergeapp = Mergeapp(record_store)
|
||||
mergeapp.run()
|
||||
args = mergeapp.merge_arguments
|
||||
merger = get_model_merger(record_store)
|
||||
merger.merge_diffusion_models_and_save(**args)
|
||||
merged_model_name = args["merged_model_name"]
|
||||
logger.info(f'Models merged into new model: "{merged_model_name}".')
|
||||
|
||||
|
||||
def run_cli(args: Namespace):
|
||||
assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
|
||||
assert (
|
||||
args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3
|
||||
), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage."
|
||||
|
||||
if not args.merged_model_name:
|
||||
args.merged_model_name = "+".join(args.model_names)
|
||||
logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"')
|
||||
|
||||
record_store: ModelRecordServiceBase = get_config_store()
|
||||
assert (
|
||||
len(record_store.search_by_attr(args.merged_model_name, args.base_model, ModelType.Main)) == 0 or args.clobber
|
||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||
|
||||
merger = get_model_merger(record_store)
|
||||
model_keys = []
|
||||
for name in args.model_names:
|
||||
if len(name) == 32 and re.match(r"^[0-9a-f]{32}$", name):
|
||||
model_keys.append(name)
|
||||
else:
|
||||
models = record_store.search_by_attr(
|
||||
model_name=name, model_type=ModelType.Main, base_model=BaseModelType(args.base_model)
|
||||
)
|
||||
assert len(models) > 0, f"{name}: Unknown model"
|
||||
assert len(models) < 2, f"{name}: More than one model by this name. Please specify the model key instead."
|
||||
model_keys.append(models[0].key)
|
||||
|
||||
merger.merge_diffusion_models_and_save(
|
||||
alpha=args.alpha,
|
||||
model_keys=model_keys,
|
||||
merged_model_name=args.merged_model_name,
|
||||
interp=args.interp,
|
||||
force=args.force,
|
||||
)
|
||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
||||
|
||||
|
||||
def get_config_store() -> ModelRecordServiceSQL:
|
||||
output_path = config.outputs_path
|
||||
assert output_path is not None
|
||||
image_files = DiskImageFileStorage(output_path / "images")
|
||||
db = init_db(config=config, logger=InvokeAILogger.get_logger(), image_files=image_files)
|
||||
return ModelRecordServiceSQL(db)
|
||||
|
||||
|
||||
def get_model_merger(record_store: ModelRecordServiceBase) -> ModelMerger:
|
||||
installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=DownloadQueueService())
|
||||
installer.start()
|
||||
return ModelMerger(installer)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
if args.root_dir:
|
||||
config.set_root(Path(args.root_dir))
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
run_gui(args)
|
||||
else:
|
||||
run_cli(args)
|
||||
except widget.NotEnoughSpaceForWidget as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
logger.error("You need to have at least two diffusers models in order to merge")
|
||||
else:
|
||||
logger.error("Not enough room for the user interface. Try making this window larger.")
|
||||
sys.exit(-1)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,5 +0,0 @@
|
||||
"""
|
||||
Initialization file for invokeai.frontend.training
|
||||
"""
|
||||
|
||||
from .textual_inversion import main as invokeai_textual_inversion # noqa: F401
|
@ -1,452 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
This is the frontend to "textual_inversion_training.py".
|
||||
|
||||
Copyright (c) 2023-24 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.install.install_helper import initialize_installer
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.training import do_textual_inversion_training, parse_args
|
||||
|
||||
TRAINING_DATA = "text-inversion-training-data"
|
||||
TRAINING_DIR = "text-inversion-output"
|
||||
CONF_FILE = "preferences.conf"
|
||||
config = None
|
||||
|
||||
|
||||
class textualInversionForm(npyscreen.FormMultiPageAction):
|
||||
resolutions = [512, 768, 1024]
|
||||
lr_schedulers = [
|
||||
"linear",
|
||||
"cosine",
|
||||
"cosine_with_restarts",
|
||||
"polynomial",
|
||||
"constant",
|
||||
"constant_with_warmup",
|
||||
]
|
||||
precisions = ["no", "fp16", "bf16"]
|
||||
learnable_properties = ["object", "style"]
|
||||
|
||||
def __init__(self, parentApp: npyscreen.NPSAppManaged, name: str, saved_args: Optional[Dict[str, str]] = None):
|
||||
self.saved_args = saved_args or {}
|
||||
super().__init__(parentApp, name)
|
||||
|
||||
def afterEditing(self) -> None:
|
||||
self.parentApp.setNextForm(None)
|
||||
|
||||
def create(self) -> None:
|
||||
self.model_names, default = self.get_model_names()
|
||||
default_initializer_token = "★"
|
||||
default_placeholder_token = ""
|
||||
saved_args = self.saved_args
|
||||
|
||||
assert config is not None
|
||||
|
||||
try:
|
||||
default = self.model_names.index(saved_args["model"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
|
||||
editable=False,
|
||||
)
|
||||
|
||||
self.model = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Model Name:",
|
||||
values=sorted(self.model_names),
|
||||
value=default,
|
||||
max_height=len(self.model_names) + 1,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.placeholder_token = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Trigger Term:",
|
||||
value="", # saved_args.get('placeholder_token',''), # to restore previous term
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.placeholder_token.when_value_edited = self.initializer_changed
|
||||
self.nextrely -= 1
|
||||
self.nextrelx += 30
|
||||
self.prompt_token = self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
name="Trigger term for use in prompt",
|
||||
value="",
|
||||
editable=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrelx -= 30
|
||||
self.initializer_token = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Initializer:",
|
||||
value=saved_args.get("initializer_token", default_initializer_token),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.resume_from_checkpoint = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Resume from last saved checkpoint",
|
||||
value=False,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.learnable_property = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Learnable property:",
|
||||
values=self.learnable_properties,
|
||||
value=self.learnable_properties.index(saved_args.get("learnable_property", "object")),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.train_data_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="Data Training Directory:",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
value=str(
|
||||
saved_args.get(
|
||||
"train_data_dir",
|
||||
config.root_path / TRAINING_DATA / default_placeholder_token,
|
||||
)
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.output_dir = self.add_widget_intelligent(
|
||||
npyscreen.TitleFilename,
|
||||
name="Output Destination Directory:",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
value=str(
|
||||
saved_args.get(
|
||||
"output_dir",
|
||||
config.root_path / TRAINING_DIR / default_placeholder_token,
|
||||
)
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.resolution = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Image resolution (pixels):",
|
||||
values=self.resolutions,
|
||||
value=self.resolutions.index(saved_args.get("resolution", 512)),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.center_crop = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Center crop images before resizing to resolution",
|
||||
value=saved_args.get("center_crop", False),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.mixed_precision = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Mixed Precision:",
|
||||
values=self.precisions,
|
||||
value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
|
||||
max_height=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.num_train_epochs = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Number of training epochs:",
|
||||
out_of=1000,
|
||||
step=50,
|
||||
lowest=1,
|
||||
value=saved_args.get("num_train_epochs", 100),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.max_train_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Max Training Steps:",
|
||||
out_of=10000,
|
||||
step=500,
|
||||
lowest=1,
|
||||
value=saved_args.get("max_train_steps", 3000),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.train_batch_size = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Batch Size (reduce if you run out of memory):",
|
||||
out_of=50,
|
||||
step=1,
|
||||
lowest=1,
|
||||
value=saved_args.get("train_batch_size", 8),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.gradient_accumulation_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
|
||||
out_of=10,
|
||||
step=1,
|
||||
lowest=1,
|
||||
value=saved_args.get("gradient_accumulation_steps", 4),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lr_warmup_steps = self.add_widget_intelligent(
|
||||
npyscreen.TitleSlider,
|
||||
name="Warmup Steps:",
|
||||
out_of=100,
|
||||
step=1,
|
||||
lowest=0,
|
||||
value=saved_args.get("lr_warmup_steps", 0),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.learning_rate = self.add_widget_intelligent(
|
||||
npyscreen.TitleText,
|
||||
name="Learning Rate:",
|
||||
value=str(
|
||||
saved_args.get("learning_rate", "5.0e-04"),
|
||||
),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.scale_lr = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Scale learning rate by number GPUs, steps and batch size",
|
||||
value=saved_args.get("scale_lr", True),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
|
||||
npyscreen.Checkbox,
|
||||
name="Use xformers acceleration",
|
||||
value=saved_args.get("enable_xformers_memory_efficient_attention", False),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.lr_scheduler = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="Learning rate scheduler:",
|
||||
values=self.lr_schedulers,
|
||||
max_height=7,
|
||||
value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.model.editing = True
|
||||
|
||||
def initializer_changed(self) -> None:
|
||||
placeholder = self.placeholder_token.value
|
||||
self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
|
||||
self.train_data_dir.value = str(config.root_path / TRAINING_DATA / placeholder)
|
||||
self.output_dir.value = str(config.root_path / TRAINING_DIR / placeholder)
|
||||
self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()
|
||||
|
||||
def on_ok(self):
|
||||
if self.validate_field_values():
|
||||
self.parentApp.setNextForm(None)
|
||||
self.editing = False
|
||||
self.parentApp.ti_arguments = self.marshall_arguments()
|
||||
npyscreen.notify("Launching textual inversion training. This will take a while...")
|
||||
else:
|
||||
self.editing = True
|
||||
|
||||
def ok_cancel(self):
|
||||
sys.exit(0)
|
||||
|
||||
def validate_field_values(self) -> bool:
|
||||
bad_fields = []
|
||||
if self.model.value is None:
|
||||
bad_fields.append("Model Name must correspond to a known model in invokeai.db")
|
||||
if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
|
||||
bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen")
|
||||
if self.train_data_dir.value is None:
|
||||
bad_fields.append("Data Training Directory cannot be empty")
|
||||
if self.output_dir.value is None:
|
||||
bad_fields.append("The Output Destination Directory cannot be empty")
|
||||
if len(bad_fields) > 0:
|
||||
message = "The following problems were detected and must be corrected:"
|
||||
for problem in bad_fields:
|
||||
message += f"\n* {problem}"
|
||||
npyscreen.notify_confirm(message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_model_names(self) -> Tuple[List[str], int]:
|
||||
global config
|
||||
assert config is not None
|
||||
installer = initialize_installer(config)
|
||||
store = installer.record_store
|
||||
main_models = store.search_by_attr(model_type=ModelType.Main)
|
||||
model_names = [f"{x.base.value}/{x.type.value}/{x.name}" for x in main_models if x.format == "diffusers"]
|
||||
default = 0
|
||||
return (model_names, default)
|
||||
|
||||
def marshall_arguments(self) -> dict:
|
||||
args = {}
|
||||
|
||||
# the choices
|
||||
args.update(
|
||||
model=self.model_names[self.model.value[0]],
|
||||
resolution=self.resolutions[self.resolution.value[0]],
|
||||
lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
|
||||
mixed_precision=self.precisions[self.mixed_precision.value[0]],
|
||||
learnable_property=self.learnable_properties[self.learnable_property.value[0]],
|
||||
)
|
||||
|
||||
# all the strings and booleans
|
||||
for attr in (
|
||||
"initializer_token",
|
||||
"placeholder_token",
|
||||
"train_data_dir",
|
||||
"output_dir",
|
||||
"scale_lr",
|
||||
"center_crop",
|
||||
"enable_xformers_memory_efficient_attention",
|
||||
):
|
||||
args[attr] = getattr(self, attr).value
|
||||
|
||||
# all the integers
|
||||
for attr in (
|
||||
"train_batch_size",
|
||||
"gradient_accumulation_steps",
|
||||
"num_train_epochs",
|
||||
"max_train_steps",
|
||||
"lr_warmup_steps",
|
||||
):
|
||||
args[attr] = int(getattr(self, attr).value)
|
||||
|
||||
# the floats (just one)
|
||||
args.update(learning_rate=float(self.learning_rate.value))
|
||||
|
||||
# a special case
|
||||
if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
|
||||
args["resume_from_checkpoint"] = "latest"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class MyApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, saved_args: Optional[Dict[str, str]] = None):
|
||||
super().__init__()
|
||||
self.ti_arguments = None
|
||||
self.saved_args = saved_args
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
self.main = self.addForm(
|
||||
"MAIN",
|
||||
textualInversionForm,
|
||||
name="Textual Inversion Settings",
|
||||
saved_args=self.saved_args,
|
||||
)
|
||||
|
||||
|
||||
def copy_to_embeddings_folder(args: Dict[str, str]) -> None:
|
||||
"""
|
||||
Copy learned_embeds.bin into the embeddings folder, and offer to
|
||||
delete the full model and checkpoints.
|
||||
"""
|
||||
assert config is not None
|
||||
source = Path(args["output_dir"], "learned_embeds.bin")
|
||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||
destination = config.root_path / "embeddings" / dest_dir_name
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||
shutil.copy(source, destination)
|
||||
if (input("Delete training logs and intermediate checkpoints? [y] ") or "y").startswith(("y", "Y")):
|
||||
shutil.rmtree(Path(args["output_dir"]))
|
||||
else:
|
||||
logger.info(f'Keeping {args["output_dir"]}')
|
||||
|
||||
|
||||
def save_args(args: dict) -> None:
|
||||
"""
|
||||
Save the current argument values to an omegaconf file
|
||||
"""
|
||||
assert config is not None
|
||||
dest_dir = config.root_path / TRAINING_DIR
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
conf_file = dest_dir / CONF_FILE
|
||||
conf = OmegaConf.create(args)
|
||||
OmegaConf.save(config=conf, f=conf_file)
|
||||
|
||||
|
||||
def previous_args() -> dict:
|
||||
"""
|
||||
Get the previous arguments used.
|
||||
"""
|
||||
assert config is not None
|
||||
conf_file = config.root_path / TRAINING_DIR / CONF_FILE
|
||||
try:
|
||||
conf = OmegaConf.load(conf_file)
|
||||
conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
|
||||
except Exception:
|
||||
conf = None
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
def do_front_end() -> None:
|
||||
global config
|
||||
saved_args = previous_args()
|
||||
myapplication = MyApplication(saved_args=saved_args)
|
||||
myapplication.run()
|
||||
|
||||
if my_args := myapplication.ti_arguments:
|
||||
os.makedirs(my_args["output_dir"], exist_ok=True)
|
||||
|
||||
# Automatically add angle brackets around the trigger
|
||||
if not re.match("^<.+>$", my_args["placeholder_token"]):
|
||||
my_args["placeholder_token"] = f"<{my_args['placeholder_token']}>"
|
||||
|
||||
my_args["only_save_embeds"] = True
|
||||
save_args(my_args)
|
||||
|
||||
try:
|
||||
print(my_args)
|
||||
do_textual_inversion_training(config, **my_args)
|
||||
copy_to_embeddings_folder(my_args)
|
||||
except Exception as e:
|
||||
logger.error("An exception occurred during training. The exception was:")
|
||||
logger.error(str(e))
|
||||
logger.error("DETAILS:")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
global config
|
||||
|
||||
args: Namespace = parse_args()
|
||||
config = get_config()
|
||||
|
||||
# change root if needed
|
||||
if args.root_dir:
|
||||
config.set_root(args.root_dir)
|
||||
|
||||
try:
|
||||
if args.front_end:
|
||||
do_front_end()
|
||||
else:
|
||||
do_textual_inversion_training(config, **vars(args))
|
||||
except AssertionError as e:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
||||
if str(e).startswith("Height of 1 allocated"):
|
||||
logger.error("You need to have at least one diffusers models defined in invokeai.db in order to train")
|
||||
elif str(e).startswith("addwstr"):
|
||||
logger.error("Not enough window space for the interface. Please make your window larger and try again.")
|
||||
else:
|
||||
logger.error(e)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -638,6 +638,14 @@
|
||||
"huggingFacePlaceholder": "owner/model-name",
|
||||
"huggingFaceRepoID": "HuggingFace Repo ID",
|
||||
"huggingFaceHelper": "If multiple models are found in this repo, you will be prompted to select one to install.",
|
||||
"hfToken": "HuggingFace Token",
|
||||
"hfTokenHelperText": "A HF token is required to use checkpoint models. Click here to create or get your token.",
|
||||
"hfTokenInvalid": "Invalid or Missing HF Token",
|
||||
"hfTokenInvalidErrorMessage": "Invalid or missing HuggingFace token.",
|
||||
"hfTokenInvalidErrorMessage2": "Update it in the ",
|
||||
"hfTokenUnableToVerify": "Unable to Verify HF Token",
|
||||
"hfTokenUnableToVerifyErrorMessage": "Unable to verify HuggingFace token. This is likely due to a network error. Please try again later.",
|
||||
"hfTokenSaved": "HF Token Saved",
|
||||
"imageEncoderModelId": "Image Encoder Model ID",
|
||||
"installQueue": "Install Queue",
|
||||
"inplaceInstall": "In-place install",
|
||||
@ -648,6 +656,8 @@
|
||||
"load": "Load",
|
||||
"localOnly": "local only",
|
||||
"manual": "Manual",
|
||||
"loraModels": "LoRAs",
|
||||
"main": "Main",
|
||||
"metadata": "Metadata",
|
||||
"model": "Model",
|
||||
"modelConversionFailed": "Model Conversion Failed",
|
||||
@ -667,6 +677,8 @@
|
||||
"modelUpdated": "Model Updated",
|
||||
"modelUpdateFailed": "Model Update Failed",
|
||||
"name": "Name",
|
||||
"noModelsInstalled": "No Models Installed",
|
||||
"noModelsInstalledDesc1": "Install models with the",
|
||||
"noModelSelected": "No Model Selected",
|
||||
"none": "none",
|
||||
"path": "Path",
|
||||
@ -686,7 +698,9 @@
|
||||
"settings": "Settings",
|
||||
"simpleModelPlaceholder": "URL or path to a local file or diffusers folder",
|
||||
"source": "Source",
|
||||
"starterModels": "Starter Models",
|
||||
"syncModels": "Sync Models",
|
||||
"textualInversions": "Textual Inversions",
|
||||
"triggerPhrases": "Trigger Phrases",
|
||||
"loraTriggerPhrases": "LoRA Trigger Phrases",
|
||||
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
|
||||
@ -1424,9 +1438,17 @@
|
||||
"undo": "Undo"
|
||||
},
|
||||
"workflows": {
|
||||
"ascending": "Ascending",
|
||||
"created": "Created",
|
||||
"descending": "Descending",
|
||||
"workflows": "Workflows",
|
||||
"workflowLibrary": "Library",
|
||||
"userWorkflows": "My Workflows",
|
||||
"defaultWorkflows": "Default Workflows",
|
||||
"projectWorkflows": "Project Workflows",
|
||||
"opened": "Opened",
|
||||
"openWorkflow": "Open Workflow",
|
||||
"updated": "Updated",
|
||||
"uploadWorkflow": "Load from File",
|
||||
"deleteWorkflow": "Delete Workflow",
|
||||
"unnamedWorkflow": "Unnamed Workflow",
|
||||
@ -1437,6 +1459,9 @@
|
||||
"savingWorkflow": "Saving Workflow...",
|
||||
"problemSavingWorkflow": "Problem Saving Workflow",
|
||||
"workflowSaved": "Workflow Saved",
|
||||
"name": "Name",
|
||||
"noRecentWorkflows": "No Recent Workflows",
|
||||
"noUserWorkflows": "No User Workflows",
|
||||
"noWorkflows": "No Workflows",
|
||||
"problemLoading": "Problem Loading Workflows",
|
||||
"loading": "Loading Workflows",
|
||||
|
@ -11,6 +11,8 @@ import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { useHFLoginToast } from 'features/modelManagerV2/hooks/useHFLoginToast';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
import InvokeTabs from 'features/ui/components/InvokeTabs';
@ -68,6 +70,9 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
    dispatch(appStarted());
  }, [dispatch]);

  useStarterModelsToast();
  useHFLoginToast();

  return (
    <ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
      <Box
@ -24,7 +24,10 @@ export type AppFeature =
  | 'resumeQueue'
  | 'prependQueue'
  | 'invocationCache'
  | 'bulkDownload';
  | 'bulkDownload'
  | 'starterModels'
  | 'hfToken';

/**
 * A disable-able Stable Diffusion feature
 */
@ -0,0 +1,78 @@
import {
  Button,
  ExternalLink,
  Flex,
  FormControl,
  FormErrorMessage,
  FormHelperText,
  FormLabel,
  Input,
  useToast,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery, useSetHFTokenMutation } from 'services/api/endpoints/models';

export const HFToken = () => {
  const { t } = useTranslation();
  const isEnabled = useFeatureStatus('hfToken').isFeatureEnabled;
  const [token, setToken] = useState('');
  const { currentData } = useGetHFTokenStatusQuery(isEnabled ? undefined : skipToken);
  const [trigger, { isLoading }] = useSetHFTokenMutation();
  const toast = useToast();
  const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
    setToken(e.target.value);
  }, []);
  const onClick = useCallback(() => {
    trigger({ token })
      .unwrap()
      .then((res) => {
        if (res === 'valid') {
          setToken('');
          toast({
            title: t('modelManager.hfTokenSaved'),
            status: 'success',
            duration: 3000,
          });
        }
      });
  }, [t, toast, token, trigger]);

  const error = useMemo(() => {
    if (!currentData || isLoading) {
      return null;
    }
    if (currentData === 'invalid') {
      return t('modelManager.hfTokenInvalidErrorMessage');
    }
    if (currentData === 'unknown') {
      return t('modelManager.hfTokenUnableToVerifyErrorMessage');
    }
    return null;
  }, [currentData, isLoading, t]);

  if (!currentData || currentData === 'valid') {
    return null;
  }

  return (
    <Flex borderRadius="base" w="full">
      <FormControl isInvalid={Boolean(error)} orientation="vertical">
        <FormLabel>{t('modelManager.hfToken')}</FormLabel>
        <Flex gap={3} alignItems="center" w="full">
          <Input type="password" value={token} onChange={onChange} />
          <Button onClick={onClick} size="sm" isDisabled={token.trim().length === 0} isLoading={isLoading}>
            {t('common.save')}
          </Button>
        </Flex>
        <FormHelperText>
          <ExternalLink label={t('modelManager.hfTokenHelperText')} href="https://huggingface.co/settings/tokens" />
        </FormHelperText>
        <FormErrorMessage>{error}</FormErrorMessage>
      </FormControl>
    </Flex>
  );
};
@ -0,0 +1,88 @@
import { Button, Text, useToast } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { t } from 'i18next';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetHFTokenStatusQuery } from 'services/api/endpoints/models';
import type { S } from 'services/api/types';

const FEATURE_ID = 'hfToken';

const getTitle = (token_status: S['HFTokenStatus']) => {
  switch (token_status) {
    case 'invalid':
      return t('modelManager.hfTokenInvalid');
    case 'unknown':
      return t('modelManager.hfTokenUnableToVerify');
  }
};

export const useHFLoginToast = () => {
  const { t } = useTranslation();
  const isEnabled = useFeatureStatus(FEATURE_ID).isFeatureEnabled;
  const [didToast, setDidToast] = useState(false);
  const { data } = useGetHFTokenStatusQuery(isEnabled ? undefined : skipToken);
  const toast = useToast();

  useEffect(() => {
    if (toast.isActive(FEATURE_ID)) {
      if (data === 'valid') {
        setDidToast(true);
        toast.close(FEATURE_ID);
      }
      return;
    }
    if (data && data !== 'valid' && !didToast && isEnabled) {
      const title = getTitle(data);
      toast({
        id: FEATURE_ID,
        title,
        description: <ToastDescription token_status={data} />,
        status: 'info',
        isClosable: true,
        duration: null,
        onCloseComplete: () => setDidToast(true),
      });
    }
  }, [data, didToast, isEnabled, t, toast]);
};

type Props = {
  token_status: S['HFTokenStatus'];
};

const ToastDescription = ({ token_status }: Props) => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const toast = useToast();

  const onClick = useCallback(() => {
    dispatch(setActiveTab('modelManager'));
    toast.close(FEATURE_ID);
  }, [dispatch, toast]);

  if (token_status === 'invalid') {
    return (
      <Text fontSize="md">
        {t('modelManager.hfTokenInvalidErrorMessage')} {t('modelManager.hfTokenInvalidErrorMessage2')}
        <Button onClick={onClick} variant="link" color="base.50" flexGrow={0}>
          {t('modelManager.modelManager')}.
        </Button>
      </Text>
    );
  }

  if (token_status === 'unknown') {
    return (
      <Text fontSize="md">
        {t('modelManager.hfTokenUnableToVerifyErrorMessage')}{' '}
        <Button onClick={onClick} variant="link" color="base.50" flexGrow={0}>
          {t('modelManager.modelManager')}.
        </Button>
      </Text>
    );
  }
};
@ -0,0 +1,54 @@
import { Button, Text, useToast } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useMainModels } from 'services/api/hooks/modelsByType';

const TOAST_ID = 'starterModels';

export const useStarterModelsToast = () => {
  const { t } = useTranslation();
  const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
  const [didToast, setDidToast] = useState(false);
  const [mainModels, { data }] = useMainModels();
  const toast = useToast();

  useEffect(() => {
    if (toast.isActive(TOAST_ID)) {
      return;
    }
    if (data && mainModels.length === 0 && !didToast && isEnabled) {
      toast({
        id: TOAST_ID,
        title: t('modelManager.noModelsInstalled'),
        description: <ToastDescription />,
        status: 'info',
        isClosable: true,
        duration: null,
        onCloseComplete: () => setDidToast(true),
      });
    }
  }, [data, didToast, isEnabled, mainModels.length, t, toast]);
};

const ToastDescription = () => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const toast = useToast();

  const onClick = useCallback(() => {
    dispatch(setActiveTab('modelManager'));
    toast.close(TOAST_ID);
  }, [dispatch, toast]);

  return (
    <Text fontSize="md">
      {t('modelManager.noModelsInstalledDesc1')}{' '}
      <Button onClick={onClick} variant="link" color="base.50" flexGrow={0}>
        {t('modelManager.modelManager')}.
      </Button>
    </Text>
  );
};
@ -139,7 +139,7 @@ type TooltipLabelProps = {

const TooltipLabel = ({ name, source, installJob }: TooltipLabelProps) => {
  const progressString = useMemo(() => {
    if (installJob.status === 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
    if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
      return '';
    }
    return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
@ -0,0 +1,75 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ModelBaseBadge from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';

type Props = {
  result: GetStarterModelsResponse[number];
};
export const StarterModelsResultItem = ({ result }: Props) => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const allSources = useMemo(() => {
    const _allSources = [result.source];
    if (result.dependencies) {
      _allSources.push(...result.dependencies);
    }
    return _allSources;
  }, [result]);
  const [installModel] = useInstallModelMutation();

  const handleQuickAdd = useCallback(() => {
    for (const source of allSources) {
      installModel({ source })
        .unwrap()
        .then((_) => {
          dispatch(
            addToast(
              makeToast({
                title: t('toast.modelAddedSimple'),
                status: 'success',
              })
            )
          );
        })
        .catch((error) => {
          if (error) {
            dispatch(
              addToast(
                makeToast({
                  title: `${error.data.detail} `,
                  status: 'error',
                })
              )
            );
          }
        });
    }
  }, [allSources, installModel, dispatch, t]);

  return (
    <Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
      <Flex fontSize="sm" flexDir="column">
        <Flex gap={3}>
          <Badge h="min-content">{result.type.replace('_', ' ')}</Badge>
          <ModelBaseBadge base={result.base} />
          <Text fontWeight="semibold">{result.name}</Text>
        </Flex>
        <Text variant="subtext">{result.description}</Text>
      </Flex>
      <Box>
        {result.is_installed ? (
          <Badge>{t('common.installed')}</Badge>
        ) : (
          <IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
        )}
      </Box>
    </Flex>
  );
};
@ -0,0 +1,16 @@
import { Flex } from '@invoke-ai/ui-library';
import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader';
import { useGetStarterModelsQuery } from 'services/api/endpoints/models';

import { StarterModelsResults } from './StarterModelsResults';

export const StarterModelsForm = () => {
  const { isLoading, data } = useGetStarterModelsQuery();

  return (
    <Flex flexDir="column" height="100%" gap={3}>
      {isLoading && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
      {data && <StarterModelsResults results={data} />}
    </Flex>
  );
};
@ -0,0 +1,72 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import type { GetStarterModelsResponse } from 'services/api/endpoints/models';

import { StarterModelsResultItem } from './StartModelsResultItem';

type StarterModelsResultsProps = {
  results: NonNullable<GetStarterModelsResponse>;
};

export const StarterModelsResults = ({ results }: StarterModelsResultsProps) => {
  const { t } = useTranslation();
  const [searchTerm, setSearchTerm] = useState('');

  const filteredResults = useMemo(() => {
    return results.filter((result) => {
      const name = result.name.toLowerCase();
      const type = result.type.toLowerCase();
      return name.includes(searchTerm.toLowerCase()) || type.includes(searchTerm.toLowerCase());
    });
  }, [results, searchTerm]);

  const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
    setSearchTerm(e.target.value.trim());
  }, []);

  const clearSearch = useCallback(() => {
    setSearchTerm('');
  }, []);

  return (
    <Flex flexDir="column" gap={3} height="100%">
      <Flex justifyContent="flex-end" alignItems="center">
        <InputGroup w={64} size="xs">
          <Input
            placeholder={t('modelManager.search')}
            value={searchTerm}
            data-testid="board-search-input"
            onChange={handleSearch}
            size="xs"
          />

          {searchTerm && (
            <InputRightElement h="full" pe={2}>
              <IconButton
                size="sm"
                variant="link"
                aria-label={t('boards.clearSearch')}
                icon={<PiXBold />}
                onClick={clearSearch}
                flexShrink={0}
              />
            </InputRightElement>
          )}
        </InputGroup>
      </Flex>
      <Flex height="100%" layerStyle="third" borderRadius="base" p={3}>
        <ScrollableContent>
          <Flex flexDir="column" gap={3}>
            {filteredResults.map((result) => (
              <StarterModelsResultItem key={result.source} result={result} />
            ))}
          </Flex>
        </ScrollableContent>
      </Flex>
    </Flex>
  );
};
@ -1,5 +1,8 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { StarterModelsForm } from 'features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useMainModels } from 'services/api/hooks/modelsByType';

import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
@ -8,14 +11,23 @@ import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';

export const InstallModels = () => {
  const { t } = useTranslation();
  const [mainModels, { data }] = useMainModels();
  const defaultIndex = useMemo(() => {
    if (data && mainModels.length) {
      return 0;
    }
    return 3;
  }, [data, mainModels.length]);

  return (
    <Flex layerStyle="first" borderRadius="base" w="full" h="full" flexDir="column" gap={4}>
      <Heading fontSize="xl">{t('modelManager.addModel')}</Heading>
      <Tabs variant="collapse" height="50%" display="flex" flexDir="column">
      <Tabs variant="collapse" height="50%" display="flex" flexDir="column" defaultIndex={defaultIndex}>
        <TabList>
          <Tab>{t('modelManager.urlOrLocalPath')}</Tab>
          <Tab>{t('modelManager.huggingFace')}</Tab>
          <Tab>{t('modelManager.scanFolder')}</Tab>
          <Tab>{t('modelManager.starterModels')}</Tab>
        </TabList>
        <TabPanels p={3} height="100%">
          <TabPanel>
@ -27,6 +39,9 @@ export const InstallModels = () => {
          <TabPanel height="100%">
            <ScanModelsForm />
          </TabPanel>
          <TabPanel height="100%">
            <StarterModelsForm />
          </TabPanel>
        </TabPanels>
      </Tabs>
      <Box layerStyle="second" borderRadius="base" h="50%">
@ -1,5 +1,6 @@
import { Button, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { HFToken } from 'features/modelManagerV2/components/HFToken';
import { SyncModelsButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsButton';
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react';
@ -27,6 +28,7 @@ export const ModelManager = () => {
        </Button>
      </Flex>
      <Flex flexDir="column" layerStyle="second" p={4} gap={4} borderRadius="base" w="full" h="full">
        <HFToken />
        <ModelListNavigation />
        <ModelList />
      </Flex>
@ -0,0 +1,15 @@
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
import { memo } from 'react';

export const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
  return (
    <Flex flexDirection="column" gap={4} borderRadius="base" p={4} bg="base.800">
      <Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
        <Spinner />
        <Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
      </Flex>
    </Flex>
  );
});

FetchingModelsLoader.displayName = 'FetchingModelsLoader';
@ -17,7 +17,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {

const ModelBaseBadge = ({ base }: Props) => {
  return (
    <Badge flexGrow={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle">
    <Badge flexGrow={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle" h="min-content">
      {MODEL_TYPE_SHORT_MAP[base]}
    </Badge>
  );
@ -1,7 +1,8 @@
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
  useControlNetModels,
  useEmbeddingModels,
@ -13,10 +14,12 @@ import {
} from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ModelType } from 'services/api/types';

import { FetchingModelsLoader } from './FetchingModelsLoader';
import { ModelListWrapper } from './ModelListWrapper';

const ModelList = () => {
  const { searchTerm, filteredModelType } = useAppSelector((s) => s.modelmanagerV2);
  const { t } = useTranslation();

  const [mainModels, { isLoading: isLoadingMainModels }] = useMainModels();
  const filteredMainModels = useMemo(
@ -66,18 +69,22 @@ const ModelList = () => {
        {/* Main Model List */}
        {isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main Models..." />}
        {!isLoadingMainModels && filteredMainModels.length > 0 && (
          <ModelListWrapper title="Main" modelList={filteredMainModels} key="main" />
          <ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
        )}
        {/* LoRAs List */}
        {isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
        {!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
          <ModelListWrapper title="LoRA" modelList={filteredLoRAModels} key="loras" />
          <ModelListWrapper title={t('modelManager.loraModels')} modelList={filteredLoRAModels} key="loras" />
        )}

        {/* TI List */}
        {isLoadingEmbeddingModels && <FetchingModelsLoader loadingMessage="Loading Embeddings..." />}
        {isLoadingEmbeddingModels && <FetchingModelsLoader loadingMessage="Loading Textual Inversions..." />}
        {!isLoadingEmbeddingModels && filteredEmbeddingModels.length > 0 && (
          <ModelListWrapper title="Embedding" modelList={filteredEmbeddingModels} key="textual-inversions" />
          <ModelListWrapper
            title={t('modelManager.textualInversions')}
            modelList={filteredEmbeddingModels}
            key="textual-inversions"
          />
        )}

        {/* VAE List */}
@ -94,12 +101,16 @@ const ModelList = () => {
        {/* IP Adapter List */}
        {isLoadingIPAdapterModels && <FetchingModelsLoader loadingMessage="Loading IP Adapters..." />}
        {!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
          <ModelListWrapper title="IP Adapter" modelList={filteredIPAdapterModels} key="ip-adapters" />
          <ModelListWrapper
            title={t('modelManager.ipAdapters')}
            modelList={filteredIPAdapterModels}
            key="ip-adapters"
          />
        )}
        {/* T2I Adapters List */}
        {isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
        {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
          <ModelListWrapper title="T2I Adapter" modelList={filteredT2IAdapterModels} key="t2i-adapters" />
          <ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
        )}
      </Flex>
    </ScrollableContent>
@ -120,16 +131,3 @@ const modelsFilter = <T extends AnyModelConfig>(
    return matchesFilter && matchesType;
  });
};

const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
  return (
    <Flex flexDirection="column" gap={4} borderRadius="base" p={4} bg="base.800">
      <Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
        <Spinner />
        <Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
      </Flex>
    </Flex>
  );
});

FetchingModelsLoader.displayName = 'FetchingModelsLoader';
@ -2,24 +2,27 @@ import { Button, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-libr
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { setFilteredModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFunnelBold } from 'react-icons/pi';
import { objectKeys } from 'tsafe';

const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = {
  main: 'Main',
  lora: 'LoRA',
  embedding: 'Textual Inversion',
  controlnet: 'ControlNet',
  vae: 'VAE',
  t2i_adapter: 'T2I Adapter',
  ip_adapter: 'IP Adapter',
};

export const ModelTypeFilter = () => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
    () => ({
      main: t('modelManager.main'),
      lora: 'LoRA',
      embedding: t('modelManager.textualInversions'),
      controlnet: 'ControlNet',
      vae: 'VAE',
      t2i_adapter: t('common.t2iAdapter'),
      ip_adapter: t('modelManager.ipAdapters'),
      clip_vision: 'Clip Vision',
    }),
    [t]
  );
  const filteredModelType = useAppSelector((s) => s.modelmanagerV2.filteredModelType);

  const selectModelType = useCallback(
@ -4,7 +4,7 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { memo } from 'react';

import { useTranslation } from 'react-i18next';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
  const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];

@ -16,10 +16,11 @@ const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
});

const InspectorDataTab = () => {
  const { t } = useTranslation();
  const { data } = useAppSelector(selector);

  if (!data) {
    return <IAINoContentFallback label="No node selected" icon={null} />;
    return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
  }

  return <DataViewer data={data} label="Node Data" />;
@ -18,6 +18,7 @@ import {
  SDXL_CANVAS_TEXT_TO_IMAGE_GRAPH,
  SDXL_IMAGE_TO_IMAGE_GRAPH,
  SDXL_REFINER_INPAINT_CREATE_MASK,
  SDXL_REFINER_SEAMLESS,
  SDXL_TEXT_TO_IMAGE_GRAPH,
  SEAMLESS,
  TEXT_TO_IMAGE_GRAPH,
@ -38,6 +39,8 @@ export const addVAEToGraph = async (

  const isAutoVae = !vae;
  const isSeamlessEnabled = seamlessXAxis || seamlessYAxis;
  const isSDXL = Boolean(graph.id?.includes('sdxl'));
  const isUsingRefiner = isSDXL && Boolean(refinerModel);

  if (!isAutoVae && !isSeamlessEnabled) {
    graph.nodes[VAE_LOADER] = {
@ -56,7 +59,13 @@ export const addVAEToGraph = async (
  ) {
    graph.edges.push({
      source: {
        node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
        node_id: isSeamlessEnabled
          ? isUsingRefiner
            ? SDXL_REFINER_SEAMLESS
            : SEAMLESS
          : isAutoVae
            ? modelLoaderNodeId
            : VAE_LOADER,
        field: 'vae',
      },
      destination: {
@ -74,7 +83,13 @@ export const addVAEToGraph = async (
  ) {
    graph.edges.push({
      source: {
        node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
        node_id: isSeamlessEnabled
          ? isUsingRefiner
            ? SDXL_REFINER_SEAMLESS
            : SEAMLESS
          : isAutoVae
            ? modelLoaderNodeId
            : VAE_LOADER,
        field: 'vae',
      },
      destination: {
@ -92,7 +107,13 @@ export const addVAEToGraph = async (
  ) {
    graph.edges.push({
      source: {
        node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
        node_id: isSeamlessEnabled
          ? isUsingRefiner
            ? SDXL_REFINER_SEAMLESS
            : SEAMLESS
          : isAutoVae
            ? modelLoaderNodeId
            : VAE_LOADER,
        field: 'vae',
      },
      destination: {
@ -111,7 +132,13 @@ export const addVAEToGraph = async (
    graph.edges.push(
      {
        source: {
          node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
          node_id: isSeamlessEnabled
            ? isUsingRefiner
              ? SDXL_REFINER_SEAMLESS
              : SEAMLESS
            : isAutoVae
              ? modelLoaderNodeId
              : VAE_LOADER,
          field: 'vae',
        },
        destination: {
@ -122,7 +149,13 @@ export const addVAEToGraph = async (

      {
        source: {
          node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
          node_id: isSeamlessEnabled
            ? isUsingRefiner
              ? SDXL_REFINER_SEAMLESS
              : SEAMLESS
            : isAutoVae
              ? modelLoaderNodeId
              : VAE_LOADER,
          field: 'vae',
        },
        destination: {
@ -137,7 +170,13 @@ export const addVAEToGraph = async (
  if (graph.id === SDXL_CANVAS_INPAINT_GRAPH || graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) {
    graph.edges.push({
      source: {
        node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
        node_id: isSeamlessEnabled
          ? isUsingRefiner
            ? SDXL_REFINER_SEAMLESS
            : SEAMLESS
          : isAutoVae
            ? modelLoaderNodeId
            : VAE_LOADER,
        field: 'vae',
      },
      destination: {
@ -8,18 +8,21 @@ import { selectOptimalDimension } from 'features/parameters/store/generationSlic
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

const OPTIONS: ComboboxOption[] = [
  { label: 'None', value: 'none' },
  { label: 'Auto', value: 'auto' },
  { label: 'Manual', value: 'manual' },
];

const ParamScaleBeforeProcessing = () => {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();
  const boundingBoxScaleMethod = useAppSelector((s) => s.canvas.boundingBoxScaleMethod);
  const optimalDimension = useAppSelector(selectOptimalDimension);

  const OPTIONS: ComboboxOption[] = useMemo(
    () => [
      { label: t('modelManager.none'), value: 'none' },
      { label: t('common.auto'), value: 'auto' },
      { label: t('modelManager.manual'), value: 'manual' },
    ],
    [t]
  );

  const onChange = useCallback<ComboboxOnChange>(
    (v) => {
      if (!isBoundingBoxScaleMethod(v?.value)) {
@ -30,7 +33,10 @@ const ParamScaleBeforeProcessing = () => {
    [dispatch, optimalDimension]
  );

  const value = useMemo(() => OPTIONS.find((o) => o.value === boundingBoxScaleMethod), [boundingBoxScaleMethod]);
  const value = useMemo(
    () => OPTIONS.find((o) => o.value === boundingBoxScaleMethod),
    [boundingBoxScaleMethod, OPTIONS]
  );

  return (
    <FormControl>
@ -37,20 +37,10 @@ const PER_PAGE = 10;
const zOrderBy = z.enum(['opened_at', 'created_at', 'updated_at', 'name']);
type OrderBy = z.infer<typeof zOrderBy>;
const isOrderBy = (v: unknown): v is OrderBy => zOrderBy.safeParse(v).success;
const ORDER_BY_OPTIONS: ComboboxOption[] = [
  { value: 'opened_at', label: 'Opened' },
  { value: 'created_at', label: 'Created' },
  { value: 'updated_at', label: 'Updated' },
  { value: 'name', label: 'Name' },
];

const zDirection = z.enum(['ASC', 'DESC']);
type Direction = z.infer<typeof zDirection>;
const isDirection = (v: unknown): v is Direction => zDirection.safeParse(v).success;
const DIRECTION_OPTIONS: ComboboxOption[] = [
  { value: 'ASC', label: 'Ascending' },
  { value: 'DESC', label: 'Descending' },
];

const WorkflowLibraryList = () => {
  const { t } = useTranslation();
@ -60,9 +50,27 @@ const WorkflowLibraryList = () => {
  const [query, setQuery] = useState('');
  const projectId = useStore($projectId);

  const ORDER_BY_OPTIONS: ComboboxOption[] = useMemo(
    () => [
      { value: 'opened_at', label: t('workflows.opened') },
      { value: 'created_at', label: t('workflows.created') },
      { value: 'updated_at', label: t('workflows.updated') },
      { value: 'name', label: t('workflows.name') },
    ],
    [t]
  );

  const DIRECTION_OPTIONS: ComboboxOption[] = useMemo(
    () => [
      { value: 'ASC', label: t('workflows.ascending') },
      { value: 'DESC', label: t('workflows.descending') },
    ],
    [t]
  );

  const orderByOptions = useMemo(() => {
    return projectId ? ORDER_BY_OPTIONS.filter((option) => option.value !== 'opened_at') : ORDER_BY_OPTIONS;
  }, [projectId]);
  }, [projectId, ORDER_BY_OPTIONS]);

  const [order_by, setOrderBy] = useState<WorkflowRecordOrderBy>(orderByOptions[0]?.value as WorkflowRecordOrderBy);
  const [direction, setDirection] = useState<SQLiteDirection>('ASC');
@ -113,7 +121,10 @@ const WorkflowLibraryList = () => {
    },
    [direction]
  );
  const valueDirection = useMemo(() => DIRECTION_OPTIONS.find((o) => o.value === direction), [direction]);
  const valueDirection = useMemo(
    () => DIRECTION_OPTIONS.find((o) => o.value === direction),
    [direction, DIRECTION_OPTIONS]
  );

  const resetFilterText = useCallback(() => {
    setQuery('');
@ -27,6 +27,18 @@ type GetModelConfigsResponse = NonNullable<
  paths['/api/v2/models/']['get']['responses']['200']['content']['application/json']
>;

type GetHFTokenStatusResponse =
  paths['/api/v2/models/hf_login']['get']['responses']['200']['content']['application/json'];
type SetHFTokenResponse = NonNullable<
  paths['/api/v2/models/hf_login']['post']['responses']['200']['content']['application/json']
>;
type SetHFTokenArg = NonNullable<
  paths['/api/v2/models/hf_login']['post']['requestBody']['content']['application/json']
>;

export type GetStarterModelsResponse =
  paths['/api/v2/models/starter_models']['get']['responses']['200']['content']['application/json'];

type DeleteModelArg = {
  key: string;
};
@ -259,6 +271,26 @@ export const modelsApi = api.injectEndpoints({
        });
      },
    }),
    getStarterModels: build.query<GetStarterModelsResponse, void>({
      query: () => buildModelsUrl('starter_models'),
      providesTags: [{ type: 'ModelConfig', id: LIST_TAG }],
    }),
    getHFTokenStatus: build.query<GetHFTokenStatusResponse, void>({
      query: () => buildModelsUrl('hf_login'),
      providesTags: ['HFTokenStatus'],
    }),
    setHFToken: build.mutation<SetHFTokenResponse, SetHFTokenArg>({
      query: (body) => ({ url: buildModelsUrl('hf_login'), method: 'POST', body }),
      invalidatesTags: ['HFTokenStatus'],
      onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
        try {
          const { data } = await queryFulfilled;
          dispatch(modelsApi.util.updateQueryData('getHFTokenStatus', undefined, () => data));
        } catch {
          // no-op
        }
      },
    }),
  }),
});

@ -277,4 +309,7 @@ export const {
  useListModelInstallsQuery,
  useCancelModelInstallMutation,
  usePruneCompletedModelInstallsMutation,
  useGetStarterModelsQuery,
  useGetHFTokenStatusQuery,
  useSetHFTokenMutation,
} = modelsApi;
@ -12,6 +12,7 @@ export const tagTypes = [
  'Board',
  'BoardImagesTotal',
  'BoardAssetsTotal',
  'HFTokenStatus',
  'Image',
  'ImageNameList',
  'ImageList',
File diff suppressed because one or more lines are too long
@ -1 +1 @@
__version__ = "4.0.0rc2"
__version__ = "4.0.0rc4"
@ -47,10 +47,10 @@ dependencies = [
"pytorch-lightning==2.1.3",
"safetensors==0.4.2",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.1.2",
"torch==2.2.1",
"torchmetrics==0.11.4",
"torchsde==0.2.6",
"torchvision==0.16.2",
"torchvision==0.17.1",
"transformers==4.38.2",

# Core application dependencies, pinned for reproducible builds.
@ -96,7 +96,7 @@ dependencies = [
[project.optional-dependencies]
"xformers" = [
# Core generation dependencies, pinned for reproducible builds.
"xformers==0.0.23.post1; sys_platform!='darwin'",
"xformers==0.0.25; sys_platform!='darwin'",
# Auxiliary dependencies, pinned only if necessary.
"triton; sys_platform=='linux'",
]
@ -125,27 +125,8 @@ dependencies = [
]

[project.scripts]

# legacy entrypoints; provided for backwards compatibility
"configure_invokeai.py" = "invokeai.frontend.install.invokeai_configure:run_configure"
"textual_inversion.py" = "invokeai.frontend.training:invokeai_textual_inversion"

# shortcut commands to start web ui
# "invokeai --web" will launch the web interface
# "invokeai" = "invokeai.frontend.legacy_launch_invokeai:main"

# new shortcut to launch web interface
"invokeai-web" = "invokeai.app.run_app:run_app"

# full commands
"invokeai-configure" = "invokeai.frontend.install.invokeai_configure:run_configure"
"invokeai-merge" = "invokeai.frontend.merge.merge_diffusers:main"
"invokeai-ti" = "invokeai.frontend.training:invokeai_textual_inversion"
"invokeai-model-install" = "invokeai.frontend.install.model_install:main"
"invokeai-model-install2" = "invokeai.frontend.install.model_install2:main" # will eventually be renamed to invokeai-model-install
"invokeai-migrate3" = "invokeai.backend.install.migrate_to_3:main"
"invokeai-update" = "invokeai.frontend.install.invokeai_update:main"
"invokeai-metadata" = "invokeai.backend.image_util.invoke_metadata:main"
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"

@ -189,7 +170,7 @@ version = { attr = "invokeai.version.__version__" }
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
markers = [
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
"timeout: Marks the timeout override."
"timeout: Marks the timeout override.",
]

[tool.coverage.run]
branch = true
@ -1,12 +0,0 @@
#!/usr/bin/env python
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)

import warnings

from invokeai.frontend.install.invokeai_configure import run_configure as configure

if __name__ == "__main__":
    warnings.warn(
        "configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning, stacklevel=2
    )
    configure()
@ -1,29 +0,0 @@
#!/usr/bin/env python
"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py"""

import sys

from PIL import Image

if len(sys.argv) < 2:
    print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
    print(
        "This script opens up the indicated invoke.py-generated PNG file(s) and prints out the prompt used to generate them."
    )
    exit(-1)

filenames = sys.argv[1:]
for f in filenames:
    try:
        im = Image.open(f)
        try:
            prompt = im.text["Dream"]
        except KeyError:
            prompt = ""
        print(f"{f}: {prompt}")
    except FileNotFoundError:
        sys.stderr.write(f"{f} not found\n")
        continue
    except PermissionError:
        sys.stderr.write(f"{f} could not be opened due to inadequate permissions\n")
        continue
@ -1,22 +0,0 @@
#!/usr/bin/env python

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import logging
import os

logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())


def main():
    # Change working directory to the repo root
    os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

    # TODO: Parse some top-level args here.
    from invokeai.app.cli_app import invoke_cli

    invoke_cli()


if __name__ == "__main__":
    main()
@ -1,5 +0,0 @@
#!/usr/bin/env python

from invokeai.frontend.install.model_install import main

main()
@ -1,25 +0,0 @@
#!/usr/bin/env python

"""
This script is used at release time to generate a markdown table describing the
starter models. This text is then manually copied into 050_INSTALL_MODELS.md.
"""

from pathlib import Path

from omegaconf import OmegaConf


def main():
    initial_models_file = Path(__file__).parent / "../invokeai/configs/INITIAL_MODELS.yaml"
    models = OmegaConf.load(initial_models_file)
    print("|Model Name | HuggingFace Repo ID | Description | URL |")
    print("|---------- | ---------- | ----------- | --- |")
    for model in models:
        repo_id = models[model].repo_id
        url = f"https://huggingface.co/{repo_id}"
        print(f"|{model}|{repo_id}|{models[model].description}|{url} |")


if __name__ == "__main__":
    main()
@ -1,29 +0,0 @@
#!/usr/bin/env python

import requests

from invokeai.version import __version__

local_version = str(__version__).replace("-", "")
package_name = "InvokeAI"


def get_pypi_versions(package_name=package_name) -> list[str]:
    """Get the versions of the package from PyPI"""
    url = f"https://pypi.org/pypi/{package_name}/json"
    response = requests.get(url).json()
    versions: list[str] = list(response["releases"].keys())
    return versions


def local_on_pypi(package_name=package_name, local_version=local_version) -> bool:
    """Compare the versions of the package from PyPI and the local package"""
    pypi_versions = get_pypi_versions(package_name)
    return local_version in pypi_versions


if __name__ == "__main__":
    if local_on_pypi():
        print(f"Package {package_name} is up to date")
    else:
        print(f"Package {package_name} is not up to date")
@ -1,61 +0,0 @@
#!/usr/bin/env python

"""
Scan the models directory and print out a new models.yaml
"""

import argparse
import os
import sys
from pathlib import Path

from omegaconf import OmegaConf


def main():
    parser = argparse.ArgumentParser(description="Model directory scanner")
    parser.add_argument("models_directory")
    parser.add_argument(
        "--all-models",
        default=False,
        action="store_true",
        help="If true, then generates stanzas for all models; otherwise just diffusers",
    )

    args = parser.parse_args()
    directory = args.models_directory

    conf = OmegaConf.create()
    conf["_version"] = "3.0.0"

    for root, dirs, files in os.walk(directory):
        parents = root.split("/")
        subpaths = parents[parents.index("models") + 1 :]
        if len(subpaths) < 2:
            continue
        base, model_type, *_ = subpaths

        if args.all_models or model_type == "diffusers":
            for d in dirs:
                conf[f"{base}/{model_type}/{d}"] = {
                    "path": os.path.join(root, d),
                    "description": f"{model_type} model {d}",
                    "format": "folder",
                    "base": base,
                }

            for f in files:
                basename = Path(f).stem
                format = Path(f).suffix[1:]
                conf[f"{base}/{model_type}/{basename}"] = {
                    "path": os.path.join(root, f),
                    "description": f"{model_type} model {basename}",
                    "format": format,
                    "base": base,
                }

    OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)


if __name__ == "__main__":
    main()
@ -1,23 +0,0 @@
#!/usr/bin/env python

import json
import sys

from invokeai.backend.image_util import retrieve_metadata

if len(sys.argv) < 2:
    print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
    print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out their metadata.")
    exit(-1)

filenames = sys.argv[1:]
for f in filenames:
    try:
        metadata = retrieve_metadata(f)
        print(f"{f}:\n", json.dumps(metadata["sd-metadata"], indent=4))
    except FileNotFoundError:
        sys.stderr.write(f"{f} not found\n")
        continue
    except PermissionError:
        sys.stderr.write(f"{f} could not be opened due to inadequate permissions\n")
        continue
@ -35,7 +35,7 @@ def store(
    datadir: Any,
) -> ModelRecordServiceSQL:
    config = InvokeAIAppConfig()
    config.set_root(datadir)
    config._root = datadir
    logger = InvokeAILogger.get_logger(config=config)
    db = create_mock_sqlite_database(config, logger)
    return ModelRecordServiceSQL(db)
@ -92,7 +92,7 @@ def diffusers_dir(mm2_model_files: Path) -> Path:
@pytest.fixture
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
    app_config = InvokeAIAppConfig(models_dir=mm2_root_dir / "models", log_level="info")
    app_config.set_root(mm2_root_dir)
    app_config._root = mm2_root_dir
    return app_config
@ -9,14 +9,14 @@ from pydantic import ValidationError
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config

v4_config = """
schema_version: 4
schema_version: 4.0.0

host: "192.168.1.1"
port: 8080
"""

invalid_v5_config = """
schema_version: 5
schema_version: 5.0.0

host: "192.168.1.1"
port: 8080
@ -170,7 +170,7 @@ def test_set_and_resolve_paths():
    """Test setting root and resolving paths based on it."""
    with TemporaryDirectory() as tmpdir:
        config = InvokeAIAppConfig()
        config.set_root(Path(tmpdir))
        config._root = Path(tmpdir)
        assert config.models_path == Path(tmpdir).resolve() / "models"
        assert config.db_path == Path(tmpdir).resolve() / "databases" / "invokeai.db"