Merge remote-tracking branch 'origin/main' into maryhipp/forward-args-through-image-service

This commit is contained in:
maryhipp 2024-01-16 12:09:26 -05:00
commit 956c8f911d
1342 changed files with 74684 additions and 244703 deletions

8
.github/CODEOWNERS vendored
View File

@ -1,5 +1,5 @@
# continuous integration
/.github/workflows/ @lstein @blessedcoolant @hipsterusername
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr
# documentation
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
@ -10,7 +10,7 @@
# installation and configuration
/pyproject.toml @lstein @blessedcoolant @hipsterusername
/docker/ @lstein @blessedcoolant @hipsterusername
/docker/ @lstein @blessedcoolant @hipsterusername @ebr
/scripts/ @ebr @lstein @hipsterusername
/installer/ @lstein @ebr @hipsterusername
/invokeai/assets @lstein @ebr @hipsterusername
@ -26,9 +26,7 @@
# front ends
/invokeai/frontend/CLI @lstein @hipsterusername
/invokeai/frontend/install @lstein @ebr @hipsterusername
/invokeai/frontend/install @lstein @ebr @hipsterusername
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp @hipsterusername

View File

@ -42,6 +42,21 @@ Please provide steps on how to test changes, any hardware or
software specifications as well as any other pertinent information.
-->
## Merge Plan
<!--
A merge plan describes how this PR should be handled after it is approved.
Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is merged"
A merge plan is particularly important for large PRs or PRs that touch the
database in any way.
-->
## Added/updated tests?
- [ ] Yes

View File

@ -40,10 +40,14 @@ jobs:
- name: Free up more disk space on the runner
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
run: |
echo "----- Free space before cleanup"
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
echo "----- Free space after cleanup"
df -h
- name: Checkout
uses: actions/checkout@v3
@ -91,6 +95,7 @@ jobs:
# password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build container
timeout-minutes: 40
id: docker_build
uses: docker/build-push-action@v4
with:

View File

@ -22,12 +22,22 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Setup Node 18
uses: actions/setup-node@v3
uses: actions/setup-node@v4
with:
node-version: '18'
- uses: actions/checkout@v3
- run: 'yarn install --frozen-lockfile'
- run: 'yarn run lint:tsc'
- run: 'yarn run lint:madge'
- run: 'yarn run lint:eslint'
- run: 'yarn run lint:prettier'
- name: Checkout
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install dependencies
run: 'pnpm install --prefer-frozen-lockfile'
- name: Typescript
run: 'pnpm run lint:tsc'
- name: Madge
run: 'pnpm run lint:madge'
- name: ESLint
run: 'pnpm run lint:eslint'
- name: Prettier
run: 'pnpm run lint:prettier'

View File

@ -1,13 +1,15 @@
name: PyPI Release
on:
push:
paths:
- 'invokeai/version/invokeai_version.py'
workflow_dispatch:
inputs:
publish_package:
description: 'Publish build on PyPi? [true/false]'
required: true
default: 'false'
jobs:
release:
build-and-release:
if: github.repository == 'invoke-ai/InvokeAI'
runs-on: ubuntu-22.04
env:
@ -15,19 +17,43 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
TWINE_NON_INTERACTIVE: 1
steps:
- name: checkout sources
uses: actions/checkout@v3
- name: Checkout
uses: actions/checkout@v4
- name: install deps
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
working-directory: invokeai/frontend/web
- name: Build frontend
run: pnpm run build
working-directory: invokeai/frontend/web
- name: Install python dependencies
run: pip install --upgrade build twine
- name: build package
- name: Build python package
run: python3 -m build
- name: check distribution
- name: Upload build as workflow artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: dist
- name: Check distribution
run: twine check dist/*
- name: check PyPI versions
- name: Check PyPI versions
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
run: |
pip install --upgrade requests
@ -36,6 +62,6 @@ jobs:
EXISTS=scripts.pypi_helper.local_on_pypi(); \
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
- name: upload package
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != ''
- name: Publish build on PyPi
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
run: twine upload dist/*

3
.gitignore vendored
View File

@ -16,7 +16,7 @@ __pycache__/
.Python
build/
develop-eggs/
# dist/
dist/
downloads/
eggs/
.eggs/
@ -187,3 +187,4 @@ installer/install.bat
installer/install.sh
installer/update.bat
installer/update.sh
installer/InvokeAI-Installer/

View File

@ -1,6 +1,20 @@
# simple Makefile with scripts that are otherwise hard to remember
# to use, run from the repo root `make <command>`
default: help
help:
@echo Developer commands:
@echo
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
ruff check . --fix
@ -18,4 +32,21 @@ mypy:
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
# (many files are ignored by the config, so this is useful for checking all files)
mypy-all:
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
# Build the frontend
frontend-build:
cd invokeai/frontend/web && pnpm build
# Run the frontend in dev mode
frontend-dev:
cd invokeai/frontend/web && pnpm dev
# Installer zip file
installer-zip:
cd installer && ./create_installer.sh
# Tag the release
tag-release:
cd installer && ./tag_release.sh

View File

@ -1,10 +1,10 @@
<div align="center">
![project hero](https://github.com/invoke-ai/InvokeAI/assets/31807370/1a917d94-e099-4fa1-a70f-7dd8d0691018)
![project hero](https://github.com/invoke-ai/InvokeAI/assets/31807370/6e3728c7-e90e-4711-905c-3b55844ff5be)
# Invoke AI - Generative AI for Professional Creatives
## Professional Creative Tools for Stable Diffusion, Custom-Trained Models, and more.
To learn more about Invoke AI, get started instantly, or implement our Business solutions, visit [invoke.ai](https://invoke.ai)
# Invoke - Professional Creative AI Tools for Visual Media
## To learn more about Invoke, or implement our Business solutions, visit [invoke.com](https://www.invoke.com/about)
[![discord badge]][discord link]
@ -56,7 +56,9 @@ the foundation for multiple commercial products.
<div align="center">
![canvas preview](https://github.com/invoke-ai/InvokeAI/raw/main/docs/assets/canvas_preview.png)
![Highlighted Features - Canvas and Workflows](https://github.com/invoke-ai/InvokeAI/assets/31807370/708f7a82-084f-4860-bfbe-e2588c53548d)
</div>
@ -125,8 +127,8 @@ and go to http://localhost:9090.
You must have Python 3.10 through 3.11 installed on your machine. Earlier or
later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed)
Node.js also needs to be installed along with `pnpm` (can be installed with
the command `npm install -g pnpm` if needed)
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
@ -270,7 +272,7 @@ upgrade script.** See the next section for a Windows recipe.
3. Select option [1] to upgrade to the latest release.
4. Once the upgrade is finished you will be returned to the launcher
menu. Select option [7] "Re-run the configure script to fix a broken
menu. Select option [6] "Re-run the configure script to fix a broken
install or to complete a major upgrade".
This will run the configure script against the v2.3 directory and

View File

@ -2,14 +2,17 @@
## Any environment variables supported by InvokeAI can be specified here,
## in addition to the examples below.
# INVOKEAI_ROOT is the path to a path on the local filesystem where InvokeAI will store data.
# HOST_INVOKEAI_ROOT is the path on the docker host's filesystem where InvokeAI will store data.
# Outputs will also be stored here by default.
# This **must** be an absolute path.
INVOKEAI_ROOT=
# If relative, it will be relative to the docker directory in which the docker-compose.yml file is located
#HOST_INVOKEAI_ROOT=../../invokeai-data
# INVOKEAI_ROOT is the path to the root of the InvokeAI repository within the container.
# INVOKEAI_ROOT=~/invokeai
# Get this value from your HuggingFace account settings page.
# HUGGING_FACE_HUB_TOKEN=
## optional variables specific to the docker setup.
# GPU_DRIVER=cuda # or rocm
# GPU_DRIVER=nvidia #| rocm
# CONTAINER_UID=1000

View File

@ -59,14 +59,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# #### Build the Web UI ------------------------------------
FROM node:18 AS web-builder
FROM node:20-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./
RUN --mount=type=cache,target=/usr/lib/node_modules \
npm install --include dev
RUN --mount=type=cache,target=/usr/lib/node_modules \
yarn vite build
RUN --mount=type=cache,target=/pnpm/store \
pnpm install --frozen-lockfile
RUN npx vite build
#### Runtime stage ---------------------------------------
@ -100,6 +102,8 @@ ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv/invokeai
ENV INVOKEAI_ROOT=/invokeai
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
# --link requires buldkit w/ dockerfile syntax 1.4
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
@ -117,7 +121,7 @@ WORKDIR ${INVOKEAI_SRC}
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python3 -c "from patchmatch import patch_match"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R 1000:1000 ${INVOKEAI_ROOT}
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]

View File

@ -1,6 +1,14 @@
# InvokeAI Containerized
All commands are to be run from the `docker` directory: `cd docker`
All commands should be run within the `docker` directory: `cd docker`
## Quickstart :rocket:
On a known working Linux+Docker+CUDA (Nvidia) system, execute `./run.sh` in this directory. It will take a few minutes - depending on your internet speed - to install the core models. Once the application starts up, open `http://localhost:9090` in your browser to Invoke!
For more configuration options (using an AMD GPU, custom root directory location, etc): read on.
## Detailed setup
#### Linux
@ -18,12 +26,12 @@ All commands are to be run from the `docker` directory: `cd docker`
This is done via Docker Desktop preferences
## Quickstart
### Configure Invoke environment
1. Make a copy of `env.sample` and name it `.env` (`cp env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
a. the desired location of the InvokeAI runtime directory, or
b. an existing, v3.0.0 compatible runtime directory.
1. `docker compose up`
1. Execute `run.sh`
The image will be built automatically if needed.
@ -37,19 +45,21 @@ The runtime directory (holding models and outputs) will be created in the locati
The Docker daemon on the system must be already set up to use the GPU. In case of Linux, this involves installing `nvidia-docker-runtime` and configuring the `nvidia` runtime as default. Steps will be different for AMD. Please see Docker documentation for the most up-to-date instructions for using your GPU with Docker.
To use an AMD GPU, set `GPU_DRIVER=rocm` in your `.env` file.
## Customize
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `docker compose up`, your custom values will be used.
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `run.sh`, your custom values will be used.
You can also set these values in `docker-compose.yml` directly, but `.env` will help avoid conflicts when code is updated.
Example (values are optional, but setting `INVOKEAI_ROOT` is highly recommended):
Values are optional, but setting `INVOKEAI_ROOT` is highly recommended. The default is `~/invokeai`. Example:
```bash
INVOKEAI_ROOT=/Volumes/WorkDrive/invokeai
HUGGINGFACE_TOKEN=the_actual_token
CONTAINER_UID=1000
GPU_DRIVER=cuda
GPU_DRIVER=nvidia
```
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.

View File

@ -1,11 +0,0 @@
#!/usr/bin/env bash
set -e
build_args=""
[[ -f ".env" ]] && build_args=$(awk '$1 ~ /\=[^$]/ {print "--build-arg " $0 " "}' .env)
echo "docker compose build args:"
echo $build_args
docker compose build $build_args

View File

@ -2,23 +2,8 @@
version: '3.8'
services:
invokeai:
x-invokeai: &invokeai
image: "local/invokeai:latest"
# edit below to run on a container runtime other than nvidia-container-runtime.
# not yet tested with rocm/AMD GPUs
# Comment out the "deploy" section to run on CPU only
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
# For AMD support, comment out the deploy section above and uncomment the devices section below:
#devices:
# - /dev/kfd:/dev/kfd
# - /dev/dri:/dev/dri
build:
context: ..
dockerfile: docker/Dockerfile
@ -36,7 +21,9 @@ services:
ports:
- "${INVOKEAI_PORT:-9090}:9090"
volumes:
- ${INVOKEAI_ROOT:-~/invokeai}:${INVOKEAI_ROOT:-/invokeai}
- type: bind
source: ${HOST_INVOKEAI_ROOT:-${INVOKEAI_ROOT:-~/invokeai}}
target: ${INVOKEAI_ROOT:-/invokeai}
- ${HF_HOME:-~/.cache/huggingface}:${HF_HOME:-/invokeai/.cache/huggingface}
# - ${INVOKEAI_MODELS_DIR:-${INVOKEAI_ROOT:-/invokeai/models}}
# - ${INVOKEAI_MODELS_CONFIG_PATH:-${INVOKEAI_ROOT:-/invokeai/configs/models.yaml}}
@ -50,3 +37,27 @@ services:
# - |
# invokeai-model-install --yes --default-only --config_file ${INVOKEAI_ROOT}/config_custom.yaml
# invokeai-nodes-web --host 0.0.0.0
services:
invokeai-nvidia:
<<: *invokeai
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
invokeai-cpu:
<<: *invokeai
profiles:
- cpu
invokeai-rocm:
<<: *invokeai
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
profiles:
- rocm

View File

@ -1,11 +1,32 @@
#!/usr/bin/env bash
set -e
set -e -o pipefail
# This script is provided for backwards compatibility with the old docker setup.
# it doesn't do much aside from wrapping the usual docker compose CLI.
run() {
local scriptdir=$(dirname "${BASH_SOURCE[0]}")
cd "$scriptdir" || exit 1
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
cd "$SCRIPTDIR" || exit 1
local build_args=""
local profile=""
docker compose up -d
docker compose logs -f
touch .env
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
[[ -z "$profile" ]] && profile="nvidia"
local service_name="invokeai-$profile"
if [[ ! -z "$build_args" ]]; then
printf "%s\n" "docker compose build args:"
printf "%s\n" "$build_args"
fi
docker compose build $build_args
unset build_args
printf "%s\n" "starting service $service_name"
docker compose --profile "$profile" up -d "$service_name"
docker compose logs -f
}
run

Binary file not shown.

Before

Width:  |  Height:  |  Size: 297 KiB

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

After

Width:  |  Height:  |  Size: 4.9 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 169 KiB

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 194 KiB

After

Width:  |  Height:  |  Size: 131 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 209 KiB

After

Width:  |  Height:  |  Size: 122 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 114 KiB

After

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 187 KiB

After

Width:  |  Height:  |  Size: 123 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 112 KiB

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

After

Width:  |  Height:  |  Size: 61 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 167 KiB

After

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 59 KiB

After

Width:  |  Height:  |  Size: 60 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 129 KiB

View File

@ -0,0 +1,277 @@
# The InvokeAI Download Queue
The DownloadQueueService provides a multithreaded parallel download
queue for arbitrary URLs, with queue prioritization, event handling,
and restart capabilities.
## Simple Example
```
from invokeai.app.services.download import DownloadQueueService, TqdmProgress
download_queue = DownloadQueueService()
for url in ['https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true',
'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true',
'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png',
'https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor',
]:
# urls start downloading as soon as download() is called
download_queue.download(source=url,
dest='/tmp/downloads',
on_progress=TqdmProgress().update
)
download_queue.join() # wait for all downloads to finish
for job in download_queue.list_jobs():
print(job.model_dump_json(exclude_none=True, indent=4),"\n")
```
Output:
```
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true",
"dest": "/tmp/downloads",
"id": 0,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/a-painting-of-a-fire.png",
"job_started": "2023-12-04T05:34:41.742174",
"job_ended": "2023-12-04T05:34:42.592035",
"bytes": 666734,
"total_bytes": 666734
}
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true",
"dest": "/tmp/downloads",
"id": 1,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/birdhouse.png",
"job_started": "2023-12-04T05:34:41.741975",
"job_ended": "2023-12-04T05:34:42.652841",
"bytes": 774949,
"total_bytes": 774949
}
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png",
"dest": "/tmp/downloads",
"id": 2,
"priority": 10,
"status": "error",
"job_started": "2023-12-04T05:34:41.742079",
"job_ended": "2023-12-04T05:34:42.147625",
"bytes": 0,
"total_bytes": 0,
"error_type": "HTTPError(Not Found)",
"error": "Traceback (most recent call last):\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 182, in _download_next_item\n self._do_download(job)\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 206, in _do_download\n raise HTTPError(resp.reason)\nrequests.exceptions.HTTPError: Not Found\n"
}
{
"source": "https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor",
"dest": "/tmp/downloads",
"id": 3,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/xl_more_art-full_v1.safetensors",
"job_started": "2023-12-04T05:34:42.147645",
"job_ended": "2023-12-04T05:34:43.735990",
"bytes": 719020768,
"total_bytes": 719020768
}
```
## The API
The default download queue is `DownloadQueueService`, an
implementation of ABC `DownloadQueueServiceBase`. It juggles multiple
background download requests and provides facilities for interrogating
and cancelling the requests. Access to a current or past download task
is mediated via `DownloadJob` objects which report the current status
of a job request
### The Queue Object
A default download queue is located in
`ApiDependencies.invoker.services.download_queue`. However, you can
create additional instances if you need to isolate your queue from the
main one.
```
queue = DownloadQueueService(event_bus=events)
```
`DownloadQueueService()` takes three optional arguments:
| **Argument** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| `max_parallel_dl` | int | 5 | Maximum number of simultaneous downloads allowed |
| `event_bus` | EventServiceBase | None | System-wide FastAPI event bus for reporting download events |
| `requests_session` | requests.sessions.Session | None | An alternative requests Session object to use for the download |
`max_parallel_dl` specifies how many download jobs are allowed to run
simultaneously. Each will run in a different thread of execution.
`event_bus` is an EventServiceBase, typically the one created at
InvokeAI startup. If present, download events are periodically emitted
on this bus to allow clients to follow download progress.
`requests_session` is a url library requests Session object. It is
used for testing.
### The Job object
The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.
The only job type currently implemented is `DownloadJob`, a pydantic object with the
following fields:
| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `source` | AnyHttpUrl | | Where to download from |
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| `priority` | int | 10 | Job priority. Lower priorities run before higher priorities |
| |
| _Fields updated over the course of the download task_
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the location of the downloaded file |
| `job_started` | float | | Timestamp for when the job started running |
| `job_ended` | float | | Timestamp for when the job completed or errored out |
| `job_sequence` | int | | A counter that is incremented each time a model is dequeued |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the file at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
When you create a job, you can assign it a `priority`. If multiple
jobs are queued, the job with the lowest priority runs first.
Every job has a `source` and a `dest`. `source` is a pydantic.networks AnyHttpUrl object.
The `dest` is a path on the local filesystem that specifies the
destination for the downloaded object. Its semantics are
described below.
When the job is submitted, it is assigned a numeric `id`. The id can
then be used to fetch the job object from the queue.
The `status` field is updated by the queue to indicate where the job
is in its lifecycle. Values are defined in the string enum
`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:
| **Value** | **String Value** | ** Description ** |
|--------------|---------------------|-------------------|
| `WAITING` | waiting | Job is on the queue but not yet running|
| `RUNNING` | running | The download is started |
| `COMPLETED` | completed | Job has finished its work without an error |
| `ERROR` | error | Job encountered an error and will not run again|
`job_started` and `job_ended` indicate when the job
was started (using a python timestamp) and when it completed.
In case of an error, the job's status will be set to `DownloadJobStatus.ERROR`, the text of the
Exception that caused the error will be placed in the `error_type`
field and the traceback that led to the error will be in `error`.
A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.
### Callbacks
Download jobs can be associated with a series of callbacks, each with
the signature `Callable[["DownloadJob"], None]`. The callbacks are assigned
using optional arguments `on_start`, `on_progress`, `on_complete` and
`on_error`. When the corresponding event occurs, the callback wil be
invoked and passed the job. The callback will be run in a `try:`
context in the same thread as the download job. Any exceptions that
occur during execution of the callback will be caught and converted
into a log error message, thereby allowing the download to continue.
#### `TqdmProgress`
The `invokeai.app.services.download.download_default` module defines a
class named `TqdmProgress` which can be used as an `on_progress`
handler to display a completion bar in the console. Use as follows:
```
from invokeai.app.services.download import TqdmProgress
download_queue.download(source='http://some.server.somewhere/some_file',
dest='/tmp/downloads',
on_progress=TqdmProgress().update
)
```
### Events
If the queue was initialized with the InvokeAI event bus (the case
when using `ApiDependencies.invoker.services.download_queue`), then
download events will also be issued on the bus. The events are:
* `download_started` -- This is issued when a job is taken off the
queue and a request is made to the remote server for the URL headers, but before any data
has been downloaded. The event payload will contain the keys `source`
and `download_path`. The latter contains the path that the URL will be
downloaded to.
* `download_progress -- This is issued periodically as the download
runs. The payload contains the keys `source`, `download_path`,
`current_bytes` and `total_bytes`. The latter two fields can be
used to display the percent complete.
* `download_complete` -- This is issued when the download completes
successfully. The payload contains the keys `source`, `download_path`
and `total_bytes`.
* `download_error` -- This is issued when the download stops because
of an error condition. The payload contains the fields `error_type`
and `error`. The former is the text representation of the exception,
and the latter is a traceback showing where the error occurred.
### Job control
To create a job call the queue's `download()` method. You can list all
jobs using `list_jobs()`, fetch a single job by its with
`id_to_job()`, cancel a running job with `cancel_job()`, cancel all
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.
#### job = queue.download(source, dest, priority, access_token)
Create a new download job and put it on the queue, returning the
DownloadJob object.
#### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s.
#### job = queue.id_to_job(id)
Return the job corresponding to given ID.
Return a list of all active and inactive `DownloadJob`s.
#### queue.prune_jobs()
Remove inactive (complete or errored) jobs from the listing returned
by `list_jobs()`.
#### queue.join()
Block until all pending jobs have run to completion or errored out.

File diff suppressed because it is too large Load Diff

View File

@ -46,17 +46,18 @@ We encourage you to ping @psychedelicious and @blessedcoolant on [Discord](http
```bash
node --version
```
2. Install [yarn classic](https://classic.yarnpkg.com/lang/en/) and confirm it is installed by running this:
2. Install [pnpm](https://pnpm.io/) and confirm it is installed by running this:
```bash
npm install --global yarn
yarn --version
npm install --global pnpm
pnpm --version
```
From `invokeai/frontend/web/` run `yarn install` to get everything set up.
From `invokeai/frontend/web/` run `pnpm install` to get everything set up.
Start everything in dev mode:
1. Ensure your virtual environment is running
2. Start the dev server: `yarn dev`
2. Start the dev server: `pnpm dev`
3. Start the InvokeAI Nodes backend: `python scripts/invokeai-web.py # run from the repo root`
4. Point your browser to the dev server address e.g. [http://localhost:5173/](http://localhost:5173/)
@ -72,4 +73,4 @@ For a number of technical and logistical reasons, we need to commit UI build art
If you submit a PR, there is a good chance we will ask you to include a separate commit with a build of the app.
To build for production, run `yarn build`.
To build for production, run `pnpm build`.

53
docs/deprecated/2to3.md Normal file
View File

@ -0,0 +1,53 @@
## :octicons-log-16: Important Changes Since Version 2.3
### Nodes
Behind the scenes, InvokeAI has been completely rewritten to support
"nodes," small unitary operations that can be combined into graphs to
form arbitrary workflows. For example, there is a prompt node that
processes the prompt string and feeds it to a text2latent node that
generates a latent image. The latents are then fed to a latent2image
node that translates the latent image into a PNG.
The WebGUI has a node editor that allows you to graphically design and
execute custom node graphs. The ability to save and load graphs is
still a work in progress, but coming soon.
### Command-Line Interface Retired
All "invokeai" command-line interfaces have been retired as of version
3.4.
To launch the Web GUI from the command-line, use the command
`invokeai-web` rather than the traditional `invokeai --web`.
### ControlNet
This version of InvokeAI features ControlNet, a system that allows you
to achieve exact poses for human and animal figures by providing a
model to follow. Full details are found in [ControlNet](features/CONTROLNET.md)
### New Schedulers
The list of schedulers has been completely revamped and brought up to date:
| **Short Name** | **Scheduler** | **Notes** |
|----------------|---------------------------------|-----------------------------|
| **ddim** | DDIMScheduler | |
| **ddpm** | DDPMScheduler | |
| **deis** | DEISMultistepScheduler | |
| **lms** | LMSDiscreteScheduler | |
| **pndm** | PNDMScheduler | |
| **heun** | HeunDiscreteScheduler | original noise schedule |
| **heun_k** | HeunDiscreteScheduler | using karras noise schedule |
| **euler** | EulerDiscreteScheduler | original noise schedule |
| **euler_k** | EulerDiscreteScheduler | using karras noise schedule |
| **kdpm_2** | KDPM2DiscreteScheduler | |
| **kdpm_2_a** | KDPM2AncestralDiscreteScheduler | |
| **dpmpp_2s** | DPMSolverSinglestepScheduler | |
| **dpmpp_2m** | DPMSolverMultistepScheduler | original noise scnedule |
| **dpmpp_2m_k** | DPMSolverMultistepScheduler | using karras noise schedule |
| **unipc** | UniPCMultistepScheduler | CPU only |
| **lcm** | LCMScheduler | |
Please see [3.0.0 Release Notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v3.0.0) for further details.

View File

@ -154,14 +154,16 @@ groups in `invokeia.yaml`:
### Web Server
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
| Setting | Default Value | Description |
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].

View File

@ -229,29 +229,28 @@ clarity on the intent and common use cases we expect for utilizing them.
currently being rendered by your browser into a merged copy of the image. This
lowers the resource requirements and should improve performance.
### Seam Correction
### Compositing / Seam Correction
When doing Inpainting or Outpainting, Invoke needs to merge the pixels generated
by Stable Diffusion into your existing image. To do this, the area around the
`seam` at the boundary between your image and the new generation is
by Stable Diffusion into your existing image. This is achieved through compositing - the area around the the boundary between your image and the new generation is
automatically blended to produce a seamless output. In a fully automatic
process, a mask is generated to cover the seam, and then the area of the seam is
process, a mask is generated to cover the boundary, and then the area of the boundary is
Inpainted.
Although the default options should work well most of the time, sometimes it can
help to alter the parameters that control the seam Inpainting. A wider seam and
a blur setting of about 1/3 of the seam have been noted as producing
consistently strong results (e.g. 96 wide and 16 blur - adds up to 32 blur with
both sides). Seam strength of 0.7 is best for reducing hard seams.
help to alter the parameters that control the Compositing. A larger blur and
a blur setting have been noted as producing
consistently strong results . Strength of 0.7 is best for reducing hard seams.
- **Mode** - What part of the image will have the the Compositing applied to it.
- **Mask edge** will apply Compositing to the edge of the masked area
- **Mask** will apply Compositing to the entire masked area
- **Unmasked** will apply Compositing to the entire image
- **Steps** - Number of generation steps that will occur during the Coherence Pass, similar to Denoising Steps. Higher step counts will generally have better results.
- **Strength** - How much noise is added for the Coherence Pass, similar to Denoising Strength. A strength of 0 will result in an unchanged image, while a strength of 1 will result in an image with a completely new area as defined by the Mode setting.
- **Blur** - Adjusts the pixel radius of the the mask. A larger blur radius will cause the mask to extend past the visibly masked area, while too small of a blur radius will result in a mask that is smaller than the visibly masked area.
- **Blur Method** - The method of blur applied to the masked area.
- **Seam Size** - The size of the seam masked area. Set higher to make a larger
mask around the seam.
- **Seam Blur** - The size of the blur that is applied on _each_ side of the
masked area.
- **Seam Strength** - The Image To Image Strength parameter used for the
Inpainting generation that is applied to the seam area.
- **Seam Steps** - The number of generation steps that should be used to Inpaint
the seam.
### Infill & Scaling

View File

@ -18,7 +18,7 @@ title: Home
width: 100%;
max-width: 100%;
height: 50px;
background-color: #448AFF;
background-color: #35A4DB;
color: #fff;
font-size: 16px;
border: none;
@ -43,7 +43,7 @@ title: Home
<div align="center" markdown>
[![project logo](assets/invoke_ai_banner.png)](https://github.com/invoke-ai/InvokeAI)
[![project logo](https://github.com/invoke-ai/InvokeAI/assets/31807370/6e3728c7-e90e-4711-905c-3b55844ff5be)](https://github.com/invoke-ai/InvokeAI)
[![discord badge]][discord link]
@ -145,60 +145,6 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
- [Guide to InvokeAI Runtime Settings](features/CONFIGURATION.md)
- [Database Maintenance and other Command Line Utilities](features/UTILITIES.md)
## :octicons-log-16: Important Changes Since Version 2.3
### Nodes
Behind the scenes, InvokeAI has been completely rewritten to support
"nodes," small unitary operations that can be combined into graphs to
form arbitrary workflows. For example, there is a prompt node that
processes the prompt string and feeds it to a text2latent node that
generates a latent image. The latents are then fed to a latent2image
node that translates the latent image into a PNG.
The WebGUI has a node editor that allows you to graphically design and
execute custom node graphs. The ability to save and load graphs is
still a work in progress, but coming soon.
### Command-Line Interface Retired
All "invokeai" command-line interfaces have been retired as of version
3.4.
To launch the Web GUI from the command-line, use the command
`invokeai-web` rather than the traditional `invokeai --web`.
### ControlNet
This version of InvokeAI features ControlNet, a system that allows you
to achieve exact poses for human and animal figures by providing a
model to follow. Full details are found in [ControlNet](features/CONTROLNET.md)
### New Schedulers
The list of schedulers has been completely revamped and brought up to date:
| **Short Name** | **Scheduler** | **Notes** |
|----------------|---------------------------------|-----------------------------|
| **ddim** | DDIMScheduler | |
| **ddpm** | DDPMScheduler | |
| **deis** | DEISMultistepScheduler | |
| **lms** | LMSDiscreteScheduler | |
| **pndm** | PNDMScheduler | |
| **heun** | HeunDiscreteScheduler | original noise schedule |
| **heun_k** | HeunDiscreteScheduler | using karras noise schedule |
| **euler** | EulerDiscreteScheduler | original noise schedule |
| **euler_k** | EulerDiscreteScheduler | using karras noise schedule |
| **kdpm_2** | KDPM2DiscreteScheduler | |
| **kdpm_2_a** | KDPM2AncestralDiscreteScheduler | |
| **dpmpp_2s** | DPMSolverSinglestepScheduler | |
| **dpmpp_2m** | DPMSolverMultistepScheduler | original noise scnedule |
| **dpmpp_2m_k** | DPMSolverMultistepScheduler | using karras noise schedule |
| **unipc** | UniPCMultistepScheduler | CPU only |
| **lcm** | LCMScheduler | |
Please see [3.0.0 Release Notes](https://github.com/invoke-ai/InvokeAI/releases/tag/v3.0.0) for further details.
## :material-target: Troubleshooting
Please check out our **[:material-frequently-asked-questions:

View File

@ -293,6 +293,19 @@ manager, please follow these steps:
## Developer Install
!!! warning
InvokeAI uses a SQLite database. By running on `main`, you accept responsibility for your database. This
means making regular backups (especially before pulling) and/or fixing it yourself in the event that a
PR introduces a schema change.
If you don't need persistent backend storage, you can use an ephemeral in-memory database by setting
`use_memory_db: true` under `Path:` in your `invokeai.yaml` file.
If this is untenable, you should run the application via the official installer or a manual install of the
python package from pypi. These releases will not break your database.
If you have an interest in how InvokeAI works, or you would like to
add features or bugfixes, you are encouraged to install the source
code for InvokeAI. For this to work, you will need to install the
@ -388,3 +401,5 @@ environment variable INVOKEAI_ROOT to point to the installation directory.
Note that if you run into problems with the Conda installation, the InvokeAI
staff will **not** be able to help you out. Caveat Emptor!
[dev-chat]: https://discord.com/channels/1020123559063990373/1049495067846524939

View File

@ -0,0 +1,10 @@
document.addEventListener("DOMContentLoaded", function () {
var script = document.createElement("script");
script.src = "https://widget.kapa.ai/kapa-widget.bundle.js";
script.setAttribute("data-website-id", "b5973bb1-476b-451e-8cf4-98de86745a10");
script.setAttribute("data-project-name", "Invoke.AI");
script.setAttribute("data-project-color", "#11213C");
script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/113954515?s=280&v=4");
script.async = true;
document.head.appendChild(script);
});

View File

@ -6,10 +6,17 @@ If you're not familiar with Diffusion, take a look at our [Diffusion Overview.](
## Features
### Workflow Library
The Workflow Library enables you to save workflows to the Invoke database, allowing you to easily creating, modify and share workflows as needed.
A curated set of workflows are provided by default - these are designed to help explain important nodes' usage in the Workflow Editor.
![workflow_library](../assets/nodes/workflow_library.png)
### Linear View
The Workflow Editor allows you to create a UI for your workflow, to make it easier to iterate on your generations.
To add an input to the Linear UI, right click on the input label and select "Add to Linear View".
To add an input to the Linear UI, right click on the **input label** and select "Add to Linear View".
The Linear UI View will also be part of the saved workflow, allowing you share workflows and enable other to use them, regardless of complexity.
@ -30,7 +37,7 @@ Any node or input field can be renamed in the workflow editor. If the input fiel
Nodes have a "Use Cache" option in their footer. This allows for performance improvements by using the previously cached values during the workflow processing.
## Important Concepts
## Important Nodes & Concepts
There are several node grouping concepts that can be examined with a narrow focus. These (and other) groupings can be pieced together to make up functional graph setups, and are important to understanding how groups of nodes work together as part of a whole. Note that the screenshots below aren't examples of complete functioning node graphs (see Examples).
@ -56,7 +63,7 @@ The ImageToLatents node takes in a pixel image and a VAE and outputs a latents.
It is common to want to use both the same seed (for continuity) and random seeds (for variety). To define a seed, simply enter it into the 'Seed' field on a noise node. Conversely, the RandomInt node generates a random integer between 'Low' and 'High', and can be used as input to the 'Seed' edge point on a noise node to randomize your seed.
![groupsrandseed](../assets/nodes/groupsrandseed.png)
![groupsrandseed](../assets/nodes/groupsnoise.png)
### ControlNet

View File

@ -13,7 +13,12 @@ If you'd prefer, you can also just download the whole node folder from the linke
To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
- Community Nodes
+ [Adapters-Linked](#adapters-linked-nodes)
+ [Average Images](#average-images)
+ [Clean Image Artifacts After Cut](#clean-image-artifacts-after-cut)
+ [Close Color Mask](#close-color-mask)
+ [Clothing Mask](#clothing-mask)
+ [Contrast Limited Adaptive Histogram Equalization](#contrast-limited-adaptive-histogram-equalization)
+ [Depth Map from Wavefront OBJ](#depth-map-from-wavefront-obj)
+ [Film Grain](#film-grain)
+ [Generative Grammar-Based Prompt Nodes](#generative-grammar-based-prompt-nodes)
@ -22,16 +27,24 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [Halftone](#halftone)
+ [Ideal Size](#ideal-size)
+ [Image and Mask Composition Pack](#image-and-mask-composition-pack)
+ [Image Dominant Color](#image-dominant-color)
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
+ [Image Picker](#image-picker)
+ [Image Resize Plus](#image-resize-plus)
+ [Load Video Frame](#load-video-frame)
+ [Make 3D](#make-3d)
+ [Mask Operations](#mask-operations)
+ [Match Histogram](#match-histogram)
+ [Metadata-Linked](#metadata-linked-nodes)
+ [Negative Image](#negative-image)
+ [Nightmare Promptgen](#nightmare-promptgen)
+ [Oobabooga](#oobabooga)
+ [Prompt Tools](#prompt-tools)
+ [Remote Image](#remote-image)
+ [Remove Background](#remove-background)
+ [Retroize](#retroize)
+ [Size Stepper Nodes](#size-stepper-nodes)
+ [Simple Skin Detection](#simple-skin-detection)
+ [Text font to Image](#text-font-to-image)
+ [Thresholding](#thresholding)
+ [Unsharp Mask](#unsharp-mask)
@ -41,6 +54,19 @@ To use a community workflow, download the the `.json` node graph file and load i
- [Help](#help)
--------------------------------
### Adapters Linked Nodes
**Description:** A set of nodes for linked adapters (ControlNet, IP-Adaptor & T2I-Adapter). This allows multiple adapters to be chained together without using a `collect` node which means it can be used inside an `iterate` node without any collecting on every iteration issues.
- `ControlNet-Linked` - Collects ControlNet info to pass to other nodes.
- `IP-Adapter-Linked` - Collects IP-Adapter info to pass to other nodes.
- `T2I-Adapter-Linked` - Collects T2I-Adapter info to pass to other nodes.
Note: These are inherited from the core nodes so any update to the core nodes should be reflected in these.
**Node Link:** https://github.com/skunkworxdark/adapters-linked-nodes
--------------------------------
### Average Images
@ -48,6 +74,46 @@ To use a community workflow, download the the `.json` node graph file and load i
**Node Link:** https://github.com/JPPhoto/average-images-node
--------------------------------
### Clean Image Artifacts After Cut
Description: Removes residual artifacts after an image is separated from its background.
Node Link: https://github.com/VeyDlin/clean-artifact-after-cut-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/clean-artifact-after-cut-node/master/.readme/node.png" width="500" />
--------------------------------
### Close Color Mask
Description: Generates a mask for images based on a closely matching color, useful for color-based selections.
Node Link: https://github.com/VeyDlin/close-color-mask-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/close-color-mask-node/master/.readme/node.png" width="500" />
--------------------------------
### Clothing Mask
Description: Employs a U2NET neural network trained for the segmentation of clothing items in images.
Node Link: https://github.com/VeyDlin/clothing-mask-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/clothing-mask-node/master/.readme/node.png" width="500" />
--------------------------------
### Contrast Limited Adaptive Histogram Equalization
Description: Enhances local image contrast using adaptive histogram equalization with contrast limiting.
Node Link: https://github.com/VeyDlin/clahe-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/clahe-node/master/.readme/node.png" width="500" />
--------------------------------
### Depth Map from Wavefront OBJ
@ -164,6 +230,16 @@ This includes 15 Nodes:
</br><img src="https://raw.githubusercontent.com/dwringer/composition-nodes/main/composition_pack_overview.jpg" width="500" />
--------------------------------
### Image Dominant Color
Description: Identifies and extracts the dominant color from an image using k-means clustering.
Node Link: https://github.com/VeyDlin/image-dominant-color-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-dominant-color-node/master/.readme/node.png" width="500" />
--------------------------------
### Image to Character Art Image Nodes
@ -185,6 +261,17 @@ This includes 15 Nodes:
**Node Link:** https://github.com/JPPhoto/image-picker-node
--------------------------------
### Image Resize Plus
Description: Provides various image resizing options such as fill, stretch, fit, center, and crop.
Node Link: https://github.com/VeyDlin/image-resize-plus-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
--------------------------------
### Load Video Frame
@ -209,6 +296,16 @@ This includes 15 Nodes:
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png" width="300" />
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png" width="300" />
--------------------------------
### Mask Operations
Description: Offers logical operations (OR, SUB, AND) for combining and manipulating image masks.
Node Link: https://github.com/VeyDlin/mask-operations-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/mask-operations-node/master/.readme/node.png" width="500" />
--------------------------------
### Match Histogram
@ -226,6 +323,37 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
<img src="https://github.com/skunkworxdark/match_histogram/assets/21961335/ed12f329-a0ef-444a-9bae-129ed60d6097" width="300" />
--------------------------------
### Metadata Linked Nodes
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
- `Metadata From Image` - Provides Metadata from an image.
- `Metadata To String` - Extracts a String value of a label from metadata.
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
- `Metadata To Float` - Extracts a Float value of a label from metadata.
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes
--------------------------------
### Negative Image
Description: Creates a negative version of an image, effective for visual effects and mask inversion.
Node Link: https://github.com/VeyDlin/negative-image-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/negative-image-node/master/.readme/node.png" width="500" />
--------------------------------
### Nightmare Promptgen
**Description:** Nightmare Prompt Generator - Uses a local text generation model to create unique imaginative (but usually nightmarish) prompts for InvokeAI. By default, it allows you to choose from some gpt-neo models I finetuned on over 2500 of my own InvokeAI prompts in Compel format, but you're able to add your own, as well. Offers support for replacing any troublesome words with a random choice from list you can also define.
**Node Link:** [https://github.com/gogurtenjoyer/nightmare-promptgen](https://github.com/gogurtenjoyer/nightmare-promptgen)
--------------------------------
### Oobabooga
@ -289,6 +417,15 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
**Node Link:** https://github.com/fieldOfView/InvokeAI-remote_image
--------------------------------
### Remove Background
Description: An integration of the rembg package to remove backgrounds from images using multiple U2NET models.
Node Link: https://github.com/VeyDlin/remove-background-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/remove-background-node/master/.readme/node.png" width="500" />
--------------------------------
### Retroize
@ -301,6 +438,17 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
<img src="https://github.com/Ar7ific1al/InvokeAI_nodes_retroize/assets/2306586/de8b4fa6-324c-4c2d-b36c-297600c73974" width="500" />
--------------------------------
### Simple Skin Detection
Description: Detects skin in images based on predefined color thresholds.
Node Link: https://github.com/VeyDlin/simple-skin-detection-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/simple-skin-detection-node/master/.readme/node.png" width="500" />
--------------------------------
### Size Stepper Nodes
@ -386,6 +534,7 @@ See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/READ
<img src="https://github.com/skunkworxdark/XYGrid_nodes/blob/main/images/collage.png" width="300" />
--------------------------------
### Example Node Template

View File

@ -1,6 +1,6 @@
# Example Workflows
We've curated some example workflows for you to get started with Workflows in InvokeAI
We've curated some example workflows for you to get started with Workflows in InvokeAI! These can also be found in the Workflow Library, located in the Workflow Editor of Invoke.
To use them, right click on your desired workflow, follow the link to GitHub and click the "⬇" button to download the raw file. You can then use the "Load Workflow" functionality in InvokeAI to load the workflow and start generating images!

View File

@ -215,6 +215,7 @@ We thank them for all of their time and hard work.
- Robert Bolender
- Robin Rombach
- Rohan Barar
- rohinish404
- rpagliuca
- rromb
- Rupesh Sreeraman

View File

@ -0,0 +1,5 @@
:root {
--md-primary-fg-color: #35A4DB;
--md-primary-fg-color--light: #35A4DB;
--md-primary-fg-color--dark: #35A4DB;
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
{
"name": "Text to Image",
"name": "Text to Image - SD1.5",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 1.5/2",
"version": "1.0.1",
"version": "1.1.0",
"contact": "invoke@invoke.ai",
"tags": "text2image, SD1.5, SD2, default",
"notes": "",
@ -18,10 +18,19 @@
{
"nodeId": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"fieldName": "prompt"
},
{
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
"fieldName": "width"
},
{
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
"fieldName": "height"
}
],
"meta": {
"version": "1.0.0"
"category": "default",
"version": "2.0.0"
},
"nodes": [
{
@ -30,44 +39,56 @@
"data": {
"id": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"type": "compel",
"label": "Negative Compel Prompt",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
"name": "prompt",
"type": "string",
"fieldKind": "input",
"label": "Negative Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
"name": "clip",
"type": "ClipField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
"name": "conditioning",
"type": "ConditioningField",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
},
"label": "Negative Compel Prompt",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.0.0"
}
},
"width": 320,
"height": 261,
"height": 259,
"position": {
"x": 995.7263915923627,
"y": 239.67783573351227
"x": 1000,
"y": 350
}
},
{
@ -76,37 +97,60 @@
"data": {
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "noise",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"nodePack": "invokeai",
"inputs": {
"seed": {
"id": "6431737c-918a-425d-a3b4-5d57e2f35d4d",
"name": "seed",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"width": {
"id": "38fc5b66-fe6e-47c8-bba9-daf58e454ed7",
"name": "width",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"height": {
"id": "16298330-e2bf-4872-a514-d6923df53cbb",
"name": "height",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"use_cpu": {
"id": "c7c436d3-7a7a-4e76-91e4-c6deb271623c",
"name": "use_cpu",
"type": "boolean",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
@ -114,35 +158,40 @@
"noise": {
"id": "50f650dc-0184-4e23-a927-0497a96fe954",
"name": "noise",
"type": "LatentsField",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "bb8a452b-133d-42d1-ae4a-3843d7e4109a",
"name": "width",
"type": "integer",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "35cfaa12-3b8b-4b7a-a884-327ff3abddd9",
"name": "height",
"type": "integer",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
},
"label": "",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.0.0"
}
},
"width": 320,
"height": 389,
"height": 388,
"position": {
"x": 993.4442117555518,
"y": 605.6757415334787
"x": 600,
"y": 325
}
},
{
@ -151,13 +200,24 @@
"data": {
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"type": "main_model_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"model": {
"id": "993eabd2-40fd-44fe-bce7-5d0c7075ddab",
"name": "model",
"type": "MainModelField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MainModelField"
},
"value": {
"model_name": "stable-diffusion-v1-5",
"base_model": "sd-1",
@ -169,35 +229,40 @@
"unet": {
"id": "5c18c9db-328d-46d0-8cb9-143391c410be",
"name": "unet",
"type": "UNetField",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"clip": {
"id": "6effcac0-ec2f-4bf5-a49e-a2c29cf921f4",
"name": "clip",
"type": "ClipField",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
},
"vae": {
"id": "57683ba3-f5f5-4f58-b9a2-4b83dacad4a1",
"name": "vae",
"type": "VaeField",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
}
},
"label": "",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.0.0"
}
},
"width": 320,
"height": 226,
"position": {
"x": 163.04436745878343,
"y": 254.63156870373479
"x": 600,
"y": 25
}
},
{
@ -206,44 +271,56 @@
"data": {
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"type": "compel",
"label": "Positive Compel Prompt",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
"name": "prompt",
"type": "string",
"fieldKind": "input",
"label": "Positive Prompt",
"value": ""
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": "Super cute tiger cub, national geographic award-winning photograph"
},
"clip": {
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
"name": "clip",
"type": "ClipField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
"name": "conditioning",
"type": "ConditioningField",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
},
"label": "Positive Compel Prompt",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.0.0"
}
},
"width": 320,
"height": 261,
"height": 259,
"position": {
"x": 595.7263915923627,
"y": 239.67783573351227
"x": 1000,
"y": 25
}
},
{
@ -252,21 +329,36 @@
"data": {
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"type": "rand_int",
"label": "Random Seed",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": false,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"low": {
"id": "3ec65a37-60ba-4b6c-a0b2-553dd7a84b84",
"name": "low",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"high": {
"id": "085f853a-1a5f-494d-8bec-e4ba29a3f2d1",
"name": "high",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 2147483647
}
},
@ -274,23 +366,20 @@
"value": {
"id": "812ade4d-7699-4261-b9fc-a6c9d2ab55ee",
"name": "value",
"type": "integer",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
},
"label": "Random Seed",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": false,
"version": "1.0.0"
}
},
"width": 320,
"height": 218,
"height": 32,
"position": {
"x": 541.094822888628,
"y": 694.5704476446829
"x": 600,
"y": 275
}
},
{
@ -299,144 +388,224 @@
"data": {
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "denoise_latents",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
"id": "90b7f4f8-ada7-4028-8100-d2e54f192052",
"name": "positive_conditioning",
"type": "ConditioningField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"negative_conditioning": {
"id": "9393779e-796c-4f64-b740-902a1177bf53",
"name": "negative_conditioning",
"type": "ConditioningField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"noise": {
"id": "8e17f1e5-4f98-40b1-b7f4-86aeeb4554c1",
"name": "noise",
"type": "LatentsField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"steps": {
"id": "9b63302d-6bd2-42c9-ac13-9b1afb51af88",
"name": "steps",
"type": "integer",
"fieldKind": "input",
"label": "",
"value": 10
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 50
},
"cfg_scale": {
"id": "87dd04d3-870e-49e1-98bf-af003a810109",
"name": "cfg_scale",
"type": "FloatPolymorphic",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "FloatField"
},
"value": 7.5
},
"denoising_start": {
"id": "f369d80f-4931-4740-9bcd-9f0620719fab",
"name": "denoising_start",
"type": "float",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"denoising_end": {
"id": "747d10e5-6f02-445c-994c-0604d814de8c",
"name": "denoising_end",
"type": "float",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 1
},
"scheduler": {
"id": "1de84a4e-3a24-4ec8-862b-16ce49633b9b",
"name": "scheduler",
"type": "Scheduler",
"fieldKind": "input",
"label": "",
"value": "euler"
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "SchedulerField"
},
"value": "unipc"
},
"unet": {
"id": "ffa6fef4-3ce2-4bdb-9296-9a834849489b",
"name": "unet",
"type": "UNetField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"control": {
"id": "077b64cb-34be-4fcc-83f2-e399807a02bd",
"name": "control",
"type": "ControlPolymorphic",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "ControlField"
}
},
"ip_adapter": {
"id": "1d6948f7-3a65-4a65-a20c-768b287251aa",
"name": "ip_adapter",
"type": "IPAdapterPolymorphic",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "IPAdapterField"
}
},
"t2i_adapter": {
"id": "75e67b09-952f-4083-aaf4-6b804d690412",
"name": "t2i_adapter",
"type": "T2IAdapterPolymorphic",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "T2IAdapterField"
}
},
"cfg_rescale_multiplier": {
"id": "9101f0a6-5fe0-4826-b7b3-47e5d506826c",
"name": "cfg_rescale_multiplier",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"latents": {
"id": "334d4ba3-5a99-4195-82c5-86fb3f4f7d43",
"name": "latents",
"type": "LatentsField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"denoise_mask": {
"id": "0d3dbdbf-b014-4e95-8b18-ff2ff9cb0bfa",
"name": "denoise_mask",
"type": "DenoiseMaskField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "DenoiseMaskField"
}
}
},
"outputs": {
"latents": {
"id": "70fa5bbc-0c38-41bb-861a-74d6d78d2f38",
"name": "latents",
"type": "LatentsField",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "98ee0e6c-82aa-4e8f-8be5-dc5f00ee47f0",
"name": "width",
"type": "integer",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "e8cb184a-5e1a-47c8-9695-4b8979564f5d",
"name": "height",
"type": "integer",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
},
"label": "",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.4.0"
}
},
"width": 320,
"height": 646,
"height": 703,
"position": {
"x": 1476.5794704734735,
"y": 256.80174342731783
"x": 1400,
"y": 25
}
},
{
@ -445,153 +614,185 @@
"data": {
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "l2i",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": false,
"useCache": true,
"version": "1.2.0",
"nodePack": "invokeai",
"inputs": {
"metadata": {
"id": "ab375f12-0042-4410-9182-29e30db82c85",
"name": "metadata",
"type": "MetadataField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MetadataField"
}
},
"latents": {
"id": "3a7e7efd-bff5-47d7-9d48-615127afee78",
"name": "latents",
"type": "LatentsField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"vae": {
"id": "a1f5f7a1-0795-4d58-b036-7820c0b0ef2b",
"name": "vae",
"type": "VaeField",
"fieldKind": "input",
"label": ""
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
},
"tiled": {
"id": "da52059a-0cee-4668-942f-519aa794d739",
"name": "tiled",
"type": "boolean",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
},
"fp32": {
"id": "c4841df3-b24e-4140-be3b-ccd454c2522c",
"name": "fp32",
"type": "boolean",
"fieldKind": "input",
"label": "",
"value": false
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
"outputs": {
"image": {
"id": "72d667d0-cf85-459d-abf2-28bd8b823fe7",
"name": "image",
"type": "ImageField",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ImageField"
}
},
"width": {
"id": "c8c907d8-1066-49d1-b9a6-83bdcd53addc",
"name": "width",
"type": "integer",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "230f359c-b4ea-436c-b372-332d7dcdca85",
"name": "height",
"type": "integer",
"fieldKind": "output"
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
},
"label": "",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": false,
"useCache": true,
"version": "1.0.0"
}
},
"width": 320,
"height": 267,
"height": 266,
"position": {
"x": 2037.9648469717395,
"y": 426.10844427600136
"x": 1800,
"y": 25
}
}
],
"edges": [
{
"source": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"sourceHandle": "value",
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
"targetHandle": "seed",
"id": "reactflow__edge-ea94bc37-d995-4a83-aa99-4af42479f2f2value-55705012-79b9-4aac-9f26-c0b10309785bseed",
"type": "default"
"source": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"sourceHandle": "clip",
"target": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"targetHandle": "clip",
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-7d8bf987-284f-413a-b2fd-d825445a5d6cclip",
"type": "default"
},
{
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"type": "default",
"sourceHandle": "clip",
"target": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"targetHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-93dc02a4-d05b-48ed-b99c-c9b616af3402clip",
"type": "default"
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"source": "55705012-79b9-4aac-9f26-c0b10309785b",
"sourceHandle": "noise",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"targetHandle": "noise",
"id": "reactflow__edge-55705012-79b9-4aac-9f26-c0b10309785bnoise-eea2702a-19fb-45b5-9d75-56b4211ec03cnoise",
"type": "default"
"source": "55705012-79b9-4aac-9f26-c0b10309785b",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"source": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"sourceHandle": "conditioning",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"targetHandle": "positive_conditioning",
"id": "reactflow__edge-7d8bf987-284f-413a-b2fd-d825445a5d6cconditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cpositive_conditioning",
"type": "default"
},
{
"source": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"source": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "conditioning",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"targetHandle": "negative_conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-93dc02a4-d05b-48ed-b99c-c9b616af3402conditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cnegative_conditioning",
"type": "default"
},
{
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"sourceHandle": "unet",
"source": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"targetHandle": "unet",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8unet-eea2702a-19fb-45b5-9d75-56b4211ec03cunet",
"type": "default"
},
{
"source": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"sourceHandle": "latents",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"targetHandle": "latents",
"id": "reactflow__edge-eea2702a-19fb-45b5-9d75-56b4211ec03clatents-58c957f5-0d01-41fc-a803-b2bbf0413d4flatents",
"type": "default"
},
{
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"sourceHandle": "vae",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-eea2702a-19fb-45b5-9d75-56b4211ec03clatents-58c957f5-0d01-41fc-a803-b2bbf0413d4flatents",
"source": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"targetHandle": "vae",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8vae-58c957f5-0d01-41fc-a803-b2bbf0413d4fvae",
"type": "default"
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
}
]
}
}

View File

@ -2,43 +2,72 @@
set -e
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
function is_bin_in_path {
builtin type -P "$1" &>/dev/null
}
function git_show {
git show -s --format='%h %s' $1
}
cd "$(dirname "$0")"
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
echo "The current working directory is $(pwd)"
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
echo
# Some machines only have `python3` in PATH, others have `python` - make an alias.
# We can use a function to approximate an alias within a non-interactive shell.
if ! is_bin_in_path python && is_bin_in_path python3; then
function python {
python3 "$@"
}
fi
if [[ -v "VIRTUAL_ENV" ]]; then
# we can't just call 'deactivate' because this function is not exported
# to the environment of this script from the bash process that runs the script
echo "A virtual environment is activated. Please deactivate it before proceeding".
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
exit -1
fi
VERSION=$(cd ..; python -c "from invokeai.version import __version__ as version; print(version)")
VERSION=$(
cd ..
python -c "from invokeai.version import __version__ as version; print(version)"
)
PATCH=""
VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v3-latest"
echo Building installer for version $VERSION
echo "Be certain that you're in the 'installer' directory before continuing."
read -p "Press any key to continue, or CTRL-C to exit..."
echo -e "${BGREEN}HEAD${RESET}:"
git_show
echo
read -e -p "Tag this repo with '${VERSION}' and '${LATEST_TAG}'? [n]: " input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
# ---------------------- FRONTEND ----------------------
git push origin :refs/tags/$VERSION
if ! git tag -fa $VERSION ; then
echo "Existing/invalid tag"
exit -1
fi
pushd ../invokeai/frontend/web >/dev/null
echo
echo "Installing frontend dependencies..."
echo
pnpm i --frozen-lockfile
echo
echo "Building frontend..."
echo
pnpm build
popd
git push origin :refs/tags/$LATEST_TAG
git tag -fa $LATEST_TAG
# ---------------------- BACKEND ----------------------
echo "remember to push --tags!"
fi
# ----------------------
echo Building the wheel
echo
echo "Building wheel..."
echo
# install the 'build' package in the user site packages, if needed
# could be improved by using a temporary venv, but it's tiny and harmless
@ -46,12 +75,15 @@ if [[ $(python -c 'from importlib.util import find_spec; print(find_spec("build"
pip install --user build
fi
rm -r ../build
rm -rf ../build
python -m build --wheel --outdir dist/ ../.
# ----------------------
echo Building installer zip fles for InvokeAI $VERSION
echo
echo "Building installer zip files for InvokeAI ${VERSION}..."
echo
# get rid of any old ones
rm -f *.zip
@ -59,9 +91,11 @@ rm -rf InvokeAI-Installer
# copy content
mkdir InvokeAI-Installer
for f in templates lib *.txt *.reg; do
for f in templates *.txt *.reg; do
cp -r ${f} InvokeAI-Installer/
done
mkdir InvokeAI-Installer/lib
cp lib/*.py InvokeAI-Installer/lib
# Move the wheel
mv dist/*.whl InvokeAI-Installer/lib/
@ -72,13 +106,13 @@ cp install.sh.in InvokeAI-Installer/install.sh
chmod a+x InvokeAI-Installer/install.sh
# Windows
perl -p -e "s/^set INVOKEAI_VERSION=.*/set INVOKEAI_VERSION=$VERSION/" install.bat.in > InvokeAI-Installer/install.bat
perl -p -e "s/^set INVOKEAI_VERSION=.*/set INVOKEAI_VERSION=$VERSION/" install.bat.in >InvokeAI-Installer/install.bat
cp WinLongPathsEnabled.reg InvokeAI-Installer/
# Zip everything up
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
# clean up
rm -rf InvokeAI-Installer tmp dist
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
exit 0

View File

@ -241,12 +241,12 @@ class InvokeAiInstance:
pip[
"install",
"--require-virtualenv",
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"numpy==1.26.3", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch==2.1.0",
"torch==2.1.2",
"torchmetrics==0.11.4",
"torchvision>=0.14.1",
"torchvision==0.16.2",
"--force-reinstall",
"--find-links" if find_links is not None else None,
find_links,

71
installer/tag_release.sh Executable file
View File

@ -0,0 +1,71 @@
#!/bin/bash
set -e
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
function does_tag_exist {
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
}
function git_show_ref {
git show-ref --dereference $1 --abbrev 7
}
function git_show {
git show -s --format='%h %s' $1
}
VERSION=$(
cd ..
python -c "from invokeai.version import __version__ as version; print(version)"
)
PATCH=""
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v${MAJOR_VERSION}-latest"
if does_tag_exist $VERSION; then
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
git_show_ref tags/$VERSION
echo
fi
if does_tag_exist $LATEST_TAG; then
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
git_show_ref tags/$LATEST_TAG
echo
fi
echo -e "${BGREEN}HEAD${RESET}:"
git_show
echo
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
read -e -p 'y/n [n]: ' input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
echo
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
git push --delete origin $VERSION
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
if ! git tag -fa $VERSION; then
echo "Existing/invalid tag"
exit -1
fi
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
git push --delete origin $LATEST_TAG
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
git tag -fa $LATEST_TAG
echo -e "Pushing updated tags to remote..."
git push origin --tags
fi
exit 0

View File

@ -2,7 +2,8 @@
from logging import Logger
from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -11,6 +12,7 @@ from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
@ -23,14 +25,13 @@ from ..services.invoker import Invoker
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_install import ModelInstallService
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite import SqliteDatabase
from ..services.shared.graph import GraphExecutionState
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -61,14 +62,15 @@ class ApiDependencies:
invoker: Invoker
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")
logger.debug(f"Internet connectivity is {config.internet_available}")
output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images")
db = SqliteDatabase(config, logger)
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
logger = logger
@ -79,14 +81,21 @@ class ApiDependencies:
boards = BoardService()
events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_files = DiskImageFileStorage(f"{output_folder}/images")
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
download_queue_service = DownloadQueueService(event_bus=events)
metadata_store = ModelMetadataStore(db=db)
model_install_service = ModelInstallService(
app_config=config,
record_store=model_record_service,
download_queue=download_queue_service,
metadata_store=metadata_store,
event_bus=events,
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@ -94,7 +103,6 @@ class ApiDependencies:
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
workflow_records = SqliteWorkflowRecordsStorage(db=db)
services = InvocationServices(
@ -105,7 +113,6 @@ class ApiDependencies:
configuration=configuration,
events=events,
graph_execution_manager=graph_execution_manager,
graph_library=graph_library,
image_files=image_files,
image_records=image_records,
images=images,
@ -114,6 +121,8 @@ class ApiDependencies:
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
download_queue=download_queue_service,
model_install=model_install_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,
@ -121,17 +130,13 @@ class ApiDependencies:
session_processor=session_processor,
session_queue=session_queue,
urls=urls,
workflow_image_records=workflow_image_records,
workflow_records=workflow_records,
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
db.clean()
@staticmethod
def shutdown():
def shutdown() -> None:
if ApiDependencies.invoker:
ApiDependencies.invoker.stop()

View File

@ -0,0 +1,111 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for the download queue."""
from typing import List, Optional
from fastapi import Body, Path, Response
from fastapi.routing import APIRouter
from pydantic.networks import AnyHttpUrl
from starlette.exceptions import HTTPException
from invokeai.app.services.download import (
DownloadJob,
UnknownJobIDException,
)
from ..dependencies import ApiDependencies
download_queue_router = APIRouter(prefix="/v1/download_queue", tags=["download_queue"])
@download_queue_router.get(
"/",
operation_id="list_downloads",
)
async def list_downloads() -> List[DownloadJob]:
"""Get a list of active and inactive jobs."""
queue = ApiDependencies.invoker.services.download_queue
return queue.list_jobs()
@download_queue_router.patch(
"/",
operation_id="prune_downloads",
responses={
204: {"description": "All completed jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_downloads():
"""Prune completed and errored jobs."""
queue = ApiDependencies.invoker.services.download_queue
queue.prune_jobs()
return Response(status_code=204)
@download_queue_router.post(
"/i/",
operation_id="download",
)
async def download(
source: AnyHttpUrl = Body(description="download source"),
dest: str = Body(description="download destination"),
priority: int = Body(default=10, description="queue priority"),
access_token: Optional[str] = Body(default=None, description="token for authorization to download"),
) -> DownloadJob:
"""Download the source URL to the file or directory indicted in dest."""
queue = ApiDependencies.invoker.services.download_queue
return queue.download(source, dest, priority, access_token)
@download_queue_router.get(
"/i/{id}",
operation_id="get_download_job",
responses={
200: {"description": "Success"},
404: {"description": "The requested download JobID could not be found"},
},
)
async def get_download_job(
id: int = Path(description="ID of the download job to fetch."),
) -> DownloadJob:
"""Get a download job using its ID."""
try:
job = ApiDependencies.invoker.services.download_queue.id_to_job(id)
return job
except UnknownJobIDException as e:
raise HTTPException(status_code=404, detail=str(e))
@download_queue_router.delete(
"/i/{id}",
operation_id="cancel_download_job",
responses={
204: {"description": "Job has been cancelled"},
404: {"description": "The requested download JobID could not be found"},
},
)
async def cancel_download_job(
id: int = Path(description="ID of the download job to cancel."),
):
"""Cancel a download job using its ID."""
try:
queue = ApiDependencies.invoker.services.download_queue
job = queue.id_to_job(id)
queue.cancel_job(job)
return Response(status_code=204)
except UnknownJobIDException as e:
raise HTTPException(status_code=404, detail=str(e))
@download_queue_router.delete(
"/i",
operation_id="cancel_all_download_jobs",
responses={
204: {"description": "Download jobs have been cancelled"},
},
)
async def cancel_all_download_jobs():
"""Cancel all download jobs."""
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
return Response(status_code=204)

View File

@ -8,10 +8,11 @@ from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator, WorkflowFieldValidator
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies
@ -73,7 +74,7 @@ async def upload_image(
workflow_raw = pil_image.info.get("invokeai_workflow", None)
if workflow_raw is not None:
try:
workflow = WorkflowFieldValidator.validate_json(workflow_raw)
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
@ -184,6 +185,18 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
@images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
)
async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"),
) -> Optional[WorkflowWithoutID]:
try:
return ApiDependencies.invoker.services.images.get_workflow(image_name)
except Exception:
raise HTTPException(status_code=404)
@images_router.api_route(
"/i/{image_name}/full",
methods=["GET", "HEAD"],

View File

@ -4,7 +4,7 @@
from hashlib import sha1
from random import randbytes
from typing import List, Optional
from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -12,30 +12,45 @@ from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordOrderBy,
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
class ModelsList(BaseModel):
"""Return list of configs."""
models: list[AnyModelConfig]
models: List[AnyModelConfig]
model_config = ConfigDict(use_enum_values=True)
class ModelTagSet(BaseModel):
"""Return tags for a set of models."""
key: str
name: str
author: str
tags: Set[str]
@model_records_router.get(
"/",
operation_id="list_model_records",
@ -43,15 +58,25 @@ class ModelsList(BaseModel):
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[ModelFormat] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
found_models.extend(
record_store.search_by_attr(
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
)
)
else:
found_models.extend(record_store.search_by_attr(model_type=model_type))
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
return ModelsList(models=found_models)
@ -75,6 +100,59 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.get("/meta", operation_id="list_model_summary")
async def list_model_summary(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of models per page"),
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
) -> PaginatedResults[ModelSummary]:
"""Gets a page of model summary data."""
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
@model_records_router.get(
"/meta/i/{key}",
operation_id="get_model_metadata",
responses={
200: {"description": "Success"},
400: {"description": "Bad request"},
404: {"description": "No metadata available"},
},
)
async def get_model_metadata(
key: str = Path(description="Key of the model repo metadata to fetch."),
) -> Optional[AnyModelRepoMetadata]:
"""Get a model metadata object."""
record_store = ApiDependencies.invoker.services.model_records
result = record_store.get_metadata(key)
if not result:
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
return result
@model_records_router.get(
"/tags",
operation_id="list_tags",
)
async def list_tags() -> Set[str]:
"""Get a unique set of all the model tags."""
record_store = ApiDependencies.invoker.services.model_records
return record_store.list_tags()
@model_records_router.get(
"/tags/search",
operation_id="search_by_metadata_tags",
)
async def search_by_metadata_tags(
tags: Set[str] = Query(default=None, description="Tags to search for"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
results = record_store.search_by_metadata_tag(tags)
return ModelsList(models=results)
@model_records_router.patch(
"/i/{key}",
operation_id="update_model_record",
@ -117,12 +195,17 @@ async def update_model_record(
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""Delete Model"""
"""
Delete model record from database.
The configuration record will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory.
"""
logger = ApiDependencies.invoker.services.logger
try:
record_store = ApiDependencies.invoker.services.model_records
record_store.del_model(key)
installer = ApiDependencies.invoker.services.model_install
installer.delete(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
@ -143,9 +226,7 @@ async def del_model_record(
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.
"""
"""Add a model using the configuration information appropriate for its type."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_records
if config.key == "<NOKEY>":
@ -162,3 +243,175 @@ async def add_model_record(
# now fetch it out
return record_store.get_model(config.key)
@model_records_router.post(
"/import",
operation_id="import_model_record",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def import_model(
source: ModelSource,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
The source object is a discriminated Union of LocalModelSource,
HFModelSource and URLModelSource. Set the "type" field to the
appropriate value:
* To install a local path using LocalModelSource, pass a source of form:
`{
"type": "local",
"path": "/path/to/model",
"inplace": false
}`
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
`{
"type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16",
"subfolder": "vae",
"access_token": "f5820a918aaf01"
}`
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass:
`{
"type": "url",
"url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01"
}`
The `access_token` field is optonal
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override.
Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events:
"model_install_running"
"model_install_completed"
"model_install_error"
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the
error and its traceback, respectively.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_records_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""Return list of model install jobs."""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
return jobs
@model_records_router.get(
"/import/{id}",
operation_id="get_model_install_job",
responses={
200: {"description": "Success"},
404: {"description": "No such job"},
},
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
"""Return model install job corresponding to the given source."""
try:
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@model_records_router.delete(
"/import/{id}",
operation_id="cancel_model_install_job",
responses={
201: {"description": "The job was cancelled successfully"},
415: {"description": "No such job"},
},
status_code=201,
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
"""Cancel the model install job(s) corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_install
try:
job = installer.get_job_by_id(id)
except ValueError as e:
raise HTTPException(status_code=415, detail=str(e))
installer.cancel_job(job)
@model_records_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_model_install_jobs() -> Response:
"""Prune all completed and errored jobs from the install job list."""
ApiDependencies.invoker.services.model_install.prune_jobs()
return Response(status_code=204)
@model_records_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
204: {"description": "Model config record database resynced with files on disk"},
400: {"description": "Bad request"},
},
)
async def sync_models_to_config() -> Response:
"""
Traverse the models and autoimport directories.
Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
return Response(status_code=204)

View File

@ -23,10 +23,11 @@ class DynamicPromptsResponse(BaseModel):
)
async def parse_dynamicprompts(
prompt: str = Body(description="The prompt to parse with dynamicprompts"),
max_prompts: int = Body(default=1000, description="The max number of prompts to generate"),
max_prompts: int = Body(ge=1, le=10000, default=1000, description="The max number of prompts to generate"),
combinatorial: bool = Body(default=True, description="Whether to use the combinatorial generator"),
) -> DynamicPromptsResponse:
"""Creates a batch process"""
max_prompts = min(max_prompts, 10000)
generator: Union[RandomPromptGenerator, CombinatorialPromptGenerator]
try:
error: Optional[str] = None

View File

@ -1,7 +1,19 @@
from fastapi import APIRouter, Path
from typing import Optional
from fastapi import APIRouter, Body, HTTPException, Path, Query
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.baseinvocation import WorkflowField
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowCategory,
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowRecordOrderBy,
WorkflowWithoutID,
)
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
@ -10,11 +22,76 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
"/i/{workflow_id}",
operation_id="get_workflow",
responses={
200: {"model": WorkflowField},
200: {"model": WorkflowRecordDTO},
},
)
async def get_workflow(
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowField:
) -> WorkflowRecordDTO:
"""Gets a workflow"""
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
try:
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
@workflows_router.patch(
"/i/{workflow_id}",
operation_id="update_workflow",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def update_workflow(
workflow: Workflow = Body(description="The updated workflow", embed=True),
) -> WorkflowRecordDTO:
"""Updates a workflow"""
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
@workflows_router.delete(
"/i/{workflow_id}",
operation_id="delete_workflow",
)
async def delete_workflow(
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
@workflows_router.post(
"/",
operation_id="create_workflow",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def create_workflow(
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
) -> WorkflowRecordDTO:
"""Creates a workflow"""
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
@workflows_router.get(
"/",
operation_id="list_workflows",
responses={
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
},
)
async def list_workflows(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of workflows per page"),
order_by: WorkflowRecordOrderBy = Query(
default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
),
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets a page of workflows"""
return ApiDependencies.invoker.services.workflow_records.get_many(
page=page, per_page=per_page, order_by=order_by, direction=direction, query=query, category=category
)

View File

@ -20,6 +20,7 @@ class SocketIO:
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
@ -28,10 +29,13 @@ class SocketIO:
room=event[1]["data"]["queue_id"],
)
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])

View File

@ -45,6 +45,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
app_info,
board_images,
boards,
download_queue,
images,
model_records,
models,
@ -75,7 +76,7 @@ mimetypes.add_type("text/css", ".css")
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)
# Add event handler
event_handler_id: int = id(app)
@ -116,6 +117,7 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
@ -203,8 +205,8 @@ app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid a
def overridden_swagger() -> HTMLResponse:
return get_swagger_ui_html(
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
title=app.title,
swagger_favicon_url="/static/docs/favicon.ico",
title=f"{app.title} - Swagger UI",
swagger_favicon_url="static/docs/invoke-favicon-docs.svg",
)
@ -212,25 +214,26 @@ def overridden_swagger() -> HTMLResponse:
def overridden_redoc() -> HTMLResponse:
return get_redoc_html(
openapi_url=app.openapi_url, # type: ignore [arg-type] # this is always a string
title=app.title,
redoc_favicon_url="/static/docs/favicon.ico",
title=f"{app.title} - Redoc",
redoc_favicon_url="static/docs/invoke-favicon-docs.svg",
)
web_root_path = Path(list(web_dir.__path__)[0])
# Only serve the UI if we it has a build
if (web_root_path / "dist").exists():
# Cannot add headers to StaticFiles, so we must serve index.html with a custom route
# Add cache-control: no-store header to prevent caching of index.html, which leads to broken UIs at release
@app.get("/", include_in_schema=False, name="ui_root")
def get_index() -> FileResponse:
return FileResponse(Path(web_root_path, "dist/index.html"), headers={"Cache-Control": "no-store"})
# Cannot add headers to StaticFiles, so we must serve index.html with a custom route
# Add cache-control: no-store header to prevent caching of index.html, which leads to broken UIs at release
@app.get("/", include_in_schema=False, name="ui_root")
def get_index() -> FileResponse:
return FileResponse(Path(web_root_path, "dist/index.html"), headers={"Cache-Control": "no-store"})
# Must mount *after* the other routes else it borks em
app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")), name="assets")
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
# # Must mount *after* the other routes else it borks em
app.mount("/static", StaticFiles(directory=Path(web_root_path, "static/")), name="static") # docs favicon is in here
app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")), name="assets")
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
def invoke_api() -> None:
@ -271,6 +274,8 @@ def invoke_api() -> None:
port=port,
loop="asyncio",
log_level=app_config.log_level,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import inspect
import re
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
@ -16,6 +17,7 @@ from pydantic.fields import FieldInfo, _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string
@ -37,6 +39,19 @@ class InvalidFieldError(TypeError):
pass
class Classification(str, Enum, metaclass=MetaEnum):
"""
The classification of an Invocation.
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
"""
Stable = "stable"
Beta = "beta"
Prototype = "prototype"
class Input(str, Enum, metaclass=MetaEnum):
"""
The type of input a field accepts.
@ -437,6 +452,7 @@ class UIConfigBase(BaseModel):
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
validate_assignment=True,
@ -452,6 +468,7 @@ class InvocationContext:
queue_id: str
queue_item_id: int
queue_batch_id: str
workflow: Optional[WorkflowWithoutID]
def __init__(
self,
@ -460,12 +477,14 @@ class InvocationContext:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
workflow: Optional[WorkflowWithoutID],
):
self.services = services
self.graph_execution_state_id = graph_execution_state_id
self.queue_id = queue_id
self.queue_item_id = queue_item_id
self.queue_batch_id = queue_batch_id
self.workflow = workflow
class BaseInvocationOutput(BaseModel):
@ -602,6 +621,7 @@ class BaseInvocation(ABC, BaseModel):
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
@ -705,8 +725,10 @@ class _Model(BaseModel):
pass
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
@ -775,6 +797,7 @@ def invocation(
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@ -785,6 +808,7 @@ def invocation(
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@ -805,11 +829,12 @@ def invocation(
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
module_name = cls.__module__.split(".")[0]
if module_name.endswith(CUSTOM_NODE_PACK_SUFFIX):
cls.UIConfig.node_pack = module_name.split(CUSTOM_NODE_PACK_SUFFIX)[0]
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
if is_custom_node:
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
else:
cls.UIConfig.node_pack = None
@ -903,24 +928,6 @@ def invocation_output(
return wrapper
class WorkflowField(RootModel):
"""
Pydantic model for workflows with custom root of type dict[str, Any].
Workflows are stored without a strict schema.
"""
root: dict[str, Any] = Field(description="The workflow")
WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = Field(
default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
class MetadataField(RootModel):
"""
Pydantic model for metadata with custom root of type dict[str, Any].
@ -943,3 +950,13 @@ class WithMetadata(BaseModel):
orig_required=False,
).model_dump(exclude_none=True),
)
class WithWorkflow:
workflow = None
def __init_subclass__(cls) -> None:
logger.warn(
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
)
super().__init_subclass__()

View File

@ -1,4 +1,3 @@
import re
from dataclasses import dataclass
from typing import List, Optional, Union
@ -17,6 +16,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.util.devices import torch_dtype
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -87,7 +87,7 @@ class CompelInvocation(BaseInvocation):
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
@ -210,7 +210,7 @@ class SDXLPromptInvocationBase:
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_list.append(

View File

@ -24,9 +24,10 @@ from controlnet_aux import (
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
@ -39,7 +40,6 @@ from .baseinvocation import (
InvocationContext,
OutputField,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
@ -76,17 +76,16 @@ class ControlField(BaseModel):
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""
if isinstance(v, list):
for i in v:
if i < -1 or i > 2:
raise ValueError("Control weights must be within -1 to 2 range")
else:
if v < -1 or v > 2:
raise ValueError("Control weights must be within -1 to 2 range")
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
@ -96,17 +95,17 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.0")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet"
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
@ -114,6 +113,17 @@ class ControlNetInvocation(BaseInvocation):
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self) -> "ControlNetInvocation":
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
@ -129,7 +139,7 @@ class ControlNetInvocation(BaseInvocation):
# This invocation exists for other invocations to subclass it - do not register with @invocation!
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class ImageProcessorInvocation(BaseInvocation, WithMetadata):
"""Base class for invocations that preprocess images for ControlNet"""
image: ImageField = InputField(description="The image to process")
@ -153,7 +163,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
node_id=self.id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
"""Builds an ImageOutput and its ImageField"""
@ -173,7 +183,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
title="Canny Processor",
tags=["controlnet", "canny"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
@ -196,7 +206,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
@ -225,7 +235,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Processor",
tags=["controlnet", "lineart"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
@ -247,7 +257,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
@ -270,7 +280,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
title="Openpose Processor",
tags=["controlnet", "openpose", "pose"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image"""
@ -295,7 +305,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
@ -322,7 +332,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Normal BAE Processor",
tags=["controlnet"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
@ -339,7 +349,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.1.0"
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
@ -362,7 +372,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.1.0"
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
@ -389,7 +399,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
@ -419,7 +429,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
@ -435,7 +445,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
@ -458,7 +468,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
@ -487,7 +497,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
title="Tile Resample Processor",
tags=["controlnet", "tile"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
@ -527,7 +537,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
@ -569,7 +579,7 @@ class SamDetectorReproducibleColors(SamDetector):
title="Color Map Processor",
tags=["controlnet"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image"""

View File

@ -6,7 +6,6 @@ import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.app.invocations.baseinvocation import CUSTOM_NODE_PACK_SUFFIX
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@ -34,7 +33,7 @@ for d in Path(__file__).parent.iterdir():
continue
# load the module, appending adding a suffix to identify it as a custom node pack
spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute())
spec = spec_from_file_location(module_name, init.absolute())
if spec is None or spec.loader is None:
logger.warn(f"Could not load {init}")

View File

@ -8,11 +8,11 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.1.0")
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
class CvInpaintInvocation(BaseInvocation, WithMetadata):
"""Simple inpaint using opencv."""
image: ImageField = InputField(description="The image to inpaint")
@ -41,7 +41,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -17,7 +17,6 @@ from invokeai.app.invocations.baseinvocation import (
InvocationContext,
OutputField,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
@ -438,8 +437,8 @@ def get_faces_list(
return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.1.0")
class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
class FaceOffInvocation(BaseInvocation, WithMetadata):
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
image: ImageField = InputField(description="Image for face detection")
@ -508,7 +507,7 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
workflow=context.workflow,
)
mask_dto = context.services.images.create(
@ -532,8 +531,8 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.1.0")
class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
class FaceMaskInvocation(BaseInvocation, WithMetadata):
"""Face mask creation using mediapipe face detection"""
image: ImageField = InputField(description="Image to face detect")
@ -627,7 +626,7 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
workflow=context.workflow,
)
mask_dto = context.services.images.create(
@ -650,9 +649,9 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.1.0"
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
)
class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
image: ImageField = InputField(description="Image to face detect")
@ -716,7 +715,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -13,7 +13,15 @@ from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import (
BaseInvocation,
Classification,
Input,
InputField,
InvocationContext,
WithMetadata,
invocation,
)
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@ -36,8 +44,14 @@ class ShowImageInvocation(BaseInvocation):
)
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.1.0")
class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation(
"blank_image",
title="Blank Image",
tags=["image"],
category="image",
version="1.2.0",
)
class BlankImageInvocation(BaseInvocation, WithMetadata):
"""Creates a blank image and forwards it to the pipeline"""
width: int = InputField(default=512, description="The width of the image")
@ -56,7 +70,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -66,8 +80,14 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.1.0")
class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_crop",
title="Crop Image",
tags=["image", "crop"],
category="image",
version="1.2.0",
)
class ImageCropInvocation(BaseInvocation, WithMetadata):
"""Crops an image to a specified box. The box can be outside of the image."""
image: ImageField = InputField(description="The image to crop")
@ -90,7 +110,7 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -155,8 +175,14 @@ class CenterPadCropInvocation(BaseInvocation):
)
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.1.0")
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_paste",
title="Paste Image",
tags=["image", "paste"],
category="image",
version="1.2.0",
)
class ImagePasteInvocation(BaseInvocation, WithMetadata):
"""Pastes an image into another image."""
base_image: ImageField = InputField(description="The base image")
@ -199,7 +225,7 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -209,8 +235,14 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.1.0")
class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"tomask",
title="Mask from Alpha",
tags=["image", "mask"],
category="image",
version="1.2.0",
)
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
"""Extracts the alpha channel of an image as a mask."""
image: ImageField = InputField(description="The image to create the mask from")
@ -231,7 +263,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -241,8 +273,14 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.1.0")
class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_mul",
title="Multiply Images",
tags=["image", "multiply"],
category="image",
version="1.2.0",
)
class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
image1: ImageField = InputField(description="The first image to multiply")
@ -262,7 +300,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -275,8 +313,14 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.1.0")
class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_chan",
title="Extract Image Channel",
tags=["image", "channel"],
category="image",
version="1.2.0",
)
class ImageChannelInvocation(BaseInvocation, WithMetadata):
"""Gets a channel from an image."""
image: ImageField = InputField(description="The image to get the channel from")
@ -295,7 +339,7 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -308,8 +352,14 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.1.0")
class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_conv",
title="Convert Image Mode",
tags=["image", "convert"],
category="image",
version="1.2.0",
)
class ImageConvertInvocation(BaseInvocation, WithMetadata):
"""Converts an image to a different mode."""
image: ImageField = InputField(description="The image to convert")
@ -328,7 +378,7 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -338,8 +388,14 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.1.0")
class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_blur",
title="Blur Image",
tags=["image", "blur"],
category="image",
version="1.2.0",
)
class ImageBlurInvocation(BaseInvocation, WithMetadata):
"""Blurs an image"""
image: ImageField = InputField(description="The image to blur")
@ -363,7 +419,7 @@ class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -373,6 +429,64 @@ class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation(
"unsharp_mask",
title="Unsharp Mask",
tags=["image", "unsharp_mask"],
category="image",
version="1.2.0",
classification=Classification.Beta,
)
class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
"""Applies an unsharp mask filter to an image"""
image: ImageField = InputField(description="The image to use")
radius: float = InputField(gt=0, description="Unsharp mask radius", default=2)
strength: float = InputField(ge=0, description="Unsharp mask strength", default=50)
def pil_from_array(self, arr):
return Image.fromarray((arr * 255).astype("uint8"))
def array_from_pil(self, img):
return numpy.array(img) / 255
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
mode = image.mode
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
image = image.convert("RGB")
image_blurred = self.array_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
image = self.array_from_pil(image)
image += (image - image_blurred) * (self.strength / 100.0)
image = numpy.clip(image, 0, 1)
image = self.pil_from_array(image)
image = image.convert(mode)
# Make the image RGBA if we had a source alpha channel
if alpha_channel is not None:
image.putalpha(alpha_channel)
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image.width,
height=image.height,
)
PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",
@ -393,8 +507,14 @@ PIL_RESAMPLING_MAP = {
}
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.1.0")
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation(
"img_resize",
title="Resize Image",
tags=["image", "resize"],
category="image",
version="1.2.0",
)
class ImageResizeInvocation(BaseInvocation, WithMetadata):
"""Resizes an image to specific dimensions"""
image: ImageField = InputField(description="The image to resize")
@ -420,7 +540,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -430,8 +550,14 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.1.0")
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation(
"img_scale",
title="Scale Image",
tags=["image", "scale"],
category="image",
version="1.2.0",
)
class ImageScaleInvocation(BaseInvocation, WithMetadata):
"""Scales an image by a factor"""
image: ImageField = InputField(description="The image to scale")
@ -462,7 +588,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -472,8 +598,14 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.1.0")
class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_lerp",
title="Lerp Image",
tags=["image", "lerp"],
category="image",
version="1.2.0",
)
class ImageLerpInvocation(BaseInvocation, WithMetadata):
"""Linear interpolation of all pixels of an image"""
image: ImageField = InputField(description="The image to lerp")
@ -496,7 +628,7 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -506,8 +638,14 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.1.0")
class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_ilerp",
title="Inverse Lerp Image",
tags=["image", "ilerp"],
category="image",
version="1.2.0",
)
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
"""Inverse linear interpolation of all pixels of an image"""
image: ImageField = InputField(description="The image to lerp")
@ -530,7 +668,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -540,8 +678,14 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.1.0")
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation(
"img_nsfw",
title="Blur NSFW Image",
tags=["image", "nsfw"],
category="image",
version="1.2.0",
)
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
"""Add blur to NSFW-flagged images"""
image: ImageField = InputField(description="The image to check")
@ -566,7 +710,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -587,9 +731,9 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
title="Add Invisible Watermark",
tags=["image", "watermark"],
category="image",
version="1.1.0",
version="1.2.0",
)
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
"""Add an invisible watermark to an image"""
image: ImageField = InputField(description="The image to check")
@ -606,7 +750,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -616,8 +760,14 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.1.0")
class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"mask_edge",
title="Mask Edge",
tags=["image", "mask", "inpaint"],
category="image",
version="1.2.0",
)
class MaskEdgeInvocation(BaseInvocation, WithMetadata):
"""Applies an edge mask to an image"""
image: ImageField = InputField(description="The image to apply the mask to")
@ -652,7 +802,7 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -667,9 +817,9 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
title="Combine Masks",
tags=["image", "mask", "multiply"],
category="image",
version="1.1.0",
version="1.2.0",
)
class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class MaskCombineInvocation(BaseInvocation, WithMetadata):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
mask1: ImageField = InputField(description="The first mask to combine")
@ -689,7 +839,7 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -699,8 +849,14 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.1.0")
class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"color_correct",
title="Color Correct",
tags=["image", "color"],
category="image",
version="1.2.0",
)
class ColorCorrectInvocation(BaseInvocation, WithMetadata):
"""
Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image.
@ -800,7 +956,7 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -810,8 +966,14 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.1.0")
class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_hue_adjust",
title="Adjust Image Hue",
tags=["image", "hue"],
category="image",
version="1.2.0",
)
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
"""Adjusts the Hue of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -840,7 +1002,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -913,9 +1075,9 @@ CHANNEL_FORMATS = {
"value",
],
category="image",
version="1.1.0",
version="1.2.0",
)
class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
"""Add or subtract a value from a specific color channel of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -950,7 +1112,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -984,9 +1146,9 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"value",
],
category="image",
version="1.1.0",
version="1.2.0",
)
class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
"""Scale a specific color channel of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -1025,7 +1187,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
workflow=context.workflow,
metadata=self.metadata,
)
@ -1043,10 +1205,10 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
title="Save Image",
tags=["primitives", "image"],
category="primitives",
version="1.1.0",
version="1.2.0",
use_cache=False,
)
class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class SaveImageInvocation(BaseInvocation, WithMetadata):
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
image: ImageField = InputField(description=FieldDescriptions.image)
@ -1064,7 +1226,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -1082,7 +1244,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
version="1.0.1",
use_cache=False,
)
class LinearUIOutputInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
"""Handles Linear UI Image Outputting tasks."""
image: ImageField = InputField(description=FieldDescriptions.image)

View File

@ -13,7 +13,7 @@ from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
@ -118,8 +118,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class InfillColorInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image with a solid color"""
image: ImageField = InputField(description="The image to infill")
@ -144,7 +144,7 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -154,8 +154,8 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1")
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class InfillTileInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image with tiles of the image"""
image: ImageField = InputField(description="The image to infill")
@ -181,7 +181,7 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -192,9 +192,9 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0"
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
)
class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
image: ImageField = InputField(description="The image to infill")
@ -235,7 +235,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -245,8 +245,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class LaMaInfillInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using the LaMa model"""
image: ImageField = InputField(description="The image to infill")
@ -264,7 +264,7 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -274,8 +274,8 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class CV2InfillInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using OpenCV Inpainting"""
image: ImageField = InputField(description="The image to infill")
@ -293,7 +293,7 @@ class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -2,7 +2,7 @@ import os
from builtins import float
from typing import List, Union
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -15,6 +15,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
@ -39,7 +40,6 @@ class IPAdapterField(BaseModel):
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
# weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@ -47,6 +47,17 @@ class IPAdapterField(BaseModel):
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
@ -54,7 +65,7 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.0")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
@ -64,18 +75,27 @@ class IPAdapterInvocation(BaseInvocation):
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
weight: Union[float, List[float]] = InputField(
default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight"
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the IP-Adapter is first applied (% of total steps)"
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_manager.model_info(

View File

@ -64,7 +64,6 @@ from .baseinvocation import (
OutputField,
UIType,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
@ -221,7 +220,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.5.0",
version="1.5.1",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@ -280,7 +279,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
ui_order=7,
)
cfg_rescale_multiplier: float = InputField(
default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
latents: Optional[LatentsField] = InputField(
default=None,
@ -802,9 +801,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.1.0",
version="1.2.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
"""Generates an image from latents."""
latents: LatentsField = InputField(
@ -886,7 +885,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -1,7 +1,6 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
import inspect
import re
# from contextlib import ExitStack
from typing import List, Literal, Union
@ -21,6 +20,7 @@ from invokeai.backend import BaseModelType, ModelType, SubModelType
from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -31,7 +31,6 @@ from .baseinvocation import (
UIComponent,
UIType,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
@ -79,7 +78,7 @@ class ONNXPromptInvocation(BaseInvocation):
]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
@ -326,9 +325,9 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
title="ONNX Latents to Image",
tags=["latents", "image", "vae", "onnx"],
category="image",
version="1.1.0",
version="1.2.0",
)
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
"""Generates an image from latents."""
latents: LatentsField = InputField(
@ -378,7 +377,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -1,6 +1,6 @@
from typing import Union
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -14,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.model_management.models.base import BaseModelType
@ -37,6 +38,17 @@ class T2IAdapterField(BaseModel):
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("t2i_adapter_output")
class T2IAdapterOutput(BaseInvocationOutput):
@ -44,7 +56,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.0"
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.1"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
@ -61,7 +73,7 @@ class T2IAdapterInvocation(BaseInvocation):
default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
@ -71,6 +83,17 @@ class T2IAdapterInvocation(BaseInvocation):
description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.",
)
@field_validator("weight")
@classmethod
def validate_ip_adapter_weight(cls, v):
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
return T2IAdapterOutput(
t2i_adapter=T2IAdapterField(

View File

@ -1,3 +1,5 @@
from typing import Literal
import numpy as np
from PIL import Image
from pydantic import BaseModel
@ -5,17 +7,24 @@ from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
Input,
InputField,
InvocationContext,
OutputField,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.tiles import (
calc_tiles_even_split,
calc_tiles_min_overlap,
calc_tiles_with_overlap,
merge_tiles_with_linear_blending,
merge_tiles_with_seam_blending,
)
from invokeai.backend.tiles.utils import Tile
@ -29,7 +38,14 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.")
@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0")
@invocation(
"calculate_image_tiles",
title="Calculate Image Tiles",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class CalculateImageTilesInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
@ -56,6 +72,79 @@ class CalculateImageTilesInvocation(BaseInvocation):
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_even_split",
title="Calculate Image Tiles Even Split",
tags=["tiles"],
category="tiles",
version="1.1.0",
classification=Classification.Beta,
)
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
num_tiles_x: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the x axis",
)
num_tiles_y: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the y axis",
)
overlap: int = InputField(
default=128,
ge=0,
multiple_of=8,
description="The overlap, in pixels, between adjacent tiles.",
)
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_even_split(
image_height=self.image_height,
image_width=self.image_width,
num_tiles_x=self.num_tiles_x,
num_tiles_y=self.num_tiles_y,
overlap=self.overlap,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_min_overlap",
title="Calculate Image Tiles Minimum Overlap",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_min_overlap(
image_height=self.image_height,
image_width=self.image_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.min_overlap,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation_output("tile_to_properties_output")
class TileToPropertiesOutput(BaseInvocationOutput):
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
@ -77,7 +166,14 @@ class TileToPropertiesOutput(BaseInvocationOutput):
overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
@invocation(
"tile_to_properties",
title="Tile to Properties",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class TileToPropertiesInvocation(BaseInvocation):
"""Split a Tile into its individual properties."""
@ -103,7 +199,14 @@ class PairTileImageOutput(BaseInvocationOutput):
tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
@invocation(
"pair_tile_image",
title="Pair Tile with Image",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class PairTileImageInvocation(BaseInvocation):
"""Pair an image with its tile properties."""
@ -122,13 +225,29 @@ class PairTileImageInvocation(BaseInvocation):
)
@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.0.0")
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
BLEND_MODES = Literal["Linear", "Seam"]
@invocation(
"merge_tiles_to_image",
title="Merge Tiles to Image",
tags=["tiles"],
category="tiles",
version="1.1.0",
classification=Classification.Beta,
)
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
"""Merge multiple tile images into a single image."""
# Inputs
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
blend_mode: BLEND_MODES = InputField(
default="Seam",
description="blending type Linear or Seam",
input=Input.Direct,
)
blend_amount: int = InputField(
default=32,
ge=0,
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
)
@ -158,10 +277,18 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
channels = tile_np_images[0].shape[-1]
dtype = tile_np_images[0].dtype
np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
if self.blend_mode == "Linear":
merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
elif self.blend_mode == "Seam":
merge_tiles_with_seam_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
else:
raise ValueError(f"Unsupported blend mode: '{self.blend_mode}'.")
merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
# Convert into a PIL image and save
pil_image = Image.fromarray(np_image)
image_dto = context.services.images.create(
@ -172,7 +299,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),

View File

@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
# TODO: Populate this from disk?
# TODO: Use model manager to load?
@ -29,8 +29,8 @@ if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.2.0")
class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
class ESRGANInvocation(BaseInvocation, WithMetadata):
"""Upscales an image using RealESRGAN."""
image: ImageField = InputField(description="The input image")
@ -118,7 +118,7 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -0,0 +1,14 @@
from typing import Union
def validate_weights(weights: Union[float, list[float]]) -> None:
"""Validate that all control weights in the valid range"""
to_validate = weights if isinstance(weights, list) else [weights]
if any(i < -1 or i > 2 for i in to_validate):
raise ValueError("Control weights must be within -1 to 2 range")
def validate_begin_end_step(begin_step_percent: float, end_step_percent: float) -> None:
"""Validate that begin_step_percent is less than end_step_percent"""
if begin_step_percent >= end_step_percent:
raise ValueError("Begin step percent must be less than or equal to end step percent")

View File

@ -4,7 +4,7 @@ from typing import Optional, cast
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .board_image_records_base import BoardImageRecordStorageBase
@ -20,63 +20,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for board id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
"""
)
# Add index for board id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,

View File

@ -3,7 +3,7 @@ import threading
from typing import Union, cast
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
from .board_records_base import BoardRecordStorageBase
@ -28,52 +28,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()

View File

@ -1,6 +1,7 @@
"""
Init file for InvokeAI configure package
"""
"""Init file for InvokeAI configure package."""
from .config_base import PagingArgumentParser # noqa F401
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401
from invokeai.app.services.config.config_common import PagingArgumentParser
from .config_default import InvokeAIAppConfig, get_invokeai_config
__all__ = ["InvokeAIAppConfig", "get_invokeai_config", "PagingArgumentParser"]

View File

@ -173,7 +173,7 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter
@ -209,7 +209,7 @@ class InvokeAIAppConfig(InvokeAISettings):
"""Configuration object for InvokeAI App."""
singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict]] = None
singleton_init: ClassVar[Optional[Dict[str, Any]]] = None
# fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
@ -221,6 +221,9 @@ class InvokeAIAppConfig(InvokeAISettings):
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
# SSL options correspond to https://www.uvicorn.org/settings/#https
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
# FEATURES
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
@ -260,7 +263,7 @@ class InvokeAIAppConfig(InvokeAISettings):
# DEVICE
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
precision : Literal["auto", "float16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
precision : Literal["auto", "float16", "bfloat16", "float32", "autocast"] = Field(default="auto", description="Floating point precision", json_schema_extra=Categories.Device)
# GENERATION
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", json_schema_extra=Categories.Generation)
@ -298,8 +301,8 @@ class InvokeAIAppConfig(InvokeAISettings):
self,
argv: Optional[list[str]] = None,
conf: Optional[DictConfig] = None,
clobber=False,
):
clobber: Optional[bool] = False,
) -> None:
"""
Update settings with contents of init file, environment, and command-line settings.
@ -334,7 +337,7 @@ class InvokeAIAppConfig(InvokeAISettings):
)
@classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
def get_config(cls, **kwargs: Any) -> InvokeAIAppConfig:
"""Return a singleton InvokeAIAppConfig configuration object."""
if (
cls.singleton_config is None
@ -353,7 +356,7 @@ class InvokeAIAppConfig(InvokeAISettings):
else:
root = self.find_root().expanduser().absolute()
self.root = root # insulate ourselves from relative paths that may change
return root
return root.resolve()
@property
def root_dir(self) -> Path:
@ -383,17 +386,17 @@ class InvokeAIAppConfig(InvokeAISettings):
return db_dir / DB_FILE
@property
def model_conf_path(self) -> Optional[Path]:
def model_conf_path(self) -> Path:
"""Path to models configuration file."""
return self._resolve(self.conf_path)
@property
def legacy_conf_path(self) -> Optional[Path]:
def legacy_conf_path(self) -> Path:
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
return self._resolve(self.legacy_conf_dir)
@property
def models_path(self) -> Optional[Path]:
def models_path(self) -> Path:
"""Path to the models directory."""
return self._resolve(self.models_dir)
@ -452,7 +455,7 @@ class InvokeAIAppConfig(InvokeAISettings):
return _find_root()
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
def get_invokeai_config(**kwargs: Any) -> InvokeAIAppConfig:
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
return InvokeAIAppConfig.get_config(**kwargs)

View File

@ -0,0 +1,12 @@
"""Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_default import DownloadQueueService, TqdmProgress
__all__ = [
"DownloadJob",
"DownloadQueueServiceBase",
"DownloadQueueService",
"TqdmProgress",
"DownloadJobStatus",
"UnknownJobIDException",
]

View File

@ -0,0 +1,262 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""Model download service."""
from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
class DownloadJobStatus(str, Enum):
"""State of a download job."""
WAITING = "waiting" # not enqueued, will not run
RUNNING = "running" # actively downloading
COMPLETED = "completed" # finished running
CANCELLED = "cancelled" # user cancelled
ERROR = "error" # terminated with an error message
class DownloadJobCancelledException(Exception):
"""This exception is raised when a download job is cancelled."""
class UnknownJobIDException(Exception):
"""This exception is raised when an invalid job id is referened."""
class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
@total_ordering
class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
# set when an error occurs
error_type: Optional[str] = Field(default=None, description="Name of exception that caused an error")
error: Optional[str] = Field(default=None, description="Traceback of the exception that caused an error")
# internal flag
_cancelled: bool = PrivateAttr(default=False)
# optional event handlers passed in on creation
_on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
# cancelled and the callbacks are private attributes in order to prevent
# them from being serialized and/or used in the Json Schema
@property
def cancelled(self) -> bool:
"""Call to cancel the job."""
return self._cancelled
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED
@property
def running(self) -> bool:
"""Return true if the job is running."""
return self.status == DownloadJobStatus.RUNNING
@property
def errored(self) -> bool:
"""Return true if the job is errored."""
return self.status == DownloadJobStatus.ERROR
@property
def in_terminal_state(self) -> bool:
"""Return true if job has finished, one way or another."""
return self.status not in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]
@property
def on_start(self) -> Optional[DownloadEventHandler]:
"""Return the on_start event handler."""
return self._on_start
@property
def on_progress(self) -> Optional[DownloadEventHandler]:
"""Return the on_progress event handler."""
return self._on_progress
@property
def on_complete(self) -> Optional[DownloadEventHandler]:
"""Return the on_complete event handler."""
return self._on_complete
@property
def on_error(self) -> Optional[DownloadExceptionHandler]:
"""Return the on_error event handler."""
return self._on_error
@property
def on_cancelled(self) -> Optional[DownloadEventHandler]:
"""Return the on_cancelled event handler."""
return self._on_cancelled
def set_callbacks(
self,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> None:
"""Set the callbacks for download events."""
self._on_start = on_start
self._on_progress = on_progress
self._on_complete = on_complete
self._on_error = on_error
self._on_cancelled = on_cancelled
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL."""
@abstractmethod
def start(self, *args: Any, **kwargs: Any) -> None:
"""Start the download worker threads."""
@abstractmethod
def stop(self, *args: Any, **kwargs: Any) -> None:
"""Stop the download worker threads."""
@abstractmethod
def download(
self,
source: AnyHttpUrl,
dest: Path,
priority: int = 10,
access_token: Optional[str] = None,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> DownloadJob:
"""
Create and enqueue download job.
:param source: Source of the download as a URL.
:param dest: Path to download to. See below.
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
:returns: A DownloadJob object for monitoring the state of the download.
The `dest` argument is a Path object. Its behavior is:
1. If the path exists and is a directory, then the URL contents will be downloaded
into that directory using the filename indicated in the response's `Content-Disposition` field.
If no content-disposition is present, then the last component of the URL will be used (similar to
wget's behavior).
2. If the path does not exist, then it is taken as the name of a new file to create with the downloaded
content.
3. If the path exists and is an existing file, then the downloader will try to resume the download from
the end of the existing file.
"""
pass
@abstractmethod
def submit_download_job(
self,
job: DownloadJob,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> None:
"""
Enqueue a download job.
:param job: The DownloadJob
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJob]:
"""
List active download jobs.
:returns List[DownloadJob]: List of download jobs whose state is not "completed."
"""
pass
@abstractmethod
def id_to_job(self, id: int) -> DownloadJob:
"""
Return the DownloadJob corresponding to the integer ID.
:param id: ID of the DownloadJob.
Exceptions:
* UnknownJobIDException
"""
pass
@abstractmethod
def cancel_all_jobs(self) -> None:
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def prune_jobs(self) -> None:
"""Prune completed and errored queue items from the job list."""
pass
@abstractmethod
def cancel_job(self, job: DownloadJob) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def join(self) -> None:
"""Wait until all jobs are off the queue."""
pass

View File

@ -0,0 +1,437 @@
# Copyright (c) 2023, Lincoln D. Stein
"""Implementation of multithreaded download queue for invokeai."""
import os
import re
import threading
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
DownloadEventHandler,
DownloadExceptionHandler,
DownloadJob,
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
ServiceInactiveException,
UnknownJobIDException,
)
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
class DownloadQueueService(DownloadQueueServiceBase):
"""Class for queued download of models."""
def __init__(
self,
max_parallel_dl: int = 5,
event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._jobs = {}
self._next_job_id = 0
self._queue = PriorityQueue()
self._stop_event = threading.Event()
self._worker_pool = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
self._event_bus = event_bus
self._requests = requests_session or requests.Session()
self._accept_download_requests = False
self._max_parallel_dl = max_parallel_dl
def start(self, *args: Any, **kwargs: Any) -> None:
"""Start the download worker threads."""
with self._lock:
if self._worker_pool:
raise Exception("Attempt to start the download service twice")
self._stop_event.clear()
self._start_workers(self._max_parallel_dl)
self._accept_download_requests = True
def stop(self, *args: Any, **kwargs: Any) -> None:
"""Stop the download worker threads."""
with self._lock:
if not self._worker_pool:
raise Exception("Attempt to stop the download service before it was started")
self._accept_download_requests = False # reject attempts to add new jobs to queue
queued_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.WAITING]
active_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.RUNNING]
if queued_jobs:
self._logger.warning(f"Cancelling {len(queued_jobs)} queued downloads")
if active_jobs:
self._logger.info(f"Waiting for {len(active_jobs)} active download jobs to complete")
with self._queue.mutex:
self._queue.queue.clear()
self.join() # wait for all active jobs to finish
self._stop_event.set()
self._worker_pool.clear()
def submit_download_job(
self,
job: DownloadJob,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> None:
"""Enqueue a download job."""
if not self._accept_download_requests:
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
with self._lock:
job.id = self._next_job_id
self._next_job_id += 1
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[job.id] = job
self._queue.put(job)
def download(
self,
source: AnyHttpUrl,
dest: Path,
priority: int = 10,
access_token: Optional[str] = None,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> DownloadJob:
"""Create and enqueue a download job and return it."""
if not self._accept_download_requests:
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
job = DownloadJob(
source=source,
dest=dest,
priority=priority,
access_token=access_token,
)
self.submit_download_job(
job,
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
return job
def join(self) -> None:
"""Wait for all jobs to complete."""
self._queue.join()
def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs."""
return list(self._jobs.values())
def prune_jobs(self) -> None:
"""Prune completed and errored queue items from the job list."""
with self._lock:
to_delete = set()
for job_id, job in self._jobs.items():
if job.in_terminal_state:
to_delete.add(job_id)
for job_id in to_delete:
del self._jobs[job_id]
def id_to_job(self, id: int) -> DownloadJob:
"""Translate a job ID into a DownloadJob object."""
try:
return self._jobs[id]
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJob) -> None:
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
job.cancel()
def cancel_all_jobs(self) -> None:
"""Cancel all jobs (those not in enqueued, running or paused state)."""
for job in self._jobs.values():
if not job.in_terminal_state:
self.cancel_job(job)
def _start_workers(self, max_workers: int) -> None:
"""Start the requested number of worker threads."""
self._stop_event.clear()
for i in range(0, max_workers): # noqa B007
worker = threading.Thread(target=self._download_next_item, daemon=True)
self._logger.debug(f"Download queue worker thread {worker.name} starting.")
worker.start()
self._worker_pool.add(worker)
def _download_next_item(self) -> None:
"""Worker thread gets next job on priority queue."""
done = False
while not done:
if self._stop_event.is_set():
done = True
continue
try:
job = self._queue.get(timeout=1)
except Empty:
continue
try:
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
finally:
job.job_ended = get_iso_timestamp()
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
# Make a streaming request. This will retrieve headers including
# content-length and content-disposition, but not fetch any content itself
resp = self._requests.get(str(url), headers=header, stream=True)
if not resp.ok:
raise HTTPError(resp.reason)
job.content_type = resp.headers.get("Content-Type")
content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length
if job.dest.is_dir():
file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL
if match := re.search('filename="(.+)"', resp.headers.get("Content-Disposition", "")):
remote_name = match.group(1)
if self._validate_filename(job.dest.as_posix(), remote_name):
file_name = remote_name
job.download_path = job.dest / file_name
else:
job.dest.parent.mkdir(parents=True, exist_ok=True)
job.download_path = job.dest
assert job.download_path
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():
raise OSError(f"[Errno 17] File {job.download_path} exists")
# append ".downloading" to the path
in_progress_path = self._in_progress_path(job.download_path)
# signal caller that the download is starting. At this point, key fields such as
# download_path and total_bytes will be populated. We call it here because the might
# discover that the local file is already complete and generate a COMPLETED status.
self._signal_job_started(job)
# "range not satisfiable" - local file is at least as large as the remote file
if resp.status_code == 416 or (content_length > 0 and job.bytes >= content_length):
self._logger.warning(f"{job.download_path}: complete file found. Skipping.")
return
# "partial content" - local file is smaller than remote file
elif resp.status_code == 206 or job.bytes > 0:
self._logger.warning(f"{job.download_path}: partial file found. Resuming")
# some other error
elif resp.status_code != 200:
raise HTTPError(resp.reason)
self._logger.debug(f"{job.source}: Downloading {job.download_path}")
report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0
# DOWNLOAD LOOP
with open(in_progress_path, open_mode) as file:
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
if job.cancelled:
raise DownloadJobCancelledException("Job was cancelled at caller's request")
job.bytes += file.write(data)
if (job.bytes - last_report_bytes >= report_delta) or (job.bytes >= job.total_bytes):
last_report_bytes = job.bytes
self._signal_job_progress(job)
# if we get here we are done and can rename the file to the original dest
self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
in_progress_path.rename(job.download_path)
def _validate_filename(self, directory: str, filename: str) -> bool:
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
pc_path_max = (
os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 32767
) # hardcoded for windows with long names enabled
if "/" in filename:
return False
if filename.startswith(".."):
return False
if len(filename) > pc_name_max:
return False
if len(os.path.join(directory, filename)) > pc_path_max:
return False
return True
def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading")
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
if job.on_start:
try:
job.on_start(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
try:
job.on_progress(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
if job.on_complete:
try:
job.on_complete(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
return
job.status = DownloadJobStatus.CANCELLED
if job.on_cancelled:
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source))
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
if job.on_error:
try:
job.on_error(job, excp)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
if self._event_bus:
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
try:
if job.download_path:
partial_file = self._in_progress_path(job.download_path)
partial_file.unlink()
except OSError as excp:
self._logger.warning(excp)
# Example on_progress event handler to display a TQDM status bar
# Activate with:
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
class TqdmProgress(object):
"""TQDM-based progress bar object to use in on_progress handlers."""
_bars: Dict[int, tqdm] # the tqdm object
_last: Dict[int, int] # last bytes downloaded
def __init__(self) -> None: # noqa D107
self._bars = {}
self._last = {}
def update(self, job: DownloadJob) -> None: # noqa D102
job_id = job.id
# new job
if job_id not in self._bars:
assert job.download_path
dest = Path(job.download_path).name
self._bars[job_id] = tqdm(
desc=dest,
initial=0,
total=job.total_bytes,
unit="iB",
unit_scale=True,
)
self._last[job_id] = 0
self._bars[job_id].update(job.bytes - self._last[job_id])
self._last[job_id] = job.bytes

View File

@ -0,0 +1 @@
from .events_base import EventServiceBase # noqa F401

View File

@ -1,6 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional
from typing import Any, Dict, List, Optional, Union
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
@ -16,6 +17,8 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
class EventServiceBase:
queue_event: str = "queue_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
@ -30,6 +33,20 @@ class EventServiceBase:
payload={"event": event_name, "data": payload},
)
def __emit_download_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
@ -313,3 +330,166 @@ class EventServiceBase:
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
def emit_download_started(self, source: str, download_path: str) -> None:
"""
Emit when a download job is started.
:param url: The downloaded url
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
"""
Emit "download_progress" events at regular intervals during a download job.
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
"""
Emit a "download_complete" event at the end of a successful download.
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
"""
Emit a "download_error" event when an download job encounters an exception.
:param source: Source URL
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_install_downloading(
self,
source: str,
local_path: str,
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
:param source: Source of the model
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
},
)
def emit_model_install_running(self, source: str) -> None:
"""
Emit once when an install job becomes active.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
:param total_bytes: Size of the model (may be None for installation of a local path)
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={
"source": source,
"total_bytes": total_bytes,
"key": key,
},
)
def emit_model_install_cancelled(self, source: str) -> None:
"""
Emit when an install job is cancelled.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source},
)
def emit_model_install_error(
self,
source: str,
error_type: str,
error: str,
) -> None:
"""
Emit when an install job encounters an exception.
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)

View File

@ -4,7 +4,8 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageFileStorageBase(ABC):
@ -33,7 +34,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType,
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -43,3 +44,8 @@ class ImageFileStorageBase(ABC):
def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets the workflow of an image."""
pass

View File

@ -7,8 +7,9 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from .image_files_base import ImageFileStorageBase
@ -56,7 +57,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType,
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
try:
@ -64,12 +65,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name)
pnginfo = PngImagePlugin.PngInfo()
info_dict = {}
if metadata is not None:
pnginfo.add_text("invokeai_metadata", metadata.model_dump_json())
metadata_json = metadata.model_dump_json()
info_dict["invokeai_metadata"] = metadata_json
pnginfo.add_text("invokeai_metadata", metadata_json)
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow.model_dump_json())
workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
image.save(
image_path,
"PNG",
@ -121,6 +129,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
return None
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]

View File

@ -75,6 +75,7 @@ class ImageRecordStorageBase(ABC):
image_category: ImageCategory,
width: int,
height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False,
session_id: Optional[str] = None,

View File

@ -100,6 +100,7 @@ IMAGE_DTO_COLS = ", ".join(
"height",
"session_id",
"node_id",
"has_workflow",
"is_intermediate",
"created_at",
"updated_at",
@ -145,6 +146,7 @@ class ImageRecord(BaseModelExcludeNull):
"""The node ID that generated this image, if it is a generated image."""
starred: bool = Field(description="Whether this image is starred.")
"""Whether this image is starred."""
has_workflow: bool = Field(description="Whether this image has a workflow.")
class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
@ -188,6 +190,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
starred = image_dict.get("starred", False)
has_workflow = image_dict.get("has_workflow", False)
return ImageRecord(
image_name=image_name,
@ -202,4 +205,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at=deleted_at,
is_intermediate=is_intermediate,
starred=starred,
has_workflow=has_workflow,
)

View File

@ -5,7 +5,7 @@ from typing import Optional, Union, cast
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .image_records_base import ImageRecordStorageBase
from .image_records_common import (
@ -32,91 +32,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "starred" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
)
def get(self, image_name: str) -> ImageRecord:
try:
self._lock.acquire()
@ -408,6 +323,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_category: ImageCategory,
width: int,
height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False,
session_id: Optional[str] = None,
@ -429,9 +345,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id,
metadata,
is_intermediate,
starred
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@ -444,6 +361,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata_json,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()

View File

@ -3,7 +3,7 @@ from typing import Callable, Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecord,
@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageServiceABC(ABC):
@ -51,7 +52,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
**kwargs
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
@ -86,6 +87,11 @@ class ImageServiceABC(ABC):
"""Gets an image's metadata."""
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets an image's workflow."""
pass
@abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path."""

View File

@ -24,11 +24,6 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
default=None, description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
workflow_id: Optional[str] = Field(
default=None,
description="The workflow that generated this image.",
)
"""The workflow that generated this image."""
def image_record_to_dto(
@ -36,7 +31,6 @@ def image_record_to_dto(
image_url: str,
thumbnail_url: str,
board_id: Optional[str],
workflow_id: Optional[str],
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
@ -44,5 +38,4 @@ def image_record_to_dto(
image_url=image_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
workflow_id=workflow_id,
)

View File

@ -2,9 +2,10 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from ..image_files.image_files_common import (
ImageFileDeleteException,
@ -42,7 +43,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
**kwargs
) -> ImageDTO:
if image_origin not in ResourceOrigin:
@ -56,12 +57,6 @@ class ImageService(ImageServiceABC):
(width, height) = image.size
try:
if workflow is not None:
created_workflow = self.__invoker.services.workflow_records.create(workflow)
workflow_id = created_workflow.model_dump()["id"]
else:
workflow_id = None
# TODO: Consider using a transaction here to ensure consistency between storage and database
self.__invoker.services.image_records.save(
# Non-nullable fields
@ -70,6 +65,7 @@ class ImageService(ImageServiceABC):
image_category=image_category,
width=width,
height=height,
has_workflow=workflow is not None,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
@ -80,8 +76,6 @@ class ImageService(ImageServiceABC):
)
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
if workflow_id is not None:
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
)
@ -145,7 +139,6 @@ class ImageService(ImageServiceABC):
image_url=self.__invoker.services.urls.get_image_url(image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
)
return image_dto
@ -166,18 +159,15 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowField]:
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
try:
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name)
if workflow_id is None:
return None
return self.__invoker.services.workflow_records.get(workflow_id)
except ImageRecordNotFoundException:
self.__invoker.services.logger.error("Image record not found")
return self.__invoker.services.image_files.get_workflow(image_name)
except ImageFileNotFoundException:
self.__invoker.services.logger.error("Image file not found")
raise
except Exception:
self.__invoker.services.logger.error("Problem getting image workflow")
raise
except Exception as e:
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
@ -225,7 +215,6 @@ class ImageService(ImageServiceABC):
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
)
for r in results.items
]

View File

@ -108,6 +108,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
queue_batch_id=queue_item.session_queue_batch_id,
workflow=queue_item.workflow,
)
)
@ -131,7 +132,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
source_node_id=source_node_id,
result=outputs.model_dump(),
)
self.__invoker.services.performance_statistics.log_stats()
except KeyboardInterrupt:
pass
@ -178,6 +178,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
workflow=queue_item.workflow,
invoke_all=True,
)
except Exception as e:
@ -193,6 +194,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error=traceback.format_exc(),
)
elif is_complete:
self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
self.__invoker.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.session_queue_batch_id,
queue_item_id=queue_item.session_queue_item_id,

View File

@ -1,9 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
@ -15,5 +18,6 @@ class InvocationQueueItem(BaseModel):
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

View File

@ -11,6 +11,7 @@ if TYPE_CHECKING:
from .board_records.board_records_base import BoardRecordStorageBase
from .boards.boards_base import BoardServiceABC
from .config import InvokeAIAppConfig
from .download import DownloadQueueServiceBase
from .events.events_base import EventServiceBase
from .image_files.image_files_base import ImageFileStorageBase
from .image_records.image_records_base import ImageRecordStorageBase
@ -21,14 +22,14 @@ if TYPE_CHECKING:
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState, LibraryGraph
from .shared.graph import GraphExecutionState
from .urls.urls_base import UrlServiceBase
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@ -43,7 +44,6 @@ class InvocationServices:
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
graph_library: "ItemStorageABC[LibraryGraph]"
images: "ImageServiceABC"
image_records: "ImageRecordStorageBase"
image_files: "ImageFileStorageBase"
@ -51,6 +51,8 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
download_queue: "DownloadQueueServiceBase"
model_install: "ModelInstallServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@ -59,7 +61,6 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase"
names: "NameServiceBase"
urls: "UrlServiceBase"
workflow_image_records: "WorkflowImageRecordsStorageBase"
workflow_records: "WorkflowRecordsStorageBase"
def __init__(
@ -71,7 +72,6 @@ class InvocationServices:
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
graph_library: "ItemStorageABC[LibraryGraph]",
images: "ImageServiceABC",
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
@ -79,6 +79,8 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
download_queue: "DownloadQueueServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -87,7 +89,6 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
urls: "UrlServiceBase",
workflow_image_records: "WorkflowImageRecordsStorageBase",
workflow_records: "WorkflowRecordsStorageBase",
):
self.board_images = board_images
@ -97,7 +98,6 @@ class InvocationServices:
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library
self.images = images
self.image_files = image_files
self.image_records = image_records
@ -105,6 +105,8 @@ class InvocationServices:
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.download_queue = download_queue
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
@ -113,5 +115,4 @@ class InvocationServices:
self.invocation_cache = invocation_cache
self.names = names
self.urls = urls
self.workflow_image_records = workflow_image_records
self.workflow_records = workflow_records

View File

@ -30,23 +30,13 @@ writes to the system log is stored in InvocationServices.performance_statistics.
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from typing import Dict
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.backend.model_management.model_cache import CacheStats
from .invocation_stats_common import NodeLog
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
# {graph_id => NodeLog}
_stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats]
ram_used: float
ram_changed: float
@abstractmethod
def __init__(self):
"""
@ -77,45 +67,8 @@ class InvocationStatsServiceBase(ABC):
pass
@abstractmethod
def reset_all_stats(self):
"""Zero all statistics"""
pass
@abstractmethod
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
"""
pass
@abstractmethod
def log_stats(self):
def log_stats(self, graph_execution_state_id: str):
"""
Write out the accumulated statistics to the log or somewhere else.
"""
pass
@abstractmethod
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
pass

View File

@ -1,25 +1,84 @@
from dataclasses import dataclass, field
from typing import Dict
# size of GIG in bytes
GIG = 1073741824
from collections import defaultdict
from dataclasses import dataclass
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
class NodeExecutionStats:
"""Class for tracking execution stats of an invocation node."""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0
invocation_type: str
start_time: float # Seconds since the epoch.
end_time: float # Seconds since the epoch.
start_ram_gb: float # GB
end_ram_gb: float # GB
peak_vram_gb: float # GB
def total_time(self) -> float:
return self.end_time - self.start_time
@dataclass
class NodeLog:
"""Class for tracking node usage"""
class GraphExecutionStats:
"""Class for tracking execution stats of a graph."""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
def __init__(self):
self._node_stats_list: list[NodeExecutionStats] = []
def add_node_execution_stats(self, node_stats: NodeExecutionStats):
self._node_stats_list.append(node_stats)
def get_total_run_time(self) -> float:
"""Get the total time spent executing nodes in the graph."""
total = 0.0
for node_stats in self._node_stats_list:
total += node_stats.total_time()
return total
def get_first_node_stats(self) -> NodeExecutionStats | None:
"""Get the stats of the first node in the graph (by start_time)."""
first_node = None
for node_stats in self._node_stats_list:
if first_node is None or node_stats.start_time < first_node.start_time:
first_node = node_stats
assert first_node is not None
return first_node
def get_last_node_stats(self) -> NodeExecutionStats | None:
"""Get the stats of the last node in the graph (by end_time)."""
last_node = None
for node_stats in self._node_stats_list:
if last_node is None or node_stats.end_time > last_node.end_time:
last_node = node_stats
return last_node
def get_pretty_log(self, graph_execution_state_id: str) -> str:
log = f"Graph stats: {graph_execution_state_id}\n"
log += f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}\n"
# Log stats aggregated by node type.
node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
for node_stats in self._node_stats_list:
node_stats_by_type[node_stats.invocation_type].append(node_stats)
for node_type, node_type_stats_list in node_stats_by_type.items():
num_calls = len(node_type_stats_list)
time_used = sum([n.total_time() for n in node_type_stats_list])
peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
log += f"{node_type:>30} {num_calls:>4} {time_used:7.3f}s {peak_vram:4.3f}G\n"
# Log stats for the entire graph.
log += f"TOTAL GRAPH EXECUTION TIME: {self.get_total_run_time():7.3f}s\n"
first_node = self.get_first_node_stats()
last_node = self.get_last_node_stats()
if first_node is not None and last_node is not None:
total_wall_time = last_node.end_time - first_node.start_time
ram_change = last_node.end_ram_gb - first_node.start_ram_gb
log += f"TOTAL GRAPH WALL TIME: {total_wall_time:7.3f}s\n"
log += f"RAM used by InvokeAI process: {last_node.end_ram_gb:4.2f}G ({ram_change:+5.3f}G)\n"
return log

View File

@ -1,5 +1,5 @@
import time
from typing import Dict
from contextlib import contextmanager
import psutil
import torch
@ -7,161 +7,119 @@ import torch
import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.backend.model_management.model_cache import CacheStats
from .invocation_stats_base import InvocationStatsServiceBase
from .invocation_stats_common import GIG, NodeLog, NodeStats
from .invocation_stats_common import GraphExecutionStats, NodeExecutionStats
# Size of 1GB in bytes.
GB = 2**30
class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems"""
_invoker: Invoker
def __init__(self):
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
self._cache_stats: Dict[str, CacheStats] = {}
self.ram_used: float = 0.0
self.ram_changed: float = 0.0
# Maps graph_execution_state_id to GraphExecutionStats.
self._stats: dict[str, GraphExecutionStats] = {}
# Maps graph_execution_state_id to model manager CacheStats.
self._cache_stats: dict[str, CacheStats] = {}
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
class StatsContext:
"""Context manager for collecting statistics."""
invocation: BaseInvocation
collector: "InvocationStatsServiceBase"
graph_id: str
start_time: float
ram_used: int
model_manager: ModelManagerServiceBase
def __init__(
self,
invocation: BaseInvocation,
graph_id: str,
model_manager: ModelManagerServiceBase,
collector: "InvocationStatsServiceBase",
):
"""Initialize statistics for this run."""
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0.0
self.ram_used = 0
self.model_manager = model_manager
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.ram_used = psutil.Process().memory_info().rss
if self.model_manager:
self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
def __exit__(self, *args):
"""Called on exit from the context."""
ram_used = psutil.Process().memory_info().rss
self.collector.update_mem_stats(
ram_used=ram_used / GIG,
ram_changed=(ram_used - self.ram_used) / GIG,
)
self.collector.update_invocation_stats(
graph_id=self.graph_id,
invocation_type=self.invocation.type, # type: ignore # `type` is not on the `BaseInvocation` model, but *is* on all invocations
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> StatsContext:
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
@contextmanager
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
if not self._stats.get(graph_execution_state_id):
# First time we're seeing this graph_execution_state_id.
self._stats[graph_execution_state_id] = GraphExecutionStats()
self._cache_stats[graph_execution_state_id] = CacheStats()
return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
def reset_all_stats(self):
"""Zero all statistics"""
self._stats = {}
# Prune stale stats. There should be none since we're starting a new graph, but just in case.
self._prune_stale_stats()
# Record state before the invocation.
start_time = time.time()
start_ram = psutil.Process().memory_info().rss
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
if self._invoker.services.model_manager:
self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id])
def reset_stats(self, graph_execution_id: str):
try:
self._stats.pop(graph_execution_id)
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
# Let the invocation run.
yield None
finally:
# Record state after the invocation.
node_stats = NodeExecutionStats(
invocation_type=invocation.type,
start_time=start_time,
end_time=time.time(),
start_ram_gb=start_ram / GB,
end_ram_gb=psutil.Process().memory_info().rss / GB,
peak_vram_gb=torch.cuda.max_memory_allocated() / GB if torch.cuda.is_available() else 0.0,
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
self.ram_used = ram_used
self.ram_changed = ram_changed
def _prune_stale_stats(self):
"""Check all graphs being tracked and prune any that have completed/errored.
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
stats.calls += 1
stats.time_used += time_used
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self):
completed = set()
errored = set()
for graph_id, _node_log in self._stats.items():
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
for now we call this function periodically to prevent them from accumulating.
"""
to_prune = []
for graph_execution_state_id in self._stats:
try:
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
except Exception:
errored.add(graph_id)
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
continue
if not current_graph_state.is_complete():
if not graph_execution_state.is_complete():
# The graph is still running, don't prune it.
continue
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info(f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:>30} {stats.calls:>4} {stats.time_used:7.3f}s {stats.max_vram:4.3f}G")
total_time += stats.time_used
to_prune.append(graph_execution_state_id)
cache_stats = self._cache_stats[graph_id]
hwm = cache_stats.high_watermark / GIG
tot = cache_stats.cache_size / GIG
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GIG
for graph_execution_state_id in to_prune:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
logger.info("RAM used by InvokeAI process: " + "%4.2fG" % self.ram_used + f" ({self.ram_changed:+5.3f}G)")
logger.info(f"RAM used to load models: {loaded:4.2f}G")
if torch.cuda.is_available():
logger.info("VRAM in use: " + "%4.3fG" % (torch.cuda.memory_allocated() / GIG))
logger.info("RAM cache statistics:")
logger.info(f" Model cache hits: {cache_stats.hits}")
logger.info(f" Model cache misses: {cache_stats.misses}")
logger.info(f" Models cached: {cache_stats.in_cache}")
logger.info(f" Models cleared from cache: {cache_stats.cleared}")
logger.info(f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G")
if len(to_prune) > 0:
logger.info(f"Pruned stale graph stats for {to_prune}.")
completed.add(graph_id)
def reset_stats(self, graph_execution_state_id: str):
try:
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]
except KeyError as e:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}.")
for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]
def log_stats(self, graph_execution_state_id: str):
try:
graph_stats = self._stats[graph_execution_state_id]
cache_stats = self._cache_stats[graph_execution_state_id]
except KeyError as e:
logger.warning(f"Attempted to log statistics for unknown graph {graph_execution_state_id}: {e}.")
return
for graph_id in errored:
del self._stats[graph_id]
del self._cache_stats[graph_id]
log = graph_stats.get_pretty_log(graph_execution_state_id)
hwm = cache_stats.high_watermark / GB
tot = cache_stats.cache_size / GB
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
log += f"RAM used to load models: {loaded:4.2f}G\n"
if torch.cuda.is_available():
log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
log += "RAM cache statistics:\n"
log += f" Model cache hits: {cache_stats.hits}\n"
log += f" Model cache misses: {cache_stats.misses}\n"
log += f" Models cached: {cache_stats.in_cache}\n"
log += f" Models cleared from cache: {cache_stats.cleared}\n"
log += f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
logger.info(log)
del self._stats[graph_execution_state_id]
del self._cache_stats[graph_execution_state_id]

View File

@ -2,6 +2,8 @@
from typing import Optional
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from .invocation_queue.invocation_queue_common import InvocationQueueItem
from .invocation_services import InvocationServices
from .shared.graph import Graph, GraphExecutionState
@ -22,6 +24,7 @@ class Invoker:
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
workflow: Optional[WorkflowWithoutID] = None,
invoke_all: bool = False,
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
@ -43,6 +46,7 @@ class Invoker:
session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
workflow=workflow,
invoke_all=invoke_all,
)
)

View File

@ -5,7 +5,7 @@ from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, TypeAdapter
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .item_storage_base import ItemStorageABC

View File

@ -0,0 +1,27 @@
"""Initialization file for model install service package."""
from .model_install_base import (
CivitaiModelSource,
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
URLModelSource,
)
from .model_install_default import ModelInstallService
__all__ = [
"ModelInstallServiceBase",
"ModelInstallService",
"InstallStatus",
"ModelInstallJob",
"UnknownInstallJobException",
"ModelSource",
"LocalModelSource",
"HFModelSource",
"URLModelSource",
"CivitaiModelSource",
]

View File

@ -0,0 +1,412 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class CivitaiModelSource(StringLikeSource):
"""A Civitai version id, with optional variant and access token."""
version_id: int
variant: Optional[ModelRepoVariant] = None
access_token: Optional[str] = None
type: Literal["civitai"] = "civitai"
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = str(self.version_id)
base += f" ({self.variant})" if self.variant else ""
return base
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
base += f":{self.subfolder}" if self.subfolder else ""
base += f" ({self.variant})" if self.variant else ""
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
]
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: Optional[int] = Field(
default=None, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.status = InstallStatus.ERROR
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
@property
def error(self) -> Optional[str]:
"""Error traceback."""
return "".join(traceback.format_exception(self._exception)) if self._exception else None
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""
@abstractmethod
def __init__(
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
metadata_store: ModelMetadataStore,
event_bus: Optional["EventServiceBase"] = None,
):
"""
Create ModelInstallService object.
:param config: Systemwide InvokeAIAppConfig.
:param store: Systemwide ModelConfigStore
:param event_bus: InvokeAI event bus for reporting events to.
"""
# make the invoker optional here because we don't need it and it
# makes the installer harder to use outside the web app
@abstractmethod
def start(self, invoker: Optional[Invoker] = None) -> None:
"""Start the installer service."""
@abstractmethod
def stop(self, invoker: Optional[Invoker] = None) -> None:
"""Stop the model install service. After this the objection can be safely deleted."""
@property
@abstractmethod
def app_config(self) -> InvokeAIAppConfig:
"""Return the appConfig object associated with the installer."""
@property
@abstractmethod
def record_store(self) -> ModelRecordServiceBase:
"""Return the ModelRecoreService object associated with the installer."""
@property
@abstractmethod
def event_bus(self) -> Optional[EventServiceBase]:
"""Return the event service base object associated with the installer."""
@abstractmethod
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe and register the model at model_path.
This keeps the model in its current location.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
@abstractmethod
def unregister(self, key: str) -> None:
"""Remove model with indicated key from the database."""
@abstractmethod
def delete(self, key: str) -> None:
"""Remove model with indicated key from the database. Delete its files only if they are within our models directory."""
@abstractmethod
def unconditionally_delete(self, key: str) -> None:
"""Remove model with indicated key from the database and unconditionally delete weight files from disk."""
@abstractmethod
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe, register and install the model in the models directory.
This moves the model from its current location into
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
@abstractmethod
def import_model(
self,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""Install the indicated model.
:param source: ModelSource object
:param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
thread and will issue the following events on the event bus:
- model_install_started
- model_install_error
- model_install_completed
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
The call returns a ModelInstallJob object which can be
polled to learn the current status and/or error message.
Variants recognized by HuggingFace currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
@abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""
@abstractmethod
def get_job_by_id(self, id: int) -> ModelInstallJob:
"""Return the ModelInstallJob corresponding to the provided id. Raises ValueError if no job has that ID."""
@abstractmethod
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
"""
List active and complete install jobs.
"""
@abstractmethod
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@abstractmethod
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
@abstractmethod
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
"""
Wait for all pending installs to complete.
This will block until all pending installs have
completed, been cancelled, or errored out.
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
installs do not complete within the indicated time.
"""
@abstractmethod
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
"""
Recursively scan directory for new models and register or install them.
:param scan_dir: Path to the directory to scan.
:param install: Install if True, otherwise register in place.
:returns list of IDs: Returns list of IDs of models registered/installed
"""
@abstractmethod
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database."""

Some files were not shown because too many files have changed in this diff Show More