Compare commits

..

20 Commits

Author SHA1 Message Date
b07d7846da Merge branch 'main' into psyche/fix/nodes-denylist 2024-04-28 15:10:54 -04:00
40f35b3088 fix test_deny_nodes() test to not interfere with later tests 2024-04-28 15:10:43 -04:00
59deef97c5 Merge branch 'main' into lstein/feat/config-migration 2024-04-28 14:31:43 -04:00
d852ca7a8d added test for non-contiguous migration routines 2024-04-28 14:31:38 -04:00
d24877561d reinstated failing deny_nodes validation test for Graph 2024-04-25 00:22:09 -04:00
8144a263de updated and reinstated the test_deny_nodes() unit test 2024-04-24 22:14:48 -04:00
ab086a7069 Merge branch 'main' into lstein/feat/config-migration 2024-04-24 21:37:46 -04:00
048306b417 Merge branch 'main' into lstein/feat/config-migration 2024-04-24 21:37:12 -04:00
6eaed9a9cb check for strictly contiguous from_version->to_version ranges 2024-04-24 21:36:28 -04:00
ab9ebef345 tests(config): fix typo 2024-04-23 17:52:51 +10:00
984dd93798 tests(config): add failing test case to for config migrator 2024-04-23 17:50:31 +10:00
d12fb7db68 fix(config): fix duplicate migration logic
This was checking a `Version` object against a `MigrationEntry`, but what we want is to check the version object against `MigrationEntry.from_version`
2024-04-23 17:25:53 +10:00
5d411e446a tidy(config): use a type alias for the migration function 2024-04-23 17:21:05 +10:00
6f128c86b4 tidy(config): use dataclass for MigrationEntry
The only pydantic usage was to convert strings to `Version` objects. The reason to do this conversion was to allow the register decorator to accept strings. MigrationEntry is only created inside this class, so we can just create versions from each migration when instantiating MigrationEntry instead.

Also, pydantic doesn't provide runtime time checking for arbitrary classes like Version, so we don't get any real benefit.
2024-04-23 17:19:54 +10:00
aca9e44a3a fix(config): use TypeAlias instead of TypeVar
TypeVar is for generics, but the usage here is as an alias
2024-04-23 17:12:19 +10:00
e39f035264 tidy(config): removed extraneous ABC
We don't need separate implementations for this class, let's not complicate it with an ABC
2024-04-23 17:11:13 +10:00
b612c73954 tidy(config): remove unused TYPE_CHECKING block 2024-04-23 17:09:50 +10:00
36495b730d use packaging.version rather than version-parse 2024-04-18 23:07:54 -04:00
6ad1948a44 add InvokeAIAppConfig schema migration system 2024-04-18 21:33:54 -04:00
f07a46f195 fix(nodes): respect nodes denylist
In #5838 graph validation was updated to resolve the issue with the nodes union and import order. That broke the nodes denylist functionality.

However, because the corresponding test was marked as `xfail`, we didn't catch the issue.

- Fix the nodes denylist handling
- Update the tests
2024-03-08 12:52:41 +11:00
1324 changed files with 60801 additions and 83754 deletions

View File

@ -9,9 +9,9 @@ runs:
node-version: '18'
- name: setup pnpm
uses: pnpm/action-setup@v4
uses: pnpm/action-setup@v2
with:
version: 8.15.6
version: 8
run_install: false
- name: get pnpm store directory

View File

@ -8,7 +8,7 @@
## QA Instructions
<!--WHEN APPLICABLE: Describe how you have tested the changes in this PR. Provide enough detail that a reviewer can reproduce your tests.-->
<!--WHEN APPLICABLE: Describe how we can test the changes in this PR.-->
## Merge Plan

View File

@ -62,7 +62,7 @@ jobs:
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff==0.6.0
run: pip install ruff
shell: bash
- name: ruff check

View File

@ -60,7 +60,7 @@ jobs:
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
os: macOS-14
os: macOS-12
github-env: $GITHUB_ENV
- platform: windows-cpu
os: windows-2022

View File

@ -18,7 +18,6 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
@ -71,6 +70,3 @@ installer-zip:
tag-release:
cd installer && ./tag_release.sh
# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py

View File

@ -12,24 +12,12 @@
Invoke is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. Invoke offers an industry leading web-based UI, and serves as the foundation for multiple commercial products.
Invoke is available in two editions:
| **Community Edition** | **Professional Edition** |
|----------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
| **For users looking for a locally installed, self-hosted and self-managed service** | **For users or teams looking for a cloud-hosted, fully managed service** |
| - Free to use under a commercially-friendly license | - Monthly subscription fee with three different plan levels |
| - Download and install on compatible hardware | - Offers additional benefits, including multi-user support, improved model training, and more |
| - Includes all core studio features: generate, refine, iterate on images, and build workflows | - Hosted in the cloud for easy, secure model access and scalability |
| Quick Start -> [Installation and Updates][installation docs] | More Information -> [www.invoke.com/pricing](https://www.invoke.com/pricing) |
[Installation][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs]
<div align="center">
![Highlighted Features - Canvas and Workflows](https://github.com/invoke-ai/InvokeAI/assets/31807370/708f7a82-084f-4860-bfbe-e2588c53548d)
# Documentation
| **Quick Links** |
|----------------------------------------------------------------------------------------------------------------------------|
| [Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs] |
</div>
## Quick Start
@ -49,33 +37,6 @@ Invoke is available in two editions:
More detail, including hardware requirements and manual install instructions, are available in the [installation documentation][installation docs].
## Docker Container
We publish official container images in Github Container Registry: https://github.com/invoke-ai/InvokeAI/pkgs/container/invokeai. Both CUDA and ROCm images are available. Check the above link for relevant tags.
> [!IMPORTANT]
> Ensure that Docker is set up to use the GPU. Refer to [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] documentation.
### Generate!
Run the container, modifying the command as necessary:
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
Then open `http://localhost:9090` and install some models using the Model Manager tab to begin generating.
For ROCm, add `--device /dev/kfd --device /dev/dri` to the `docker run` command.
### Persist your data
You will likely want to persist your workspace outside of the container. Use the `--volume /home/myuser/invokeai:/invokeai` flag to mount some local directory (using its **absolute** path) to the `/invokeai` path inside the container. Your generated images and models will reside there. You can use this directory with other InvokeAI installations, or switch between runtime directories as needed.
### DIY
Build your own image and customize the environment to match your needs using our `docker-compose` stack. See [README.md](./docker/README.md) in the [docker](./docker) directory.
## Troubleshooting, FAQ and Support
Please review our [FAQ][faq] for solutions to common installation problems and other issues.
@ -153,5 +114,3 @@ Original portions of the software are Copyright © 2024 by respective contributo
[latest release link]: https://github.com/invoke-ai/InvokeAI/releases/latest
[translation status badge]: https://hosted.weblate.org/widgets/invokeai/-/svg-badge.svg
[translation status link]: https://hosted.weblate.org/engage/invokeai/
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html

View File

@ -19,9 +19,8 @@
## INVOKEAI_PORT is the port on which the InvokeAI web interface will be available
# INVOKEAI_PORT=9090
## GPU_DRIVER can be set to either `cuda` or `rocm` to enable GPU support in the container accordingly.
# GPU_DRIVER=cuda #| rocm
## GPU_DRIVER can be set to either `nvidia` or `rocm` to enable GPU support in the container accordingly.
# GPU_DRIVER=nvidia #| rocm
## CONTAINER_UID can be set to the UID of the user on the host system that should own the files in the container.
## It is usually not necessary to change this. Use `id -u` on the host system to find the UID.
# CONTAINER_UID=1000

View File

@ -55,7 +55,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
FROM node:20-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
WORKDIR /build

View File

@ -1,88 +1,41 @@
# Invoke in Docker
# InvokeAI Containerized
First things first:
All commands should be run within the `docker` directory: `cd docker`
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
## Quickstart :rocket:
## Quickstart
On a known working Linux+Docker+CUDA (Nvidia) system, execute `./run.sh` in this directory. It will take a few minutes - depending on your internet speed - to install the core models. Once the application starts up, open `http://localhost:9090` in your browser to Invoke!
No `docker compose`, no persistence, single command, using the official images:
For more configuration options (using an AMD GPU, custom root directory location, etc): read on.
**CUDA (NVIDIA GPU):**
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
**ROCm (AMD GPU):**
```bash
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
```
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
### Data persistence
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
```bash
docker run --volume /some/local/path:/invokeai {...etc...}
```
`/some/local/path/invokeai` will contain all your data.
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
## Customize the container
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
```bash
cd docker
cp .env.sample .env
# edit .env to your liking if you need to; it is well commented.
./run.sh
```
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
>[!TIP]
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
## Docker setup in detail
## Detailed setup
#### Linux
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
- The deprecated `docker-compose` (hyphenated) CLI continues to work for now.
3. Ensure docker daemon is able to access the GPU.
- [NVIDIA docs](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
- [AMD docs](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html)
- You may need to install [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
#### macOS
> [!TIP]
> You'll be better off installing Invoke directly on your system, because Docker can not use the GPU on macOS.
If you are still reading:
1. Ensure Docker has at least 16GB RAM
2. Enable VirtioFS for file sharing
3. Enable `docker compose` V2 support
This is done via Docker Desktop preferences.
This is done via Docker Desktop preferences
### Configure the Invoke Environment
### Configure Invoke environment
1. Make a copy of `.env.sample` and name it `.env` (`cp .env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to the desired location of the InvokeAI runtime directory. It may be an existing directory from a previous installation (post 4.0.0).
1. Make a copy of `.env.sample` and name it `.env` (`cp .env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
a. the desired location of the InvokeAI runtime directory, or
b. an existing, v3.0.0 compatible runtime directory.
1. Execute `run.sh`
The image will be built automatically if needed.
The runtime directory (holding models and outputs) will be created in the location specified by `INVOKEAI_ROOT`. The default location is `~/invokeai`. Navigate to the Model Manager tab and install some models before generating.
The runtime directory (holding models and outputs) will be created in the location specified by `INVOKEAI_ROOT`. The default location is `~/invokeai`. The runtime directory will be populated with the base configs and models necessary to start generating.
### Use a GPU
@ -90,9 +43,9 @@ The runtime directory (holding models and outputs) will be created in the locati
- WSL2 is *required* for Windows.
- only `x86_64` architecture is supported.
The Docker daemon on the system must be already set up to use the GPU. In case of Linux, this involves installing `nvidia-docker-runtime` and configuring the `nvidia` runtime as default. Steps will be different for AMD. Please see Docker/NVIDIA/AMD documentation for the most up-to-date instructions for using your GPU with Docker.
The Docker daemon on the system must be already set up to use the GPU. In case of Linux, this involves installing `nvidia-docker-runtime` and configuring the `nvidia` runtime as default. Steps will be different for AMD. Please see Docker documentation for the most up-to-date instructions for using your GPU with Docker.
To use an AMD GPU, set `GPU_DRIVER=rocm` in your `.env` file before running `./run.sh`.
To use an AMD GPU, set `GPU_DRIVER=rocm` in your `.env` file.
## Customize
@ -106,12 +59,30 @@ Values are optional, but setting `INVOKEAI_ROOT` is highly recommended. The defa
INVOKEAI_ROOT=/Volumes/WorkDrive/invokeai
HUGGINGFACE_TOKEN=the_actual_token
CONTAINER_UID=1000
GPU_DRIVER=cuda
GPU_DRIVER=nvidia
```
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
---
## Even Moar Customizing!
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
### Reconfigure the runtime directory
Can be used to download additional models from the supported model list
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
```yaml
command:
- invokeai-configure
- --yes
```
Or install models:
```yaml
command:
- invokeai-model-install
```

View File

@ -1,5 +1,7 @@
# Copyright (c) 2023 Eugene Brodsky https://github.com/ebr
version: '3.8'
x-invokeai: &invokeai
image: "local/invokeai:latest"
build:
@ -30,7 +32,7 @@ x-invokeai: &invokeai
services:
invokeai-cuda:
invokeai-nvidia:
<<: *invokeai
deploy:
resources:

View File

@ -23,18 +23,18 @@ usermod -u ${USER_ID} ${USER} 1>/dev/null
# but it is useful to have the full SSH server e.g. on Runpod.
# (use SCP to copy files to/from the image, etc)
if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]; then
apt-get update
apt-get install -y openssh-server
pushd "$HOME"
mkdir -p .ssh
echo "${PUBLIC_KEY}" >.ssh/authorized_keys
chmod -R 700 .ssh
popd
service ssh start
apt-get update
apt-get install -y openssh-server
pushd "$HOME"
mkdir -p .ssh
echo "${PUBLIC_KEY}" > .ssh/authorized_keys
chmod -R 700 .ssh
popd
service ssh start
fi
mkdir -p "${INVOKEAI_ROOT}"
chown --recursive ${USER} "${INVOKEAI_ROOT}" || true
chown --recursive ${USER} "${INVOKEAI_ROOT}"
cd "${INVOKEAI_ROOT}"
# Run the CMD as the Container User (not root).

View File

@ -8,15 +8,11 @@ run() {
local build_args=""
local profile=""
# create .env file if it doesn't exist, otherwise docker compose will fail
touch .env
# parse .env file for build args
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
# default to 'cuda' profile
[[ -z "$profile" ]] && profile="cuda"
[[ -z "$profile" ]] && profile="nvidia"
local service_name="invokeai-$profile"

View File

@ -128,8 +128,7 @@ The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.
Two job types are defined. `DownloadJob` and
`MultiFileDownloadJob`. The former is a pydantic object with the
The only job type currently implemented is `DownloadJob`, a pydantic object with the
following fields:
| **Field** | **Type** | **Default** | **Description** |
@ -139,7 +138,7 @@ following fields:
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
@ -191,33 +190,6 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.
The `MultiFileDownloadJob` is used for diffusers model downloads,
which contain multiple files and directories under a common root:
| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `download_parts` | Set[DownloadJob]| | Component download jobs |
| `dest` | Path | | Where to download to |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| _Fields updated over the course of the download task_
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the root of the downloaded files |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the file at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
Note that the MultiFileDownloadJob does not support the `priority`,
`job_started`, `job_ended` or `content_type` attributes. You can get
these from the individual download jobs in `download_parts`.
### Callbacks
Download jobs can be associated with a series of callbacks, each with
@ -279,40 +251,11 @@ jobs using `list_jobs()`, fetch a single job by its with
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.
#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
#### job = queue.download(source, dest, priority, access_token)
Create a new download job and put it on the queue, returning the
DownloadJob object.
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
This is similar to download(), but instead of taking a single source,
it accepts a `parts` argument consisting of a list of
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
where the URL is the location of the remote file, and the Path is the
destination.
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
consists of a url/path pair. Note that the path *must* be relative.
The method returns a `MultiFileDownloadJob`.
```
from invokeai.backend.model_manager.metadata import RemoteModelFile
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors'',
path='my_model/textencoder/pytorch_model.safetensors'
)
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
path='my_model/vae/diffusers_model.safetensors'
)
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
dest='/tmp/downloads',
on_progress=TqdmProgress().update)
queue.wait_for_job(job)
print(f"The files were downloaded to {job.download_path}")
```
#### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s.

View File

@ -397,25 +397,26 @@ In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import get_config
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db, logger)
db = SqliteDatabase(config, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()
installer = ModelInstallService(app_config=config,
installer = ModelInstallService(app_config=config,
record_store=record_store,
download_queue=queue
)
download_queue=queue
)
installer.start()
```
@ -1366,20 +1367,12 @@ the in-memory loaded model:
| `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
### get_model_by_key(key, [submodel]) -> LoadedModel
The `get_model_by_key()` method will retrieve the model using its
unique database key. For example:
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
`get_model_by_key()` may raise any of the following exceptions:
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Using the Loaded Model in Inference
Because the loader can return multiple model types, it is typed to
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
`LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model
@ -1387,33 +1380,17 @@ in the execution device for the duration of the context, and returns
the model. Use it like this:
```
loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model as vae:
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae:
image = vae.decode(latents)[0]
```
The object returned by the LoadedModel context manager is an
`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
In addition, you may call `LoadedModel.model_on_device()`, a context
manager that returns a tuple of the model's state dict in CPU and the
model itself in VRAM. It is used to optimize the LoRA patching and
unpatching process:
```
loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
image = vae.decode(latents)[0]
```
Since not all models have state dicts, the `state_dict` return value
can be None.
`get_model_by_key()` may raise any of the following exceptions:
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Emitting model loading events
When the `context` argument is passed to `load_model_*()`, it will
@ -1601,59 +1578,3 @@ This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.
## Invocation Context Model Manager API
Within invocations, the following methods are available from the
`InvocationContext` object:
### context.download_and_cache_model(source) -> Path
This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
be a direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
### context.load_local_model(model_path, [loader]) -> LoadedModel
This method loads a local model from the indicated path, returning a
`LoadedModel`. The optional loader is a Callable that accepts a Path
to the object, and returns a `AnyModel` object. If no loader is
provided, then the method will use `torch.load()` for a .ckpt or .bin
checkpoint file, `safetensors.torch.load_file()` for a safetensors
checkpoint file, or `cls.from_pretrained()` for a directory that looks
like a diffusers directory.
### context.load_remote_model(source, [loader]) -> LoadedModel
This method accepts a `source` of a remote model, downloads and caches
it locally, loads it, and returns a `LoadedModel`. The source can be a
direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors

View File

@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
"Custom" fields will always be treated as stateless fields.
##### Single and Collection Fields
##### Collection and Scalar Fields
Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field.
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field.
- If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`).
- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`).
- If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`).
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
## Implementation
@ -173,7 +173,8 @@ Field types are represented as structured objects:
```ts
type FieldType = {
name: string;
cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';
isCollection: boolean;
isCollectionOrScalar: boolean;
};
```
@ -185,7 +186,7 @@ There are 4 general cases for field type parsing.
When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property.
We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`.
We create a field type name from this `type` string (e.g. `string` -> `StringField`).
##### Complex Types
@ -199,13 +200,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi
When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type.
We use the item type for field type name. The cardinality is `"COLLECTION"`.
We use the item type for field type name, adding `isCollection: true` to the field type.
##### Single or Collection Types
##### Collection or Scalar Types
When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union.
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`.
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type.
##### Optional Fields

View File

@ -165,7 +165,7 @@ Additionally, each section can be expanded with the "Show Advanced" button in o
There are several ways to install IP-Adapter models with an existing InvokeAI installation:
1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models.
2. Through the Model Manager UI with models from the *Tools* section of [models.invoke.ai](https://models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](https://www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
#### Using IP-Adapter

View File

@ -4,6 +4,278 @@ title: Training
# :material-file-document: Training
Invoke Training has moved to its own repository, with a dedicated UI for accessing common scripts like Textual Inversion and LoRA training.
# Textual Inversion Training
## **Personalizing Text-to-Image Generation**
You can find more by visiting the repo at https://github.com/invoke-ai/invoke-training
You may personalize the generated images to provide your own styles or objects
by training a new LDM checkpoint and introducing a new vocabulary to the fixed
model as a (.pt) embeddings file. Alternatively, you may use or train
HuggingFace Concepts embeddings files (.bin) from
<https://huggingface.co/sd-concepts-library> and its associated
notebooks.
## **Hardware and Software Requirements**
You will need a GPU to perform training in a reasonable length of
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
training process further. During training, about ~8 GB is temporarily
needed in order to store intermediate models, checkpoints and logs.
## **Preparing for Training**
To train, prepare a folder that contains 3-5 images that illustrate
the object or concept. It is good to provide a variety of examples or
poses to avoid overtraining the system. Format these images as PNG
(preferred) or JPG. You do not need to resize or crop the images in
advance, but for more control you may wish to do so.
Place the training images in a directory on the machine InvokeAI runs
on. We recommend placing them in a subdirectory of the
`text-inversion-training-data` folder located in the InvokeAI root
directory, ordinarily `~/invokeai` (Linux/Mac), or
`C:\Users\your_name\invokeai` (Windows). For example, to create an
embedding for the "psychedelic" style, you'd place the training images
into the directory
`~invokeai/text-inversion-training-data/psychedelic`.
## **Launching Training Using the Console Front End**
InvokeAI 2.3 and higher comes with a text console-based training front
end. From within the `invoke.sh`/`invoke.bat` Invoke launcher script,
start training tool selecting choice (3):
```sh
1 "Generate images with a browser-based interface"
2 "Explore InvokeAI nodes using a command-line interface"
3 "Textual inversion training"
4 "Merge models (diffusers type only)"
5 "Download and install models"
6 "Change InvokeAI startup options"
7 "Re-run the configure script to fix a broken install or to complete a major upgrade"
8 "Open the developer console"
9 "Update InvokeAI"
```
Alternatively, you can select option (8) or from the command line, with the InvokeAI virtual environment active,
you can then launch the front end with the command `invokeai-ti --gui`.
This will launch a text-based front end that will look like this:
<figure markdown>
![ti-frontend](../assets/textual-inversion/ti-frontend.png)
</figure>
The interface is keyboard-based. Move from field to field using
control-N (^N) to move to the next field and control-P (^P) to the
previous one. <Tab> and <shift-TAB> work as well. Once a field is
active, use the cursor keys. In a checkbox group, use the up and down
cursor keys to move from choice to choice, and <space> to select a
choice. In a scrollbar, use the left and right cursor keys to increase
and decrease the value of the scroll. In textfields, type the desired
values.
The number of parameters may look intimidating, but in most cases the
predefined defaults work fine. The red circled fields in the above
illustration are the ones you will adjust most frequently.
### Model Name
This will list all the diffusers models that are currently
installed. Select the one you wish to use as the basis for your
embedding. Be aware that if you use a SD-1.X-based model for your
training, you will only be able to use this embedding with other
SD-1.X-based models. Similarly, if you train on SD-2.X, you will only
be able to use the embeddings with models based on SD-2.X.
### Trigger Term
This is the prompt term you will use to trigger the embedding. Type a
single word or phrase you wish to use as the trigger, example
"psychedelic" (without angle brackets). Within InvokeAI, you will then
be able to activate the trigger using the syntax `<psychedelic>`.
### Initializer
This is a single character that is used internally during the training
process as a placeholder for the trigger term. It defaults to "*" and
can usually be left alone.
### Resume from last saved checkpoint
As training proceeds, textual inversion will write a series of
intermediate files that can be used to resume training from where it
was left off in the case of an interruption. This checkbox will be
automatically selected if you provide a previously used trigger term
and at least one checkpoint file is found on disk.
Note that as of 20 January 2023, resume does not seem to be working
properly due to an issue with the upstream code.
### Data Training Directory
This is the location of the images to be used for training. When you
select a trigger term like "my-trigger", the frontend will prepopulate
this field with `~/invokeai/text-inversion-training-data/my-trigger`,
but you can change the path to wherever you want.
### Output Destination Directory
This is the location of the logs, checkpoint files, and embedding
files created during training. When you select a trigger term like
"my-trigger", the frontend will prepopulate this field with
`~/invokeai/text-inversion-output/my-trigger`, but you can change the
path to wherever you want.
### Image resolution
The images in the training directory will be automatically scaled to
the value you use here. For best results, you will want to use the
same default resolution of the underlying model (512 pixels for
SD-1.5, 768 for the larger version of SD-2.1).
### Center crop images
If this is selected, your images will be center cropped to make them
square before resizing them to the desired resolution. Center cropping
can indiscriminately cut off the top of subjects' heads for portrait
aspect images, so if you have images like this, you may wish to use a
photoeditor to manually crop them to a square aspect ratio.
### Mixed precision
Select the floating point precision for the embedding. "no" will
result in a full 32-bit precision, "fp16" will provide 16-bit
precision, and "bf16" will provide mixed precision (only available
when XFormers is used).
### Max training steps
How many steps the training will take before the model converges. Most
training sets will converge with 2000-3000 steps.
### Batch size
This adjusts how many training images are processed simultaneously in
each step. Higher values will cause the training process to run more
quickly, but use more memory. The default size will run with GPUs with
as little as 12 GB.
### Learning rate
The rate at which the system adjusts its internal weights during
training. Higher values risk overtraining (getting the same image each
time), and lower values will take more steps to train a good
model. The default of 0.0005 is conservative; you may wish to increase
it to 0.005 to speed up training.
### Scale learning rate by number of GPUs, steps and batch size
If this is selected (the default) the system will adjust the provided
learning rate to improve performance.
### Use xformers acceleration
This will activate XFormers memory-efficient attention. You need to
have XFormers installed for this to have an effect.
### Learning rate scheduler
This adjusts how the learning rate changes over the course of
training. The default "constant" means to use a constant learning rate
for the entire training session. The other values scale the learning
rate according to various formulas.
Only "constant" is supported by the XFormers library.
### Gradient accumulation steps
This is a parameter that allows you to use bigger batch sizes than
your GPU's VRAM would ordinarily accommodate, at the cost of some
performance.
### Warmup steps
If "constant_with_warmup" is selected in the learning rate scheduler,
then this provides the number of warmup steps. Warmup steps have a
very low learning rate, and are one way of preventing early
overtraining.
## The training run
Start the training run by advancing to the OK button (bottom right)
and pressing <enter>. A series of progress messages will be displayed
as the training process proceeds. This may take an hour or two,
depending on settings and the speed of your system. Various log and
checkpoint files will be written into the output directory (ordinarily
`~/invokeai/text-inversion-output/my-model/`)
At the end of successful training, the system will copy the file
`learned_embeds.bin` into the InvokeAI root directory's `embeddings`
directory, using a subdirectory named after the trigger token. For
example, if the trigger token was `psychedelic`, then look for the
embeddings file in
`~/invokeai/embeddings/psychedelic/learned_embeds.bin`
You may now launch InvokeAI and try out a prompt that uses the trigger
term. For example `a plate of banana sushi in <psychedelic> style`.
## **Training with the Command-Line Script**
Training can also be done using a traditional command-line script. It
can be launched from within the "developer's console", or from the
command line after activating InvokeAI's virtual environment.
It accepts a large number of arguments, which can be summarized by
passing the `--help` argument:
```sh
invokeai-ti --help
```
Typical usage is shown here:
```sh
invokeai-ti \
--model=stable-diffusion-1.5 \
--resolution=512 \
--learnable_property=style \
--initializer_token='*' \
--placeholder_token='<psychedelic>' \
--train_data_dir=/home/lstein/invokeai/training-data/psychedelic \
--output_dir=/home/lstein/invokeai/text-inversion-training/psychedelic \
--scale_lr \
--train_batch_size=8 \
--gradient_accumulation_steps=4 \
--max_train_steps=3000 \
--learning_rate=0.0005 \
--resume_from_checkpoint=latest \
--lr_scheduler=constant \
--mixed_precision=fp16 \
--only_save_embeds
```
## Troubleshooting
### `Cannot load embedding for <trigger>. It was trained on a model with token dimension 1024, but the current model has token dimension 768`
Messages like this indicate you trained the embedding on a different base model than the currently selected one.
For example, in the error above, the training was done on SD2.1 (768x768) but it was used on SD1.5 (512x512).
## Reading
For more information on textual inversion, please see the following
resources:
* The [textual inversion repository](https://github.com/rinongal/textual_inversion) and
associated paper for details and limitations.
* [HuggingFace's textual inversion training
page](https://huggingface.co/docs/diffusers/training/text_inversion)
* [HuggingFace example script
documentation](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)
(Note that this script is similar to, but not identical, to
`textual_inversion`, but produces embed files that are completely compatible.
---
copyright (c) 2023, Lincoln Stein and the InvokeAI Development Team

View File

@ -154,18 +154,6 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
Check the [configuration docs] for more detail about the settings and how to specify them.
## `ModuleNotFoundError: No module named 'controlnet_aux'`
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
- Run the Invoke launcher
- Choose the developer console option
- Run this command: `pip cache remove controlnet_aux`
- Close the terminal window
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
## Out of Memory Issues
The models are large, VRAM is expensive, and you may find yourself

View File

@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s
4. The VAE decodes the final latent image from latent space into image space.
Image-to-image is a similar process, with only step 1 being different:
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).

View File

@ -1,10 +1,8 @@
# Automatic Install & Updates
# Automatic Install
**The same packaged installer file can be used for both new installs and updates.**
Using the installer for updates will leave everything you've added since installation, and just update the core libraries used to run Invoke.
Simply use the same path you installed to originally.
The installer is used for both new installs and updates.
Both release and pre-release versions can be installed using the installer. It also supports install through a wheel if needed.
Both release and pre-release versions can be installed using it. It also supports install a wheel if needed.
Be sure to review the [installation requirements] and ensure your system has everything it needs to install Invoke.
@ -98,7 +96,7 @@ Updating is exactly the same as installing - download the latest installer, choo
If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord].
[installation requirements]: INSTALL_REQUIREMENTS.md
[installation requirements]: INSTALLATION.md#installation-requirements
[FAQ]: ../help/FAQ.md
[install some models]: 050_INSTALLING_MODELS.md
[configuration docs]: ../features/CONFIGURATION.md

View File

@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The
### Requirements
Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md).
Before you start, go through the [installation requirements].
### Installation Walkthrough
@ -79,7 +79,7 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
- You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.
- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command.
!!! example "Install with an extra index URL"
@ -116,4 +116,4 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME
!!! warning
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.

View File

@ -4,37 +4,50 @@ title: Installing with Docker
# :fontawesome-brands-docker: Docker
!!! warning "macOS users"
!!! warning "macOS and AMD GPU Users"
Docker can not access the GPU on macOS, so your generation speeds will be slow. [Install InvokeAI](INSTALLATION.md) instead.
We highly recommend to Install InvokeAI locally using [these instructions](INSTALLATION.md),
because Docker containers can not access the GPU on macOS.
!!! warning "AMD GPU Users"
Container support for AMD GPUs has been reported to work by the community, but has not received
extensive testing. Please make sure to set the `GPU_DRIVER=rocm` environment variable (see below), and
use the `build.sh` script to build the image for this to take effect at build time.
!!! tip "Linux and Windows Users"
Configure Docker to access your machine's GPU.
For optimal performance, configure your Docker daemon to access your machine's GPU.
Docker Desktop on Windows [includes GPU support](https://www.docker.com/blog/wsl-2-gpu-support-for-docker-desktop-on-nvidia-gpus/).
Linux users should follow the [NVIDIA](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) or [AMD](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html) documentation.
Linux users should install and configure the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
## Why containers?
They provide a flexible, reliable way to build and deploy InvokeAI.
See [Processes](https://12factor.net/processes) under the Twelve-Factor App
methodology for details on why running applications in such a stateless fashion is important.
The container is configured for CUDA by default, but can be built to support AMD GPUs
by setting the `GPU_DRIVER=rocm` environment variable at Docker image build time.
Developers on Apple silicon (M1/M2/M3): You
[can't access your GPU cores from Docker containers](https://github.com/pytorch/pytorch/issues/81224)
and performance is reduced compared with running it directly on macOS but for
development purposes it's fine. Once you're done with development tasks on your
laptop you can build for the target platform and architecture and deploy to
another environment with NVIDIA GPUs on-premises or in the cloud.
## TL;DR
Ensure your Docker setup is able to use your GPU. Then:
```bash
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
```
Once the container starts up, open http://localhost:9090 in your browser, install some models, and start generating.
## Build-It-Yourself
All the docker materials are located inside the [docker](https://github.com/invoke-ai/InvokeAI/tree/main/docker) directory in the Git repo.
This assumes properly configured Docker on Linux or Windows/WSL2. Read on for detailed customization options.
```bash
# docker compose commands should be run from the `docker` directory
cd docker
cp .env.sample .env
docker compose up
```
We also ship the `run.sh` convenience script. See the `docker/README.md` file for detailed instructions on how to customize the docker setup to your needs.
## Installation in a Linux container (desktop)
### Prerequisites
@ -45,9 +58,18 @@ Preferences, Resources, Advanced. Increase the CPUs and Memory to avoid this
[Issue](https://github.com/invoke-ai/InvokeAI/issues/342). You may need to
increase Swap and Disk image size too.
#### Get a Huggingface-Token
Besides the Docker Agent you will need an Account on
[huggingface.co](https://huggingface.co/join).
After you succesfully registered your account, go to
[huggingface.co/settings/tokens](https://huggingface.co/settings/tokens), create
a token and copy it, since you will need in for the next step.
### Setup
Set up your environment variables. In the `docker` directory, make a copy of `.env.sample` and name it `.env`. Make changes as necessary.
Set up your environmnent variables. In the `docker` directory, make a copy of `.env.sample` and name it `.env`. Make changes as necessary.
Any environment variables supported by InvokeAI can be set here - please see the [CONFIGURATION](../features/CONFIGURATION.md) for further detail.
@ -81,9 +103,10 @@ Once the container starts up (and configures the InvokeAI root directory if this
## Troubleshooting / FAQ
- Q: I am running on Windows under WSL2, and am seeing a "no such file or directory" error.
- A: Your `docker-entrypoint.sh` might have has Windows (CRLF) line endings, depending how you cloned the repository.
To solve this, change the line endings in the `docker-entrypoint.sh` file to `LF`. You can do this in VSCode
- A: Your `docker-entrypoint.sh` file likely has Windows (CRLF) as opposed to Unix (LF) line endings,
and you may have cloned this repository before the issue was fixed. To solve this, please change
the line endings in the `docker-entrypoint.sh` file to `LF`. You can do this in VSCode
(`Ctrl+P` and search for "line endings"), or by using the `dos2unix` utility in WSL.
Finally, you may delete `docker-entrypoint.sh` followed by `git pull; git checkout docker/docker-entrypoint.sh`
to reset the file to its most recent version.
For more information on this issue, see [Docker Desktop documentation](https://docs.docker.com/desktop/troubleshoot/topics/#avoid-unexpected-syntax-errors-use-unix-style-line-endings-for-files-in-containers)
For more information on this issue, please see the [Docker Desktop documentation](https://docs.docker.com/desktop/troubleshoot/topics/#avoid-unexpected-syntax-errors-use-unix-style-line-endings-for-files-in-containers)

View File

@ -1,4 +1,4 @@
# Installation and Updating Overview
# Installation Overview
Before installing, review the [installation requirements] to ensure your system is set up properly.
@ -6,21 +6,14 @@ See the [FAQ] for frequently-encountered installation issues.
If you need more help, join our [discord] or [create an issue].
<h2>Automatic Install & Updates </h2>
<h2>Automatic Install</h2>
✅ The automatic install is the best way to run InvokeAI. Check out the [installation guide] to get started.
⬆️ The same installer is also the best way to update InvokeAI - Simply rerun it for the same folder you installed to.
The installation process simply manages installation for the core libraries & application dependencies that run Invoke.
Any models, images, or other assets in the Invoke root folder won't be affected by the installation process.
<h2>Manual Install</h2>
If you are familiar with python and want more control over the packages that are installed, you can [install InvokeAI manually via PyPI].
Updates are managed by reinstalling the latest version through PyPi.
<h2>Developer Install</h2>
If you want to contribute to InvokeAI, consult the [developer install guide].

View File

@ -37,13 +37,13 @@ Invoke runs best with a dedicated GPU, but will fall back to running on CPU, alb
=== "Nvidia"
```
Any GPU with at least 8GB VRAM.
Any GPU with at least 8GB VRAM. Linux only.
```
=== "AMD"
```
Any GPU with at least 16GB VRAM. Linux only.
Any GPU with at least 16GB VRAM.
```
=== "Mac"

View File

@ -10,10 +10,11 @@ set INVOKEAI_ROOT=.
echo Desired action:
echo 1. Generate images with the browser-based interface
echo 2. Open the developer console
echo 3. Command-line help
echo 3. Run the InvokeAI image database maintenance script
echo 4. Command-line help
echo Q - Quit
echo.
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.
echo.
set /P choice="Please enter 1-4, Q: [1] "
if not defined choice set choice=1
@ -33,6 +34,9 @@ IF /I "%choice%" == "1" (
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k
) ELSE IF /I "%choice%" == "3" (
echo Running the db maintenance script...
python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "4" (
echo Displaying command line help...
python .venv\Scripts\invokeai-web.exe --help %*
pause

View File

@ -17,7 +17,7 @@
set -eu
# Ensure we're in the correct folder in case user's CWD is somewhere else
scriptdir=$(dirname $(readlink -f "$0"))
scriptdir=$(dirname "$0")
cd "$scriptdir"
. .venv/bin/activate
@ -47,6 +47,11 @@ do_choice() {
bash --init-file "$file_name"
;;
3)
clear
printf "Running the db maintenance script\n"
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
;;
4)
clear
printf "Command-line help\n"
invokeai-web --help
@ -66,7 +71,8 @@ do_line_input() {
printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n"
printf "2: Open the developer console\n"
printf "3: Command-line help\n"
printf "3: Run the InvokeAI image database maintenance script\n"
printf "4: Command-line help\n"
printf "Q: Quit\n\n"
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n"
read -p "Please enter 1-4, Q: [1] " yn

View File

@ -1,45 +1,40 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from logging import Logger
import torch
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from invokeai.app.services.board_images.board_images_default import BoardImagesService
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.names.names_default import SimpleNameService
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
from invokeai.app.services.session_processor.session_processor_default import (
DefaultSessionProcessor,
DefaultSessionRunner,
)
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
# TODO: is there a better way to achieve this?
def check_internet() -> bool:
@ -66,12 +61,7 @@ class ApiDependencies:
invoker: Invoker
@staticmethod
def initialize(
config: InvokeAIAppConfig,
event_handler_id: int,
loop: asyncio.AbstractEventLoop,
logger: Logger = logger,
) -> None:
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")
@ -82,7 +72,6 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
style_presets_folder = config.style_presets_path
db = init_db(config=config, logger=logger, image_files=image_files)
@ -93,7 +82,7 @@ class ApiDependencies:
board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id, loop=loop)
events = FastAPIEventService(event_handler_id)
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
@ -104,22 +93,20 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
download_queue_service = DownloadQueueService(event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
model_record_service=ModelRecordServiceSQL(db=db),
download_queue=download_queue_service,
events=events,
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
services = InvocationServices(
board_image_records=board_image_records,
@ -145,8 +132,6 @@ class ApiDependencies:
workflow_records=workflow_records,
tensors=tensors,
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
)
ApiDependencies.invoker = Invoker(services)

View File

@ -0,0 +1,52 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events.events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(
event.get("event_name"),
payload=event.get("payload"),
middleware_id=self.event_handler_id,
)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -10,13 +10,15 @@ from fastapi import Body
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.logging import logging
from invokeai.version import __version__
from ..dependencies import ApiDependencies
class LogLevel(int, Enum):
NotSet = logging.NOTSET
@ -107,7 +109,9 @@ async def get_config() -> AppConfig:
upscaling_models.append(str(Path(model).stem))
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
nsfw_methods = ["nsfw_checker"]
nsfw_methods = []
if SafetyChecker.safety_checker_available():
nsfw_methods.append("nsfw_checker")
watermarking_methods = ["invisible_watermark"]

View File

@ -2,7 +2,7 @@ from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from ..dependencies import ApiDependencies
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])

View File

@ -4,11 +4,12 @@ from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from ..dependencies import ApiDependencies
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@ -31,7 +32,6 @@ class DeleteBoardResult(BaseModel):
)
async def create_board(
board_name: str = Query(description="The name of the board to create"),
is_private: bool = Query(default=False, description="Whether the board is private"),
) -> BoardDTO:
"""Creates a board"""
try:
@ -118,13 +118,15 @@ async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all(include_archived)
return ApiDependencies.invoker.services.boards.get_all()
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
return ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
else:
raise HTTPException(
status_code=400,

View File

@ -8,12 +8,13 @@ from fastapi.routing import APIRouter
from pydantic.networks import AnyHttpUrl
from starlette.exceptions import HTTPException
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.download import (
DownloadJob,
UnknownJobIDException,
)
from ..dependencies import ApiDependencies
download_queue_router = APIRouter(prefix="/v1/download_queue", tags=["download_queue"])

View File

@ -6,18 +6,15 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field, JsonValue
from pydantic import BaseModel, Field, ValidationError
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@ -45,17 +42,13 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
metadata: Optional[JsonValue] = Body(
default=None, description="The metadata to associate with the image", embed=True
),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
_metadata = None
_workflow = None
_graph = None
metadata = None
workflow = None
contents = await file.read()
try:
@ -69,28 +62,22 @@ async def upload_image(
# TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image
metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
if isinstance(metadata_raw, str):
_metadata = metadata_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
pass
metadata_raw = pil_image.info.get("invokeai_metadata", None)
if metadata_raw:
try:
metadata = MetadataFieldValidator.validate_json(metadata_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
# attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None)
if isinstance(workflow_raw, str):
_workflow = workflow_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
pass
# attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None)
if isinstance(graph_raw, str):
_graph = graph_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
pass
if workflow_raw is not None:
try:
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
try:
image_dto = ApiDependencies.invoker.services.images.create(
@ -99,9 +86,8 @@ async def upload_image(
image_category=image_category,
session_id=session_id,
board_id=board_id,
metadata=_metadata,
workflow=_workflow,
graph=_graph,
metadata=metadata,
workflow=workflow,
is_intermediate=is_intermediate,
)
@ -199,27 +185,21 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
class WorkflowAndGraphResponse(BaseModel):
workflow: Optional[str] = Field(description="The workflow used to generate the image, as stringified JSON")
graph: Optional[str] = Field(description="The graph used to generate the image, as stringified JSON")
@images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
)
async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"),
) -> WorkflowAndGraphResponse:
) -> Optional[WorkflowWithoutID]:
try:
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
return WorkflowAndGraphResponse(workflow=workflow, graph=graph)
return ApiDependencies.invoker.services.images.get_workflow(image_name)
except Exception:
raise HTTPException(status_code=404)
@images_router.get(
@images_router.api_route(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
operation_id="get_image_full",
response_class=Response,
responses={
@ -230,30 +210,24 @@ async def get_image_workflow(
404: {"description": "Image not found"},
},
)
@images_router.head(
"/i/{image_name}/full",
operation_id="get_image_full_head",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_name)
with open(path, "rb") as f:
content = f.read()
response = Response(content, media_type="image/png")
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
response = FileResponse(
path,
media_type="image/png",
filename=image_name,
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
return response
except Exception:
raise HTTPException(status_code=404)
@ -273,14 +247,15 @@ async def get_image_full(
)
async def get_image_thumbnail(
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> Response:
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
with open(path, "rb") as f:
content = f.read()
response = Response(content, media_type="image/webp")
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
@ -324,14 +299,16 @@ async def list_image_dtos(
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs"""
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
offset,
limit,
image_origin,
categories,
is_intermediate,
board_id,
)
return image_dtos

View File

@ -3,23 +3,22 @@
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import List, Optional, Type
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordChanges,
UnknownModelException,
@ -30,12 +29,15 @@ from invokeai.backend.model_manager.config import (
MainCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.starter_models import STARTER_MODELS, StarterModel, StarterModelWithoutDependencies
from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
# images are immutable; set a high max-age
@ -50,13 +52,6 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True)
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
"""Add a cover image URL to a model configuration."""
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
config.cover_image = cover_image
return config
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
@ -123,7 +118,8 @@ async def list_model_records(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
for model in found_models:
model = add_cover_image_to_model_config(model, ApiDependencies)
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
model.cover_image = cover_image
return ModelsList(models=found_models)
@ -164,13 +160,28 @@ async def get_model_record(
key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
"""Get a model record"""
record_store = ApiDependencies.invoker.services.model_manager.store
try:
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
return add_cover_image_to_model_config(config, ApiDependencies)
config: AnyModelConfig = record_store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
config.cover_image = cover_image
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
# page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@ -283,15 +294,14 @@ async def update_model_record(
installer = ApiDependencies.invoker.services.model_manager.install
try:
record_store.update_model(key, changes=changes)
config = installer.sync_model_path(key)
config = add_cover_image_to_model_config(config, ApiDependencies)
model_response: AnyModelConfig = installer.sync_model_path(key)
logger.info(f"Updated model: {key}")
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return config
return model_response
@model_manager_router.get(
@ -430,11 +440,13 @@ async def delete_model_image(
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
example={"name": "string", "description": "string"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install a model using a string identifier.
@ -449,9 +461,8 @@ async def install_model(
- model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors
`config` is a ModelRecordChanges object. Fields in this object will override
the ones that are probed automatically. Pass an empty object to accept
all the defaults.
`config` is an optional dict containing model configuration values that will override
the ones that are probed automatically.
`access_token` is an optional access token for use with Urls that require
authentication.
@ -486,133 +497,6 @@ async def install_model(
return result
@model_manager_router.get(
"/install/huggingface",
operation_id="install_hugging_face_model",
responses={
201: {"description": "The model is being installed"},
400: {"description": "Bad request"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_class=HTMLResponse,
)
async def install_hugging_face_model(
source: str = Query(description="HuggingFace repo_id to install"),
) -> HTMLResponse:
"""Install a Hugging Face model using a string identifier."""
def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
if message:
message = f"<p>{message}</p>"
title_class = "error" if is_error else "success"
return f"""
<html>
<head>
<title>{title}</title>
<style>
body {{
text-align: center;
background-color: hsl(220 12% 10% / 1);
font-family: Helvetica, sans-serif;
color: hsl(220 12% 86% / 1);
}}
.repo-id {{
color: hsl(220 12% 68% / 1);
}}
.error {{
color: hsl(0 42% 68% / 1)
}}
.message-box {{
display: inline-block;
border-radius: 5px;
background-color: hsl(220 12% 20% / 1);
padding-inline-end: 30px;
padding: 20px;
padding-inline-start: 30px;
padding-inline-end: 30px;
}}
.container {{
display: flex;
width: 100%;
height: 100%;
align-items: center;
justify-content: center;
}}
a {{
color: inherit
}}
a:visited {{
color: inherit
}}
a:active {{
color: inherit
}}
</style>
</head>
<body style="background-color: hsl(220 12% 10% / 1);">
<div class="container">
<div class="message-box">
<h2 class="{title_class}">{heading}</h2>
{message}
<p class="repo-id">Repo ID: {repo_id}</p>
</div>
</div>
</body>
</html>
"""
try:
metadata = HuggingFaceMetadataFetch().from_id(source)
assert isinstance(metadata, ModelMetadataWithFiles)
except UnknownMetadataException:
title = "Unable to Install Model"
heading = "No HuggingFace repository found with that repo ID."
message = "Ensure the repo ID is correct and try again."
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
if metadata.is_diffusers:
installer.heuristic_import(
source=source,
inplace=False,
)
elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
installer.heuristic_import(
source=str(metadata.ckpt_urls[0]),
inplace=False,
)
else:
title = "Unable to Install Model"
heading = "This HuggingFace repo has multiple models."
message = "Please use the Model Manager to install this model."
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
title = "Model Install Started"
heading = "Your HuggingFace model is installing now."
message = "You can close this tab and check the Model Manager for installation progress."
return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
except Exception as e:
logger.error(str(e))
title = "Unable to Install Model"
heading = "There was an problem installing this model."
message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
@model_manager_router.get(
"/install",
operation_id="list_model_installs",
@ -730,54 +614,48 @@ async def convert_model(
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
converted_model = loader.load_model(model_config)
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path) # type: ignore
assert convert_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:
new_key = installer.install_path(
convert_path,
config=ModelRecordChanges(
name=original_name,
description=model_config.description,
hash=model_config.hash,
source=model_config.source,
),
)
except Exception as e:
logger.error(str(e))
store.update_model(key, changes=ModelRecordChanges(name=original_name))
raise HTTPException(status_code=409, detail=str(e))
# Update the model image if the model had one
# loading the model will convert it into a cached diffusers file
try:
model_image = ApiDependencies.invoker.services.model_images.get(key)
ApiDependencies.invoker.services.model_images.save(model_image, new_key)
ApiDependencies.invoker.services.model_images.delete(key)
except ModelImageFileNotFoundException:
pass
cc_size = loader.convert_cache.max_size
if cc_size == 0: # temporary set the convert cache to a positive number so that cached model is written
loader._convert_cache.max_size = 1.0
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
loader._convert_cache.max_size = cc_size
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
assert cache_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:
new_key = installer.install_path(
cache_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# delete the original safetensors file
installer.delete(key)
# delete the temporary directory
# shutil.rmtree(cache_path)
# delete the cached version
shutil.rmtree(cache_path)
# return the config record for the new diffusers directory
new_config = store.get_model(new_key)
new_config = add_cover_image_to_model_config(new_config, ApiDependencies)
new_config: AnyModelConfig = store.get_model(new_key)
return new_config

View File

@ -4,14 +4,12 @@ from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
ClearResult,
EnqueueBatchResult,
PruneResult,
@ -21,6 +19,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from ..dependencies import ApiDependencies
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
@ -106,19 +106,6 @@ async def cancel_by_batch_ids(
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
"/{queue_id}/cancel_by_origin",
operation_id="cancel_by_origin",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_origin(
queue_id: str = Path(description="The queue id to perform this operation on"),
origin: str = Query(description="The origin to cancel all queue items for"),
) -> CancelByOriginResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",
@ -216,7 +203,6 @@ async def get_batch_status(
responses={
200: {"model": SessionQueueItem},
},
response_model_exclude_none=True,
)
async def get_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),

View File

@ -1,274 +0,0 @@
import csv
import io
import json
import traceback
from typing import Optional
import pydantic
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
from invokeai.app.services.style_preset_records.style_preset_records_common import (
InvalidPresetImportDataError,
PresetData,
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordWithImage,
StylePresetWithoutId,
UnsupportedFileTypeError,
parse_presets_from_file,
)
class StylePresetFormData(BaseModel):
name: str = Field(description="Preset name")
positive_prompt: str = Field(description="Positive prompt")
negative_prompt: str = Field(description="Negative prompt")
type: PresetType = Field(description="Preset type")
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
@style_presets_router.get(
"/i/{style_preset_id}",
operation_id="get_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def get_style_preset(
style_preset_id: str = Path(description="The style preset to get"),
) -> StylePresetRecordWithImage:
"""Gets a style preset"""
try:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
except StylePresetNotFoundError:
raise HTTPException(status_code=404, detail="Style preset not found")
@style_presets_router.patch(
"/i/{style_preset_id}",
operation_id="update_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def update_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
style_preset_id: str = Path(description="The id of the style preset to update"),
data: str = Form(description="The data of the style preset to update"),
) -> StylePresetRecordWithImage:
"""Updates a style preset"""
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
else:
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
style_preset_id=style_preset_id, changes=changes
)
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
@style_presets_router.delete(
"/i/{style_preset_id}",
operation_id="delete_style_preset",
)
async def delete_style_preset(
style_preset_id: str = Path(description="The style preset to delete"),
) -> None:
"""Deletes a style preset"""
try:
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
except StylePresetImageFileNotFoundException:
pass
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
@style_presets_router.post(
"/",
operation_id="create_style_preset",
responses={
200: {"model": StylePresetRecordWithImage},
},
)
async def create_style_preset(
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
data: str = Form(description="The data of the style preset to create"),
) -> StylePresetRecordWithImage:
"""Creates a style preset"""
try:
parsed_data = json.loads(data)
validated_data = StylePresetFormData(**parsed_data)
name = validated_data.name
type = validated_data.type
positive_prompt = validated_data.positive_prompt
negative_prompt = validated_data.negative_prompt
except pydantic.ValidationError:
raise HTTPException(status_code=400, detail="Invalid preset data")
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
if image is not None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
try:
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
@style_presets_router.get(
"/",
operation_id="list_style_presets",
responses={
200: {"model": list[StylePresetRecordWithImage]},
},
)
async def list_style_presets() -> list[StylePresetRecordWithImage]:
"""Gets a page of style presets"""
style_presets_with_image: list[StylePresetRecordWithImage] = []
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
for preset in style_presets:
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
style_presets_with_image.append(style_preset_with_image)
return style_presets_with_image
@style_presets_router.get(
"/i/{style_preset_id}/image",
operation_id="get_style_preset_image",
responses={
200: {
"description": "The style preset image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The style preset image could not be found"},
},
status_code=200,
)
async def get_style_preset_image(
style_preset_id: str = Path(description="The id of the style preset image to get"),
) -> FileResponse:
"""Gets an image file that previews the model"""
try:
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
response = FileResponse(
path,
media_type="image/png",
filename=style_preset_id + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
@style_presets_router.get(
"/export",
operation_id="export_style_presets",
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
status_code=200,
)
async def export_style_presets():
# Create an in-memory stream to store the CSV data
output = io.StringIO()
writer = csv.writer(output)
# Write the header
writer.writerow(["name", "prompt", "negative_prompt"])
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
for preset in style_presets:
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
csv_data = output.getvalue()
output.close()
return Response(
content=csv_data,
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
)
@style_presets_router.post(
"/import",
operation_id="import_style_presets",
)
async def import_style_presets(file: UploadFile = File(description="The file to import")):
try:
style_presets = await parse_presets_from_file(file)
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
except InvalidPresetImportDataError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=400, detail=str(e))
except UnsupportedFileTypeError as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail=str(e))

View File

@ -1,125 +1,66 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from socketio import ASGIApp, AsyncServer
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadEventBase,
DownloadProgressEvent,
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io queue room.
This is a pydantic model to ensure the data is in the correct format."""
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io bulk downloads room.
This is a pydantic model to ensure the data is in the correct format."""
bulk_download_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
}
MODEL_EVENTS = {
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
from ..services.events.events_base import EventServiceBase
class SocketIO:
_sub_queue = "subscribe_queue"
_unsub_queue = "unsubscribe_queue"
__sio: AsyncServer
__app: ASGIApp
_sub_bulk_download = "subscribe_bulk_download"
_unsub_bulk_download = "unsubscribe_bulk_download"
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI):
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self._app)
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app)
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],
)
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@ -3,7 +3,9 @@ import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
import torch
import uvicorn
@ -11,18 +13,26 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.api.routers import (
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
from .api.routers import (
app_info,
board_images,
boards,
@ -30,15 +40,15 @@ from invokeai.app.api.routers import (
images,
model_manager,
session_queue,
style_presets,
utilities,
workflows,
)
from invokeai.app.api.sockets import SocketIO
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config()
@ -56,13 +66,11 @@ mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()
@ -109,9 +117,86 @@ app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.include_router(style_presets.style_presets_router, prefix="/api")
app.openapi = get_openapi_func(app)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
@app.get("/docs", include_in_schema=False)
@ -165,7 +250,6 @@ def invoke_api() -> None:
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_port(port=port + 1)
else:
@ -188,6 +272,8 @@ def invoke_api() -> None:
check_cudnn(logger)
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(
app=app,
host=app_config.host,

View File

@ -40,7 +40,7 @@ from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.logging import InvokeAILogger
if TYPE_CHECKING:
from invokeai.app.services.invocation_services import InvocationServices
from ..services.invocation_services import InvocationServices
logger = InvokeAILogger.get_logger()
@ -98,13 +98,11 @@ class BaseInvocationOutput(BaseModel):
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@ -114,12 +112,11 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
return cls._typeadapter
@classmethod
@ -128,13 +125,12 @@ class BaseInvocationOutput(BaseModel):
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"])
@classmethod
@ -171,7 +167,6 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def get_type(cls) -> str:
@ -182,17 +177,15 @@ class BaseInvocation(ABC, BaseModel):
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter
@classmethod
@ -228,7 +221,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
@ -244,7 +237,6 @@ class BaseInvocation(ABC, BaseModel):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"])
@abstractmethod
@ -318,7 +310,7 @@ class BaseInvocation(ABC, BaseModel):
protected_namespaces=(),
validate_assignment=True,
json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=False,
json_schema_serialization_defaults_required=True,
coerce_numbers_to_str=True,
)

View File

@ -1,98 +0,0 @@
from typing import Any, Union
import numpy as np
import numpy.typing as npt
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
@invocation(
"lblend",
title="Blend Latents",
tags=["latents", "blend"],
category="latents",
version="1.0.3",
)
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
latents_a: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
latents_b: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents_a = context.tensors.load(self.latents_a.latents_name)
latents_b = context.tensors.load(self.latents_b.latents_name)
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
v0: Union[torch.Tensor, npt.NDArray[Any]],
v1: Union[torch.Tensor, npt.NDArray[Any]],
DOT_THRESHOLD: float = 0.9995,
) -> Union[torch.Tensor, npt.NDArray[Any]]:
"""
Spherical linear interpolation
Args:
t (float/np.ndarray): Float value between 0.0 and 1.0
v0 (np.ndarray): Starting vector
v1 (np.ndarray): Final vector
DOT_THRESHOLD (float): Threshold for considering the two vectors as
colineal. Not recommended to alter this.
Returns:
v2 (np.ndarray): Interpolation vector between v0 and v1
"""
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
inputs_are_torch = True
v0 = v0.detach().cpu().numpy()
if not isinstance(v1, np.ndarray):
inputs_are_torch = True
v1 = v1.detach().cpu().numpy()
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
v2 = (1 - t) * v0 + t * v1
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
v2 = s0 * v0 + s1 * v1
if inputs_are_torch:
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
return v2_torch
else:
assert isinstance(v2, np.ndarray)
return v2
# blend
bl = slerp(self.alpha, latents_a, latents_b)
assert isinstance(bl, torch.Tensor)
blended_latents: torch.Tensor = bl # for type checking convenience
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)

View File

@ -4,12 +4,13 @@
import numpy as np
from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField
@invocation(
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"

View File

@ -5,7 +5,6 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
@ -15,7 +14,6 @@ from invokeai.app.invocations.fields import (
TensorField,
UIComponent,
)
from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
@ -28,6 +26,9 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
# unconditioned: Optional[torch.Tensor]
@ -64,7 +65,11 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
@ -79,25 +84,19 @@ class CompelInvocation(BaseInvocation):
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
):
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(
tokenizer=patched_tokenizer,
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -107,7 +106,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt)
if context.config.get().log_tokenization:
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
log_tokenization_for_conjunction(conjunction, tokenizer)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@ -137,7 +136,11 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
@ -174,28 +177,20 @@ class SDXLPromptInvocationBase:
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
cached_weights=cached_weights,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel(
tokenizer=patched_tokenizer,
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -208,7 +203,7 @@ class SDXLPromptInvocationBase:
if context.config.get().log_tokenization:
# TODO: better logging for and syntax
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
log_tokenization_for_conjunction(conjunction, tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)

View File

@ -1,6 +1,6 @@
from typing import Literal
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
LATENT_SCALE_FACTOR = 8
"""
@ -10,7 +10,8 @@ factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
"""
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()

View File

@ -2,7 +2,6 @@
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from pathlib import Path
from typing import Dict, List, Literal, Union
import cv2
@ -21,19 +20,11 @@ from controlnet_aux import (
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, field_validator, model_validator
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
UIType,
@ -44,14 +35,15 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
class ControlField(BaseModel):
@ -87,13 +79,13 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
@ -147,7 +139,6 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return context.images.get_pil(self.image.image_name, "RGB")
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@ -173,13 +164,13 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Canny Processor",
tags=["controlnet", "canny"],
category="controlnet",
version="1.3.3",
version="1.3.2",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
low_threshold: int = InputField(
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
)
@ -207,13 +198,13 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
# safe not supported in controlnet_aux v0.0.3
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
@ -236,13 +227,13 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Processor",
tags=["controlnet", "lineart"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image: Image.Image) -> Image.Image:
@ -258,13 +249,13 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
processor = LineartAnimeProcessor()
@ -281,20 +272,19 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
version="1.2.4",
version="1.2.3",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
@ -313,15 +303,15 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Normal BAE Processor",
tags=["controlnet"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@ -330,17 +320,17 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.3"
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.2"
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
@ -353,17 +343,17 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.3"
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.2"
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(
image,
@ -380,18 +370,18 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
h: int = InputField(default=512, ge=0, description="Content shuffle `h` parameter")
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(
image,
@ -410,12 +400,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
@ -426,17 +416,17 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
version="1.2.4",
version="1.2.3",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image):
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(
image,
@ -453,7 +443,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
@ -461,10 +451,10 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
thr_a: float = InputField(default=0, description="Leres parameter `thr_a`")
thr_b: float = InputField(default=0, description="Leres parameter `thr_b`")
boost: bool = InputField(default=False, description="Whether to use boost mode")
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
@ -482,7 +472,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
title="Tile Resample Processor",
tags=["controlnet", "tile"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
@ -506,8 +496,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
def run_processor(self, image: Image.Image) -> Image.Image:
np_img = np.array(image, dtype=np.uint8)
def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
processed_np_image = self.tile_resample(
np_img,
# res=self.tile_size,
@ -522,15 +512,15 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
version="1.2.4",
version="1.2.3",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
@ -569,14 +559,14 @@ class SamDetectorReproducibleColors(SamDetector):
title="Color Map Processor",
tags=["controlnet"],
category="controlnet",
version="1.2.3",
version="1.2.2",
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image"""
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)
def run_processor(self, image: Image.Image) -> Image.Image:
def run_processor(self, image: Image.Image):
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]
@ -593,14 +583,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
return color_map
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
DEPTH_ANYTHING_MODELS = {
"large": "LiheYoung/depth-anything-large-hf",
"base": "LiheYoung/depth-anything-base-hf",
"small": "LiheYoung/depth-anything-small-hf",
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
}
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
@invocation(
@ -608,33 +591,22 @@ DEPTH_ANYTHING_MODELS = {
title="Depth Anything Processor",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
version="1.1.3",
version="1.1.1",
)
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a depth map based on the Depth Anything algorithm"""
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
default="small_v2", description="The size of the depth model to use"
default="small", description="The size of the depth model to use"
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
def load_depth_anything(model_path: Path):
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
return DepthAnythingPipeline(depth_anything_pipeline)
def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)
# Resizing to user target specified size
new_height = int(image.size[1] * (self.resolution / image.size[0]))
depth_map = depth_map.resize((self.resolution, new_height))
return depth_map
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
@invocation(
@ -642,7 +614,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
title="DW Openpose Image Processor",
tags=["controlnet", "dwpose", "openpose"],
category="controlnet",
version="1.1.1",
version="1.1.0",
)
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Generates an openpose pose from an image using DWPose"""
@ -650,13 +622,10 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
draw_body: bool = InputField(default=True)
draw_face: bool = InputField(default=False)
draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
def run_processor(self, image: Image.Image):
dw_openpose = DWOpenposeDetector()
processed_image = dw_openpose(
image,
draw_face=self.draw_face,
@ -665,27 +634,3 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
resolution=self.image_resolution,
)
return processed_image
@invocation(
"heuristic_resize",
title="Heuristic Resize",
tags=["image, controlnet"],
category="image",
version="1.0.1",
classification=Classification.Prototype,
)
class HeuristicResizeInvocation(BaseInvocation):
"""Resize an image using a heuristic method. Preserves edge maps."""
image: ImageField = InputField(description="The image to resize")
width: int = InputField(default=512, ge=1, description="The width to resize to (px)")
height: int = InputField(default=512, ge=1, description="The height to resize to (px)")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
np_img = pil_to_np(image)
np_resized = heuristic_resize(np_img, (self.width, self.height))
resized = np_to_pil(np_resized)
image_dto = context.images.save(image=resized)
return ImageOutput.build(image_dto)

View File

@ -1,80 +0,0 @@
from typing import Optional
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import DenoiseMaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"create_denoise_mask",
title="Create Denoise Mask",
tags=["mask", "denoise"],
category="latents",
version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=4,
)
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None:
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
else:
image_tensor = None
mask = self.prep_mask_tensor(
context.images.get_pil(self.mask.image_name),
)
if image_tensor is not None:
vae_info = context.models.load(self.vae.vae)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)
else:
masked_latents_name = None
mask_name = context.tensors.save(tensor=mask)
return DenoiseMaskOutput.build(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
gradient=False,
)

View File

@ -1,139 +0,0 @@
from typing import Literal, Optional
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageFilter
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
expanded_mask_area: ImageField = OutputField(
description="Image representing the total gradient area of the mask. For paste-back purposes."
)
@invocation(
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.2.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
minimum_denoise: float = InputField(
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
)
image: Optional[ImageField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] Image",
ui_order=6,
)
unet: Optional[UNetField] = InputField(
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
default=None,
input=Input.Connection,
title="[OPTIONAL] UNet",
ui_order=5,
)
vae: Optional[VAEField] = InputField(
default=None,
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
title="[OPTIONAL] VAE",
input=Input.Connection,
ui_order=7,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32,
description=FieldDescriptions.fp32,
ui_order=9,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
if self.edge_radius > 0:
if self.coherence_mode == "Box Blur":
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
else: # Gaussian Blur OR Staged
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
blur_tensor[blur_tensor < 0] = 0.0
threshold = 1 - self.minimum_denoise
if self.coherence_mode == "Staged":
# wherever the blur_tensor is less than fully masked, convert it to threshold
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
else:
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
else:
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
# compute a [0, 1] mask from the blur_tensor
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
expanded_image_dto = context.images.save(expanded_mask_image)
masked_latents_name = None
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)
masked_latents_name = context.tensors.save(tensor=masked_latents)
return GradientMaskOutput(
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
)

View File

@ -1,61 +0,0 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
"crop_latents",
title="Crop Latents",
tags=["latents", "crop"],
category="latents",
version="1.0.2",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
"""Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
divisible by the latent scale factor of 8.
"""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
x: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
y: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
width: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
height: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR
x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
cropped_latents = latents[..., y1:y2, x1:x2]
name = context.tensors.save(tensor=cropped_latents)
return LatentsOutput.build(latents_name=name, latents=cropped_latents)

View File

@ -5,11 +5,13 @@ import cv2 as cv
import numpy
from PIL import Image, ImageOps
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.3.1")
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard):

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Callable, Optional, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
@ -40,7 +40,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@ -49,8 +48,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
# region Misc Field Types
@ -127,20 +124,16 @@ class FieldDescriptions:
negative_cond = "Negative conditioning tensor"
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
vae_model = "VAE model to load"
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
spandrel_image_to_image_model = "Image-to-Image model"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
@ -167,7 +160,6 @@ class FieldDescriptions:
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage."
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"
@ -236,12 +228,6 @@ class ColorField(BaseModel):
return (self.r, self.g, self.b, self.a)
class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
@ -253,31 +239,6 @@ class ConditioningField(BaseModel):
)
class BoundingBoxField(BaseModel):
"""A bounding box primitive value."""
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
score: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
"when the bounding box was produced by a detector and has an associated confidence score.",
)
@model_validator(mode="after")
def check_coords(self):
if self.x_min > self.x_max:
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
if self.y_min > self.y_max:
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
return self
class MetadataField(RootModel[dict[str, Any]]):
"""
Pydantic model for metadata with custom root of type dict[str, Any].

View File

@ -1,86 +0,0 @@
from typing import Literal
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@invocation(
"flux_text_encoder",
title="FLUX Text Encoding",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a flux image."""
clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
t5_embeddings, clip_embeddings = self._encode_prompt(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# Load CLIP.
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
# Load T5.
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
prompt_embeds = t5_encoder(prompt)
with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return prompt_embeds, pooled_prompt_embeds

View File

@ -1,172 +0,0 @@
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux_text_to_image",
title="FLUX Text to Image",
tags=["image", "flux"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Text-to-image generation using a FLUX model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
image = self._run_vae_decoding(context, latents)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
def _run_diffusion(
self,
context: InvocationContext,
clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor,
):
transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = torch.bfloat16
# Prepare input noise.
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
img, img_ids = prepare_latent_img_patches(x)
is_schnell = "schnell" in transformer_info.config.config_path
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
shift=not is_schnell,
)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
with transformer_info as transformer:
assert isinstance(transformer, Flux)
def step_callback() -> None:
if context.util.is_canceled():
raise CanceledException
# TODO: Make this look like the image before re-enabling
# latent_image = unpack(img.float(), self.height, self.width)
# latent_image = latent_image.squeeze() # Remove unnecessary dimensions
# flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128]
# # Create a new tensor of the required shape [255, 255, 3]
# latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format
# # Convert to a NumPy array and then to a PIL Image
# image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))
# (width, height) = image.size
# width *= 8
# height *= 8
# dataURL = image_to_dataURL(image, image_format="JPEG")
# # TODO: move this whole function to invocation context to properly reference these variables
# context._services.events.emit_invocation_denoise_progress(
# context._data.queue_item,
# context._data.invocation,
# state,
# ProgressImage(dataURL=dataURL, width=width, height=height),
# )
x = denoise(
model=transformer,
img=img,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
step_callback=step_callback,
guidance=self.guidance,
)
x = unpack(x.float(), self.height, self.width)
return x
def _run_vae_decoding(
self,
context: InvocationContext,
latents: torch.Tensor,
) -> Image.Image:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
img = vae.decode(latents)
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
return img_pil

View File

@ -1,100 +0,0 @@
from pathlib import Path
from typing import Literal
import torch
from PIL import Image
from transformers import pipeline
from transformers.pipelines import ZeroShotObjectDetectionPipeline
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
}
@invocation(
"grounding_dino",
title="Grounding DINO (Text Prompt Object Detection)",
tags=["prompt", "object detection"],
category="image",
version="1.0.0",
)
class GroundingDinoInvocation(BaseInvocation):
"""Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
# Reference:
# - https://arxiv.org/pdf/2303.05499
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
prompt: str = InputField(description="The prompt describing the object to segment.")
image: ImageField = InputField(description="The image to segment.")
detection_threshold: float = InputField(
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
ge=0.0,
le=1.0,
default=0.3,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
# The model expects a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
detections = self._detect(
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
)
# Convert detections to BoundingBoxCollectionOutput.
bounding_boxes: list[BoundingBoxField] = []
for detection in detections:
bounding_boxes.append(
BoundingBoxField(
x_min=detection.box.xmin,
x_max=detection.box.xmax,
y_min=detection.box.ymin,
y_max=detection.box.ymax,
score=detection.score,
)
)
return BoundingBoxCollectionOutput(collection=bounding_boxes)
@staticmethod
def _load_grounding_dino(model_path: Path):
grounding_dino_pipeline = pipeline(
model=str(model_path),
task="zero-shot-object-detection",
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
return GroundingDinoPipeline(grounding_dino_pipeline)
def _detect(
self,
context: InvocationContext,
image: Image.Image,
labels: list[str],
threshold: float = 0.3,
) -> list[DetectionResult]:
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
# actually makes a difference.
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

View File

@ -1,65 +0,0 @@
import math
from typing import Tuple
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import UNetField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
@invocation_output("ideal_size_output")
class IdealSizeOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
width: int = OutputField(description="The ideal width of the image (in pixels)")
height: int = OutputField(description="The ideal height of the image (in pixels)")
@invocation(
"ideal_size",
title="Ideal Size",
tags=["latents", "math", "ideal_size"],
version="1.0.3",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "
"initial generation artifacts if too large)",
)
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif unet_config.base == BaseModelType.StableDiffusionXL:
dimension = 1024
dimension = dimension * self.multiplier
min_dimension = math.floor(dimension * 0.5)
model_area = dimension * dimension # hardcoded for now since all models are trained on square images
if aspect > 1.0:
init_height = max(min_dimension, math.sqrt(model_area / aspect))
init_width = init_height * aspect
else:
init_width = max(min_dimension, math.sqrt(model_area * aspect))
init_height = init_width / aspect
scaled_width, scaled_height = self.trim_to_multiple_of(
math.floor(init_width),
math.floor(init_height),
)
return IdealSizeOutput(width=scaled_width, height=scaled_height)

View File

@ -1,24 +1,18 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path
from typing import Literal, Optional
import cv2
import numpy
from PIL import Image, ImageChops, ImageFilter, ImageOps
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import (
ColorField,
FieldDescriptions,
ImageField,
InputField,
OutputField,
WithBoard,
WithMetadata,
)
@ -28,6 +22,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import BaseInvocation, Classification, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
class ShowImageInvocation(BaseInvocation):
@ -508,7 +504,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Blur NSFW Image",
tags=["image", "nsfw"],
category="image",
version="1.2.3",
version="1.2.2",
)
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add blur to NSFW-flagged images"""
@ -520,12 +516,23 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
logger = context.logger
logger.debug("Running NSFW checker")
image = SafetyChecker.blur_if_nsfw(image)
if SafetyChecker.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution, (0, 0), caution)
image = blurry_image
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
def _get_caution_img(self) -> Image.Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))
@invocation(
"img_watermark",
@ -1013,62 +1020,3 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
@invocation_output("canvas_v2_mask_and_crop_output")
class CanvasV2MaskAndCropOutput(ImageOutput):
offset_x: int = OutputField(description="The x offset of the image, after cropping")
offset_y: int = OutputField(description="The y offset of the image, after cropping")
@invocation(
"canvas_v2_mask_and_crop",
title="Canvas V2 Mask and Crop",
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Handles Canvas V2 image output masking and cropping"""
source_image: ImageField | None = InputField(
default=None,
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
)
generated_image: ImageField = InputField(description="The image to apply the mask to")
mask: ImageField = InputField(description="The mask to apply")
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
if self.source_image:
generated_image = context.images.get_pil(self.generated_image.image_name)
source_image = context.images.get_pil(self.source_image.image_name)
source_image.paste(generated_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
else:
generated_image = context.images.get_pil(self.generated_image.image_name)
generated_image.putalpha(mask)
image_dto = context.images.save(image=generated_image)
# bbox = image.getbbox()
# image = image.crop(bbox)
return CanvasV2MaskAndCropOutput(
image=ImageField(image_name=image_dto.image_name),
offset_x=0,
offset_y=0,
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,143 +0,0 @@
from contextlib import nullcontext
from functools import singledispatchmethod
import einops
import torch
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@invocation(
"i2l",
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.1.0",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
image: ImageField = InputField(
description="The image to encode",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(orig_dtype)
vae.decoder.conv_in.to(orig_dtype)
vae.decoder.mid_block.to(orig_dtype)
# else:
# latents = latents.float()
else:
vae.to(dtype=torch.float16)
# latents = latents.half()
if tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
tiling_context = nullcontext()
if tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=tile_size,
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@singledispatchmethod
@staticmethod
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents: torch.Tensor = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
return latents
@_encode_to_tensor.register
@staticmethod
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert isinstance(vae, torch.nn.Module)
latents: torch.FloatTensor = vae.encode(image_tensor).latents
return latents

View File

@ -3,9 +3,7 @@ from typing import Literal, get_args
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
from invokeai.app.invocations.fields import ColorField, ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
@ -16,6 +14,10 @@ from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, in
from invokeai.backend.image_util.infill_methods.tile import infill_tile
from invokeai.backend.util.logging import InvokeAILogger
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
logger = InvokeAILogger.get_logger()
@ -40,16 +42,15 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infill the image with the specified method"""
pass
def load_image(self) -> tuple[Image.Image, bool]:
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
"""Process the image to have an alpha channel before being infilled"""
image = self._context.images.get_pil(self.image.image_name)
image = context.images.get_pil(self.image.image_name)
has_alpha = True if image.mode == "RGBA" else False
return image, has_alpha
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
# Retrieve and process image to be infilled
input_image, has_alpha = self.load_image()
input_image, has_alpha = self.load_image(context)
# If the input image has no alpha channel, return it
if has_alpha is False:
@ -132,12 +133,8 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
with self._context.models.load_remote_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)
return lama(image)
lama = LaMA()
return lama(image)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@ -67,6 +67,7 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)

File diff suppressed because it is too large Load Diff

View File

@ -1,121 +0,0 @@
from contextlib import nullcontext
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@invocation(
"l2i",
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.3.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
tiling_context = nullcontext()
if self.tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=self.tile_size,
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1) # denormalize
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@ -1,10 +1,9 @@
import numpy as np
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput
@invocation(
@ -119,27 +118,3 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
height=mask.shape[1],
width=mask.shape[2],
)
@invocation(
"tensor_mask_to_image",
title="Tensor Mask to Image",
tags=["mask"],
category="mask",
version="1.0.0",
)
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Convert a mask tensor to an image."""
mask: TensorField = InputField(description="The mask tensor to convert.")
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.tensors.load(self.mask.tensor_name)
# Ensure that the mask is binary.
if mask.dtype != torch.bool:
mask = mask > 0.5
mask_np = (mask.float() * 255).byte().cpu().numpy()
mask_pil = Image.fromarray(mask_np, mode="L")
image_dto = context.images.save(image=mask_pil)
return ImageOutput.build(image_dto)

View File

@ -5,11 +5,12 @@ from typing import Literal
import numpy as np
from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, InputField
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.1")
class AddInvocation(BaseInvocation):

View File

@ -14,7 +14,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
from invokeai.version.invokeai_version import __version__
from ...version import __version__
class MetadataItemField(BaseModel):

View File

@ -1,25 +1,18 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
@ -67,15 +60,6 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@ -109,152 +93,19 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass
@invocation_output("model_identifier_output")
class ModelIdentifierOutput(BaseInvocationOutput):
"""Model identifier output"""
model: ModelIdentifierField = OutputField(description="Model identifier", title="Model")
@invocation(
"model_identifier",
title="Model identifier",
tags=["model"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an
error."""
model: ModelIdentifierField = InputField(description="The model to select", title="Model")
def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.3",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
model_key = self.model.key
if not context.models.exists(model_key):
raise ValueError(f"Unknown model: {model_key}")
transformer = self._get_model(context, SubModelType.Transformer)
tokenizer = self._get_model(context, SubModelType.Tokenizer)
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._get_model(context, SubModelType.VAE)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
match submodel:
case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case SubModelType.VAE:
return self._pull_model_from_mm(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
ModelType.VAE,
BaseModelType.Flux,
)
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._pull_model_from_mm(
context,
submodel,
"clip-vit-large-patch14",
ModelType.CLIPEmbed,
BaseModelType.Any,
)
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
return self._pull_model_from_mm(
context,
submodel,
self.t5_encoder.name,
ModelType.T5Encoder,
BaseModelType.Any,
)
case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
def _pull_model_from_mm(
self,
context: InvocationContext,
submodel: SubModelType,
name: str,
type: ModelType,
base: BaseModelType,
):
if models := context.models.search_by_attrs(name=name, base=base, type=type):
if len(models) != 1:
raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
@invocation(
"main_model_loader",
title="Main Model",
tags=["model"],
category="model",
version="1.0.3",
version="1.0.2",
)
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
# TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@ -283,12 +134,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2")
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@ -339,75 +190,6 @@ class LoRALoaderInvocation(BaseInvocation):
return output
@invocation_output("lora_selector_output")
class LoRASelectorOutput(BaseInvocationOutput):
"""Model loader output"""
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> LoRASelectorOutput:
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
output = LoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
for lora in loras:
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2)
added_loras.append(lora.lora.key)
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
return output
@invocation_output("sdxl_lora_loader_output")
class SDXLLoRALoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output"""
@ -422,13 +204,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
title="SDXL LoRA",
tags=["lora", "model"],
category="model",
version="1.0.3",
version="1.0.2",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
@ -497,78 +279,12 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
return output
@invocation(
"sdxl_lora_collection_loader",
title="SDXL LoRA Collection Loader",
tags=["model"],
category="model",
version="1.0.0",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
clip2: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
output = SDXLLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
for lora in loras:
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base is BaseModelType.StableDiffusionXL
added_loras.append(lora.lora.key)
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.clip2 is not None:
if output.clip2 is None:
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(lora)
return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
)
def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@ -4,12 +4,18 @@
import torch
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.util.devices import TorchDevice
from ...backend.util.devices import TorchDevice
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
"""
Utilities

View File

@ -39,11 +39,12 @@ from easing_functions import (
)
from matplotlib.ticker import MaxNLocator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.primitives import FloatCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField
@invocation(
"float_range",

View File

@ -4,15 +4,12 @@ from typing import Optional
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
BoundingBoxField,
ColorField,
ConditioningField,
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
ImageField,
Input,
InputField,
@ -24,6 +21,13 @@ from invokeai.app.invocations.fields import (
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
"""
Primitives: Boolean, Integer, Float, String, Image, Latents, Conditioning, Color
- primitive nodes
@ -415,17 +419,6 @@ class MaskOutput(BaseInvocationOutput):
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("flux_conditioning_output")
class FluxConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""
@ -482,42 +475,3 @@ class ConditioningCollectionInvocation(BaseInvocation):
# endregion
# region BoundingBox
@invocation_output("bounding_box_output")
class BoundingBoxOutput(BaseInvocationOutput):
"""Base class for nodes that output a single bounding box"""
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
@invocation_output("bounding_box_collection_output")
class BoundingBoxCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of bounding boxes"""
collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
@invocation(
"bounding_box",
title="Bounding Box",
tags=["primitives", "segmentation", "collection", "bounding box"],
category="primitives",
version="1.0.0",
)
class BoundingBoxInvocation(BaseInvocation):
"""Create a bounding box manually by supplying box coordinates"""
x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
return BoundingBoxOutput(bounding_box=bounding_box)
# endregion

View File

@ -5,11 +5,12 @@ import numpy as np
from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, UIComponent
from invokeai.app.invocations.primitives import StringCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, UIComponent
@invocation(
"dynamic_prompt",

View File

@ -1,103 +0,0 @@
from typing import Literal
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
)
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
@invocation(
"lresize",
title="Resize Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
width: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@invocation(
"lscale",
title="Scale Latents",
tags=["latents", "resize"],
category="latents",
version="1.0.2",
)
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)

View File

@ -1,34 +0,0 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
OutputField,
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@invocation(
"scheduler",
title="Scheduler",
tags=["scheduler"],
category="latents",
version="1.0.0",
)
class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler."""
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
def invoke(self, context: InvocationContext) -> SchedulerOutput:
return SchedulerOutput(scheduler=self.scheduler)

View File

@ -1,9 +1,15 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .model import CLIPField, ModelIdentifierField, UNetField, VAEField
@invocation_output("sdxl_model_loader_output")
class SDXLModelLoaderOutput(BaseInvocationOutput):
@ -24,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
@ -61,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"],
category="model",
version="1.0.3",
version="1.0.2",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?

View File

@ -1,161 +0,0 @@
from pathlib import Path
from typing import Literal
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
"segment-anything-base": "facebook/sam-vit-base",
"segment-anything-large": "facebook/sam-vit-large",
"segment-anything-huge": "facebook/sam-vit-huge",
}
@invocation(
"segment_anything",
title="Segment Anything",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Runs a Segment Anything Model."""
# Reference:
# - https://arxiv.org/pdf/2304.02643
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
image: ImageField = InputField(description="The image to segment.")
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
apply_polygon_refinement: bool = InputField(
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
default=True,
)
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
description="The filtering to apply to the detected masks before merging them into a final output.",
default="all",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if len(self.bounding_boxes) == 0:
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
else:
masks = self._segment(context=context, image=image_pil)
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
# masks contains bool values, so we merge them via max-reduce.
combined_mask, _ = torch.stack(masks).max(dim=0)
mask_tensor_name = context.tensors.save(combined_mask)
height, width = combined_mask.shape
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
@staticmethod
def _load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
def _segment(
self,
context: InvocationContext,
image: Image.Image,
) -> list[torch.Tensor]:
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
# Convert the bounding boxes to the SAM input format.
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
masks = self._process_masks(masks)
if self.apply_polygon_refinement:
masks = self._apply_polygon_refinement(masks)
return masks
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
"""Convert the tensor output from the Segment Anything model from a tensor of shape
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
"""
assert masks.dtype == torch.bool
# [num_masks, channels, height, width] -> [num_masks, height, width]
masks, _ = masks.max(dim=1)
# Split the first dimension into a list of masks.
return list(masks.cpu().unbind(dim=0))
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
"""Apply polygon refinement to the masks.
Convert each mask to a polygon, then back to a mask. This has the following effect:
- Smooth the edges of the mask slightly.
- Ensure that each mask consists of a single closed polygon
- Removes small mask pieces.
- Removes holes from the mask.
"""
# Convert tensor masks to np masks.
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
# Apply polygon refinement.
for idx, mask in enumerate(np_masks):
shape = mask.shape
assert len(shape) == 2 # Assert length to satisfy type checker.
polygon = mask_to_polygon(mask)
mask = polygon_to_mask(polygon, shape)
np_masks[idx] = mask
# Convert np masks back to tensor masks.
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
return masks
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
"""Filter the detected masks based on the specified mask filter."""
assert len(masks) == len(bounding_boxes)
if self.mask_filter == "all":
return masks
elif self.mask_filter == "largest":
# Find the largest mask.
return [max(masks, key=lambda x: float(x.sum()))]
elif self.mask_filter == "highest_box_score":
# Find the index of the bounding box with the highest score.
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
# reasonable fallback since the expected score range is [0.0, 1.0].
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
return [masks[max_score_idx]]
else:
raise ValueError(f"Invalid mask filter: {self.mask_filter}")

View File

@ -1,253 +0,0 @@
from typing import Callable
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
image: ImageField = InputField(description="The input image")
image_to_image_model: ModelIdentifierField = InputField(
title="Image-to-Image Model",
description=FieldDescriptions.spandrel_image_to_image_model,
ui_type=UIType.SpandrelImageToImageModel,
)
tile_size: int = InputField(
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
)
@classmethod
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
return Tile(
coords=TBLR(
top=tile.coords.top * scale,
bottom=tile.coords.bottom * scale,
left=tile.coords.left * scale,
right=tile.coords.right * scale,
),
overlap=TBLR(
top=tile.overlap.top * scale,
bottom=tile.overlap.bottom * scale,
left=tile.overlap.left * scale,
right=tile.overlap.right * scale,
),
)
@classmethod
def upscale_image(
cls,
image: Image.Image,
tile_size: int,
spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool],
) -> Image.Image:
# Compute the image tiles.
if tile_size > 0:
min_overlap = 20
tiles = calc_tiles_min_overlap(
image_height=image.height,
image_width=image.width,
tile_height=tile_size,
tile_width=tile_size,
min_overlap=min_overlap,
)
else:
# No tiling. Generate a single tile that covers the entire image.
min_overlap = 0
tiles = [
Tile(
coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
]
# Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
# over tiles left-to-right, top-to-bottom.
tiles = sorted(tiles, key=lambda x: x.coords.left)
tiles = sorted(tiles, key=lambda x: x.coords.top)
# Prepare input image for inference.
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
# Scale the tiles for re-assembling the final image.
scale = spandrel_model.scale
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
# Prepare the output tensor.
_, channels, height, width = image_tensor.shape
output_tensor = torch.zeros(
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on each tile.
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
# Exit early if the invocation has been canceled.
if is_canceled():
raise CanceledException
# Extract the current tile from the input tensor.
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
# Convert the output tile into the output tensor's format.
# (N, C, H, W) -> (C, H, W)
output_tile = output_tile.squeeze(0)
# (C, H, W) -> (H, W, C)
output_tile = output_tile.permute(1, 2, 0)
output_tile = output_tile.clamp(0, 1)
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
# Merge the output tile into the output tensor.
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
# it seems unnecessary, but we may find a need in the future.
top_overlap = scaled_tile.overlap.top // 2
left_overlap = scaled_tile.overlap.left // 2
output_tensor[
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
:,
] = output_tile[top_overlap:, left_overlap:, :]
# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
return pil_image
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@invocation(
"spandrel_image_to_image_autoscale",
title="Image-to-Image (Autoscale)",
tags=["upscale"],
category="upscale",
version="1.0.0",
)
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""
scale: float = InputField(
default=4.0,
gt=0.0,
le=16.0,
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
)
fit_to_multiple_of_8: bool = InputField(
default=False,
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale)
# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
# to be considered an upscale model.
is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height:
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
iterations += 1
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit.
if iterations >= 5:
context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
)
break
else:
# This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
# to be the same as the processed image size.
# The output size is now the size of the processed image.
target_width = pil_image.width
target_height = pil_image.height
# Warn the user if they requested a scale greater than 1.
if self.scale > 1:
context.logger.warning(
"Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
)
# We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
# in the final resize
if self.fit_to_multiple_of_8:
target_width = int(target_width // 8 * 8)
target_height = int(target_height // 8 * 8)
# Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
# See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)

View File

@ -2,11 +2,17 @@
import re
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import InputField, OutputField, UIComponent
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .fields import InputField, OutputField, UIComponent
from .primitives import StringOutput
@invocation_output("string_pos_neg_output")
class StringPosNegOutput(BaseInvocationOutput):

View File

@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
@ -55,6 +55,7 @@ class T2IAdapterInvocation(BaseInvocation):
t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.T2IAdapterModel,
)

View File

@ -1,287 +0,0 @@
import copy
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
LatentsField,
UIType,
)
from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
from invokeai.backend.tiles.utils import TBLR
from invokeai.backend.util.devices import TorchDevice
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
"""Crop a ControlNetData object to a region."""
# Create a shallow copy of the control_data object.
control_data_copy = copy.copy(control_data)
# The ControlNet reference image is the only attribute that needs to be cropped.
control_data_copy.image_tensor = control_data.image_tensor[
:,
:,
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
]
return control_data_copy
@invocation(
"tiled_multi_diffusion_denoise_latents",
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
classification=Classification.Beta,
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
This node handles automatically tiling the input image, and is primarily intended for global refinement of images
in tiled upscaling workflows. Future Multi-Diffusion nodes should allow the user to specify custom regions with
different parameters for each region to harness the full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
"""
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
tile_height: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
)
tile_width: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
)
tile_overlap: int = InputField(
default=32,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
"amount.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
control: ControlField | list[ControlField] | None = InputField(
default=None,
input=Input.Connection,
)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def create_pipeline(
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
) -> MultiDiffusionPipeline:
# TODO(ryand): Get rid of this FakeVae hack.
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
vae=FakeVae(),
text_encoder=None,
tokenizer=None,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# Convert tile image-space dimensions to latent-space dimensions.
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
# - How much overlap 'context' to provide for each denoising step.
# - How much overlap to use during merging/blending.
# - Should we 'jitter' the tile locations in each step so that the seams are in different places?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=latent_tile_height,
tile_width=latent_tile_width,
min_overlap=latent_tile_overlap,
)
# Get the unet's config so that we can pass the base to sd_step_callback().
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# Split the controlnet_data into tiles.
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
controlnet_data_tiles: list[list[ControlNetData]] = []
for tile in tiles:
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
controlnet_data_tiles.append(tile_controlnet_data)
# Prepare the MultiDiffusionRegionConditioning list.
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
target_overlap=latent_tile_overlap,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
callback=step_callback,
)
result_latents = result_latents.to("cpu")
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

View File

@ -1,4 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal
import cv2
@ -6,12 +7,16 @@ import numpy as np
from PIL import Image
from pydantic import ConfigDict
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
# TODO: Populate this from disk?
# TODO: Use model manager to load?
@ -47,6 +52,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
rrdbnet_model = None
netscale = None
esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
@ -89,25 +95,28 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
loadnet = context.models.load_remote_model(
source=ESRGAN_MODEL_URLS[self.model_name],
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
# Downloads the ESRGAN model if it doesn't already exist
download_with_progress_bar(
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
)
with loadnet as loadnet_model:
upscaler = RealESRGAN(
scale=netscale,
loadnet=loadnet_model,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
)
upscaler = RealESRGAN(
scale=netscale,
model_path=esrgan_model_path,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
TorchDevice.empty_cache()
image_dto = context.images.save(image=pil_image)

View File

@ -2,11 +2,12 @@ import sqlite3
import threading
from typing import Optional, cast
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .board_image_records_base import BoardImageRecordStorageBase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_conn: sqlite3.Connection

View File

@ -1,8 +1,9 @@
from typing import Optional
from invokeai.app.services.board_images.board_images_base import BoardImagesServiceABC
from invokeai.app.services.invoker import Invoker
from .board_images_base import BoardImagesServiceABC
class BoardImagesService(BoardImagesServiceABC):
__invoker: Invoker

View File

@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from .board_records_common import BoardChanges, BoardRecord
class BoardRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
@ -39,12 +40,16 @@ class BoardRecordStorageBase(ABC):
@abstractmethod
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
def get_all(
self,
) -> list[BoardRecord]:
"""Gets all board records."""
pass

View File

@ -22,10 +22,6 @@ class BoardRecord(BaseModelExcludeNull):
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
"""The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
"""Whether the board is private."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
@ -39,8 +35,6 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
is_private = board_dict.get("is_private", False)
return BoardRecord(
board_id=board_id,
@ -49,15 +43,12 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
archived=archived,
is_private=is_private,
)
class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.")
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
class BoardRecordNotFoundException(Exception):

View File

@ -2,8 +2,12 @@ import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
from invokeai.app.services.board_records.board_records_common import (
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
from .board_records_base import BoardRecordStorageBase
from .board_records_common import (
BoardChanges,
BoardRecord,
BoardRecordDeleteException,
@ -11,9 +15,6 @@ from invokeai.app.services.board_records.board_records_common import (
BoardRecordSaveException,
deserialize_board_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
@ -124,17 +125,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
self._cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@ -144,49 +134,35 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
return self.get(board_id)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Build base query
base_query = """
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
""",
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
self._cursor.execute(count_query)
)
count = cast(int, self._cursor.fetchone()[0])
@ -198,25 +174,20 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
finally:
self._lock.release()
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
def get_all(
self,
) -> list[BoardRecord]:
try:
self._lock.acquire()
base_query = """
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
self._cursor.execute(final_query)
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

View File

@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from .boards_common import BoardDTO
class BoardServiceABC(ABC):
"""High-level service for board management."""
@ -43,12 +44,16 @@ class BoardServiceABC(ABC):
@abstractmethod
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
def get_all(
self,
) -> list[BoardDTO]:
"""Gets all boards."""
pass

View File

@ -2,7 +2,7 @@ from typing import Optional
from pydantic import Field
from invokeai.app.services.board_records.board_records_common import BoardRecord
from ..board_records.board_records_common import BoardRecord
class BoardDTO(BoardRecord):

View File

@ -1,9 +1,11 @@
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from .boards_base import BoardServiceABC
from .boards_common import board_record_to_dto
class BoardService(BoardServiceABC):
__invoker: Invoker
@ -46,10 +48,8 @@ class BoardService(BoardServiceABC):
def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@ -63,8 +63,8 @@ class BoardService(BoardServiceABC):
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
def get_all(self) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all()
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)

View File

@ -4,7 +4,6 @@ from typing import Optional, Union
from zipfile import ZipFile
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_base import BulkDownloadBase
from invokeai.app.services.bulk_download.bulk_download_common import (
DEFAULT_BULK_DOWNLOAD_ID,
BulkDownloadException,
@ -16,6 +15,8 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
from .bulk_download_base import BulkDownloadBase
class BulkDownloadService(BulkDownloadBase):
def start(self, invoker: Invoker) -> None:
@ -105,7 +106,9 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_completed(
@ -115,8 +118,10 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
self._invoker.services.events.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_failed(
@ -126,8 +131,11 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_error(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
self._invoker.services.events.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
)
def stop(self, *args, **kwargs):

View File

@ -1,6 +1,7 @@
"""Init file for InvokeAI configure package."""
from invokeai.app.services.config.config_common import PagingArgumentParser
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
from .config_default import InvokeAIAppConfig, get_config
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import copy
import locale
import os
import re
@ -21,18 +20,21 @@ import invokeai.configs as model_configs
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_migrate import ConfigMigrator
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"
CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@ -85,13 +87,11 @@ class InvokeAIAppConfig(BaseSettings):
log_tokenization: Enable logging of parsed prompt tokens.
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).
download_cache_dir: Path to the directory that contains dynamically downloaded models.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
custom_nodes_dir: Path to directory for custom nodes.
style_presets_dir: Path to directory for style presets.
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
@ -103,6 +103,7 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
@ -113,7 +114,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@ -148,13 +148,11 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
# LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
@ -171,8 +169,9 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
# CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
@ -187,7 +186,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
@ -302,21 +300,11 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the models directory, resolved to an absolute path.."""
return self._resolve(self.models_dir)
@property
def style_presets_path(self) -> Path:
"""Path to the style presets directory, resolved to an absolute path.."""
return self._resolve(self.style_presets_dir)
@property
def convert_cache_path(self) -> Path:
"""Path to the converted cache models directory, resolved to an absolute path.."""
return self._resolve(self.convert_cache_dir)
@property
def download_cache_path(self) -> Path:
"""Path to the downloaded models directory, resolved to an absolute path.."""
return self._resolve(self.download_cache_dir)
@property
def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory, resolved to an absolute path.."""
@ -362,84 +350,6 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a v4.0.0.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
An `InvokeAIAppConfig` config dict.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
parsed_config_dict["schema_version"] = "4.0.0"
return parsed_config_dict
def migrate_v4_0_0_to_4_0_1_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to a v4.0.1 config dictionary
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A config dict with the settings migrated to v4.0.1.
"""
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# precision "autocast" was replaced by "auto" in v4.0.1
if parsed_config_dict.get("precision") == "autocast":
parsed_config_dict["precision"] = "auto"
parsed_config_dict["schema_version"] = "4.0.1"
return parsed_config_dict
def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.1 config dictionary to a v4.0.2 config dictionary.
Args:
config_dict: A dictionary of settings from a v4.0.1 config file.
Returns:
An config dict with the settings migrated to v4.0.2.
"""
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# convert_cache was removed in 4.0.2
parsed_config_dict.pop("convert_cache", None)
parsed_config_dict["schema_version"] = "4.0.2"
return parsed_config_dict
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@ -451,37 +361,24 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict: dict[str, Any] = yaml.safe_load(file)
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
migrated = False
if "InvokeAI" in loaded_config_dict:
migrated = True
loaded_config_dict = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
if loaded_config_dict["schema_version"] == "4.0.0":
migrated = True
loaded_config_dict = migrate_v4_0_0_to_4_0_1_config_dict(loaded_config_dict)
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)
if migrated:
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# load and write without environment variables
migrated_config = DefaultInvokeAIAppConfig.model_validate(loaded_config_dict)
migrated_config.write_file(config_path)
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
# Attempt to load as a v4 config file
try:
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@ -531,6 +428,7 @@ def get_config() -> InvokeAIAppConfig:
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
config_from_file.write_file(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
@ -539,3 +437,73 @@ def get_config() -> InvokeAIAppConfig:
default_config.write_file(config.config_file_path, as_example=False)
return config
####################################################
# VERSION MIGRATIONS
####################################################
@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
A dictionary of settings from a 4.0.0 config file.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
return parsed_config_dict
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A dictionary of settings from a v4.0.1 config file
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
return parsed_config_dict

View File

@ -0,0 +1,89 @@
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Utility class for migrating among versions of the InvokeAI app config schema.
"""
from dataclasses import dataclass
from typing import Any, Callable, List, TypeAlias
from packaging.version import Version
AppConfigDict: TypeAlias = dict[str, Any]
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass
class MigrationEntry:
"""Defines an individual migration."""
from_version: Version
to_version: Version
function: MigrationFunction
class ConfigMigrator:
"""This class allows migrators to register their input and output versions."""
_migrations: List[MigrationEntry] = []
@classmethod
def register(
cls,
from_version: str,
to_version: str,
) -> Callable[[MigrationFunction], MigrationFunction]:
"""Define a decorator which registers the migration between two versions."""
def decorator(function: MigrationFunction) -> MigrationFunction:
if any(from_version == m.from_version for m in cls._migrations):
raise ValueError(
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
)
cls._migrations.append(
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
)
return function
return decorator
@staticmethod
def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None:
current_version = Version("3.0.0")
for m in migrations:
if current_version != m.from_version:
raise ValueError(
f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
)
current_version = m.to_version
@classmethod
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
"""
Use the registered migration steps to bring config up to latest version.
:param config: The original configuration.
:return: The new configuration, lifted up to the latest version.
As a side effect, the new configuration will be written to disk.
If an inconsistency in the registered migration steps' `from_version`
and `to_version` parameters are identified, this will raise a
ValueError exception.
"""
# Sort migrations by version number and raise a ValueError if
# any version range overlaps are detected.
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
cls._check_for_discontinuities(sorted_migrations)
if "InvokeAI" in config_dict:
version = Version("3.0.0")
else:
version = Version(config_dict["schema_version"])
for migration in sorted_migrations:
if version == migration.from_version and version < migration.to_version:
config_dict = migration.function(config_dict)
version = migration.to_version
config_dict["schema_version"] = str(version)
return config_dict

View File

@ -1,17 +1,10 @@
"""Init file for download queue."""
from invokeai.app.services.download.download_base import (
DownloadJob,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
UnknownJobIDException,
)
from invokeai.app.services.download.download_default import DownloadQueueService, TqdmProgress
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_default import DownloadQueueService, TqdmProgress
__all__ = [
"DownloadJob",
"MultiFileDownloadJob",
"DownloadQueueServiceBase",
"DownloadQueueService",
"TqdmProgress",

View File

@ -5,13 +5,11 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional, Set, Union
from typing import Any, Callable, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
from invokeai.backend.model_manager.metadata import RemoteModelFile
class DownloadJobStatus(str, Enum):
"""State of a download job."""
@ -35,23 +33,30 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
DownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
class DownloadJobBase(BaseModel):
"""Base of classes to monitor and control downloads."""
@total_ordering
class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
# set internally during download process
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
@ -69,6 +74,14 @@ class DownloadJobBase(BaseModel):
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
@ -85,11 +98,6 @@ class DownloadJobBase(BaseModel):
"""Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED
@property
def waiting(self) -> bool:
"""Return true if the job is waiting to run."""
return self.status == DownloadJobStatus.WAITING
@property
def running(self) -> bool:
"""Return true if the job is running."""
@ -146,37 +154,6 @@ class DownloadJobBase(BaseModel):
self._on_cancelled = on_cancelled
@total_ordering
class DownloadJob(DownloadJobBase):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
class MultiFileDownloadJob(DownloadJobBase):
"""Class to monitor and control multifile downloads."""
download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL."""
@ -224,48 +201,6 @@ class DownloadQueueServiceBase(ABC):
"""
pass
@abstractmethod
def multifile_download(
self,
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
"""
Create and enqueue a multifile download job.
:param parts: Set of URL / filename pairs
:param dest: Path to download to. See below.
:param access_token: Access token to download the indicated files. If not provided,
each file's URL may be matched to an access token using the config file matching
system.
:param submit_job: If true [default] then submit the job for execution. Otherwise,
you will need to pass the job to submit_multifile_download().
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
The `dest` argument is a Path object pointing to a directory. All downloads
with be placed inside this directory. The callbacks will receive the
MultiFileDownloadJob.
"""
pass
@abstractmethod
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
"""
Enqueue a previously-created multi-file download job.
:param job: A MultiFileDownloadJob created with multifile_download()
"""
pass
@abstractmethod
def submit_download_job(
self,
@ -317,7 +252,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def cancel_job(self, job: DownloadJobBase) -> None:
def cancel_job(self, job: DownloadJob) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@ -327,7 +262,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
"""Wait until the indicated download job has reached a terminal state.
This will block until the indicated install job has completed,

View File

@ -8,30 +8,27 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Literal, Optional, Set
from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.download.download_base import (
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
DownloadEventHandler,
DownloadExceptionHandler,
DownloadJob,
DownloadJobBase,
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
ServiceInactiveException,
UnknownJobIDException,
)
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.model_manager.metadata import RemoteModelFile
from invokeai.backend.util.logging import InvokeAILogger
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
@ -43,24 +40,20 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional["EventServiceBase"] = None,
event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
:param app_config: InvokeAIAppConfig object
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {}
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
self._stop_event = threading.Event()
self._job_terminated_event = threading.Event()
self._job_completed_event = threading.Event()
self._worker_pool: Set[threading.Thread] = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@ -112,16 +105,18 @@ class DownloadQueueService(DownloadQueueServiceBase):
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
job.id = self._next_id()
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[job.id] = job
self._queue.put(job)
with self._lock:
job.id = self._next_job_id
self._next_job_id += 1
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[job.id] = job
self._queue.put(job)
def download(
self,
@ -144,7 +139,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
source=source,
dest=dest,
priority=priority,
access_token=access_token or self._lookup_access_token(source),
access_token=access_token,
)
self.submit_download_job(
job,
@ -156,63 +151,10 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
return job
def multifile_download(
self,
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
mfdj.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
for part in parts:
url = part.url
path = dest / part.path
assert path.is_relative_to(dest), "only relative download paths accepted"
job = DownloadJob(
source=url,
dest=path,
access_token=access_token or self._lookup_access_token(url),
)
mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj
if submit_job:
self.submit_multifile_download(mfdj)
return mfdj
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
for download_job in job.download_parts:
self.submit_download_job(
download_job,
on_start=self._mfd_started,
on_progress=self._mfd_progress,
on_complete=self._mfd_complete,
on_cancelled=self._mfd_cancelled,
on_error=self._mfd_error,
)
def join(self) -> None:
"""Wait for all jobs to complete."""
self._queue.join()
def _next_id(self) -> int:
with self._lock:
id = self._next_job_id
self._next_job_id += 1
return id
def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs."""
return list(self._jobs.values())
@ -234,14 +176,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJobBase) -> None:
def cancel_job(self, job: DownloadJob) -> None:
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
with self._lock:
job.cancel()
def cancel_all_jobs(self) -> None:
@ -250,12 +192,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state:
self.cancel_job(job)
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event
self._job_terminated_event.clear()
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
self._job_completed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
@ -284,25 +226,22 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
except Exception as excp:
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
finally:
job.job_ended = get_iso_timestamp()
self._job_terminated_event.set() # signal a change to terminal state
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
self._job_terminated_event.set()
self._job_completed_event.set() # signal a change to terminal state
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
@ -379,8 +318,10 @@ class DownloadQueueService(DownloadQueueServiceBase):
in_progress_path.rename(job.download_path)
def _validate_filename(self, directory: str, filename: str) -> bool:
pc_name_max = get_pc_name_max(directory)
pc_path_max = get_pc_path_max(directory)
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
pc_path_max = (
os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 32767
) # hardcoded for windows with long names enabled
if "/" in filename:
return False
if filename.startswith(".."):
@ -394,53 +335,79 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading")
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
# Pull the token from config if it exists and matches the URL
token = None
for pair in self._app_config.remote_api_tokens or []:
if re.search(pair.url_regex, str(source)):
token = pair.token
break
return token
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
self._execute_cb(job, "on_start")
if job.on_start:
try:
job.on_start(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_started(job)
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None:
self._execute_cb(job, "on_progress")
if job.on_progress:
try:
job.on_progress(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_progress(job)
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
self._execute_cb(job, "on_complete")
if job.on_complete:
try:
job.on_complete(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_complete(job)
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
return
job.status = DownloadJobStatus.CANCELLED
self._execute_cb(job, "on_cancelled")
if job.on_cancelled:
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_cancelled(job)
# if multifile download, then signal the parent
if parent_job := self._download_part2parent.get(job.source, None):
if not parent_job.in_terminal_state:
parent_job.status = DownloadJobStatus.CANCELLED
self._execute_cb(parent_job, "on_cancelled")
self._event_bus.emit_download_cancelled(str(job.source))
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
self._execute_cb(job, "on_error", excp)
if job.on_error:
try:
job.on_error(job, excp)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_error(job)
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
@ -451,117 +418,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
except OSError as excp:
self._logger.warning(excp)
########################################
# callbacks used for multifile downloads
########################################
def _mfd_started(self, download_job: DownloadJob) -> None:
self._logger.info(f"File download started: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.waiting:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.status = DownloadJobStatus.RUNNING
assert download_job.download_path is not None
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
mf_job.download_path = (
mf_job.dest / path_relative_to_destdir.parts[0]
) # keep just the first component of the path
self._execute_cb(mf_job, "on_start")
def _mfd_progress(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.cancelled:
for part in mf_job.download_parts:
self.cancel_job(part)
elif mf_job.running:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
self._execute_cb(mf_job, "on_progress")
def _mfd_complete(self, download_job: DownloadJob) -> None:
self._logger.info(f"Download complete: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
# are there any more active jobs left in this task?
if mf_job.running and all(x.complete for x in mf_job.download_parts):
mf_job.status = DownloadJobStatus.COMPLETED
self._execute_cb(mf_job, "on_complete")
# we're done with this sub-job
self._job_terminated_event.set()
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
self._logger.warning(f"Download cancelled: {download_job.source}")
mf_job.cancel()
for s in mf_job.download_parts:
self.cancel_job(s)
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
mf_job.status = download_job.status
mf_job.error = download_job.error
mf_job.error_type = download_job.error_type
self._execute_cb(mf_job, "on_error", excp)
self._logger.error(
f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
)
for s in [x for x in mf_job.download_parts if x.running]:
self.cancel_job(s)
self._download_part2parent.pop(download_job.source)
self._job_terminated_event.set()
def _execute_cb(
self,
job: DownloadJob | MultiFileDownloadJob,
callback_name: Literal[
"on_start",
"on_progress",
"on_complete",
"on_cancelled",
"on_error",
],
excp: Optional[Exception] = None,
) -> None:
if callback := getattr(job, callback_name, None):
args = [job, excp] if excp else [job]
try:
callback(*args)
except Exception as e:
self._logger.error(
f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
)
def get_pc_name_max(directory: str) -> int:
if hasattr(os, "pathconf"):
try:
return os.pathconf(directory, "PC_NAME_MAX")
except OSError:
# macOS w/ external drives raise OSError
pass
return 260 # hardcoded for windows
def get_pc_path_max(directory: str) -> int:
if hasattr(os, "pathconf"):
try:
return os.pathconf(directory, "PC_PATH_MAX")
except OSError:
# some platforms may not have this value
pass
return 32767 # hardcoded for windows with long names enabled
# Example on_progress event handler to display a TQDM status bar
# Activate with:

View File

@ -1,199 +1,486 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import TYPE_CHECKING, Optional
from typing import Any, Dict, List, Optional, Union
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallDownloadStartedEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event: "EventBase") -> None:
def dispatch(self, event_name: str, payload: Any) -> None:
pass
# region: Invocation
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
)
def emit_invocation_denoise_progress(
def __emit_download_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int,
order: int,
total_steps: int,
) -> None:
"""Emitted at each step during denoising of an invocation."""
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
def emit_invocation_complete(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation is complete"""
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
def emit_invocation_error(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str,
error_message: str,
error_traceback: str,
error: str,
) -> None:
"""Emitted when an invocation encounters an error"""
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
},
)
# endregion
def emit_invocation_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
# region Queue
def emit_graph_execution_complete(
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_queue_item_status_changed(
self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
self,
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
) -> None:
"""Emitted when a queue item's status changes"""
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
self.__emit_queue_event(
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
# endregion
def emit_download_started(self, source: str, download_path: str) -> None:
"""
Emit when a download job is started.
# region Download
:param url: The downloaded url
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_download_started(self, job: "DownloadJob") -> None:
"""Emitted when a download is started"""
self.dispatch(DownloadStartedEvent.build(job))
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
"""
Emit "download_progress" events at regular intervals during a download job.
def emit_download_progress(self, job: "DownloadJob") -> None:
"""Emitted at intervals during a download"""
self.dispatch(DownloadProgressEvent.build(job))
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_complete(self, job: "DownloadJob") -> None:
"""Emitted when a download is completed"""
self.dispatch(DownloadCompleteEvent.build(job))
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
"""
Emit a "download_complete" event at the end of a successful download.
def emit_download_cancelled(self, job: "DownloadJob") -> None:
"""Emitted when a download is cancelled"""
self.dispatch(DownloadCancelledEvent.build(job))
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_error(self, job: "DownloadJob") -> None:
"""Emitted when a download encounters an error"""
self.dispatch(DownloadErrorEvent.build(job))
def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
# endregion
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
"""
Emit a "download_error" event when an download job encounters an exception.
# region Model loading
:param source: Source URL
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
def emit_model_load_complete(
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
def emit_model_install_downloading(
self,
source: str,
local_path: str,
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
id: int,
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
"""
Emit at intervals while the install job is in progress (remote models only).
# endregion
:param source: Source of the model
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
# region Model install
def emit_model_install_downloads_done(self, source: str) -> None:
"""
Emit once when all parts are downloaded, but before the probing and registration start.
def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is started (remote models only)."""
self.dispatch(ModelInstallDownloadStartedEvent.build(job))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_downloads_done",
payload={"source": source},
)
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
def emit_model_install_running(self, source: str) -> None:
"""
Emit once when an install job becomes active.
def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_started(self, job: "ModelInstallJob") -> None:
"""Emitted once when an install job is started (after any download)."""
self.dispatch(ModelInstallStartedEvent.build(job))
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job is completed successfully."""
self.dispatch(ModelInstallCompleteEvent.build(job))
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
:param total_bytes: Size of the model (may be None for installation of a local path)
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job is cancelled."""
self.dispatch(ModelInstallCancelledEvent.build(job))
def emit_model_install_cancelled(self, source: str, id: int) -> None:
"""
Emit when an install job is cancelled.
def emit_model_install_error(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job encounters an exception."""
self.dispatch(ModelInstallErrorEvent.build(job))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
)
# endregion
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
"""
Emit when an install job encounters an exception.
# region Bulk image download
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk image download is started"""
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk image download is complete"""
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
# endregion
def emit_bulk_download_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
)

View File

@ -1,634 +0,0 @@
from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
Events must define a class attribute `__event_name__` to identify the event.
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
__event_name__: ClassVar[str]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@classmethod
def get_events(cls) -> set[type["EventBase"]]:
"""Get a set of all event models."""
event_subclasses: set[type["EventBase"]] = set()
for subclass in cls.__subclasses__():
# We only want to include subclasses that are event models, not intermediary classes
if hasattr(subclass, "__event_name__"):
event_subclasses.add(subclass)
event_subclasses.update(subclass.get_events())
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol, Generic[TEvent]):
def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
"""Register a function to handle specific events.
:param events: An event or set of events to handle
:param func: The function to handle the events
"""
events = events if isinstance(events, set) else {events}
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
class QueueEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events"""
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
origin: str | None = Field(default=None, description="The origin of the batch")
class InvocationEventBase(QueueItemEventBase):
"""Base class for invocation events"""
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
)
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
)
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
__event_name__ = "invocation_complete"
result: AnyInvocationOutput = Field(description="The result of the invocation")
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
result=result,
)
@payload_schema.register
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The error type")
error_message: str = Field(description="The error message")
error_traceback: str = Field(description="The error traceback")
user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
error_type: str,
error_message: str,
error_traceback: str,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
)
@payload_schema.register
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
@classmethod
def build(
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
origin=queue_item.origin,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
)
@payload_schema.register
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
__event_name__ = "batch_enqueued"
batch_id: str = Field(description="The ID of the batch")
enqueued: int = Field(description="The number of invocations enqueued")
requested: int = Field(
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
origin=enqueue_result.batch.origin,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
)
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str) -> "QueueClearedEvent":
return cls(queue_id=queue_id)
class DownloadEventBase(EventBase):
"""Base class for events associated with a download"""
source: str = Field(description="The source of the download")
@payload_schema.register
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
__event_name__ = "download_started"
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
__event_name__ = "download_progress"
download_path: str = Field(description="The local path where the download is saved")
current_bytes: int = Field(description="The number of bytes downloaded so far")
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
__event_name__ = "download_complete"
download_path: str = Field(description="The local path where the download is saved")
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
return cls(source=str(job.source))
@payload_schema.register
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
__event_name__ = "download_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelInstallDownloadStartedEvent(ModelEventBase):
"""Event model for model_install_download_started"""
__event_name__ = "model_install_download_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download"""
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
__event_name__ = "bulk_download_started"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
__event_name__ = "bulk_download_complete"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""
__event_name__ = "bulk_download_error"
error: str = Field(description="The error message")
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
)

View File

@ -1,44 +0,0 @@
import asyncio
import threading
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import EventBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
self.event_handler_id = event_handler_id
self._queue = asyncio.Queue[EventBase | None]()
self._stop_event = threading.Event()
self._loop = loop
# We need to store a reference to the task so it doesn't get GC'd
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
self._background_tasks: set[asyncio.Task[None]] = set()
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.remove)
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
def dispatch(self, event: EventBase) -> None:
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = await self._queue.get()
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -4,6 +4,9 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@ -30,9 +33,8 @@ class ImageFileStorageBase(ABC):
self,
image: PILImageType,
image_name: str,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -44,11 +46,6 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[str]:
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets the workflow of an image."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets the graph of an image."""
pass

View File

@ -1,30 +1,36 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from pathlib import Path
from queue import Queue
from typing import Optional, Union
from typing import Dict, Optional, Union
from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.image_files.image_files_common import (
ImageFileDeleteException,
ImageFileNotFoundException,
ImageFileSaveException,
)
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from .image_files_base import ImageFileStorageBase
from .image_files_common import ImageFileDeleteException, ImageFileNotFoundException, ImageFileSaveException
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[Path, PILImageType]
__max_cache_size: int
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
self.__cache: dict[Path, PILImageType] = {}
self.__cache_ids = Queue[Path]()
self.__cache = {}
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / "thumbnails"
# Validate required output folders at launch
self.__validate_storage_folders()
@ -50,9 +56,8 @@ class DiskImageFileStorage(ImageFileStorageBase):
self,
image: PILImageType,
image_name: str,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
try:
@ -63,14 +68,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
info_dict = {}
if metadata is not None:
info_dict["invokeai_metadata"] = metadata
pnginfo.add_text("invokeai_metadata", metadata)
metadata_json = metadata.model_dump_json()
info_dict["invokeai_metadata"] = metadata_json
pnginfo.add_text("invokeai_metadata", metadata_json)
if workflow is not None:
info_dict["invokeai_workflow"] = workflow
pnginfo.add_text("invokeai_workflow", workflow)
if graph is not None:
info_dict["invokeai_graph"] = graph
pnginfo.add_text("invokeai_graph", graph)
workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
@ -96,7 +100,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name)
if image_path.exists():
image_path.unlink()
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
@ -104,7 +108,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
thumbnail_path = self.get_path(thumbnail_name, True)
if thumbnail_path.exists():
thumbnail_path.unlink()
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:
@ -125,18 +129,11 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def get_workflow(self, image_name: str) -> str | None:
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if isinstance(workflow, str):
return workflow
return None
def get_graph(self, image_name: str) -> str | None:
image = self.get(image_name)
graph = image.info.get("invokeai_graph", None)
if isinstance(graph, str):
return graph
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
return None
def __validate_storage_folders(self) -> None:

View File

@ -3,14 +3,9 @@ from datetime import datetime
from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
class ImageRecordStorageBase(ABC):
@ -42,13 +37,10 @@ class ImageRecordStorageBase(ABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@ -88,7 +80,7 @@ class ImageRecordStorageBase(ABC):
starred: Optional[bool] = False,
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
metadata: Optional[MetadataField] = None,
) -> datetime:
"""Saves an image record."""
pass

View File

@ -4,8 +4,11 @@ from datetime import datetime
from typing import Optional, Union, cast
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import (
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .image_records_base import ImageRecordStorageBase
from .image_records_common import (
IMAGE_DTO_COLS,
ImageCategory,
ImageRecord,
@ -16,9 +19,6 @@ from invokeai.app.services.image_records.image_records_common import (
ResourceOrigin,
deserialize_image_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
@ -144,13 +144,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
@ -211,21 +208,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
query_pagination = """--sql
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
@ -343,9 +328,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
starred: Optional[bool] = False,
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
metadata: Optional[MetadataField] = None,
) -> datetime:
try:
metadata_json = metadata.model_dump_json() if metadata is not None else None
self._lock.acquire()
self._cursor.execute(
"""--sql
@ -372,7 +358,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height,
node_id,
session_id,
metadata,
metadata_json,
is_intermediate,
starred,
has_workflow,

View File

@ -12,7 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageServiceABC(ABC):
@ -51,9 +51,8 @@ class ImageServiceABC(ABC):
session_id: Optional[str] = None,
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[str] = None,
workflow: Optional[str] = None,
graph: Optional[str] = None,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -88,12 +87,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets an image's workflow."""
pass
@ -117,13 +111,10 @@ class ImageServiceABC(ABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass

Some files were not shown because too many files have changed in this diff Show More