Compare commits

..

2 Commits

Author SHA1 Message Date
ec1e66dcd3 refactor the model bases variable 2023-11-04 11:25:13 -04:00
69543c23d0 fix model-not-found error 2023-11-04 11:12:29 -04:00
727 changed files with 13710 additions and 19081 deletions

20
.github/workflows/pyflakes.yml vendored Normal file
View File

@ -0,0 +1,20 @@
on:
pull_request:
push:
branches:
- main
- development
- 'release-candidate-*'
jobs:
pyflakes:
name: runner / pyflakes
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: pyflakes
uses: reviewdog/action-pyflakes@v1
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
reporter: github-pr-review

View File

@ -6,7 +6,7 @@ on:
branches: main
jobs:
ruff:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
@ -18,7 +18,8 @@ jobs:
- name: Install dependencies with pip
run: |
pip install ruff
pip install black flake8 Flake8-pyproject isort
- run: ruff check --output-format=github .
- run: ruff format --check .
- run: isort --check-only .
- run: black --check .
- run: flake8

View File

@ -1,21 +0,0 @@
# simple Makefile with scripts that are otherwise hard to remember
# to use, run from the repo root `make <command>`
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
ruff check . --fix
ruff format .
# Runs ruff, fixing all errors it can fix and formatting
ruff-unsafe:
ruff check . --fix --unsafe-fixes
ruff format .
# Runs mypy, using the config in pyproject.toml
mypy:
mypy scripts/invokeai-web.py
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
# (many files are ignored by the config, so this is useful for checking all files)
mypy-all:
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports

View File

@ -161,7 +161,7 @@ the command `npm install -g yarn` if needed)
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
_For Linux with an AMD GPU:_
@ -175,7 +175,7 @@ the command `npm install -g yarn` if needed)
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
_For Macintoshes, either Intel or M1/M2/M3:_
_For Macintoshes, either Intel or M1/M2:_
```sh
pip install InvokeAI --use-pep517
@ -395,7 +395,7 @@ Notes](https://github.com/invoke-ai/InvokeAI/releases) and the
### Troubleshooting
Please check out our **[Troubleshooting Guide](https://invoke-ai.github.io/InvokeAI/installation/010_INSTALL_AUTOMATED/#troubleshooting)** to get solutions for common installation
Please check out our **[Q&A](https://invoke-ai.github.io/InvokeAI/help/TROUBLESHOOT/#faq)** to get solutions for common installation
problems and other issues. For more help, please join our [Discord][discord link]
## Contributing

View File

@ -11,5 +11,5 @@ INVOKEAI_ROOT=
# HUGGING_FACE_HUB_TOKEN=
## optional variables specific to the docker setup.
# GPU_DRIVER=cuda # or rocm
# CONTAINER_UID=1000
# GPU_DRIVER=cuda
# CONTAINER_UID=1000

View File

@ -18,8 +18,8 @@ ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv/invokeai
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ARG TORCH_VERSION=2.1.0
ARG TORCHVISION_VERSION=0.16
ARG TORCH_VERSION=2.0.1
ARG TORCHVISION_VERSION=0.15.2
ARG GPU_DRIVER=cuda
ARG TARGETPLATFORM="linux/amd64"
# unused but available
@ -35,7 +35,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
extra_index_url_arg="--index-url https://download.pytorch.org/whl/rocm5.6"; \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.4.2"; \
else \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
fi &&\

View File

@ -15,10 +15,6 @@ services:
- driver: nvidia
count: 1
capabilities: [gpu]
# For AMD support, comment out the deploy section above and uncomment the devices section below:
#devices:
# - /dev/kfd:/dev/kfd
# - /dev/dri:/dev/dri
build:
context: ..
dockerfile: docker/Dockerfile

View File

@ -7,5 +7,5 @@ set -e
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
cd "$SCRIPTDIR" || exit 1
docker compose up -d
docker compose up --build -d
docker compose logs -f

View File

@ -1,6 +1,6 @@
# Nodes
# Invocations
Features in InvokeAI are added in the form of modular nodes systems called
Features in InvokeAI are added in the form of modular node-like systems called
**Invocations**.
An Invocation is simply a single operation that takes in some inputs and gives
@ -9,34 +9,13 @@ complex functionality.
## Invocations Directory
InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These can be used as examples to create your own nodes.
InvokeAI Invocations can be found in the `invokeai/app/invocations` directory.
New nodes should be added to a subfolder in `nodes` direction found at the root level of the InvokeAI installation location. Nodes added to this folder will be able to be used upon application startup.
Example `nodes` subfolder structure:
```py
├── __init__.py # Invoke-managed custom node loader
├── cool_node
├── __init__.py # see example below
└── cool_node.py
└── my_node_pack
├── __init__.py # see example below
├── tasty_node.py
├── bodacious_node.py
├── utils.py
└── extra_nodes
└── fancy_node.py
```
Each node folder must have an `__init__.py` file that imports its nodes. Only nodes imported in the `__init__.py` file are loaded.
See the README in the nodes folder for more examples:
```py
from .cool_node import CoolInvocation
```
You can add your new functionality to one of the existing Invocations in this
directory or create a new file in this directory as per your needs.
**Note:** _All Invocations must be inside this directory for InvokeAI to
recognize them as valid Invocations._
## Creating A New Invocation
@ -65,7 +44,7 @@ The first set of things we need to do when creating a new Invocation are -
So let us do that.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from .baseinvocation import BaseInvocation, invocation
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@ -99,8 +78,8 @@ create your own custom field types later in this guide. For now, let's go ahead
and use it.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@ -124,8 +103,8 @@ image: ImageField = InputField(description="The input image")
Great. Now let us create our other inputs for `width` and `height`
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
from invokeai.app.invocations.primitives import ImageField
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@ -160,8 +139,8 @@ that are provided by it by InvokeAI.
Let us create this function first.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@ -189,9 +168,9 @@ all the necessary info related to image outputs. So let us use that.
We will cover how to create your own output types later in this guide.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.image import ImageOutput
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
from .image import ImageOutput
@invocation('resize')
class ResizeInvocation(BaseInvocation):
@ -216,9 +195,9 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
So let's do that.
```python
from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
from .baseinvocation import BaseInvocation, InputField, invocation
from .primitives import ImageField
from .image import ImageOutput
@invocation("resize")
class ResizeInvocation(BaseInvocation):

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,12 @@
---
title: Textual Inversion Embeddings and LoRAs
---
# :material-library-shelves: Textual Inversions and LoRAs
With the advances in research, many new capabilities are available to customize the knowledge and understanding of novel concepts not originally contained in the base model.
## Using Textual Inversion Files
Textual inversion (TI) files are small models that customize the output of
@ -52,4 +61,29 @@ files it finds there for compatible models. At startup you will see a message si
>> Current embedding manager terms: <HOI4-Leader>, <princess-knight>
```
To use these when generating, simply type the `<` key in your prompt to open the Textual Inversion WebUI and
select the embedding you'd like to use. This UI has type-ahead support, so you can easily find supported embeddings.
select the embedding you'd like to use. This UI has type-ahead support, so you can easily find supported embeddings.
## Using LoRAs
LoRA files are models that customize the output of Stable Diffusion
image generation. Larger than embeddings, but much smaller than full
models, they augment SD with improved understanding of subjects and
artistic styles.
Unlike TI files, LoRAs do not introduce novel vocabulary into the
model's known tokens. Instead, LoRAs augment the model's weights that
are applied to generate imagery. LoRAs may be supplied with a
"trigger" word that they have been explicitly trained on, or may
simply apply their effect without being triggered.
LoRAs are typically stored in .safetensors files, which are the most
secure way to store and transmit these types of weights. You may
install any number of `.safetensors` LoRA files simply by copying them
into the `autoimport/lora` directory of the corresponding InvokeAI models
directory (usually `invokeai` in your home directory).
To use these when generating, open the LoRA menu item in the options
panel, select the LoRAs you want to apply and ensure that they have
the appropriate weight recommended by the model provider. Typically,
most LoRAs perform best at a weight of .75-1.

View File

@ -1,53 +0,0 @@
---
title: LoRAs & LCM-LoRAs
---
# :material-library-shelves: LoRAs & LCM-LoRAs
With the advances in research, many new capabilities are available to customize the knowledge and understanding of novel concepts not originally contained in the base model.
## LoRAs
Low-Rank Adaptation (LoRA) files are models that customize the output of Stable Diffusion
image generation. Larger than embeddings, but much smaller than full
models, they augment SD with improved understanding of subjects and
artistic styles.
Unlike TI files, LoRAs do not introduce novel vocabulary into the
model's known tokens. Instead, LoRAs augment the model's weights that
are applied to generate imagery. LoRAs may be supplied with a
"trigger" word that they have been explicitly trained on, or may
simply apply their effect without being triggered.
LoRAs are typically stored in .safetensors files, which are the most
secure way to store and transmit these types of weights. You may
install any number of `.safetensors` LoRA files simply by copying them
into the `autoimport/lora` directory of the corresponding InvokeAI models
directory (usually `invokeai` in your home directory).
To use these when generating, open the LoRA menu item in the options
panel, select the LoRAs you want to apply and ensure that they have
the appropriate weight recommended by the model provider. Typically,
most LoRAs perform best at a weight of .75-1.
## LCM-LoRAs
Latent Consistency Models (LCMs) allowed a reduced number of steps to be used to generate images with Stable Diffusion. These are created by distilling base models, creating models that only require a small number of steps to generate images. However, LCMs require that any fine-tune of a base model be distilled to be used as an LCM.
LCM-LoRAs are models that provide the benefit of LCMs but are able to be used as LoRAs and applied to any fine tune of a base model. LCM-LoRAs are created by training a small number of adapters, rather than distilling the entire fine-tuned base model. The resulting LoRA can be used the same way as a standard LoRA, but with a greatly reduced step count. This enables SDXL images to be generated up to 10x faster than without the use of LCM-LoRAs.
**Using LCM-LoRAs**
LCM-LoRAs are natively supported in InvokeAI throughout the application. To get started, install any diffusers format LCM-LoRAs using the model manager and select it in the LoRA field.
There are a number parameter differences when using LCM-LoRAs and standard generation:
- When using LCM-LoRAs, the LoRA strength should be lower than if using a standard LoRA, with 0.35 recommended as a starting point.
- The LCM scheduler should be used for generation
- CFG-Scale should be reduced to ~1
- Steps should be reduced in the range of 4-8
Standard LoRAs can also be used alongside LCM-LoRAs, but will also require a lower strength, with 0.45 being recommended as a starting point.
More information can be found here: https://huggingface.co/blog/lcm_lora#fast-inference-with-sdxl-lcm-loras

View File

@ -20,7 +20,7 @@ a single convenient digital artist-optimized user interface.
### * [Prompt Engineering](PROMPTS.md)
Get the images you want with the InvokeAI prompt engineering language.
### * The [LoRA, LyCORIS, LCM-LoRA Models](CONCEPTS.md)
### * The [LoRA, LyCORIS and Textual Inversion Models](CONCEPTS.md)
Add custom subjects and styles using a variety of fine-tuned models.
### * [ControlNet](CONTROLNET.md)
@ -40,7 +40,7 @@ guide also covers optimizing models to load quickly.
Teach an old model new tricks. Merge 2-3 models together to create a
new model that combines characteristics of the originals.
### * [Textual Inversion](TEXTUAL_INVERSIONS.md)
### * [Textual Inversion](TRAINING.md)
Personalize models by adding your own style or subjects.
## Other Features

View File

@ -1,43 +0,0 @@
# FAQs
**Where do I get started? How can I install Invoke?**
- You can download the latest installers [here](https://github.com/invoke-ai/InvokeAI/releases) - Note that any releases marked as *pre-release* are in a beta state. You may experience some issues, but we appreciate your help testing those! For stable/reliable installations, please install the **[Latest Release](https://github.com/invoke-ai/InvokeAI/releases/latest)**
**How can I download models? Can I use models I already have downloaded?**
- Models can be downloaded through the model manager, or through option [4] in the invoke.bat/invoke.sh launcher script. To download a model through the Model Manager, use the HuggingFace Repo ID by pressing the “Copy” button next to the repository name. Alternatively, to download a model from CivitAi, use the download link in the Model Manager.
- Models that are already downloaded can be used by creating a symlink to the model location in the `autoimport` folder or by using the Model Mangers “Scan for Models” function.
**My images are taking a long time to generate. How can I speed up generation?**
- A common solution is to reduce the size of your RAM & VRAM cache to 0.25. This ensures your system has enough memory to generate images.
- Additionally, check the [hardware requirements](https://invoke-ai.github.io/InvokeAI/#hardware-requirements) to ensure that your system is capable of generating images.
- Lastly, double check your generations are happening on your GPU (if you have one). InvokeAI will log what is being used for generation upon startup.
**Ive installed Python on Windows but the installer says it cant find it?**
- Then ensure that you checked **'Add python.exe to PATH'** when installing Python. This can be found at the bottom of the Python Installer window. If you already have Python installed, this can be done with the modify / repair feature of the installer.
**Ive installed everything successfully but I still get an error about Triton when starting Invoke?**
- This can be safely ignored. InvokeAI doesn't use Triton, but if you are on Linux and wish to dismiss the error, you can install Triton.
**I updated to 3.4.0 and now xFormers cant load C++/CUDA?**
- An issue occurred with your PyTorch update. Follow these steps to fix :
1. Launch your invoke.bat / invoke.sh and select the option to open the developer console
2. Run:`pip install ".[xformers]" --upgrade --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121`
- If you run into an error with `typing_extensions`, re-open the developer console and run: `pip install -U typing-extensions`
**It says my pip is out of date - is that why my install isn't working?**
- An out of date won't cause an installation to fail. The cause of the error can likely be found above the message that says pip is out of date.
- If you saw that warning but the install went well, don't worry about it (but you can update pip afterwards if you'd like).
**How can I generate the exact same that I found on the internet?**
Most example images with prompts that you'll find on the internet have been generated using different software, so you can't expect to get identical results. In order to reproduce an image, you need to replicate the exact settings and processing steps, including (but not limited to) the model, the positive and negative prompts, the seed, the sampler, the exact image size, any upscaling steps, etc.
**Where can I get more help?**
- Create an issue on [GitHub](https://github.com/invoke-ai/InvokeAI/issues) or post in the [#help channel](https://discord.com/channels/1020123559063990373/1149510134058471514) of the InvokeAI Discord

View File

@ -101,13 +101,16 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
<div align="center"><img src="assets/invoke-web-server-1.png" width=640></div>
!!! Note
This project is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates as it will help aid response time.
## :octicons-link-24: Quick Links
<div class="button-container">
<a href="installation/INSTALLATION"> <button class="button">Installation</button> </a>
<a href="features/"> <button class="button">Features</button> </a>
<a href="help/gettingStartedWithAI/"> <button class="button">Getting Started</button> </a>
<a href="help/FAQ/"> <button class="button">FAQ</button> </a>
<a href="contributing/CONTRIBUTING/"> <button class="button">Contributing</button> </a>
<a href="https://github.com/invoke-ai/InvokeAI/"> <button class="button">Code and Downloads</button> </a>
<a href="https://github.com/invoke-ai/InvokeAI/issues"> <button class="button">Bug Reports </button> </a>
@ -195,7 +198,6 @@ The list of schedulers has been completely revamped and brought up to date:
| **dpmpp_2m** | DPMSolverMultistepScheduler | original noise scnedule |
| **dpmpp_2m_k** | DPMSolverMultistepScheduler | using karras noise schedule |
| **unipc** | UniPCMultistepScheduler | CPU only |
| **lcm** | LCMScheduler | |
Please see [3.0.0 Release Notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v3.0.0) for further details.

View File

@ -179,7 +179,7 @@ experimental versions later.
you will have the choice of CUDA (NVidia cards), ROCm (AMD cards),
or CPU (no graphics acceleration). On Windows, you'll have the
choice of CUDA vs CPU, and on Macs you'll be offered CPU only. When
you select CPU on M1/M2/M3 Macintoshes, you will get MPS-based
you select CPU on M1 or M2 Macintoshes, you will get MPS-based
graphics acceleration without installing additional drivers. If you
are unsure what GPU you are using, you can ask the installer to
guess.
@ -471,7 +471,7 @@ Then type the following commands:
=== "NVIDIA System"
```bash
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu121
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/cu118
pip install xformers
```

View File

@ -148,7 +148,7 @@ manager, please follow these steps:
=== "CUDA (NVidia)"
```bash
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
=== "ROCm (AMD)"
@ -327,7 +327,7 @@ installation protocol (important!)
=== "CUDA (NVidia)"
```bash
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
pip install -e .[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
```
=== "ROCm (AMD)"
@ -375,7 +375,7 @@ you can do so using this unsupported recipe:
mkdir ~/invokeai
conda create -n invokeai python=3.10
conda activate invokeai
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
invokeai-configure --root ~/invokeai
invokeai --root ~/invokeai --web
```

View File

@ -85,7 +85,7 @@ You can find which version you should download from [this link](https://docs.nvi
When installing torch and torchvision manually with `pip`, remember to provide
the argument `--extra-index-url
https://download.pytorch.org/whl/cu121` as described in the [Manual
https://download.pytorch.org/whl/cu118` as described in the [Manual
Installation Guide](020_INSTALL_MANUAL.md).
## :simple-amd: ROCm

View File

@ -30,7 +30,7 @@ methodology for details on why running applications in such a stateless fashion
The container is configured for CUDA by default, but can be built to support AMD GPUs
by setting the `GPU_DRIVER=rocm` environment variable at Docker image build time.
Developers on Apple silicon (M1/M2/M3): You
Developers on Apple silicon (M1/M2): You
[can't access your GPU cores from Docker containers](https://github.com/pytorch/pytorch/issues/81224)
and performance is reduced compared with running it directly on macOS but for
development purposes it's fine. Once you're done with development tasks on your

View File

@ -28,7 +28,7 @@ command line, then just be sure to activate it's virtual environment.
Then run the following three commands:
```sh
pip install xformers~=0.0.22
pip install xformers~=0.0.19
pip install triton # WON'T WORK ON WINDOWS
python -m xformers.info output
```
@ -42,7 +42,7 @@ If all goes well, you'll see a report like the
following:
```sh
xFormers 0.0.22
xFormers 0.0.20
memory_efficient_attention.cutlassF: available
memory_efficient_attention.cutlassB: available
memory_efficient_attention.flshattF: available
@ -59,14 +59,14 @@ swiglu.gemm_fused_operand_sum: available
swiglu.fused.p.cpp: available
is_triton_available: True
is_functorch_available: False
pytorch.version: 2.1.0+cu121
pytorch.version: 2.0.1+cu118
pytorch.cuda: available
gpu.compute_capability: 8.9
gpu.name: NVIDIA GeForce RTX 4070
build.info: available
build.cuda_version: 1108
build.python_version: 3.10.11
build.torch_version: 2.1.0+cu121
build.torch_version: 2.0.1+cu118
build.env.TORCH_CUDA_ARCH_LIST: 5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
build.env.XFORMERS_BUILD_TYPE: Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS: None
@ -92,22 +92,33 @@ installed from source. These instructions were written for a system
running Ubuntu 22.04, but other Linux distributions should be able to
adapt this recipe.
#### 1. Install CUDA Toolkit 12.1
#### 1. Install CUDA Toolkit 11.8
You will need the CUDA developer's toolkit in order to compile and
install xFormers. **Do not try to install Ubuntu's nvidia-cuda-toolkit
package.** It is out of date and will cause conflicts among the NVIDIA
driver and binaries. Instead install the CUDA Toolkit package provided
by NVIDIA itself. Go to [CUDA Toolkit 12.1
Downloads](https://developer.nvidia.com/cuda-12-1-0-download-archive)
by NVIDIA itself. Go to [CUDA Toolkit 11.8
Downloads](https://developer.nvidia.com/cuda-11-8-0-download-archive)
and use the target selection wizard to choose your platform and Linux
distribution. Select an installer type of "runfile (local)" at the
last step.
This will provide you with a recipe for downloading and running a
install shell script that will install the toolkit and drivers.
install shell script that will install the toolkit and drivers. For
example, the install script recipe for Ubuntu 22.04 running on a
x86_64 system is:
#### 2. Confirm/Install pyTorch 2.1.0 with CUDA 12.1 support
```
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sudo sh cuda_11.8.0_520.61.05_linux.run
```
Rather than cut-and-paste this example, We recommend that you walk
through the toolkit wizard in order to get the most up to date
installer for your system.
#### 2. Confirm/Install pyTorch 2.01 with CUDA 11.8 support
If you are using InvokeAI 3.0.2 or higher, these will already be
installed. If not, you can check whether you have the needed libraries
@ -122,7 +133,7 @@ Then run the command:
python -c 'exec("import torch\nprint(torch.__version__)")'
```
If it prints __2.1.0+cu121__ you're good. If not, you can install the
If it prints __1.13.1+cu118__ you're good. If not, you can install the
most up to date libraries with this command:
```sh

View File

@ -8,7 +8,7 @@ To use a node, add the node to the `nodes` folder found in your InvokeAI install
The suggested method is to use `git clone` to clone the repository the node is found in. This allows for easy updates of the node in the future.
If you'd prefer, you can also just download the whole node folder from the linked repository and add it to the `nodes` folder.
If you'd prefer, you can also just download the `.py` file from the linked repository and add it to the `nodes` folder.
To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
@ -26,15 +26,12 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [Image Picker](#image-picker)
+ [Load Video Frame](#load-video-frame)
+ [Make 3D](#make-3d)
+ [Match Histogram](#match-histogram)
+ [Oobabooga](#oobabooga)
+ [Prompt Tools](#prompt-tools)
+ [Remote Image](#remote-image)
+ [Retroize](#retroize)
+ [Size Stepper Nodes](#size-stepper-nodes)
+ [Text font to Image](#text-font-to-image)
+ [Thresholding](#thresholding)
+ [Unsharp Mask](#unsharp-mask)
+ [XY Image to Grid and Images to Grids nodes](#xy-image-to-grid-and-images-to-grids-nodes)
- [Example Node Template](#example-node-template)
- [Disclaimer](#disclaimer)
@ -209,23 +206,6 @@ This includes 15 Nodes:
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png" width="300" />
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png" width="300" />
--------------------------------
### Match Histogram
**Description:** An InvokeAI node to match a histogram from one image to another. This is a bit like the `color correct` node in the main InvokeAI but this works in the YCbCr colourspace and can handle images of different sizes. Also does not require a mask input.
- Option to only transfer luminance channel.
- Option to save output as grayscale
A good use case for this node is to normalize the colors of an image that has been through the tiled scaling workflow of my XYGrid Nodes.
See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
**Node Link:** https://github.com/skunkworxdark/match_histogram
**Output Examples**
<img src="https://github.com/skunkworxdark/match_histogram/assets/21961335/ed12f329-a0ef-444a-9bae-129ed60d6097" width="300" />
--------------------------------
### Oobabooga
@ -255,41 +235,22 @@ This node works best with SDXL models, especially as the style can be described
--------------------------------
### Prompt Tools
**Description:** A set of InvokeAI nodes that add general prompt (string) manipulation tools. Designed to accompany the `Prompts From File` node and other prompt generation nodes.
1. `Prompt To File` - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option.
2. `PTFields Collect` - Converts image generation fields into a Json format string that can be passed to Prompt to file.
3. `PTFields Expand` - Takes Json string and converts it to individual generation parameters. This can be fed from the Prompts to file node.
4. `Prompt Strength` - Formats prompt with strength like the weighted format of compel
5. `Prompt Strength Combine` - Combines weighted prompts for .and()/.blend()
6. `CSV To Index String` - Gets a string from a CSV by index. Includes a Random index option
The following Nodes are now included in v3.2 of Invoke and are nolonger in this set of tools.<br>
- `Prompt Join` -> `String Join`
- `Prompt Join Three` -> `String Join Three`
- `Prompt Replace` -> `String Replace`
- `Prompt Split Neg` -> `String Split Neg`
**Description:** A set of InvokeAI nodes that add general prompt manipulation tools. These were written to accompany the PromptsFromFile node and other prompt generation nodes.
1. PromptJoin - Joins to prompts into one.
2. PromptReplace - performs a search and replace on a prompt. With the option of using regex.
3. PromptSplitNeg - splits a prompt into positive and negative using the old V2 method of [] for negative.
4. PromptToFile - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option.
5. PTFieldsCollect - Converts image generation fields into a Json format string that can be passed to Prompt to file.
6. PTFieldsExpand - Takes Json string and converts it to individual generation parameters This can be fed from the Prompts to file node.
7. PromptJoinThree - Joins 3 prompt together.
8. PromptStrength - This take a string and float and outputs another string in the format of (string)strength like the weighted format of compel.
9. PromptStrengthCombine - This takes a collection of prompt strength strings and outputs a string in the .and() or .blend() format that can be fed into a proper prompt node.
See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
**Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes
**Workflow Examples**
<img src="https://github.com/skunkworxdark/prompt-tools/blob/main/images/CSVToIndexStringNode.png" width="300" />
--------------------------------
### Remote Image
**Description:** This is a pack of nodes to interoperate with other services, be they public websites or bespoke local servers. The pack consists of these nodes:
- *Load Remote Image* - Lets you load remote images such as a realtime webcam image, an image of the day, or dynamically created images.
- *Post Image to Remote Server* - Lets you upload an image to a remote server using an HTTP POST request, eg for storage, display or further processing.
**Node Link:** https://github.com/fieldOfView/InvokeAI-remote_image
--------------------------------
### Retroize
@ -355,37 +316,18 @@ Highlights/Midtones/Shadows (with LUT blur enabled):
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/0a440e43-697f-4d17-82ee-f287467df0a5" width="300" />
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/0701fd0f-2ca7-4fe2-8613-2b52547bafce" width="300" />
--------------------------------
### Unsharp Mask
**Description:** Applies an unsharp mask filter to an image, preserving its alpha channel in the process.
**Node Link:** https://github.com/JPPhoto/unsharp-mask-node
--------------------------------
### XY Image to Grid and Images to Grids nodes
**Description:** These nodes add the following to InvokeAI:
- Generate grids of images from multiple input images
- Create XY grid images with labels from parameters
- Split images into overlapping tiles for processing (for super-resolution workflows)
- Recombine image tiles into a single output image blending the seams
**Description:** Image to grid nodes and supporting tools.
The nodes include:
1. `Images To Grids` - Combine multiple images into a grid of images
2. `XYImage To Grid` - Take X & Y params and creates a labeled image grid.
3. `XYImage Tiles` - Super-resolution (embiggen) style tiled resizing
4. `Image Tot XYImages` - Takes an image and cuts it up into a number of columns and rows.
5. Multiple supporting nodes - Helper nodes for data wrangling and building `XYImage` collections
1. "Images To Grids" node - Takes a collection of images and creates a grid(s) of images. If there are more images than the size of a single grid then multiple grids will be created until it runs out of images.
2. "XYImage To Grid" node - Converts a collection of XYImages into a labeled Grid of images. The XYImages collection has to be built using the supporting nodes. See example node setups for more details.
See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md
**Node Link:** https://github.com/skunkworxdark/XYGrid_nodes
**Output Examples**
<img src="https://github.com/skunkworxdark/XYGrid_nodes/blob/main/images/collage.png" width="300" />
--------------------------------
### Example Node Template

View File

@ -7,12 +7,12 @@ To use them, right click on your desired workflow, follow the link to GitHub and
If you're interested in finding more workflows, checkout the [#share-your-workflows](https://discord.com/channels/1020123559063990373/1130291608097661000) channel in the InvokeAI Discord.
* [SD1.5 / SD2 Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Text_to_Image.json)
* [SDXL Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/SDXL_Text_to_Image.json)
* [SDXL Text to Image with Refiner](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/SDXL_w_Refiner_Text_to_Image.json)
* [Multi ControlNet (Canny & Depth)](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Multi_ControlNet_Canny_and_Depth.json)
* [SDXL Text to Image](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/SDXL_Text_to_Image.json)
* [SDXL Text to Image with Refiner](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/SDXL_w_Refiner_Text_to_Image.json)
* [Multi ControlNet (Canny & Depth)](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/Multi_ControlNet_Canny_and_Depth.json)
* [Tiled Upscaling with ControlNet](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/ESRGAN_img2img_upscale_w_Canny_ControlNet.json)
* [Prompt From File](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Prompt_from_File.json)
* [Face Detailer with IP-Adapter & ControlNet](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Face_Detailer_with_IP-Adapter_and_Canny.json)
* [Prompt From File](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/Prompt_from_File.json)
* [Face Detailer with IP-Adapter & ControlNet](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/Face_Detailer_with_IP-Adapter_and_Canny.json.json)
* [FaceMask](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/FaceMask.json)
* [FaceOff with 2x Face Scaling](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/FaceOff_FaceScale2x.json)
* [QR Code Monster](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/QR_Code_Monster.json)
* [QR Code Monster](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/QR_Code_Monster.json)

View File

@ -244,7 +244,7 @@ class InvokeAiInstance:
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch==2.1.0",
"torch~=2.0.0",
"torchmetrics==0.11.4",
"torchvision>=0.14.1",
"--force-reinstall",
@ -460,10 +460,10 @@ def get_torch_source() -> (Union[str, None], str):
url = "https://download.pytorch.org/whl/cpu"
if device == "cuda":
url = "https://download.pytorch.org/whl/cu121"
url = "https://download.pytorch.org/whl/cu118"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu121"
url = "https://download.pytorch.org/whl/cu118"
optional_modules = "[xformers,onnx-directml]"
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@ -137,7 +137,7 @@ def dest_path(dest=None) -> Path:
path_completer = PathCompleter(
only_directories=True,
expanduser=True,
get_paths=lambda: [browse_start], # noqa: B023
get_paths=lambda: [browse_start],
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
)
@ -149,7 +149,7 @@ def dest_path(dest=None) -> Path:
completer=path_completer,
default=str(browse_start) + os.sep,
vi_mode=True,
complete_while_typing=True,
complete_while_typing=True
# Test that this is not needed on Windows
# complete_style=CompleteStyle.READLINE_LIKE,
)

View File

@ -24,7 +24,6 @@ from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
@ -86,7 +85,6 @@ class ApiDependencies:
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@ -113,7 +111,6 @@ class ApiDependencies:
latents=latents,
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,

View File

@ -28,7 +28,7 @@ class FastAPIEventService(EventServiceBase):
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
self.__queue.put(dict(event_name=event_name, payload=payload))
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""

View File

@ -1,164 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
from hashlib import sha1
from random import randbytes
from typing import List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
class ModelsList(BaseModel):
"""Return list of configs."""
models: list[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
@model_records_router.get(
"/",
operation_id="list_model_records",
)
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
else:
found_models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=found_models)
@model_records_router.get(
"/i/{key}",
operation_id="get_model_record",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_records
try:
return record_store.get_model(key)
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=AnyModelConfig,
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
try:
model_response = record_store.update_model(key, config=info)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return model_response
@model_records_router.delete(
"/i/{key}",
operation_id="del_model_record",
responses={
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"},
},
status_code=204,
)
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
record_store = ApiDependencies.invoker.services.model_records
record_store.del_model(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.post(
"/i/",
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.
"""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
if config.key == "<NOKEY>":
config.key = sha1(randbytes(100)).hexdigest()
logger.info(f"Created model {config.key} for {config.name}")
try:
record_store.add_model(config.key, config)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
# now fetch it out
return record_store.get_model(config.key)

View File

@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
import pathlib
from typing import Annotated, List, Literal, Optional, Union
@ -54,7 +55,7 @@ async def list_models(
) -> ModelsList:
"""Gets a list of models"""
if base_models and len(base_models) > 0:
models_raw = []
models_raw = list()
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:

View File

@ -34,4 +34,4 @@ class SocketIO:
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
await self.__sio.enter_room(sid, data["queue_id"])

View File

@ -1,17 +1,14 @@
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from typing import Any
from invokeai.version.invokeai_version import __version__
from fastapi.responses import HTMLResponse
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
if app_config.version:
print(f"InvokeAI version {__version__}")
sys.exit(0)
if True: # hack to make flake8 happy with imports coming after setting up the config
import asyncio
@ -19,7 +16,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
import socket
from inspect import signature
from pathlib import Path
from typing import Any
import uvicorn
from fastapi import FastAPI
@ -27,7 +23,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
@ -38,6 +34,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.version.invokeai_version import __version__
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@ -46,7 +43,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
board_images,
boards,
images,
model_records,
models,
session_queue,
sessions,
@ -54,12 +50,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
UIConfigBase,
)
from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
@ -115,7 +106,6 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
@ -140,7 +130,7 @@ def custom_openapi() -> dict[str, Any]:
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
output_type_titles = dict()
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
@ -155,11 +145,7 @@ def custom_openapi() -> dict[str, Any]:
# Add Node Editor UI helper schemas
ui_config_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
],
[(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
ref_template="#/components/schemas/{model}",
)
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
@ -167,7 +153,7 @@ def custom_openapi() -> dict[str, Any]:
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
invoker_name = invoker.__name__
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
@ -185,12 +171,12 @@ def custom_openapi() -> dict[str, Any]:
# print(f"Config with name {name} already defined")
continue
openapi_schema["components"]["schemas"][name] = {
"title": name,
"description": "An enumeration.",
"type": "string",
"enum": [v.value for v in model_config_format_enum],
}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)
app.openapi_schema = openapi_schema
return app.openapi_schema
@ -285,4 +271,7 @@ def invoke_api() -> None:
if __name__ == "__main__":
invoke_api()
if app_config.version:
print(f"InvokeAI version {__version__}")
else:
invoke_api()

View File

@ -5,7 +5,7 @@ from pathlib import Path
from invokeai.app.services.config.config_default import InvokeAIAppConfig
custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.resolve())
custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.absolute())
custom_nodes_path.mkdir(parents=True, exist_ok=True)
custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
@ -25,4 +25,4 @@ spec.loader.exec_module(module)
# add core nodes to __all__
python_files = filter(lambda f: not f.name.startswith("_"), Path(__file__).parent.glob("*.py"))
__all__ = [f.stem for f in python_files] # type: ignore
__all__ = list(f.stem for f in python_files) # type: ignore

View File

@ -1,4 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI team
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from __future__ import annotations
@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
import semver
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
@ -16,18 +16,11 @@ from pydantic.fields import FieldInfo, _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.logging import InvokeAILogger
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
logger = InvokeAILogger.get_logger()
CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
class InvalidVersionError(ValueError):
pass
@ -37,7 +30,71 @@ class InvalidFieldError(TypeError):
pass
class Input(str, Enum, metaclass=MetaEnum):
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
skipped_layers = "Number of layers to skip in text encoder"
seed = "Seed for random number generation"
steps = "Number of steps to run"
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
t2i_adapter = "T2I-Adapter(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
metadata = "Optional metadata to be saved with the image"
metadata_collection = "Collection of Metadata"
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
metadata_item_label = "Label for this metadata item"
metadata_item_value = "The value for this metadata item (may be any type)"
workflow = "Optional workflow to be saved with the image"
interp_mode = "Interpolation mode"
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
blend_alpha = (
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
)
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"
inclusive_low = "The inclusive low value"
exclusive_high = "The exclusive high value"
decimal_places = "The number of decimal places to round to"
class Input(str, Enum):
"""
The type of input a field accepts.
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
@ -51,124 +108,86 @@ class Input(str, Enum, metaclass=MetaEnum):
Any = "any"
class FieldKind(str, Enum, metaclass=MetaEnum):
class UIType(str, Enum):
"""
The kind of field.
- `Input`: An input field on a node.
- `Output`: An output field on a node.
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
allowing "metadata" for that field.
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
but which are used to store information about the node. For example, the `id` and `type` fields are node
attributes.
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
startup, and when generating the OpenAPI schema for the workflow editor.
Type hints for the UI.
If a field should be provided a data type that does not exactly match the python type of the field, \
use this to provide the type that should be used instead. See the node development docs for detail \
on adding a new field type, which involves client-side changes.
"""
Input = "input"
Output = "output"
Internal = "internal"
NodeAttribute = "node_attribute"
# region Primitives
Boolean = "boolean"
Color = "ColorField"
Conditioning = "ConditioningField"
Control = "ControlField"
Float = "float"
Image = "ImageField"
Integer = "integer"
Latents = "LatentsField"
String = "string"
# endregion
# region Collection Primitives
BooleanCollection = "BooleanCollection"
ColorCollection = "ColorCollection"
ConditioningCollection = "ConditioningCollection"
ControlCollection = "ControlCollection"
FloatCollection = "FloatCollection"
ImageCollection = "ImageCollection"
IntegerCollection = "IntegerCollection"
LatentsCollection = "LatentsCollection"
StringCollection = "StringCollection"
# endregion
class UIType(str, Enum, metaclass=MetaEnum):
"""
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
# region Polymorphic Primitives
BooleanPolymorphic = "BooleanPolymorphic"
ColorPolymorphic = "ColorPolymorphic"
ConditioningPolymorphic = "ConditioningPolymorphic"
ControlPolymorphic = "ControlPolymorphic"
FloatPolymorphic = "FloatPolymorphic"
ImagePolymorphic = "ImagePolymorphic"
IntegerPolymorphic = "IntegerPolymorphic"
LatentsPolymorphic = "LatentsPolymorphic"
StringPolymorphic = "StringPolymorphic"
# endregion
- Model Fields
The most common node-author-facing use will be for model fields. Internally, there is no difference
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
the field is an SDXL main model field.
- Any Field
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
- Scheduler Field
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
- Internal Fields
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
should not be used by node authors.
- DEPRECATED Fields
These types are deprecated and should not be used by node authors. A warning will be logged if one is
used, and the type will be ignored. They are included here for backwards compatibility.
"""
# region Model Field Types
# region Models
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
VaeModel = "VAEModelField"
VaeModel = "VaeModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
UNet = "UNetField"
Vae = "VaeField"
CLIP = "ClipField"
# endregion
# region Misc Field Types
Scheduler = "SchedulerField"
Any = "AnyField"
# region Iterate/Collect
Collection = "Collection"
CollectionItem = "CollectionItem"
# endregion
# region Internal Field Types
_Collection = "CollectionField"
_CollectionItem = "CollectionItemField"
# endregion
# region DEPRECATED
Boolean = "DEPRECATED_Boolean"
Color = "DEPRECATED_Color"
Conditioning = "DEPRECATED_Conditioning"
Control = "DEPRECATED_Control"
Float = "DEPRECATED_Float"
Image = "DEPRECATED_Image"
Integer = "DEPRECATED_Integer"
Latents = "DEPRECATED_Latents"
String = "DEPRECATED_String"
BooleanCollection = "DEPRECATED_BooleanCollection"
ColorCollection = "DEPRECATED_ColorCollection"
ConditioningCollection = "DEPRECATED_ConditioningCollection"
ControlCollection = "DEPRECATED_ControlCollection"
FloatCollection = "DEPRECATED_FloatCollection"
ImageCollection = "DEPRECATED_ImageCollection"
IntegerCollection = "DEPRECATED_IntegerCollection"
LatentsCollection = "DEPRECATED_LatentsCollection"
StringCollection = "DEPRECATED_StringCollection"
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
StringPolymorphic = "DEPRECATED_StringPolymorphic"
MainModel = "DEPRECATED_MainModel"
UNet = "DEPRECATED_UNet"
Vae = "DEPRECATED_Vae"
CLIP = "DEPRECATED_CLIP"
Collection = "DEPRECATED_Collection"
CollectionItem = "DEPRECATED_CollectionItem"
Enum = "DEPRECATED_Enum"
WorkflowField = "DEPRECATED_WorkflowField"
IsIntermediate = "DEPRECATED_IsIntermediate"
BoardField = "DEPRECATED_BoardField"
MetadataItem = "DEPRECATED_MetadataItem"
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
MetadataDict = "DEPRECATED_MetadataDict"
# region Misc
Enum = "enum"
Scheduler = "Scheduler"
WorkflowField = "WorkflowField"
IsIntermediate = "IsIntermediate"
BoardField = "BoardField"
Any = "Any"
MetadataItem = "MetadataItem"
MetadataItemCollection = "MetadataItemCollection"
MetadataItemPolymorphic = "MetadataItemPolymorphic"
MetadataDict = "MetadataDict"
# endregion
class UIComponent(str, Enum, metaclass=MetaEnum):
class UIComponent(str, Enum):
"""
The type of UI component to use for a field, used to override the default components, which are
The type of UI component to use for a field, used to override the default components, which are \
inferred from the field type.
"""
@ -177,22 +196,21 @@ class UIComponent(str, Enum, metaclass=MetaEnum):
Slider = "slider"
class InputFieldJSONSchemaExtra(BaseModel):
class _InputField(BaseModel):
"""
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
and by the workflow editor during schema parsing and UI rendering.
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
"""
input: Input
orig_required: bool
field_kind: FieldKind
default: Optional[Any] = None
orig_default: Optional[Any] = None
ui_hidden: bool = False
ui_type: Optional[UIType] = None
ui_component: Optional[UIComponent] = None
ui_order: Optional[int] = None
ui_choice_labels: Optional[dict[str, str]] = None
ui_hidden: bool
ui_type: Optional[UIType]
ui_component: Optional[UIComponent]
ui_order: Optional[int]
ui_choice_labels: Optional[dict[str, str]]
item_default: Optional[Any]
model_config = ConfigDict(
validate_assignment=True,
@ -200,13 +218,14 @@ class InputFieldJSONSchemaExtra(BaseModel):
)
class OutputFieldJSONSchemaExtra(BaseModel):
class _OutputField(BaseModel):
"""
Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
during schema parsing and UI rendering.
*DO NOT USE*
This helper class is used to tell the client about our custom field attributes via OpenAPI
schema generation, and Typescript type generation from that schema. It serves no functional
purpose in the backend.
"""
field_kind: FieldKind
ui_hidden: bool
ui_type: Optional[UIType]
ui_order: Optional[int]
@ -217,9 +236,13 @@ class OutputFieldJSONSchemaExtra(BaseModel):
)
def get_type(klass: BaseModel) -> str:
"""Helper function to get an invocation or invocation output's type. This is the default value of the `type` field."""
return klass.model_fields["type"].default
def InputField(
# copied from pydantic's Field
# TODO: Can we support default_factory?
default: Any = _Unset,
default_factory: Callable[[], Any] | None = _Unset,
title: str | None = _Unset,
@ -243,11 +266,12 @@ def InputField(
ui_hidden: bool = False,
ui_order: Optional[int] = None,
ui_choice_labels: Optional[dict[str, str]] = None,
item_default: Optional[Any] = None,
) -> Any:
"""
Creates an input field for an invocation.
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
that adds a few extra parameters to support graph execution and the node editor UI.
:param Input input: [Input.Any] The kind of input this field requires. \
@ -267,102 +291,108 @@ def InputField(
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
For this case, you could provide `UIComponent.Textarea`.
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
Ignored for non-collection fields.
"""
json_schema_extra_ = InputFieldJSONSchemaExtra(
json_schema_extra_: dict[str, Any] = dict(
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
item_default=item_default,
ui_choice_labels=ui_choice_labels,
field_kind=FieldKind.Input,
orig_required=True,
_field_kind="input",
)
field_args = dict(
default=default,
default_factory=default_factory,
title=title,
description=description,
pattern=pattern,
strict=strict,
gt=gt,
ge=ge,
lt=lt,
le=le,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
min_length=min_length,
max_length=max_length,
)
"""
There is a conflict between the typing of invocation definitions and the typing of an invocation's
`invoke()` function.
On instantiation of a node, the invocation definition is used to create the python class. At this time,
any number of fields may be optional, because they may be provided by connections.
On calling of `invoke()`, however, those fields may be required.
Invocation definitions have their fields typed correctly for their `invoke()` functions.
This typing is often more specific than the actual invocation definition requires, because
fields may have values provided only by connections.
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
`image` is required during the call to `invoke()`, but when the python class is instantiated,
the field may not be present. This is fine, because that image field will be provided by a
connection from an ancestor node, which outputs an image.
an ancestor node that outputs the image.
This means we want to type the `image` field as optional for the node class definition, but required
for the `invoke()` function.
So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then
we need to handle a lot of extra logic in the `invoke()` function to check if the field has a
value or not. This is very tedious.
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
any static type analysis tools will complain.
Ideally, the invocation definition would be able to specify that the field is required during
invocation, but optional during instantiation. So the field would be typed as `image: ImageField`,
but when calling the `invoke()` function, we raise an error if the field is not present.
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do
extra validation when calling `invoke()`.
There is some additional logic here to cleaning create the pydantic field via the wrapper.
"""
if default_factory is not _Unset and default_factory is not None:
default = default_factory()
logger.warn('"default_factory" is not supported, calling it now to set "default"')
# These are the args we may wish pass to the pydantic `Field()` function
field_args = {
"default": default,
"title": title,
"description": description,
"pattern": pattern,
"strict": strict,
"gt": gt,
"ge": ge,
"lt": lt,
"le": le,
"multiple_of": multiple_of,
"allow_inf_nan": allow_inf_nan,
"max_digits": max_digits,
"decimal_places": decimal_places,
"min_length": min_length,
"max_length": max_length,
}
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
# Filter out field args not provided
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
# Because we are manually making fields optional, we need to store the original required bool for reference later
json_schema_extra_.orig_required = default is PydanticUndefined
if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined):
raise ValueError("Cannot specify both default and default_factory")
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
if input is Input.Any or input is Input.Connection:
# because we are manually making fields optional, we need to store the original required bool for reference later
if default is PydanticUndefined and default_factory is PydanticUndefined:
json_schema_extra_.update(dict(orig_required=True))
else:
json_schema_extra_.update(dict(orig_required=False))
# make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
default_ = None if default is PydanticUndefined else default
provided_args.update({"default": default_})
provided_args.update(dict(default=default_))
if default is not PydanticUndefined:
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
json_schema_extra_.default = default
json_schema_extra_.orig_default = default
elif default is not PydanticUndefined:
# before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
json_schema_extra_.update(dict(default=default))
json_schema_extra_.update(dict(orig_default=default))
elif default is not PydanticUndefined and default_factory is PydanticUndefined:
default_ = default
provided_args.update({"default": default_})
json_schema_extra_.orig_default = default_
provided_args.update(dict(default=default_))
json_schema_extra_.update(dict(orig_default=default_))
elif default_factory is not PydanticUndefined:
provided_args.update(dict(default_factory=default_factory))
# TODO: cannot serialize default_factory...
# json_schema_extra_.update(dict(orig_default_factory=default_factory))
return Field(
**provided_args,
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
json_schema_extra=json_schema_extra_,
)
def OutputField(
# copied from pydantic's Field
default: Any = _Unset,
default_factory: Callable[[], Any] | None = _Unset,
title: str | None = _Unset,
description: str | None = _Unset,
pattern: str | None = _Unset,
@ -395,12 +425,13 @@ def OutputField(
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
"""
return Field(
default=default,
default_factory=default_factory,
title=title,
description=description,
pattern=pattern,
@ -415,12 +446,12 @@ def OutputField(
decimal_places=decimal_places,
min_length=min_length,
max_length=max_length,
json_schema_extra=OutputFieldJSONSchemaExtra(
json_schema_extra=dict(
ui_type=ui_type,
ui_hidden=ui_hidden,
ui_order=ui_order,
field_kind=FieldKind.Output,
).model_dump(exclude_none=True),
_field_kind="output",
),
)
@ -433,10 +464,10 @@ class UIConfigBase(BaseModel):
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
title: Optional[str] = Field(default=None, description="The node's display name")
category: Optional[str] = Field(default=None, description="The node's category")
version: str = Field(
version: Optional[str] = Field(
default=None,
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
model_config = ConfigDict(
validate_assignment=True,
@ -479,39 +510,29 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
def get_outputs_union(cls) -> UnionType:
"""Gets a union of all invocation outputs."""
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
return outputs_union # type: ignore [return-value]
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
return map(lambda i: get_type(i), BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["required"] = list()
schema["required"].extend(["type"])
@classmethod
def get_type(cls) -> str:
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
return cls.model_fields["type"].default
model_config = ConfigDict(
protected_namespaces=(),
validate_assignment=True,
@ -541,29 +562,21 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
@classmethod
def get_type(cls) -> str:
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
return cls.model_fields["type"].default
@classmethod
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
@classmethod
def get_invocations_union(cls) -> UnionType:
"""Gets a union of all invocation types."""
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
return invocations_union # type: ignore [return-value]
@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = InvokeAIAppConfig.get_config()
allowed_invocations: set[BaseInvocation] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
invocation_type = get_type(sc)
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
@ -576,35 +589,36 @@ class BaseInvocation(ABC, BaseModel):
@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
# Get the type strings out of the literals and into a dictionary
return dict(
map(
lambda i: (get_type(i), i),
BaseInvocation.get_invocations(),
)
)
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in BaseInvocation.get_invocations())
return map(lambda i: get_type(i), BaseInvocation.get_invocations())
@classmethod
def get_output_annotation(cls) -> BaseInvocationOutput:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
def get_output_type(cls) -> BaseInvocationOutput:
return signature(cls.invoke).return_annotation
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
if uiconfig.title is not None:
schema["title"] = uiconfig.title
if uiconfig.tags is not None:
schema["tags"] = uiconfig.tags
if uiconfig.category is not None:
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
# Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
uiconfig = getattr(model_class, "UIConfig", None)
if uiconfig and hasattr(uiconfig, "title"):
schema["title"] = uiconfig.title
if uiconfig and hasattr(uiconfig, "tags"):
schema["tags"] = uiconfig.tags
if uiconfig and hasattr(uiconfig, "category"):
schema["category"] = uiconfig.category
if uiconfig and hasattr(uiconfig, "version"):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["required"] = list()
schema["required"].extend(["type", "id"])
@abstractmethod
@ -613,10 +627,6 @@ class BaseInvocation(ABC, BaseModel):
pass
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
"""
Internal invoke method, calls `invoke()` after some prep.
Handles optional fields that are required to call `invoke()` and invocation cache.
"""
for field_name, field in self.model_fields.items():
if not field.json_schema_extra or callable(field.json_schema_extra):
# something has gone terribly awry, we should always have this and it should be a dict
@ -656,20 +666,21 @@ class BaseInvocation(ABC, BaseModel):
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
return self.invoke(context)
def get_type(self) -> str:
return self.model_fields["type"].default
id: str = Field(
default_factory=uuid_string,
description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
json_schema_extra=dict(_field_kind="internal"),
)
is_intermediate: bool = Field(
default=False,
description="Whether or not this is an intermediate invocation.",
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
json_schema_extra=dict(ui_type=UIType.IsIntermediate, _field_kind="internal"),
)
use_cache: bool = Field(
default=True,
description="Whether or not to use the cache",
json_schema_extra={"field_kind": FieldKind.NodeAttribute},
default=True, description="Whether or not to use the cache", json_schema_extra=dict(_field_kind="internal")
)
UIConfig: ClassVar[Type[UIConfigBase]]
@ -686,15 +697,12 @@ class BaseInvocation(ABC, BaseModel):
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
RESERVED_INPUT_FIELD_NAMES = {
"id",
"is_intermediate",
"use_cache",
"type",
"workflow",
}
RESERVED_INPUT_FIELD_NAMES = {
"metadata",
}
@ -706,65 +714,48 @@ class _Model(BaseModel):
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
RESERVED_PYDANTIC_FIELD_NAMES = set(map(lambda m: m[0], inspect.getmembers(_Model())))
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
"""
Validates the fields of an invocation or invocation output:
- Must not override any pydantic reserved fields
- Must have a type annotation
- Must have a json_schema_extra dict
- Must have field_kind in json_schema_extra
- Field name must not be reserved, according to its field_kind
- must not override any pydantic reserved fields
- must be created via `InputField`, `OutputField`, or be an internal field defined in this file
"""
for name, field in model_fields.items():
if name in RESERVED_PYDANTIC_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
if not field.annotation:
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
if not isinstance(field.json_schema_extra, dict):
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
)
field_kind = field.json_schema_extra.get("field_kind", None)
field_kind = (
# _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
field.json_schema_extra.get("_field_kind", None)
if field.json_schema_extra
else None
)
# must have a field_kind
if not isinstance(field_kind, FieldKind):
if field_kind is None or field_kind not in {"input", "output", "internal"}:
raise InvalidFieldError(
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
)
if field_kind is FieldKind.Input and (
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
):
if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
# internal fields *must* be in the reserved list
if (
field_kind == "internal"
and name not in RESERVED_INPUT_FIELD_NAMES
and name not in RESERVED_OUTPUT_FIELD_NAMES
):
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
)
# node attribute fields *must* be in the reserved list
if (
field_kind is FieldKind.NodeAttribute
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
and name not in RESERVED_OUTPUT_FIELD_NAMES
):
raise InvalidFieldError(
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
)
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring")
field.json_schema_extra.pop("ui_type")
return None
@ -799,30 +790,21 @@ def invocation(
validate_fields(cls.model_fields, invocation_type)
# Add OpenAPI schema extras
uiconfig_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
# Grab the node pack's name from the module name, if it's a custom node
module_name = cls.__module__.split(".")[0]
if module_name.endswith(CUSTOM_NODE_PACK_SUFFIX):
cls.UIConfig.node_pack = module_name.split(CUSTOM_NODE_PACK_SUFFIX)[0]
else:
cls.UIConfig.node_pack = None
uiconf_name = cls.__qualname__ + ".UIConfig"
if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict())
if title is not None:
cls.UIConfig.title = title
if tags is not None:
cls.UIConfig.tags = tags
if category is not None:
cls.UIConfig.category = category
if version is not None:
try:
semver.Version.parse(version)
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
else:
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
cls.UIConfig.version = "1.0.0"
if use_cache is not None:
cls.model_fields["use_cache"].default = use_cache
@ -837,7 +819,7 @@ def invocation(
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
title="type", default=invocation_type, json_schema_extra=dict(_field_kind="internal")
)
docstring = cls.__doc__
@ -883,9 +865,7 @@ def invocation_output(
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
output_type_field = Field(title="type", default=output_type, json_schema_extra=dict(_field_kind="internal"))
docstring = cls.__doc__
cls = create_model(
@ -917,7 +897,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = Field(
default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
default=None, description=FieldDescriptions.workflow, json_schema_extra=dict(_field_kind="internal")
)
@ -935,11 +915,5 @@ MetadataFieldValidator = TypeAdapter(MetadataField)
class WithMetadata(BaseModel):
metadata: Optional[MetadataField] = Field(
default=None,
description=FieldDescriptions.metadata,
json_schema_extra=InputFieldJSONSchemaExtra(
field_kind=FieldKind.Internal,
input=Input.Connection,
orig_required=False,
).model_dump(exclude_none=True),
default=None, description=FieldDescriptions.metadata, json_schema_extra=dict(_field_kind="internal")
)

View File

@ -5,7 +5,7 @@ import numpy as np
from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.util.misc import SEED_MAX
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@ -55,7 +55,7 @@ class RangeOfSizeInvocation(BaseInvocation):
title="Random Range",
tags=["range", "integer", "random", "collection"],
category="collections",
version="1.0.1",
version="1.0.0",
use_cache=False,
)
class RandomRangeInvocation(BaseInvocation):
@ -65,10 +65,10 @@ class RandomRangeInvocation(BaseInvocation):
high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
size: int = InputField(default=1, description="The number of values to generate")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description="The seed for the RNG (omit for random)",
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:

View File

@ -7,7 +7,6 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ExtraConditioningInfo,
@ -20,6 +19,7 @@ from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -112,11 +112,10 @@ class CompelInvocation(BaseInvocation):
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,
@ -235,11 +234,10 @@ class SDXLPromptInvocationBase:
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
):
compel = Compel(
tokenizer=tokenizer,

View File

@ -28,12 +28,12 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from ...backend.model_management import BaseModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -96,7 +96,7 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.0")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
@ -173,7 +173,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
title="Canny Processor",
tags=["controlnet", "canny"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
@ -196,7 +196,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
@ -225,7 +225,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Processor",
tags=["controlnet", "lineart"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
@ -247,7 +247,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
@ -270,7 +270,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
title="Openpose Processor",
tags=["controlnet", "openpose", "pose"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image"""
@ -295,7 +295,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
@ -322,7 +322,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Normal BAE Processor",
tags=["controlnet"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
@ -339,7 +339,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.1.0"
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.0.0"
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
@ -362,7 +362,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.1.0"
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.0.0"
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
@ -389,7 +389,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
@ -419,7 +419,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
@ -435,7 +435,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
@ -458,7 +458,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
@ -487,7 +487,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
title="Tile Resample Processor",
tags=["controlnet", "tile"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
@ -527,7 +527,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
@ -569,7 +569,7 @@ class SamDetectorReproducibleColors(SamDetector):
title="Color Map Processor",
tags=["controlnet"],
category="controlnet",
version="1.1.0",
version="1.0.0",
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image"""

View File

@ -6,7 +6,6 @@ import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.app.invocations.baseinvocation import CUSTOM_NODE_PACK_SUFFIX
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@ -33,15 +32,13 @@ for d in Path(__file__).parent.iterdir():
if module_name in globals():
continue
# load the module, appending adding a suffix to identify it as a custom node pack
spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute())
# we have a legit module to import
spec = spec_from_file_location(module_name, init.absolute())
if spec is None or spec.loader is None:
logger.warn(f"Could not load {init}")
continue
logger.info(f"Loading node pack {module_name}")
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
@ -50,5 +47,5 @@ for d in Path(__file__).parent.iterdir():
del init, module_name
if loaded_count > 0:
logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}")
logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}")

View File

@ -11,7 +11,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.1.0")
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Simple inpaint using opencv."""

View File

@ -131,7 +131,7 @@ def prepare_faces_list(
deduped_faces: list[FaceResultData] = []
if len(face_result_list) == 0:
return []
return list()
for candidate in face_result_list:
should_add = True
@ -210,7 +210,7 @@ def generate_face_box_mask(
# Check if any face is detected.
if results.multi_face_landmarks: # type: ignore # this are via protobuf and not typed
# Search for the face_id in the detected faces.
for _face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
for face_id, face_landmarks in enumerate(results.multi_face_landmarks): # type: ignore #this are via protobuf and not typed
# Get the bounding box of the face mesh.
x_coordinates = [landmark.x for landmark in face_landmarks.landmark]
y_coordinates = [landmark.y for landmark in face_landmarks.landmark]
@ -438,7 +438,7 @@ def get_faces_list(
return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.1.0")
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.2")
class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
@ -532,7 +532,7 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.1.0")
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.2")
class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Face mask creation using mediapipe face detection"""
@ -650,7 +650,7 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.1.0"
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.2"
)
class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""

View File

@ -8,12 +8,20 @@ import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import (
BaseInvocation,
FieldDescriptions,
Input,
InputField,
InvocationContext,
WithMetadata,
WithWorkflow,
invocation,
)
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@ -36,7 +44,7 @@ class ShowImageInvocation(BaseInvocation):
)
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.1.0")
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Creates a blank image and forwards it to the pipeline"""
@ -66,7 +74,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.1.0")
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Crops an image to a specified box. The box can be outside of the image."""
@ -100,7 +108,7 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.1.0")
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.1")
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Pastes an image into another image."""
@ -154,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.1.0")
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Extracts the alpha channel of an image as a mask."""
@ -186,7 +194,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.1.0")
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
@ -220,7 +228,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.1.0")
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Gets a channel from an image."""
@ -253,7 +261,7 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.1.0")
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Converts an image to a different mode."""
@ -283,7 +291,7 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.1.0")
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Blurs an image"""
@ -338,7 +346,7 @@ PIL_RESAMPLING_MAP = {
}
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.1.0")
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Resizes an image to specific dimensions"""
@ -375,7 +383,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.1.0")
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Scales an image by a factor"""
@ -417,7 +425,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.1.0")
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Linear interpolation of all pixels of an image"""
@ -451,7 +459,7 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.1.0")
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Inverse linear interpolation of all pixels of an image"""
@ -485,7 +493,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.1.0")
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Add blur to NSFW-flagged images"""
@ -532,7 +540,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
title="Add Invisible Watermark",
tags=["image", "watermark"],
category="image",
version="1.1.0",
version="1.0.0",
)
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Add an invisible watermark to an image"""
@ -561,7 +569,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.1.0")
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Applies an edge mask to an image"""
@ -612,7 +620,7 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
title="Combine Masks",
tags=["image", "mask", "multiply"],
category="image",
version="1.1.0",
version="1.0.0",
)
class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
@ -644,7 +652,7 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.1.0")
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""
Shifts the colors of a target image to match the reference image, optionally
@ -755,7 +763,7 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.1.0")
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Adjusts the Hue of an image."""
@ -858,7 +866,7 @@ CHANNEL_FORMATS = {
"value",
],
category="image",
version="1.1.0",
version="1.0.0",
)
class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Add or subtract a value from a specific color channel of an image."""
@ -929,7 +937,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"value",
],
category="image",
version="1.1.0",
version="1.0.0",
)
class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Scale a specific color channel of an image."""
@ -988,7 +996,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
title="Save Image",
tags=["primitives", "image"],
category="primitives",
version="1.1.0",
version="1.0.1",
use_cache=False,
)
class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@ -1017,35 +1025,3 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
width=image_dto.width,
height=image_dto.height,
)
@invocation(
"linear_ui_output",
title="Linear UI Image Output",
tags=["primitives", "image"],
category="primitives",
version="1.0.1",
use_cache=False,
)
class LinearUIOutputInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Handles Linear UI Image Outputting tasks."""
image: ImageField = InputField(description=FieldDescriptions.image)
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
def invoke(self, context: InvocationContext) -> ImageOutput:
image_dto = context.services.images.get_dto(self.image.image_name)
if self.board:
context.services.board_images.add_image_to_board(self.board.board_id, self.image.image_name)
if image_dto.is_intermediate != self.is_intermediate:
context.services.images.update(
self.image.image_name, changes=ImageRecordChanges(is_intermediate=self.is_intermediate)
)
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -8,7 +8,7 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch
@ -118,7 +118,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Infills transparent areas of an image with a solid color"""
@ -154,17 +154,17 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1")
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Infills transparent areas of an image with tiles of the image"""
image: ImageField = InputField(description="The image to infill")
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description="The seed to use for tile generation (omit for random)",
default_factory=get_random_seed,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
@ -192,7 +192,7 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0"
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
)
class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
@ -245,7 +245,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Infills transparent areas of an image using the LaMa model"""
@ -274,7 +274,7 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Infills transparent areas of an image using OpenCV Inpainting"""

View File

@ -7,15 +7,16 @@ from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
@ -66,7 +67,7 @@ class IPAdapterInvocation(BaseInvocation):
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
weight: Union[float, List[float]] = InputField(
default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight"
default=1, ge=-1, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight"
)
begin_step_percent: float = InputField(

View File

@ -10,7 +10,7 @@ import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import T2IAdapter
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@ -34,7 +34,6 @@ from invokeai.app.invocations.primitives import (
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
@ -58,6 +57,7 @@ from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -77,7 +77,7 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@invocation_output("scheduler_output")
@ -274,10 +274,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
ui_order=7,
)
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
ui_order=4,
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
@ -565,6 +562,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
# HACK(ryand): Work around a bug in FullAdapterXL. This is being addressed upstream in diffusers by
# this PR: https://github.com/huggingface/diffusers/pull/5134.
total_downscale_factor = total_downscale_factor // 2
# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
@ -709,7 +710,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
with (
ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching.
@ -792,7 +792,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.1.0",
version="1.0.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Generates an image from latents."""
@ -1107,7 +1107,7 @@ class BlendLatentsInvocation(BaseInvocation):
latents_b = context.services.latents.get(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
raise "Latents to blend must be the same size."
# TODO:
device = choose_torch_device()

View File

@ -6,9 +6,8 @@ import numpy as np
from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
from invokeai.app.shared.fields import FieldDescriptions
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
@ -145,17 +144,17 @@ INTEGER_OPERATIONS = Literal[
]
INTEGER_OPERATIONS_LABELS = {
"ADD": "Add A+B",
"SUB": "Subtract A-B",
"MUL": "Multiply A*B",
"DIV": "Divide A/B",
"EXP": "Exponentiate A^B",
"MOD": "Modulus A%B",
"ABS": "Absolute Value of A",
"MIN": "Minimum(A,B)",
"MAX": "Maximum(A,B)",
}
INTEGER_OPERATIONS_LABELS = dict(
ADD="Add A+B",
SUB="Subtract A-B",
MUL="Multiply A*B",
DIV="Divide A/B",
EXP="Exponentiate A^B",
MOD="Modulus A%B",
ABS="Absolute Value of A",
MIN="Minimum(A,B)",
MAX="Maximum(A,B)",
)
@invocation(
@ -183,8 +182,8 @@ class IntegerMathInvocation(BaseInvocation):
operation: INTEGER_OPERATIONS = InputField(
default="ADD", description="The operation to perform", ui_choice_labels=INTEGER_OPERATIONS_LABELS
)
a: int = InputField(default=1, description=FieldDescriptions.num_1)
b: int = InputField(default=1, description=FieldDescriptions.num_2)
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
@field_validator("b")
def no_unrepresentable_results(cls, v: int, info: ValidationInfo):
@ -231,17 +230,17 @@ FLOAT_OPERATIONS = Literal[
]
FLOAT_OPERATIONS_LABELS = {
"ADD": "Add A+B",
"SUB": "Subtract A-B",
"MUL": "Multiply A*B",
"DIV": "Divide A/B",
"EXP": "Exponentiate A^B",
"ABS": "Absolute Value of A",
"SQRT": "Square Root of A",
"MIN": "Minimum(A,B)",
"MAX": "Maximum(A,B)",
}
FLOAT_OPERATIONS_LABELS = dict(
ADD="Add A+B",
SUB="Subtract A-B",
MUL="Multiply A*B",
DIV="Divide A/B",
EXP="Exponentiate A^B",
ABS="Absolute Value of A",
SQRT="Square Root of A",
MIN="Minimum(A,B)",
MAX="Maximum(A,B)",
)
@invocation(
@ -257,8 +256,8 @@ class FloatMathInvocation(BaseInvocation):
operation: FLOAT_OPERATIONS = InputField(
default="ADD", description="The operation to perform", ui_choice_labels=FLOAT_OPERATIONS_LABELS
)
a: float = InputField(default=1, description=FieldDescriptions.num_1)
b: float = InputField(default=1, description=FieldDescriptions.num_2)
a: float = InputField(default=0, description=FieldDescriptions.num_1)
b: float = InputField(default=0, description=FieldDescriptions.num_2)
@field_validator("b")
def no_unrepresentable_results(cls, v: float, info: ValidationInfo):
@ -266,7 +265,7 @@ class FloatMathInvocation(BaseInvocation):
raise ValueError("Cannot divide by zero")
elif info.data["operation"] == "EXP" and info.data["a"] == 0 and v < 0:
raise ValueError("Cannot raise zero to a negative power")
elif info.data["operation"] == "EXP" and isinstance(info.data["a"] ** v, complex):
elif info.data["operation"] == "EXP" and type(info.data["a"] ** v) is complex:
raise ValueError("Root operation resulted in a complex number")
return v

View File

@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
InvocationContext,
MetadataField,
@ -18,7 +19,6 @@ from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.shared.fields import FieldDescriptions
from ...version import __version__
@ -112,7 +112,7 @@ GENERATION_MODES = Literal[
]
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1")
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.0")
class CoreMetadataInvocation(BaseInvocation):
"""Collects core generation metadata into a MetadataField"""
@ -160,14 +160,13 @@ class CoreMetadataInvocation(BaseInvocation):
)
# High resolution fix metadata.
hrf_enabled: Optional[bool] = InputField(
hrf_width: Optional[int] = InputField(
default=None,
description="Whether or not high resolution fix was enabled.",
description="The high resolution fix height and width multipler.",
)
# TODO: should this be stricter or do we just let the UI handle it?
hrf_method: Optional[str] = InputField(
hrf_height: Optional[int] = InputField(
default=None,
description="The high resolution fix upscale method.",
description="The high resolution fix height and width multipler.",
)
hrf_strength: Optional[float] = InputField(
default=None,

View File

@ -3,17 +3,16 @@ from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.shared.models import FreeUConfig
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
@ -37,7 +36,6 @@ class UNetField(BaseModel):
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class ClipField(BaseModel):
@ -53,32 +51,13 @@ class VaeField(BaseModel):
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@invocation_output("unet_output")
class UNetOutput(BaseInvocationOutput):
"""Base class for invocations that output a UNet field"""
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
@invocation_output("vae_output")
class VAEOutput(BaseInvocationOutput):
"""Base class for invocations that output a VAE field"""
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation_output("clip_output")
class CLIPOutput(BaseInvocationOutput):
"""Base class for invocations that output a CLIP field"""
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
@invocation_output("model_loader_output")
class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
pass
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP")
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
class MainModelField(BaseModel):
@ -387,6 +366,13 @@ class VAEModelField(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@invocation_output("vae_loader_output")
class VaeLoaderOutput(BaseInvocationOutput):
"""VAE output"""
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
@ -394,10 +380,11 @@ class VaeLoaderInvocation(BaseInvocation):
vae_model: VAEModelField = InputField(
description=FieldDescriptions.vae_model,
input=Input.Direct,
ui_type=UIType.VaeModel,
title="VAE",
)
def invoke(self, context: InvocationContext) -> VAEOutput:
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
@ -408,7 +395,7 @@ class VaeLoaderInvocation(BaseInvocation):
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VAEOutput(
return VaeLoaderOutput(
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
@ -470,24 +457,3 @@ class SeamlessModeInvocation(BaseInvocation):
vae.seamless_axes = seamless_axes_list
return SeamlessModeOutput(unet=unet, vae=vae)
@invocation("freeu", title="FreeU", tags=["freeu"], category="unet", version="1.0.0")
class FreeUInvocation(BaseInvocation):
"""
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
SD1.5: 1.2/1.4/0.9/0.2,
SD2: 1.1/1.2/0.9/0.2,
SDXL: 1.1/1.2/0.6/0.4,
"""
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet")
b1: float = InputField(default=1.2, ge=-1, le=3, description=FieldDescriptions.freeu_b1)
b2: float = InputField(default=1.4, ge=-1, le=3, description=FieldDescriptions.freeu_b2)
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
def invoke(self, context: InvocationContext) -> UNetOutput:
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
return UNetOutput(unet=self.unet)

View File

@ -5,13 +5,13 @@ import torch
from pydantic import field_validator
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.misc import SEED_MAX
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
InvocationContext,
OutputField,
@ -83,16 +83,16 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
title="Noise",
tags=["latents", "noise"],
category="latents",
version="1.0.1",
version="1.0.0",
)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description=FieldDescriptions.seed,
default_factory=get_random_seed,
)
width: int = InputField(
default=512,

View File

@ -14,7 +14,6 @@ from tqdm import tqdm
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
@ -24,6 +23,7 @@ from ...backend.util import choose_torch_device
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -54,7 +54,7 @@ ORT_TO_NP_TYPE = {
"tensor(double)": np.float64,
}
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
@ -252,7 +252,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
scheduler.set_timesteps(self.steps)
latents = latents * np.float64(scheduler.init_noise_sigma)
extra_step_kwargs = {}
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
@ -326,7 +326,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
title="ONNX Latents to Image",
tags=["latents", "image", "vae", "onnx"],
category="image",
version="1.1.0",
version="1.0.0",
)
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
"""Generates an image from latents."""

View File

@ -100,7 +100,7 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
# actually I think for now could just use CollectionOutput (which is list[Any]
@ -161,7 +161,7 @@ class StepParamEasingInvocation(BaseInvocation):
easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
context.services.logger.debug("easing class: " + str(easing_class))
easing_list = []
easing_list = list()
if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
# and create reverse copy of list to append
@ -178,7 +178,7 @@ class StepParamEasingInvocation(BaseInvocation):
end=self.end_value,
duration=base_easing_duration - 1,
)
base_easing_vals = []
base_easing_vals = list()
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)

View File

@ -5,11 +5,10 @@ from typing import Optional, Tuple
import torch
from pydantic import BaseModel, Field
from invokeai.app.shared.fields import FieldDescriptions
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
@ -62,12 +61,12 @@ class BooleanInvocation(BaseInvocation):
title="Boolean Collection Primitive",
tags=["primitives", "boolean", "collection"],
category="primitives",
version="1.0.1",
version="1.0.0",
)
class BooleanCollectionInvocation(BaseInvocation):
"""A collection of boolean primitive values"""
collection: list[bool] = InputField(default=[], description="The collection of boolean values")
collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
return BooleanCollectionOutput(collection=self.collection)
@ -111,12 +110,12 @@ class IntegerInvocation(BaseInvocation):
title="Integer Collection Primitive",
tags=["primitives", "integer", "collection"],
category="primitives",
version="1.0.1",
version="1.0.0",
)
class IntegerCollectionInvocation(BaseInvocation):
"""A collection of integer primitive values"""
collection: list[int] = InputField(default=[], description="The collection of integer values")
collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=self.collection)
@ -158,12 +157,12 @@ class FloatInvocation(BaseInvocation):
title="Float Collection Primitive",
tags=["primitives", "float", "collection"],
category="primitives",
version="1.0.1",
version="1.0.0",
)
class FloatCollectionInvocation(BaseInvocation):
"""A collection of float primitive values"""
collection: list[float] = InputField(default=[], description="The collection of float values")
collection: list[float] = InputField(default_factory=list, description="The collection of float values")
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
return FloatCollectionOutput(collection=self.collection)
@ -205,12 +204,12 @@ class StringInvocation(BaseInvocation):
title="String Collection Primitive",
tags=["primitives", "string", "collection"],
category="primitives",
version="1.0.1",
version="1.0.0",
)
class StringCollectionInvocation(BaseInvocation):
"""A collection of string primitive values"""
collection: list[str] = InputField(default=[], description="The collection of string values")
collection: list[str] = InputField(default_factory=list, description="The collection of string values")
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
return StringCollectionOutput(collection=self.collection)
@ -467,13 +466,13 @@ class ConditioningInvocation(BaseInvocation):
title="Conditioning Collection Primitive",
tags=["primitives", "conditioning", "collection"],
category="primitives",
version="1.0.1",
version="1.0.0",
)
class ConditioningCollectionInvocation(BaseInvocation):
"""A collection of conditioning tensor primitive values"""
collection: list[ConditioningField] = InputField(
default=[],
default_factory=list,
description="The collection of conditioning tensors",
)

View File

@ -44,7 +44,7 @@ class DynamicPromptInvocation(BaseInvocation):
title="Prompts from File",
tags=["prompt", "file"],
category="prompt",
version="1.0.1",
version="1.0.0",
)
class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file"""
@ -82,7 +82,7 @@ class PromptsFromFileInvocation(BaseInvocation):
end_line = start_line + max_prompts
if max_prompts <= 0:
end_line = np.iinfo(np.int32).max
with open(file_path, encoding="utf-8") as f:
with open(file_path) as f:
for i, line in enumerate(f):
if i >= start_line and i < end_line:
prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))

View File

@ -1,9 +1,8 @@
from invokeai.app.shared.fields import FieldDescriptions
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,

View File

@ -5,16 +5,17 @@ from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType
@ -58,7 +59,7 @@ class T2IAdapterInvocation(BaseInvocation):
ui_order=-1,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"

View File

@ -2,16 +2,16 @@
from pathlib import Path
from typing import Literal
import cv2
import cv2 as cv
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import ConfigDict
from realesrgan import RealESRGANer
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
@ -29,7 +29,7 @@ if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.2.0")
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.1.0")
class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"""Upscales an image using RealESRGAN."""
@ -92,9 +92,9 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
upscaler = RealESRGAN(
upsampler = RealESRGANer(
scale=netscale,
model_path=models_path / esrgan_model_path,
model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model,
half=False,
tile=self.tile_size,
@ -102,9 +102,15 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
# upscaling, you'll need to add a resize node after this one.
upscaled_image, img_mode = upsampler.enhance(cv_image)
# back to PIL
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):

View File

@ -139,7 +139,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
@ -167,7 +167,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = [r[0] for r in result]
image_names = list(map(lambda r: r[0], result))
return image_names
except sqlite3.Error as e:
self._conn.rollback()

View File

@ -199,7 +199,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
# Get the total number of boards
self._cursor.execute(
@ -236,7 +236,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
return boards

View File

@ -15,7 +15,7 @@ import os
import sys
from argparse import ArgumentParser
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, ListConfig, OmegaConf
from pydantic_settings import BaseSettings, SettingsConfigDict
@ -24,7 +24,10 @@ from invokeai.app.services.config.config_common import PagingArgumentParser, int
class InvokeAISettings(BaseSettings):
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
"""
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
"""
initconf: ClassVar[Optional[DictConfig]] = None
argparse_groups: ClassVar[Dict] = {}
@ -32,7 +35,6 @@ class InvokeAISettings(BaseSettings):
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
"""Call to parse command-line arguments."""
parser = self.get_parser()
opt, unknown_opts = parser.parse_known_args(argv)
if len(unknown_opts) > 0:
@ -47,21 +49,22 @@ class InvokeAISettings(BaseSettings):
setattr(self, name, value)
def to_yaml(self) -> str:
"""Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later."""
"""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0]
field_dict: Dict[str, Dict[str, Any]] = {type: {}}
field_dict = dict({type: dict()})
for name, field in self.model_fields.items():
if name in cls._excluded_from_yaml():
continue
assert isinstance(field.json_schema_extra, dict)
category = (
field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
)
value = getattr(self, name)
assert isinstance(category, str)
if category not in field_dict[type]:
field_dict[type][category] = {}
field_dict[type][category] = dict()
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)
@ -69,7 +72,6 @@ class InvokeAISettings(BaseSettings):
@classmethod
def add_parser_arguments(cls, parser):
"""Dynamically create arguments for a settings parser."""
if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
else:
@ -87,7 +89,7 @@ class InvokeAISettings(BaseSettings):
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = {}
upcase_environ = dict()
for key, value in os.environ.items():
upcase_environ[key.upper()] = value
@ -114,7 +116,6 @@ class InvokeAISettings(BaseSettings):
@classmethod
def cmd_name(cls, command_field: str = "type") -> str:
"""Return the category of a setting."""
hints = get_type_hints(cls)
if command_field in hints:
return get_args(hints[command_field])[0]
@ -123,7 +124,6 @@ class InvokeAISettings(BaseSettings):
@classmethod
def get_parser(cls) -> ArgumentParser:
"""Get the command-line parser for a setting."""
parser = PagingArgumentParser(
prog=cls.cmd_name(),
description=cls.__doc__,
@ -152,14 +152,10 @@ class InvokeAISettings(BaseSettings):
"free_gpu_mem",
"xformers_enabled",
"tiled_decode",
"lora_dir",
"embedding_dir",
"controlnet_dir",
]
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
"""Add the argparse arguments for a setting parser."""
field_type = get_type_hints(cls).get(name)
default = (
default_override

View File

@ -177,7 +177,6 @@ from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hint
from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter
from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict
from .config_base import InvokeAISettings
@ -189,24 +188,28 @@ DEFAULT_MAX_VRAM = 0.5
class Categories(object):
"""Category headers for configuration variable groups."""
WebServer: JsonDict = {"category": "Web Server"}
Features: JsonDict = {"category": "Features"}
Paths: JsonDict = {"category": "Paths"}
Logging: JsonDict = {"category": "Logging"}
Development: JsonDict = {"category": "Development"}
Other: JsonDict = {"category": "Other"}
ModelCache: JsonDict = {"category": "Model Cache"}
Device: JsonDict = {"category": "Device"}
Generation: JsonDict = {"category": "Generation"}
Queue: JsonDict = {"category": "Queue"}
Nodes: JsonDict = {"category": "Nodes"}
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
WebServer = dict(category="Web Server")
Features = dict(category="Features")
Paths = dict(category="Paths")
Logging = dict(category="Logging")
Development = dict(category="Development")
Other = dict(category="Other")
ModelCache = dict(category="Model Cache")
Device = dict(category="Device")
Generation = dict(category="Generation")
Queue = dict(category="Queue")
Nodes = dict(category="Nodes")
MemoryPerformance = dict(category="Memory/Performance")
class InvokeAIAppConfig(InvokeAISettings):
"""Configuration object for InvokeAI App."""
"""
Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
"""
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict]] = None
@ -231,12 +234,15 @@ class InvokeAIAppConfig(InvokeAISettings):
# PATHS
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
autoimport_dir : Optional[Path] = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Optional[Path] = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Optional[Path] = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Optional[Path] = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
@ -279,15 +285,11 @@ class InvokeAIAppConfig(InvokeAISettings):
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
# this is not referred to in the source code and can be removed entirely
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
# fmt: on
@ -301,8 +303,8 @@ class InvokeAIAppConfig(InvokeAISettings):
clobber=False,
):
"""
Update settings with contents of init file, environment, and command-line settings.
Update settings with contents of init file, environment, and
command-line settings.
:param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list
:param clobber: ovewrite any initialization parameters passed during initialization
@ -335,7 +337,9 @@ class InvokeAIAppConfig(InvokeAISettings):
@classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
"""Return a singleton InvokeAIAppConfig configuration object."""
"""
This returns a singleton InvokeAIAppConfig configuration object.
"""
if (
cls.singleton_config is None
or type(cls.singleton_config) is not cls
@ -347,7 +351,9 @@ class InvokeAIAppConfig(InvokeAISettings):
@property
def root_path(self) -> Path:
"""Path to the runtime root directory."""
"""
Path to the runtime root directory
"""
if self.root:
root = Path(self.root).expanduser().absolute()
else:
@ -357,7 +363,9 @@ class InvokeAIAppConfig(InvokeAISettings):
@property
def root_dir(self) -> Path:
"""Alias for above."""
"""
Alias for above.
"""
return self.root_path
def _resolve(self, partial_path: Path) -> Path:
@ -365,95 +373,108 @@ class InvokeAIAppConfig(InvokeAISettings):
@property
def init_file_path(self) -> Path:
"""Path to invokeai.yaml."""
resolved_path = self._resolve(INIT_FILE)
assert resolved_path is not None
return resolved_path
"""
Path to invokeai.yaml
"""
return self._resolve(INIT_FILE)
@property
def output_path(self) -> Optional[Path]:
"""Path to defaults outputs directory."""
def output_path(self) -> Path:
"""
Path to defaults outputs directory.
"""
return self._resolve(self.outdir)
@property
def db_path(self) -> Path:
"""Path to the invokeai.db file."""
db_dir = self._resolve(self.db_dir)
assert db_dir is not None
return db_dir / DB_FILE
"""
Path to the invokeai.db file.
"""
return self._resolve(self.db_dir) / DB_FILE
@property
def model_conf_path(self) -> Optional[Path]:
"""Path to models configuration file."""
def model_conf_path(self) -> Path:
"""
Path to models configuration file.
"""
return self._resolve(self.conf_path)
@property
def legacy_conf_path(self) -> Optional[Path]:
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
def legacy_conf_path(self) -> Path:
"""
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
"""
return self._resolve(self.legacy_conf_dir)
@property
def models_path(self) -> Optional[Path]:
"""Path to the models directory."""
def models_path(self) -> Path:
"""
Path to the models directory
"""
return self._resolve(self.models_dir)
@property
def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory."""
custom_nodes_path = self._resolve(self.custom_nodes_dir)
assert custom_nodes_path is not None
return custom_nodes_path
"""
Path to the custom nodes directory
"""
return self._resolve(self.custom_nodes_dir)
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self) -> bool:
"""Return true if precision set to float32."""
"""Return true if precision set to float32"""
return self.precision == "float32"
@property
def try_patchmatch(self) -> bool:
"""Return true if patchmatch true."""
"""Return true if patchmatch true"""
return self.patchmatch
@property
def nsfw_checker(self) -> bool:
"""Return value for NSFW checker. The NSFW node is always active and disabled from Web UI."""
"""NSFW node is always active and disabled from Web UIe"""
return True
@property
def invisible_watermark(self) -> bool:
"""Return value of invisible watermark. It is always active and disabled from Web UI."""
"""invisible watermark node is always active and disabled from Web UIe"""
return True
@property
def ram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the ram cache size using the legacy or modern setting."""
return self.max_cache_size or self.ram
@property
def vram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the vram cache size using the legacy or modern setting."""
return self.max_vram_cache_size or self.vram
@property
def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
return self.always_use_cpu or self.device == "cpu"
@property
def disable_xformers(self) -> bool:
"""Return true if enable_xformers is false (reversed logic) and attention type is not set to xformers."""
"""
Return true if enable_xformers is false (reversed logic)
and attention type is not set to xformers.
"""
disabled_in_config = not self.xformers_enabled
return disabled_in_config and self.attention_type != "xformers"
@staticmethod
def find_root() -> Path:
"""Choose the runtime root directory when not specified on command line or init file."""
"""
Choose the runtime root directory when not specified on command line or
init file.
"""
return _find_root()
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
"""
Legacy function which returns InvokeAIAppConfig.get_config()
"""
return InvokeAIAppConfig.get_config(**kwargs)
@ -461,7 +482,7 @@ def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()

View File

@ -27,7 +27,7 @@ class EventServiceBase:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
payload=dict(event=event_name, data=payload),
)
# Define events here for every event in the system.
@ -48,18 +48,18 @@ class EventServiceBase:
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node.get("id"),
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump() if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node.get("id"),
source_node_id=source_node_id,
progress_image=progress_image.model_dump() if progress_image is not None else None,
step=step,
order=order,
total_steps=total_steps,
),
)
def emit_invocation_complete(
@ -75,15 +75,15 @@ class EventServiceBase:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
result=result,
),
)
def emit_invocation_error(
@ -100,16 +100,16 @@ class EventServiceBase:
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
error_type=error_type,
error=error,
),
)
def emit_invocation_started(
@ -124,14 +124,14 @@ class EventServiceBase:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
),
)
def emit_graph_execution_complete(
@ -140,12 +140,12 @@ class EventServiceBase:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_model_load_started(
@ -162,16 +162,16 @@ class EventServiceBase:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
),
)
def emit_model_load_completed(
@ -189,19 +189,19 @@ class EventServiceBase:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_name": model_name,
"base_model": base_model,
"model_type": model_type,
"submodel": submodel,
"hash": model_info.hash,
"location": str(model_info.location),
"precision": str(model_info.precision),
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
hash=model_info.hash,
location=str(model_info.location),
precision=str(model_info.precision),
),
)
def emit_session_retrieval_error(
@ -216,14 +216,14 @@ class EventServiceBase:
"""Emitted when session retrieval fails"""
self.__emit_queue_event(
event_name="session_retrieval_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"error_type": error_type,
"error": error,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
),
)
def emit_invocation_retrieval_error(
@ -239,15 +239,15 @@ class EventServiceBase:
"""Emitted when invocation retrieval fails"""
self.__emit_queue_event(
event_name="invocation_retrieval_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"error_type": error_type,
"error": error,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
error=error,
),
)
def emit_session_canceled(
@ -260,12 +260,12 @@ class EventServiceBase:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_queue_item_status_changed(
@ -277,39 +277,39 @@ class EventServiceBase:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(),
"queue_status": queue_status.model_dump(),
},
payload=dict(
queue_id=queue_status.queue_id,
queue_item=dict(
queue_id=session_queue_item.queue_id,
item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
),
batch_status=batch_status.model_dump(),
queue_status=queue_status.model_dump(),
),
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
payload=dict(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
),
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload={"queue_id": queue_id},
payload=dict(queue_id=queue_id),
)

View File

@ -25,7 +25,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
self.__cache = {}
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config

View File

@ -90,23 +90,25 @@ class ImageRecordDeleteException(Exception):
IMAGE_DTO_COLS = ", ".join(
[
"images." + c
for c in [
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
"starred",
]
]
list(
map(
lambda c: "images." + c,
[
"image_name",
"image_origin",
"image_category",
"width",
"height",
"session_id",
"node_id",
"is_intermediate",
"created_at",
"updated_at",
"deleted_at",
"starred",
],
)
)
)

View File

@ -263,7 +263,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
@ -307,7 +307,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
@ -386,7 +386,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = [r[0] for r in result]
image_names = list(map(lambda r: r[0], result))
self._cursor.execute(
"""--sql
DELETE FROM images

View File

@ -21,8 +21,8 @@ class ImageServiceABC(ABC):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
"""Register a callback for when an image is changed"""

View File

@ -217,16 +217,18 @@ class ImageService(ImageServiceABC):
board_id,
)
image_dtos = [
image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
image_dtos = list(
map(
lambda r: image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
),
results.items,
)
for r in results.items
]
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,

View File

@ -1,5 +1,5 @@
from abc import ABC
class InvocationProcessorABC(ABC): # noqa: B024
class InvocationProcessorABC(ABC):
pass

View File

@ -26,7 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker_thread = Thread(
name="invoker_processor",
target=self.__process,
kwargs={"stop_event": self.__stop_event},
kwargs=dict(stop_event=self.__stop_event),
)
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.start()

View File

@ -14,7 +14,7 @@ class MemoryInvocationQueue(InvocationQueueABC):
def __init__(self):
self.__queue = Queue()
self.__cancellations = {}
self.__cancellations = dict()
def get(self) -> InvocationQueueItem:
item = self.__queue.get()

View File

@ -22,7 +22,6 @@ if TYPE_CHECKING:
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
@ -50,7 +49,6 @@ class InvocationServices:
latents: "LatentsStorageBase"
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@ -78,7 +76,6 @@ class InvocationServices:
latents: "LatentsStorageBase",
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -104,7 +101,6 @@ class InvocationServices:
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -122,7 +122,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
def log_stats(self):
completed = set()
errored = set()
for graph_id, _node_log in self._stats.items():
for graph_id, node_log in self._stats.items():
try:
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
except Exception:
@ -142,7 +142,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
loaded = sum([v for v in cache_stats.loaded_model_sizes.values()]) / GIG
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")

View File

@ -15,8 +15,8 @@ class ItemStorageABC(ABC, Generic[T]):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
"""Base item storage class"""

View File

@ -112,7 +112,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = [self._parse_item(r[0]) for r in result]
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
@ -132,7 +132,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
)
result = self._cursor.fetchall()
items = [self._parse_item(r[0]) for r in result]
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",

View File

@ -13,8 +13,8 @@ class LatentsStorageBase(ABC):
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = []
self._on_deleted_callbacks = []
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
@abstractmethod
def get(self, name: str) -> torch.Tensor:

View File

@ -19,7 +19,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
super().__init__()
self.__underlying_storage = underlying_storage
self.__cache = {}
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size

View File

@ -1,8 +0,0 @@
"""Init file for model record services."""
from .model_records_base import ( # noqa F401
DuplicateModelException,
InvalidModelException,
ModelRecordServiceBase,
UnknownModelException,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401

View File

@ -1,169 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for storing and retrieving model configuration records.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2.0"
class DuplicateModelException(Exception):
"""Raised on an attempt to add a model with the same key twice."""
class InvalidModelException(Exception):
"""Raised when an invalid model is detected."""
class UnknownModelException(Exception):
"""Raised on an attempt to fetch or delete a model with a nonexistent key."""
class ConfigFileVersionMismatchException(Exception):
"""Raised on an attempt to open a config with an incompatible version."""
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
pass
@abstractmethod
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
pass
@abstractmethod
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
pass
@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the configuration for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
pass
@abstractmethod
def search_by_path(
self,
path: Union[str, Path],
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated path."""
pass
@abstractmethod
def search_by_hash(
self,
hash: str,
) -> List[AnyModelConfig]:
"""Return the model(s) having the indicated original hash."""
pass
@abstractmethod
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
pass
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_attr()
def model_info_by_name(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> AnyModelConfig:
"""
Return information about a single model using its name, base type and model type.
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
model_configs = self.search_by_attr(model_name=model_name, base_model=base_model, model_type=model_type)
if len(model_configs) > 1:
raise DuplicateModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
if len(model_configs) == 0:
raise UnknownModelException(
f"More than one model matched the search criteria: base_model='{base_model}', model_type='{model_type}', model_name='{model_name}'."
)
return model_configs[0]
def rename_model(
self,
key: str,
new_name: str,
) -> AnyModelConfig:
"""
Rename the indicated model. Just a special case of update_model().
In some implementations, renaming the model may involve changing where
it is stored on the filesystem. So this is broken out.
:param key: Model key
:param new_name: New name for model
"""
config = self.get_model(key)
config.name = new_name
return self.update_model(key, config)

View File

@ -1,396 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
SQL Implementation of the ModelRecordServiceBase API
Typical usage:
from invokeai.backend.model_manager import ModelConfigStoreSQL
store = ModelConfigStoreSQL(sqlite_db)
config = dict(
path='/tmp/pokemon.bin',
name='old name',
base_model='sd-1',
type='embedding',
format='embedding_file',
)
# adding - the key becomes the model's "key" field
store.add_model('key1', config)
# updating
config.name='new name'
store.update_model('key1', config)
# checking for existence
if store.exists('key1'):
print("yes")
# fetching config
new_config = store.get_model('key1')
print(new_config.name, new_config.base)
assert new_config.key == 'key1'
# deleting
store.del_model('key1')
# searching
configs = store.search_by_path(path='/tmp/pokemon.bin')
configs = store.search_by_hash('750a499f35e43b7e1b4d15c207aa2f01')
configs = store.search_by_attr(base_model='sd-2', model_type='main')
"""
import json
import sqlite3
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigFactory,
ModelType,
)
from ..shared.sqlite import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,
ModelRecordServiceBase,
UnknownModelException,
)
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
_db: SqliteDatabase
_cursor: sqlite3.Cursor
def __init__(self, db: SqliteDatabase):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
:param conn: sqlite3 connection object
:param lock: threading Lock object
"""
super().__init__()
self._db = db
self._cursor = self._db.conn.cursor()
with self._db.lock:
# Enable foreign keys
self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._db.conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
:param key: Unique key for the model
:param config: Model configuration record, either a dict with the
required fields or a ModelConfigBase instance.
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
id,
base,
type,
name,
path,
original_hash,
config
)
VALUES (?,?,?,?,?,?,?);
""",
(
key,
record.base,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def del_model(self, key: str) -> None:
"""
Delete a model.
:param key: Unique key for the model to be deleted
Can raise an UnknownModelException
"""
with self._db.lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
:param key: Unique key for the model to be updated
:param config: Model configuration record. Either a dict with the
required fields, or a ModelConfigBase instance.
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE model_config
SET base=?,
type=?,
name=?,
path=?,
config=?
WHERE id=?;
""",
(record.base, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
:param key: Key of model config to be fetched.
Exceptions: UnknownModelException
"""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]))
return model
def exists(self, key: str) -> bool:
"""
Return True if a model with the indicated key exists in the databse.
:param key: Unique key for the model to be deleted
"""
count = 0
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
return count > 0
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
If none of the optional filters are passed, will return all
models in the database.
"""
results = []
where_clause = []
bindings = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
self._cursor.execute(
f"""--sql
select config FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated original_hash."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results

View File

@ -1,6 +1,7 @@
import traceback
from threading import BoundedSemaphore, Thread
from threading import BoundedSemaphore
from threading import Event as ThreadEvent
from threading import Thread
from typing import Optional
from fastapi_events.handlers.local import local_handler
@ -32,11 +33,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__thread = Thread(
name="session_processor",
target=self.__process,
kwargs={
"stop_event": self.__stop_event,
"poll_now_event": self.__poll_now_event,
"resume_event": self.__resume_event,
},
kwargs=dict(
stop_event=self.__stop_event, poll_now_event=self.__poll_now_event, resume_event=self.__resume_event
),
)
self.__thread.start()

View File

@ -129,12 +129,12 @@ class Batch(BaseModel):
return v
model_config = ConfigDict(
json_schema_extra={
"required": [
json_schema_extra=dict(
required=[
"graph",
"runs",
]
}
)
)
@ -191,8 +191,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra={
"required": [
json_schema_extra=dict(
required=[
"item_id",
"status",
"batch_id",
@ -203,7 +203,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
"created_at",
"updated_at",
]
}
)
)
@ -222,8 +222,8 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
return SessionQueueItem(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra={
"required": [
json_schema_extra=dict(
required=[
"item_id",
"status",
"batch_id",
@ -235,7 +235,7 @@ class SessionQueueItem(SessionQueueItemWithoutGraph):
"created_at",
"updated_at",
]
}
)
)
@ -355,7 +355,7 @@ def create_session_nfv_tuples(
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
data.append(list(zip(*node_field_values_to_zip))) # type: ignore [arg-type]
# create generator to yield session,nfv tuples
count = 0
@ -383,7 +383,7 @@ def calc_session_count(batch: Batch) -> int:
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip, strict=True)))
data.append(list(zip(*to_zip)))
data_product = list(product(*data))
return len(data_product) * batch.runs

View File

@ -78,7 +78,7 @@ def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[Li
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
graphs: list[LibraryGraph] = []
graphs: list[LibraryGraph] = list()
text_to_image = graph_library.get(default_text_to_image_graph_id)

View File

@ -49,7 +49,7 @@ class Edge(BaseModel):
def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_outputs = get_type_hints(node_type.get_output_type())
node_output_field = node_outputs.get(field) or None
return node_output_field
@ -188,7 +188,7 @@ class GraphInvocationOutput(BaseInvocationOutput):
# TODO: Fill this out and move to invocations
@invocation("graph", version="1.0.0")
@invocation("graph")
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
@ -205,7 +205,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""
item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
)
@ -215,7 +215,7 @@ class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""
collection: list[Any] = InputField(
description="The list of items to iterate over", default=[], ui_type=UIType._Collection
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
)
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
@ -227,7 +227,7 @@ class IterateInvocation(BaseInvocation):
@invocation_output("collect_output")
class CollectInvocationOutput(BaseInvocationOutput):
collection: list[Any] = OutputField(
description="The collection of input items", title="Collection", ui_type=UIType._Collection
description="The collection of input items", title="Collection", ui_type=UIType.Collection
)
@ -238,12 +238,12 @@ class CollectInvocation(BaseInvocation):
item: Optional[Any] = InputField(
default=None,
description="The item to collect (all inputs must be of the same type)",
ui_type=UIType._CollectionItem,
ui_type=UIType.CollectionItem,
title="Collection Item",
input=Input.Connection,
)
collection: list[Any] = InputField(
description="The collection, will be provided on execution", default=[], ui_hidden=True
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
)
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
@ -352,7 +352,7 @@ class Graph(BaseModel):
# Validate that all node ids are unique
node_ids = [n.id for n in self.nodes.values()]
duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
@ -379,7 +379,7 @@ class Graph(BaseModel):
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_annotation().model_fields:
if edge.source.field not in source_node.get_output_type().model_fields:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
@ -616,7 +616,7 @@ class Graph(BaseModel):
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
edges = []
edges = list()
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
@ -658,7 +658,7 @@ class Graph(BaseModel):
self, node_path: str, prefix: Optional[str] = None
) -> list[tuple["Graph", Union[str, None], Edge]]:
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
edges = []
edges = list()
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
@ -680,8 +680,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
if new_input is not None:
inputs.append(new_input)
@ -694,7 +694,7 @@ class Graph(BaseModel):
# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
# Input type must be a list
if get_origin(input_field) != list:
@ -713,8 +713,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> bool:
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
if new_input is not None:
inputs.append(new_input)
@ -722,16 +722,18 @@ class Graph(BaseModel):
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
# Validate that all inputs are derived from or match a single type
input_field_types = {
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
} # Get unique types
input_field_types = set(
[
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
if t != NoneType
]
) # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
@ -759,15 +761,15 @@ class Graph(BaseModel):
"""Returns a NetworkX DiGraph representing the layout of this graph"""
# TODO: Cache this?
g = nx.DiGraph()
g.add_nodes_from(list(self.nodes.keys()))
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
g.add_nodes_from([n for n in self.nodes.keys()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
g.add_nodes_from(list(self.nodes.items()))
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
g.add_nodes_from([n for n in self.nodes.items()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
@ -789,7 +791,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
return g
@ -841,8 +843,8 @@ class GraphExecutionState(BaseModel):
return v
model_config = ConfigDict(
json_schema_extra={
"required": [
json_schema_extra=dict(
required=[
"id",
"graph",
"execution_graph",
@ -853,7 +855,7 @@ class GraphExecutionState(BaseModel):
"prepared_source_mapping",
"source_prepared_mapping",
]
}
)
)
def next(self) -> Optional[BaseInvocation]:
@ -893,7 +895,7 @@ class GraphExecutionState(BaseModel):
source_node = self.prepared_source_mapping[node_id]
prepared_nodes = self.source_prepared_mapping[source_node]
if all(n in self.executed for n in prepared_nodes):
if all([n in self.executed for n in prepared_nodes]):
self.executed.add(source_node)
self.executed_history.append(source_node)
@ -928,7 +930,7 @@ class GraphExecutionState(BaseModel):
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
self_iteration_count = len(input_collection)
new_nodes: list[str] = []
new_nodes: list[str] = list()
if self_iteration_count == 0:
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
return new_nodes
@ -938,7 +940,7 @@ class GraphExecutionState(BaseModel):
# Create new edges for this iteration
# For collect nodes, this may contain multiple inputs to the same field
new_edges: list[Edge] = []
new_edges: list[Edge] = list()
for edge in input_edges:
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
new_edge = Edge(
@ -1032,7 +1034,7 @@ class GraphExecutionState(BaseModel):
# Create execution nodes
next_node = self.graph.get_node(next_node_id)
new_node_ids = []
new_node_ids = list()
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(
@ -1053,10 +1055,7 @@ class GraphExecutionState(BaseModel):
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
# TODO: Handle a node mapping to none
eg = self.execution_graph.nx_graph_flat()
prepared_parent_mappings = [
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
for it in iterator_node_prepared_combinations
] # type: ignore
prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations] # type: ignore
# Create execution node for each iteration
for iteration_mappings in prepared_parent_mappings:
@ -1122,7 +1121,7 @@ class GraphExecutionState(BaseModel):
for edge in input_edges
if edge.destination.field == "item"
]
node.collection = output_collection
setattr(node, "collection", output_collection)
else:
for edge in input_edges:
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
@ -1202,7 +1201,7 @@ class LibraryGraph(BaseModel):
@field_validator("exposed_inputs", "exposed_outputs")
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
if len(v) != len({i.alias for i in v}):
if len(v) != len(set(i.alias for i in v)):
raise ValueError("Duplicate exposed alias")
return v

View File

@ -1,5 +0,0 @@
"""
This module contains various classes, functions and models which are shared across the app, particularly by invocations.
Lifting these classes, functions and models into this shared module helps to reduce circular imports.
"""

View File

@ -1,66 +0,0 @@
class FieldDescriptions:
denoising_start = "When to start denoising, expressed a percentage of total steps"
denoising_end = "When to stop denoising, expressed a percentage of total steps"
cfg_scale = "Classifier-Free Guidance scale"
scheduler = "Scheduler to use during inference"
positive_cond = "Positive conditioning tensor"
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
skipped_layers = "Number of layers to skip in text encoder"
seed = "Seed for random number generation"
steps = "Number of steps to run"
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
t2i_adapter = "T2I-Adapter(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
metadata = "Optional metadata to be saved with the image"
metadata_collection = "Collection of Metadata"
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
metadata_item_label = "Label for this metadata item"
metadata_item_value = "The value for this metadata item (may be any type)"
workflow = "Optional workflow to be saved with the image"
interp_mode = "Interpolation mode"
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"
scribble_mode = "Whether or not to use scribble mode"
scale_factor = "The factor by which to scale"
blend_alpha = (
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
)
num_1 = "The first number"
num_2 = "The second number"
mask = "The mask to use for the operation"
board = "The board to save the image to"
image = "The image to process"
tile_size = "Tile size"
inclusive_low = "The inclusive low value"
exclusive_high = "The exclusive high value"
decimal_places = "The number of decimal places to round to"
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."

View File

@ -1,16 +0,0 @@
from pydantic import BaseModel, Field
from invokeai.app.shared.fields import FieldDescriptions
class FreeUConfig(BaseModel):
"""
Configuration for the FreeU hyperparameters.
- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu
- https://github.com/ChenyangSi/FreeU
"""
s1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_s2)
b1: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b1)
b2: float = Field(ge=-1, le=3, description=FieldDescriptions.freeu_b2)

View File

@ -59,7 +59,7 @@ def thin_one_time(x, kernels):
def lvmin_thin(x, prunings=True):
y = x
for _i in range(32):
for i in range(32):
y, is_done = thin_one_time(y, lvmin_kernels)
if is_done:
break

View File

@ -21,11 +21,11 @@ def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
# sanity check make sure the graph is at least reasonably shaped
if (
not isinstance(graph, dict)
type(graph) is not dict
or "nodes" not in graph
or not isinstance(graph["nodes"], dict)
or type(graph["nodes"]) is not dict
or "edges" not in graph
or not isinstance(graph["edges"], list)
or type(graph["edges"]) is not list
):
# something has gone terribly awry, return an empty dict
return None

View File

@ -88,7 +88,7 @@ class PromptFormatter:
t2i = self.t2i
opt = self.opt
switches = []
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f"-s{opt.steps or t2i.steps}")
switches.append(f"-W{opt.width or t2i.width}")

View File

@ -1,29 +0,0 @@
BSD 3-Clause License
Copyright (c) 2021, Xintao Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,274 +0,0 @@
import math
from enum import Enum
from pathlib import Path
from typing import Any, Optional
import cv2
import numpy as np
import numpy.typing as npt
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.util.devices import choose_torch_device
"""
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
License is BSD3, copied to `LICENSE` in this directory.
The adaptation here has a few changes:
- Remove print statements, use `tqdm` to show progress
- Remove unused "outscale" logic, which simply scales the final image to a given factor
- Remove `dni_weight` logic, which was only used when multiple models were used
- Remove logic to fetch models from network
- Add types, rename a few things
"""
class ImageMode(str, Enum):
L = "L"
RGB = "RGB"
RGBA = "RGBA"
class RealESRGAN:
"""A helper class for upsampling images with RealESRGAN.
Args:
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
model (nn.Module): The defined network. Default: None.
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
input images into tiles, and then process each of them. Finally, they will be merged into one image.
0 denotes for do not use tile. Default: 0.
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
half (float): Whether to use half precision during inference. Default: False.
"""
output: torch.Tensor
def __init__(
self,
scale: int,
model_path: Path,
model: RRDBNet,
tile: int = 0,
tile_pad: int = 10,
pre_pad: int = 10,
half: bool = False,
) -> None:
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale: Optional[int] = None
self.half = half
self.device = choose_torch_device()
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
# prefer to use params_ema
if "params_ema" in loadnet:
keyname = "params_ema"
else:
keyname = "params"
model.load_state_dict(loadnet[keyname], strict=True)
model.eval()
self.model = model.to(self.device)
if self.half:
self.model = self.model.half()
def pre_process(self, img: MatLike) -> None:
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
img_tensor: torch.Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img_tensor.unsqueeze(0).to(self.device)
if self.half:
self.img = self.img.half()
# pre_pad
if self.pre_pad != 0:
self.img = torch.nn.functional.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
# mod pad for divisible borders
if self.scale == 2:
self.mod_scale = 2
elif self.scale == 1:
self.mod_scale = 4
if self.mod_scale is not None:
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.img.size()
if h % self.mod_scale != 0:
self.mod_pad_h = self.mod_scale - h % self.mod_scale
if w % self.mod_scale != 0:
self.mod_pad_w = self.mod_scale - w % self.mod_scale
self.img = torch.nn.functional.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect")
def process(self) -> None:
# model inference
self.output = self.model(self.img)
def tile_process(self) -> None:
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/ata4/esrgan-launcher
"""
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
# start with black image
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
# loop over all tiles
total_steps = tiles_y * tiles_x
for i in tqdm(range(total_steps), desc="Upscaling"):
y = i // tiles_x
x = i % tiles_x
# extract tile from input image
ofs_x = x * self.tile_size
ofs_y = y * self.tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + self.tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
input_tile = self.img[
:,
:,
input_start_y_pad:input_end_y_pad,
input_start_x_pad:input_end_x_pad,
]
# upscale tile
with torch.no_grad():
output_tile = self.model(input_tile)
# output tile area on total image
output_start_x = input_start_x * self.scale
output_end_x = input_end_x * self.scale
output_start_y = input_start_y * self.scale
output_end_y = input_end_y * self.scale
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
# put tile into output image
self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[
:,
:,
output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile,
]
def post_process(self) -> torch.Tensor:
# remove extra pad
if self.mod_scale is not None:
_, _, h, w = self.output.size()
self.output = self.output[
:,
:,
0 : h - self.mod_pad_h * self.scale,
0 : w - self.mod_pad_w * self.scale,
]
# remove prepad
if self.pre_pad != 0:
_, _, h, w = self.output.size()
self.output = self.output[
:,
:,
0 : h - self.pre_pad * self.scale,
0 : w - self.pre_pad * self.scale,
]
return self.output
@torch.no_grad()
def upscale(self, img: MatLike, esrgan_alpha_upscale: bool = True) -> npt.NDArray[Any]:
np_img = img.astype(np.float32)
alpha: Optional[np.ndarray] = None
if np.max(np_img) > 256:
# 16-bit image
max_range = 65535
else:
max_range = 255
np_img = np_img / max_range
if len(np_img.shape) == 2:
# grayscale image
img_mode = ImageMode.L
np_img = cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
elif np_img.shape[2] == 4:
# RGBA image with alpha channel
img_mode = ImageMode.RGBA
alpha = np_img[:, :, 3]
np_img = np_img[:, :, 0:3]
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
if esrgan_alpha_upscale:
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
else:
img_mode = ImageMode.RGB
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
# ------------------- process image (without the alpha channel) ------------------- #
self.pre_process(np_img)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_tensor = self.post_process()
output_img: npt.NDArray[Any] = output_tensor.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
if img_mode is ImageMode.L:
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
# ------------------- process the alpha channel if necessary ------------------- #
if img_mode is ImageMode.RGBA:
if esrgan_alpha_upscale:
assert alpha is not None
self.pre_process(alpha)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_alpha_tensor = self.post_process()
output_alpha: npt.NDArray[Any] = output_alpha_tensor.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
else: # use the cv2 resize for alpha channel
assert alpha is not None
h, w = alpha.shape[0:2]
output_alpha = cv2.resize(
alpha,
(w * self.scale, h * self.scale),
interpolation=cv2.INTER_LINEAR,
)
# merge the alpha channel
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
output_img[:, :, 3] = output_alpha
# ------------------------------ return ------------------------------ #
if max_range == 65535: # 16-bit image
output = (output_img * 65535.0).round().astype(np.uint16)
else:
output = (output_img * 255.0).round().astype(np.uint8)
return output

View File

@ -88,7 +88,7 @@ class Txt2Mask(object):
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
"""
if isinstance(image, str):
if type(image) is str:
image = Image.open(image).convert("RGB")
image = ImageOps.exif_transpose(image)

View File

@ -40,7 +40,7 @@ class InitImageResizer:
(rw, rh) = (int(scale * im.width), int(scale * im.height))
# round everything to multiples of 64
width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))
width, height, rw, rh = map(lambda x: x - x % 64, (width, height, rw, rh))
# no resize necessary, but return a copy
if im.width == width and im.height == height:

View File

@ -32,7 +32,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from huggingface_hub import HfFolder
from huggingface_hub import login as hf_hub_login
from omegaconf import OmegaConf
from pydantic import ValidationError
from pydantic.error_wrappers import ValidationError
from tqdm import tqdm
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@ -197,7 +197,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
def download_conversion_models():
target_dir = config.models_path / "core/convert"
kwargs = {} # for future use
kwargs = dict() # for future use
try:
logger.info("Downloading core tokenizers and text encoders")
@ -252,26 +252,26 @@ def download_conversion_models():
def download_realesrgan():
logger.info("Installing ESRGAN Upscaling models...")
URLs = [
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"description": "RealESRGAN_x4plus.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"description": "RealESRGAN_x4plus_anime_6B.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"dest": "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"description": "ESRGAN_SRx4_DF2KOST_official.pth",
},
{
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
"dest": "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
"description": "RealESRGAN_x2plus.pth",
},
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
description="RealESRGAN_x4plus.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
description="RealESRGAN_x4plus_anime_6B.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description="ESRGAN_SRx4_DF2KOST_official.pth",
),
dict(
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description="RealESRGAN_x2plus.pth",
),
]
for model in URLs:
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
@ -680,7 +680,7 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
if program_opts.default_only
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all
else [],
else list(),
)

View File

@ -38,7 +38,6 @@ SAMPLER_CHOICES = [
"k_heun",
"k_lms",
"plms",
"lcm",
]
PRECISION_CHOICES = [

View File

@ -123,6 +123,8 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
for f in files:
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
# let them be copied as part of a tree copy operation
@ -141,6 +143,8 @@ class MigrateTo3(object):
logger.error(str(e))
except KeyboardInterrupt:
raise
except Exception as e:
logger.error(str(e))
def migrate_support_models(self):
"""
@ -178,10 +182,10 @@ class MigrateTo3(object):
"""
dest_directory = self.dest_models
kwargs = {
"cache_dir": self.root_directory / "models/hub",
kwargs = dict(
cache_dir=self.root_directory / "models/hub",
# local_files_only = True
}
)
try:
logger.info("Migrating core tokenizers and text encoders")
target_dir = dest_directory / "core" / "convert"
@ -312,11 +316,11 @@ class MigrateTo3(object):
dest_dir = self.dest_models
cache = self.root_directory / "models/hub"
kwargs = {
"cache_dir": cache,
"safety_checker": None,
kwargs = dict(
cache_dir=cache,
safety_checker=None,
# local_files_only = True,
}
)
owner, repo_name = repo_id.split("/")
model_name = model_name or repo_name

View File

@ -120,7 +120,7 @@ class ModelInstall(object):
be treated uniformly. It also sorts the models alphabetically
by their name, to improve the display somewhat.
"""
model_dict = {}
model_dict = dict()
# first populate with the entries in INITIAL_MODELS.yaml
for key, value in self.datasets.items():
@ -134,7 +134,7 @@ class ModelInstall(object):
model_dict[key] = model_info
# supplement with entries in models.yaml
installed_models = list(self.mgr.list_models())
installed_models = [x for x in self.mgr.list_models()]
for md in installed_models:
base = md["base_model"]
@ -176,7 +176,7 @@ class ModelInstall(object):
# logic here a little reversed to maintain backward compatibility
def starter_models(self, all_models: bool = False) -> Set[str]:
models = set()
for key, _value in self.datasets.items():
for key, value in self.datasets.items():
name, base, model_type = ModelManager.parse_key(key)
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
models.add(key)
@ -184,7 +184,7 @@ class ModelInstall(object):
def recommended_models(self) -> Set[str]:
starters = self.starter_models(all_models=True)
return {x for x in starters if self.datasets[x].get("recommended", False)}
return set([x for x in starters if self.datasets[x].get("recommended", False)])
def default_model(self) -> str:
starters = self.starter_models()
@ -234,7 +234,7 @@ class ModelInstall(object):
"""
if not models_installed:
models_installed = {}
models_installed = dict()
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
@ -252,14 +252,10 @@ class ModelInstall(object):
# folders style or similar
elif path.is_dir() and any(
(path / x).exists()
for x in {
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"pytorch_lora_weights.safetensors",
}
[
(path / x).exists()
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
]
):
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
@ -361,7 +357,7 @@ class ModelInstall(object):
for suffix in ["safetensors", "bin"]:
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
location = self._download_hf_model(
repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder
) # LoRA
break
elif (
@ -431,17 +427,17 @@ class ModelInstall(object):
rel_path = self.relative_to_root(path, self.config.models_path)
attributes = {
"path": str(rel_path),
"description": str(description),
"model_format": info.format,
}
attributes = dict(
path=str(rel_path),
description=str(description),
model_format=info.format,
)
legacy_conf = None
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
attributes.update(
{
"variant": info.variant_type,
}
dict(
variant=info.variant_type,
)
)
if info.format == "checkpoint":
try:
@ -472,7 +468,7 @@ class ModelInstall(object):
)
if legacy_conf:
attributes.update({"config": str(legacy_conf)})
attributes.update(dict(config=str(legacy_conf)))
return attributes
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
@ -517,7 +513,7 @@ class ModelInstall(object):
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
_, name = repo_id.split("/")
location = staging / name
paths = []
paths = list()
for filename in files:
filePath = Path(filename)
p = hf_download_with_resume(

View File

@ -130,9 +130,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.

View File

@ -56,7 +56,7 @@ class PerceiverAttention(nn.Module):
x = self.norm1(x)
latents = self.norm2(latents)
b, L, _ = latents.shape
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
@ -72,7 +72,7 @@ class PerceiverAttention(nn.Module):
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, L, -1)
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)

View File

@ -1,13 +1,12 @@
# ruff: noqa: I001, F401
"""
Initialization file for invokeai.backend.model_management
"""
# This import must be first
from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType
from .lora import ModelPatcher, ONNXModelPatcher
from .model_cache import ModelCache
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType # noqa: F401 isort: split
from .models import (
from .lora import ModelPatcher, ONNXModelPatcher # noqa: F401
from .model_cache import ModelCache # noqa: F401
from .models import ( # noqa: F401
BaseModelType,
DuplicateModelException,
ModelNotFoundException,
@ -17,4 +16,4 @@ from .models import (
)
# This import must be last
from .model_merge import MergeInterpolationMethod, ModelMerger
from .model_merge import ModelMerger, MergeInterpolationMethod # noqa: F401 isort: split

View File

@ -269,7 +269,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
resolution *= 2
up_block_types = []
for _i in range(len(block_out_channels)):
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
@ -1223,7 +1223,7 @@ def download_from_original_stable_diffusion_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1664,7 +1664,7 @@ def download_controlnet_from_original_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)

View File

@ -12,8 +12,6 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
from invokeai.app.shared.models import FreeUConfig
from .models.lora import LoRAModel
"""
@ -104,7 +102,7 @@ class ModelPatcher:
loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
original_weights = {}
original_weights = dict()
try:
with torch.no_grad():
for lora, lora_weight in loras:
@ -166,15 +164,6 @@ class ModelPatcher:
init_tokens_count = None
new_tokens_added = None
# TODO: This is required since Transformers 4.32 see
# https://github.com/huggingface/transformers/pull/25088
# More information by NVIDIA:
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
# This value might need to be changed in the future and take the GPUs model into account as there seem
# to be ideal values for different GPUS. This value is temporary!
# For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
pad_to_multiple_of = 8
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
@ -184,7 +173,7 @@ class ModelPatcher:
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
def _get_trigger(ti_name, index):
trigger = ti_name
@ -199,7 +188,7 @@ class ModelPatcher:
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
@ -231,7 +220,7 @@ class ModelPatcher:
finally:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
text_encoder.resize_token_embeddings(init_tokens_count)
@classmethod
@contextmanager
@ -242,7 +231,7 @@ class ModelPatcher:
):
skipped_layers = []
try:
for _i in range(clip_skip):
for i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
yield
@ -251,25 +240,6 @@ class ModelPatcher:
while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
@classmethod
@contextmanager
def apply_freeu(
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
):
did_apply_freeu = False
try:
if freeu_config is not None:
unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
did_apply_freeu = True
yield
finally:
if did_apply_freeu:
unet.disable_freeu()
class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280]
@ -324,7 +294,7 @@ class TextualInversionManager(BaseTextualInversionManager):
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.pad_tokens = dict()
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
@ -385,10 +355,10 @@ class ONNXModelPatcher:
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = {}
orig_weights = dict()
try:
blended_loras = {}
blended_loras = dict()
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
@ -404,7 +374,7 @@ class ONNXModelPatcher:
else:
blended_loras[layer_key] = layer_weight
node_names = {}
node_names = dict()
for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

View File

@ -66,13 +66,11 @@ class CacheStats(object):
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
@ -134,7 +132,7 @@ class ModelCache(object):
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = {}
self.model_infos: Dict[str, ModelBase] = dict()
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype = precision
@ -149,8 +147,8 @@ class ModelCache(object):
# used for stats collection
self.stats = None
self._cached_models = {}
self._cache_stack = []
self._cached_models = dict()
self._cache_stack = list()
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:

View File

@ -26,5 +26,5 @@ def skip_torch_weight_init():
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
for torch_module, saved_function in zip(torch_modules, saved_functions):
torch_module.reset_parameters = saved_function

Some files were not shown because too many files have changed in this diff Show More