Compare commits

..

43 Commits

Author SHA1 Message Date
2b1762d8da feat: add 'project' workflow category 2023-12-05 18:39:53 +11:00
b056a9c181 chore: bump ruff 2023-12-05 14:56:19 +11:00
f56c47b550 fix(db): remove extraneous lock 2023-12-05 14:49:00 +11:00
f8e35aec7b Merge remote-tracking branch 'origin/main' into feat/workflow-saving 2023-12-05 13:31:47 +11:00
10cf10c16c fix(db): fix bug with releasing without lock in db.clean() 2023-12-05 06:48:53 +11:00
7d4a78e470 fix(ui): fix circular dependency 2023-12-03 20:27:40 +11:00
37c87affd0 fix(ui): fix save/save-as workflow naming 2023-12-03 20:21:47 +11:00
3863bd9da3 fix(ui): fix workflow loading
- Different handling for loading from library vs external
- Fix bug where only nodes and edges loaded
2023-12-03 20:13:44 +11:00
4b2e3aa54d fix(ui): remove commented out property 2023-12-03 19:56:15 +11:00
d699efa5bc feat(ui): use custom hook in current image buttons
Already in use elsewhere, forgot to use it here.
2023-12-03 19:54:23 +11:00
b9a1374b8f feat(backend): workflow_records.get_many arg "filter_text" -> "query" 2023-12-03 19:22:16 +11:00
411ea75861 fix(backend): revert unrelated service organisational changes 2023-12-03 19:03:29 +11:00
375c9a1c20 fix: tidy up unused files, unrelated changes 2023-12-03 19:00:01 +11:00
907340b1e1 docs: update default workflows README 2023-12-03 18:54:42 +11:00
0f32d260b7 feat(ui): split out workflow redux state
The `nodes` slice is a rather complicated slice. Removing `workflow` makes it a bit more reasonable.

Also helps to flatten state out a bit.
2023-12-02 23:06:34 +11:00
92bc04dc87 feat(ui): tweak reset workflow editor translations 2023-12-02 21:49:50 +11:00
929b1f4a41 fix(db): fix mis-ordered db cleanup step
It was happening before pruning queue items - should happen afterwards, else you have to restart the app again to free disk space made available by the pruning.
2023-12-02 21:40:01 +11:00
6d7b4b8e8a feat(ui): clean up workflow library hooks 2023-12-02 21:01:18 +11:00
4a14ee0e01 fix(tests): fix tests 2023-12-02 20:00:07 +11:00
f268ea4e39 fix(workflow_records): typo 2023-12-02 19:54:15 +11:00
78face3481 feat(ui): refine workflow list UI 2023-12-02 19:52:17 +11:00
5a0e8261bf feat(workflows): update default workflows
- Update TextToImage_SD15
- Add TextToImage_SDXL
- Add README
2023-12-02 19:51:55 +11:00
0447fa2dcb Merge remote-tracking branch 'origin/main' into feat/workflow-saving 2023-12-02 13:55:19 +11:00
4fd163698c feat: simplify default workflows
- Rename "system" -> "default"
- Simplify syncing logic
- Update UI to match
2023-12-02 11:41:05 +11:00
224438a108 fix: merge conflicts 2023-12-01 23:05:52 +11:00
81d2d5abae Merge remote-tracking branch 'origin/main' into feat/workflow-saving 2023-12-01 23:03:01 +11:00
734e871e8f feat(backend): sync system workflows to db 2023-12-01 22:45:17 +11:00
b0350e9bc8 feat: workflow library - system graphs - wip 2023-12-01 19:36:38 +11:00
0a25efd054 feat: workflow library WIP
- Save to library
- Duplicate
- Filter/sort
- UI/queries
2023-12-01 15:24:22 +11:00
46905175a9 wip 2023-11-30 19:19:58 +11:00
11085783ef feat(ui): workflow library pagination button styles 2023-11-30 16:13:41 +11:00
3d57c14bb3 feat(ui): add workflow loading, deleting to workflow library UI 2023-11-29 23:39:10 +11:00
18f3190857 chore(ui): typegen 2023-11-29 23:38:39 +11:00
fcc056fe6a feat(backend): add WorkflowRecordListItemDTO
This is the id, name, description, created at and updated at workflow columns/attrs. Used to display lists of workflowsl
2023-11-29 23:37:11 +11:00
c1bfc1f47b fix(nodes): fix get_workflow from queue item dict func 2023-11-29 15:36:28 +11:00
14bf87e5e7 feat(nodes): restore WithWorkflow as no-op class
This class is deprecated and no longer needed. Set its workflow attr value to None (meaning it is now a no-op), and issue a warning when an invocation subclasses it.
2023-11-29 13:32:24 +11:00
715ce8538b feat(nodes): do not overwrite custom node module names
Use a different, simpler method to detect if a node is custom.
2023-11-29 13:30:52 +11:00
1987bc9cc5 feat: revert SQLiteMigrator class
Will pursue this in a separate PR.
2023-11-29 13:29:19 +11:00
0b079df4ae feat(ui): updated workflow handling (WIP)
Clientside updates for the backend workflow changes.

Includes roughed-out workflow library UI.
2023-11-29 12:42:10 +11:00
a514c9e28b feat(backend): update workflows handling
Update workflows handling for Workflow Library.

**Updated Workflow Storage**

"Embedded Workflows" are workflows associated with images, and are now only stored in the image files. "Library Workflows" are not associated with images, and are stored only in DB.

This works out nicely. We have always saved workflows to files, but recently began saving them to the DB in addition to in image files. When that happened, we stopped reading workflows from files, so all the workflows that only existed in images were inaccessible. With this change, access to those workflows is restored, and no workflows are lost.

**Updated Workflow Handling in Nodes**

Prior to this change, workflows were embedded in images by passing the whole workflow JSON to a special workflow field on a node. In the node's `invoke()` function, the node was able to access this workflow and save it with the image. This (inaccurately) models workflows as a property of an image and is rather awkward technically.

A workflow is now a property of a batch/session queue item. It is available in the InvocationContext and therefore available to all nodes during `invoke()`.

**Database Migrations**

Added a `SQLiteMigrator` class to handle database migrations. Migrations were needed to accomodate the DB-related changes in this PR. See the code for details.

The `images`, `workflows` and `session_queue` tables required migrations for this PR, and are using the new migrator. Other tables/services are still creating tables themselves. A followup PR will adapt them to use the migrator.

**Other/Support Changes**

- Add a `has_workflow` column to `images` table to indicate that the image has an embedded workflow.
- Add handling for retrieving the workflow from an image in python. The image file must be fetched, the workflow extracted, and then sent to client, avoiding needing the browser to parse the image file. With the `has_workflow` column, the UI knows if there is a workflow to be fetched, and only fetches when the user requests to load the workflow.
- Add route to get the workflow from an image
- Add CRUD service/routes for the library workflows
- `workflow_images` table and services removed (no longer needed now that embedded workflows are not in the DB)
2023-11-29 12:42:10 +11:00
8cf2806489 fix(workflow_records): fix SQLite workflow insertion to ignore duplicates 2023-11-29 12:42:10 +11:00
eb446471cc fix(ui): exclude public/en.json from prettier config 2023-11-29 12:42:10 +11:00
7392d07331 chore: bump pydantic to 2.5.2
This release fixes pydantic/pydantic#8175 and allows us to use `JsonValue`
2023-11-29 12:42:10 +11:00
464 changed files with 222452 additions and 39007 deletions

View File

@ -42,21 +42,6 @@ Please provide steps on how to test changes, any hardware or
software specifications as well as any other pertinent information.
-->
## Merge Plan
<!--
A merge plan describes how this PR should be handled after it is approved.
Example merge plans:
- "This PR can be merged when approved"
- "This must be squash-merged when approved"
- "DO NOT MERGE - I will rebase and tidy commits before merging"
- "#dev-chat on discord needs to be advised of this change when it is merged"
A merge plan is particularly important for large PRs or PRs that touch the
database in any way.
-->
## Added/updated tests?
- [ ] Yes

View File

@ -22,22 +22,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Setup Node 18
uses: actions/setup-node@v4
uses: actions/setup-node@v3
with:
node-version: '18'
- name: Checkout
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install dependencies
run: 'pnpm install --prefer-frozen-lockfile'
- name: Typescript
run: 'pnpm run lint:tsc'
- name: Madge
run: 'pnpm run lint:madge'
- name: ESLint
run: 'pnpm run lint:eslint'
- name: Prettier
run: 'pnpm run lint:prettier'
- uses: actions/checkout@v3
- run: 'yarn install --frozen-lockfile'
- run: 'yarn run lint:tsc'
- run: 'yarn run lint:madge'
- run: 'yarn run lint:eslint'
- run: 'yarn run lint:prettier'

View File

@ -1,15 +1,13 @@
name: PyPI Release
on:
push:
paths:
- 'invokeai/version/invokeai_version.py'
workflow_dispatch:
inputs:
publish_package:
description: 'Publish build on PyPi? [true/false]'
required: true
default: 'false'
jobs:
build-and-release:
release:
if: github.repository == 'invoke-ai/InvokeAI'
runs-on: ubuntu-22.04
env:
@ -17,43 +15,19 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
TWINE_NON_INTERACTIVE: 1
steps:
- name: Checkout
uses: actions/checkout@v4
- name: checkout sources
uses: actions/checkout@v3
- name: Setup Node 18
uses: actions/setup-node@v4
with:
node-version: '18'
- name: Setup pnpm
uses: pnpm/action-setup@v2
with:
version: '8.12.1'
- name: Install frontend dependencies
run: pnpm install --prefer-frozen-lockfile
working-directory: invokeai/frontend/web
- name: Build frontend
run: pnpm run build
working-directory: invokeai/frontend/web
- name: Install python dependencies
- name: install deps
run: pip install --upgrade build twine
- name: Build python package
- name: build package
run: python3 -m build
- name: Upload build as workflow artifact
uses: actions/upload-artifact@v4
with:
name: dist
path: dist
- name: Check distribution
- name: check distribution
run: twine check dist/*
- name: Check PyPI versions
- name: check PyPI versions
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
run: |
pip install --upgrade requests
@ -62,6 +36,6 @@ jobs:
EXISTS=scripts.pypi_helper.local_on_pypi(); \
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
- name: Publish build on PyPi
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
- name: upload package
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != ''
run: twine upload dist/*

3
.gitignore vendored
View File

@ -16,7 +16,7 @@ __pycache__/
.Python
build/
develop-eggs/
dist/
# dist/
downloads/
eggs/
.eggs/
@ -187,4 +187,3 @@ installer/install.bat
installer/install.sh
installer/update.bat
installer/update.sh
installer/InvokeAI-Installer/

View File

@ -1,20 +1,6 @@
# simple Makefile with scripts that are otherwise hard to remember
# to use, run from the repo root `make <command>`
default: help
help:
@echo Developer commands:
@echo
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
ruff check . --fix
@ -32,21 +18,4 @@ mypy:
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
# (many files are ignored by the config, so this is useful for checking all files)
mypy-all:
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
# Build the frontend
frontend-build:
cd invokeai/frontend/web && pnpm build
# Run the frontend in dev mode
frontend-dev:
cd invokeai/frontend/web && pnpm dev
# Installer zip file
installer-zip:
cd installer && ./create_installer.sh
# Tag the release
tag-release:
cd installer && ./tag_release.sh
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports

View File

@ -125,8 +125,8 @@ and go to http://localhost:9090.
You must have Python 3.10 through 3.11 installed on your machine. Earlier or
later versions are not supported.
Node.js also needs to be installed along with `pnpm` (can be installed with
the command `npm install -g pnpm` if needed)
Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed)
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
@ -270,7 +270,7 @@ upgrade script.** See the next section for a Windows recipe.
3. Select option [1] to upgrade to the latest release.
4. Once the upgrade is finished you will be returned to the launcher
menu. Select option [6] "Re-run the configure script to fix a broken
menu. Select option [7] "Re-run the configure script to fix a broken
install or to complete a major upgrade".
This will run the configure script against the v2.3 directory and

View File

@ -11,5 +11,5 @@ INVOKEAI_ROOT=
# HUGGING_FACE_HUB_TOKEN=
## optional variables specific to the docker setup.
# GPU_DRIVER=nvidia #| rocm
# GPU_DRIVER=cuda # or rocm
# CONTAINER_UID=1000

View File

@ -59,16 +59,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# #### Build the Web UI ------------------------------------
FROM node:18-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack enable
FROM node:18 AS web-builder
WORKDIR /build
COPY invokeai/frontend/web/ ./
RUN --mount=type=cache,target=/pnpm/store \
pnpm install --frozen-lockfile
RUN pnpm run build
RUN --mount=type=cache,target=/usr/lib/node_modules \
npm install --include dev
RUN --mount=type=cache,target=/usr/lib/node_modules \
yarn vite build
#### Runtime stage ---------------------------------------
@ -102,8 +100,6 @@ ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv/invokeai
ENV INVOKEAI_ROOT=/invokeai
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
# --link requires buldkit w/ dockerfile syntax 1.4
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
@ -121,7 +117,7 @@ WORKDIR ${INVOKEAI_SRC}
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
RUN python3 -c "from patchmatch import patch_match"
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R 1000:1000 ${INVOKEAI_ROOT}
COPY docker/docker-entrypoint.sh ./
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]

View File

@ -1,14 +1,6 @@
# InvokeAI Containerized
All commands should be run within the `docker` directory: `cd docker`
## Quickstart :rocket:
On a known working Linux+Docker+CUDA (Nvidia) system, execute `./run.sh` in this directory. It will take a few minutes - depending on your internet speed - to install the core models. Once the application starts up, open `http://localhost:9090` in your browser to Invoke!
For more configuration options (using an AMD GPU, custom root directory location, etc): read on.
## Detailed setup
All commands are to be run from the `docker` directory: `cd docker`
#### Linux
@ -26,12 +18,12 @@ For more configuration options (using an AMD GPU, custom root directory location
This is done via Docker Desktop preferences
### Configure Invoke environment
## Quickstart
1. Make a copy of `env.sample` and name it `.env` (`cp env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
a. the desired location of the InvokeAI runtime directory, or
b. an existing, v3.0.0 compatible runtime directory.
1. Execute `run.sh`
1. `docker compose up`
The image will be built automatically if needed.
@ -45,21 +37,19 @@ The runtime directory (holding models and outputs) will be created in the locati
The Docker daemon on the system must be already set up to use the GPU. In case of Linux, this involves installing `nvidia-docker-runtime` and configuring the `nvidia` runtime as default. Steps will be different for AMD. Please see Docker documentation for the most up-to-date instructions for using your GPU with Docker.
To use an AMD GPU, set `GPU_DRIVER=rocm` in your `.env` file.
## Customize
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `run.sh`, your custom values will be used.
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `docker compose up`, your custom values will be used.
You can also set these values in `docker-compose.yml` directly, but `.env` will help avoid conflicts when code is updated.
Values are optional, but setting `INVOKEAI_ROOT` is highly recommended. The default is `~/invokeai`. Example:
Example (values are optional, but setting `INVOKEAI_ROOT` is highly recommended):
```bash
INVOKEAI_ROOT=/Volumes/WorkDrive/invokeai
HUGGINGFACE_TOKEN=the_actual_token
CONTAINER_UID=1000
GPU_DRIVER=nvidia
GPU_DRIVER=cuda
```
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.

11
docker/build.sh Executable file
View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -e
build_args=""
[[ -f ".env" ]] && build_args=$(awk '$1 ~ /\=[^$]/ {print "--build-arg " $0 " "}' .env)
echo "docker compose build args:"
echo $build_args
docker compose build $build_args

View File

@ -2,8 +2,23 @@
version: '3.8'
x-invokeai: &invokeai
services:
invokeai:
image: "local/invokeai:latest"
# edit below to run on a container runtime other than nvidia-container-runtime.
# not yet tested with rocm/AMD GPUs
# Comment out the "deploy" section to run on CPU only
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
# For AMD support, comment out the deploy section above and uncomment the devices section below:
#devices:
# - /dev/kfd:/dev/kfd
# - /dev/dri:/dev/dri
build:
context: ..
dockerfile: docker/Dockerfile
@ -35,27 +50,3 @@ x-invokeai: &invokeai
# - |
# invokeai-model-install --yes --default-only --config_file ${INVOKEAI_ROOT}/config_custom.yaml
# invokeai-nodes-web --host 0.0.0.0
services:
invokeai-nvidia:
<<: *invokeai
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
invokeai-cpu:
<<: *invokeai
profiles:
- cpu
invokeai-rocm:
<<: *invokeai
devices:
- /dev/kfd:/dev/kfd
- /dev/dri:/dev/dri
profiles:
- rocm

View File

@ -1,32 +1,11 @@
#!/usr/bin/env bash
set -e -o pipefail
set -e
run() {
local scriptdir=$(dirname "${BASH_SOURCE[0]}")
cd "$scriptdir" || exit 1
# This script is provided for backwards compatibility with the old docker setup.
# it doesn't do much aside from wrapping the usual docker compose CLI.
local build_args=""
local profile=""
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
cd "$SCRIPTDIR" || exit 1
touch .env
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
[[ -z "$profile" ]] && profile="nvidia"
local service_name="invokeai-$profile"
if [[ ! -z "$build_args" ]]; then
printf "%s\n" "docker compose build args:"
printf "%s\n" "$build_args"
fi
docker compose build $build_args
unset build_args
printf "%s\n" "starting service $service_name"
docker compose --profile "$profile" up -d "$service_name"
docker compose logs -f
}
run
docker compose up -d
docker compose logs -f

View File

@ -1,277 +0,0 @@
# The InvokeAI Download Queue
The DownloadQueueService provides a multithreaded parallel download
queue for arbitrary URLs, with queue prioritization, event handling,
and restart capabilities.
## Simple Example
```
from invokeai.app.services.download import DownloadQueueService, TqdmProgress
download_queue = DownloadQueueService()
for url in ['https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true',
'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true',
'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png',
'https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor',
]:
# urls start downloading as soon as download() is called
download_queue.download(source=url,
dest='/tmp/downloads',
on_progress=TqdmProgress().update
)
download_queue.join() # wait for all downloads to finish
for job in download_queue.list_jobs():
print(job.model_dump_json(exclude_none=True, indent=4),"\n")
```
Output:
```
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true",
"dest": "/tmp/downloads",
"id": 0,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/a-painting-of-a-fire.png",
"job_started": "2023-12-04T05:34:41.742174",
"job_ended": "2023-12-04T05:34:42.592035",
"bytes": 666734,
"total_bytes": 666734
}
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true",
"dest": "/tmp/downloads",
"id": 1,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/birdhouse.png",
"job_started": "2023-12-04T05:34:41.741975",
"job_ended": "2023-12-04T05:34:42.652841",
"bytes": 774949,
"total_bytes": 774949
}
{
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png",
"dest": "/tmp/downloads",
"id": 2,
"priority": 10,
"status": "error",
"job_started": "2023-12-04T05:34:41.742079",
"job_ended": "2023-12-04T05:34:42.147625",
"bytes": 0,
"total_bytes": 0,
"error_type": "HTTPError(Not Found)",
"error": "Traceback (most recent call last):\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 182, in _download_next_item\n self._do_download(job)\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 206, in _do_download\n raise HTTPError(resp.reason)\nrequests.exceptions.HTTPError: Not Found\n"
}
{
"source": "https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor",
"dest": "/tmp/downloads",
"id": 3,
"priority": 10,
"status": "completed",
"download_path": "/tmp/downloads/xl_more_art-full_v1.safetensors",
"job_started": "2023-12-04T05:34:42.147645",
"job_ended": "2023-12-04T05:34:43.735990",
"bytes": 719020768,
"total_bytes": 719020768
}
```
## The API
The default download queue is `DownloadQueueService`, an
implementation of ABC `DownloadQueueServiceBase`. It juggles multiple
background download requests and provides facilities for interrogating
and cancelling the requests. Access to a current or past download task
is mediated via `DownloadJob` objects which report the current status
of a job request
### The Queue Object
A default download queue is located in
`ApiDependencies.invoker.services.download_queue`. However, you can
create additional instances if you need to isolate your queue from the
main one.
```
queue = DownloadQueueService(event_bus=events)
```
`DownloadQueueService()` takes three optional arguments:
| **Argument** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| `max_parallel_dl` | int | 5 | Maximum number of simultaneous downloads allowed |
| `event_bus` | EventServiceBase | None | System-wide FastAPI event bus for reporting download events |
| `requests_session` | requests.sessions.Session | None | An alternative requests Session object to use for the download |
`max_parallel_dl` specifies how many download jobs are allowed to run
simultaneously. Each will run in a different thread of execution.
`event_bus` is an EventServiceBase, typically the one created at
InvokeAI startup. If present, download events are periodically emitted
on this bus to allow clients to follow download progress.
`requests_session` is a url library requests Session object. It is
used for testing.
### The Job object
The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.
The only job type currently implemented is `DownloadJob`, a pydantic object with the
following fields:
| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `source` | AnyHttpUrl | | Where to download from |
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| `priority` | int | 10 | Job priority. Lower priorities run before higher priorities |
| |
| _Fields updated over the course of the download task_
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the location of the downloaded file |
| `job_started` | float | | Timestamp for when the job started running |
| `job_ended` | float | | Timestamp for when the job completed or errored out |
| `job_sequence` | int | | A counter that is incremented each time a model is dequeued |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the file at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
When you create a job, you can assign it a `priority`. If multiple
jobs are queued, the job with the lowest priority runs first.
Every job has a `source` and a `dest`. `source` is a pydantic.networks AnyHttpUrl object.
The `dest` is a path on the local filesystem that specifies the
destination for the downloaded object. Its semantics are
described below.
When the job is submitted, it is assigned a numeric `id`. The id can
then be used to fetch the job object from the queue.
The `status` field is updated by the queue to indicate where the job
is in its lifecycle. Values are defined in the string enum
`DownloadJobStatus`, a symbol available from
`invokeai.app.services.download_manager`. Possible values are:
| **Value** | **String Value** | ** Description ** |
|--------------|---------------------|-------------------|
| `WAITING` | waiting | Job is on the queue but not yet running|
| `RUNNING` | running | The download is started |
| `COMPLETED` | completed | Job has finished its work without an error |
| `ERROR` | error | Job encountered an error and will not run again|
`job_started` and `job_ended` indicate when the job
was started (using a python timestamp) and when it completed.
In case of an error, the job's status will be set to `DownloadJobStatus.ERROR`, the text of the
Exception that caused the error will be placed in the `error_type`
field and the traceback that led to the error will be in `error`.
A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.
### Callbacks
Download jobs can be associated with a series of callbacks, each with
the signature `Callable[["DownloadJob"], None]`. The callbacks are assigned
using optional arguments `on_start`, `on_progress`, `on_complete` and
`on_error`. When the corresponding event occurs, the callback wil be
invoked and passed the job. The callback will be run in a `try:`
context in the same thread as the download job. Any exceptions that
occur during execution of the callback will be caught and converted
into a log error message, thereby allowing the download to continue.
#### `TqdmProgress`
The `invokeai.app.services.download.download_default` module defines a
class named `TqdmProgress` which can be used as an `on_progress`
handler to display a completion bar in the console. Use as follows:
```
from invokeai.app.services.download import TqdmProgress
download_queue.download(source='http://some.server.somewhere/some_file',
dest='/tmp/downloads',
on_progress=TqdmProgress().update
)
```
### Events
If the queue was initialized with the InvokeAI event bus (the case
when using `ApiDependencies.invoker.services.download_queue`), then
download events will also be issued on the bus. The events are:
* `download_started` -- This is issued when a job is taken off the
queue and a request is made to the remote server for the URL headers, but before any data
has been downloaded. The event payload will contain the keys `source`
and `download_path`. The latter contains the path that the URL will be
downloaded to.
* `download_progress -- This is issued periodically as the download
runs. The payload contains the keys `source`, `download_path`,
`current_bytes` and `total_bytes`. The latter two fields can be
used to display the percent complete.
* `download_complete` -- This is issued when the download completes
successfully. The payload contains the keys `source`, `download_path`
and `total_bytes`.
* `download_error` -- This is issued when the download stops because
of an error condition. The payload contains the fields `error_type`
and `error`. The former is the text representation of the exception,
and the latter is a traceback showing where the error occurred.
### Job control
To create a job call the queue's `download()` method. You can list all
jobs using `list_jobs()`, fetch a single job by its with
`id_to_job()`, cancel a running job with `cancel_job()`, cancel all
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.
#### job = queue.download(source, dest, priority, access_token)
Create a new download job and put it on the queue, returning the
DownloadJob object.
#### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s.
#### job = queue.id_to_job(id)
Return the job corresponding to given ID.
Return a list of all active and inactive `DownloadJob`s.
#### queue.prune_jobs()
Remove inactive (complete or errored) jobs from the listing returned
by `list_jobs()`.
#### queue.join()
Block until all pending jobs have run to completion or errored out.

View File

@ -10,36 +10,40 @@ model. These are the:
tracks the type of the model, its provenance, and where it can be
found on disk.
* _ModelLoadServiceBase_ Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
* _DownloadQueueServiceBase_ A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
queue has special methods for downloading repo_id folders from
Hugging Face, as well as discriminating among model versions in
Civitai, but can be used for arbitrary content.
* _ModelInstallServiceBase_ A service for installing models to
disk. It uses `DownloadQueueServiceBase` to download models and
their metadata, and `ModelRecordServiceBase` to store that
information. It is also responsible for managing the InvokeAI
`models` directory and its contents.
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
A multithreaded downloader responsible
for downloading models from a remote source to disk. The download
queue has special methods for downloading repo_id folders from
Hugging Face, as well as discriminating among model versions in
Civitai, but can be used for arbitrary content.
* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
Responsible for loading a model from disk
into RAM and VRAM and getting it ready for inference.
## Location of the Code
All four of these services can be found in
`invokeai/app/services` in the following directories:
* `invokeai/app/services/model_records/`
* `invokeai/app/services/downloads/`
* `invokeai/app/services/model_loader/`
* `invokeai/app/services/model_install/`
* `invokeai/app/services/model_loader/` (**under development**)
* `invokeai/app/services/downloads/`(**under development**)
With the exception of the install service, each of these is a thin
shell around a corresponding implementation located in
`invokeai/backend/model_manager`. The main difference between the
modules found in app services and those in the backend folder is that
the former add support for event reporting and are more tied to the
needs of the InvokeAI API.
Code related to the FastAPI web API can be found in
`invokeai/app/api/routers/model_records.py`.
`invokeai/app/api/routers/models.py`.
***
@ -161,6 +165,10 @@ of the fields, including `name`, `model_type` and `base_model`, are
shared between `ModelConfigBase` and `ModelBase`, and this is a
potential source of confusion.
** TO DO: ** The `ModelBase` code needs to be revised to reduce the
duplication of similar classes and to support using the `key` as the
primary model identifier.
## Reading and Writing Model Configuration Records
The `ModelRecordService` provides the ability to retrieve model
@ -354,7 +362,7 @@ model and pass its key to `get_model()`.
Several methods allow you to create and update stored model config
records.
#### add_model(key, config) -> AnyModelConfig:
#### add_model(key, config) -> ModelConfigBase:
Given a key and a configuration, this will add the model's
configuration record to the database. `config` can either be a subclass of
@ -378,356 +386,27 @@ fields to be updated. This will return an `AnyModelConfig` on success,
or raise `InvalidModelConfigException` or `UnknownModelException`
exceptions on failure.
***TO DO:*** Investigate why `update_model()` returns an
`AnyModelConfig` while `add_model()` returns a `ModelConfigBase`.
### rename_model(key, new_name) -> ModelConfigBase:
This is a special case of `update_model()` for the use case of
changing the model's name. It is broken out because there are cases in
which the InvokeAI application wants to synchronize the model's name
with its path in the `models` directory after changing the name, type
or base. However, when using the ModelRecordService directly, the call
is equivalent to:
```
store.rename_model(key, {'name': 'new_name'})
```
***TO DO:*** Investigate why `rename_model()` is returning a
`ModelConfigBase` while `update_model()` returns a `AnyModelConfig`.
***
## Model installation
The `ModelInstallService` class implements the
`ModelInstallServiceBase` abstract base class, and provides a one-stop
shop for all your model install needs. It provides the following
functionality:
- Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path.
- Installing a model alreadiy located on the local filesystem, by
moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir`
specifies).
- Probing of models to determine their type, base type and other key
information.
- Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process.
- Downloading a model from an arbitrary URL and installing it in
`models_dir` (_implementation pending_).
- Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link (_implementation pending_).
- Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16. (_implementation pending_)
### Initializing the installer
A default installer is created at InvokeAI api startup time and stored
in `ApiDependencies.invoker.services.model_install` and can
also be retrieved from an invocation's `context` argument with
`context.services.model_install`.
In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
store = ModelRecordServiceSQL(db)
installer = ModelInstallService(config, store)
```
The full form of `ModelInstallService()` takes the following
required parameters:
| **Argument** | **Type** | **Description** |
|------------------|------------------------------|------------------------------|
| `config` | InvokeAIAppConfig | InvokeAI app configuration object |
| `record_store` | ModelRecordServiceBase | Config record storage database |
| `event_bus` | EventServiceBase | Optional event bus to send download/install progress events to |
Once initialized, the installer will provide the following methods:
#### install_job = installer.import_model()
The `import_model()` method is the core of the installer. The
following illustrates basic usage:
```
from invokeai.app.services.model_install import (
LocalModelSource,
HFModelSource,
URLModelSource,
)
source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local diffusers folder
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
for source in [source1, source2, source3, source4, source5, source6, source7]:
install_job = installer.install_model(source)
source2job = installer.wait_for_installs()
for source in sources:
job = source2job[source]
if job.status == "completed":
model_config = job.config_out
model_key = model_config.key
print(f"{source} installed as {model_key}")
elif job.status == "error":
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
```
As shown here, the `import_model()` method accepts a variety of
sources, including local safetensors files, local diffusers folders,
HuggingFace repo_ids with and without a subfolder designation,
Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).
Each call to `import_model()` return a `ModelInstallJob` job,
an object which tracks the progress of the install.
If a remote model is requested, the model's files are downloaded in
parallel across a multiple set of threads using the download
queue. During the download process, the `ModelInstallJob` is updated
to provide status and progress information. After the files (if any)
are downloaded, the remainder of the installation runs in a single
serialized background thread. These are the model probing, file
copying, and config record database update steps.
Multiple install jobs can be queued up. You may block until all
install jobs are completed (or errored) by calling the
`wait_for_installs()` method as shown in the code
example. `wait_for_installs()` will return a `dict` that maps the
requested source to its job. This object can be interrogated
to determine its status. If the job errored out, then the error type
and details can be recovered from `job.error_type` and `job.error`.
The full list of arguments to `import_model()` is as follows:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
| `inplace` | bool | True | Leave a local model in its current location |
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
| `access_token` | str | None | Provide authorization information needed to download |
The `inplace` field controls how local model Paths are handled. If
True (the default), then the model is simply registered in its current
location by the installer's `ModelConfigRecordService`. Otherwise, a
copy of the model put into the location specified by the `models_dir`
application configuration parameter.
The `variant` field is used for HuggingFace repo_ids only. If
provided, the repo_id download handler will look for and download
tensors files that follow the convention for the selected variant:
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
- "onnx" will select files ending with the suffix ".onnx"
- "openvino" will select files beginning with "openvino_model"
In the special case of the "fp16" variant, the installer will select
the 32-bit version of the files if the 16-bit version is unavailable.
`subfolder` is used for HuggingFace repo_ids only. If provided, the
model will be downloaded from the designated subfolder rather than the
top-level repository folder. If a subfolder is attached to the repo_id
using the format `repo_owner/repo_name:subfolder`, then the subfolder
specified by the repo_id will override the subfolder argument.
`config` can be used to override all or a portion of the configuration
attributes returned by the model prober. See the section below for
details.
`access_token` is passed to the download queue and used to access
repositories that require it.
#### Monitoring the install job process
When you create an install job with `import_model()`, it launches the
download and installation process in the background and returns a
`ModelInstallJob` object for monitoring the process.
The `ModelInstallJob` class has the following structure:
| **Attribute** | **Type** | **Description** |
|----------------|-----------------|------------------|
| `status` | `InstallStatus` | An enum of ["waiting", "running", "completed" and "error" |
| `config_in` | `dict` | Overriding configuration values provided by the caller |
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
| `error_type` | `str` | Name of the exception that led to an error status |
| `error` | `str` | Traceback of the error |
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as an event of type `EventServiceBase.model_event`, a timestamp and
the following event names:
- `model_install_started`
The payload will contain the keys `timestamp` and `source`. The latter
indicates the requested model source for installation.
- `model_install_progress`
Emitted at regular intervals when downloading a remote model, the
payload will contain the keys `timestamp`, `source`, `current_bytes`
and `total_bytes`. These events are _not_ emitted when a local model
already on the filesystem is imported.
- `model_install_completed`
Issued once at the end of a successful installation. The payload will
contain the keys `timestamp`, `source` and `key`, where `key` is the
ID under which the model has been registered.
- `model_install_error`
Emitted if the installation process fails for some reason. The payload
will contain the keys `timestamp`, `source`, `error_type` and
`error`. `error_type` is a short message indicating the nature of the
error, and `error` is the long traceback to help debug the problem.
#### Model confguration and probing
The install service uses the `invokeai.backend.model_manager.probe`
module during import to determine the model's type, base type, and
other configuration parameters. Among other things, it assigns a
default name and description for the model based on probed
fields.
When downloading remote models is implemented, additional
configuration information, such as list of trigger terms, will be
retrieved from the HuggingFace and Civitai model repositories.
The probed values can be overriden by providing a dictionary in the
optional `config` argument passed to `import_model()`. You may provide
overriding values for any of the model's configuration
attributes. Here is an example of setting the
`SchedulerPredictionType` and `name` for an sd-2 model:
This is typically used to set
the model's name and description, but can also be used to overcome
cases in which automatic probing is unable to (correctly) determine
the model's attribute. The most common situation is the
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
example of how it works:
```
install_job = installer.import_model(
source='stabilityai/stable-diffusion-2-1',
variant='fp16',
config=dict(
prediction_type=SchedulerPredictionType('v_prediction')
name='stable diffusion 2 base model',
)
)
```
### Other installer methods
This section describes additional methods provided by the installer class.
#### jobs = installer.wait_for_installs()
Block until all pending installs are completed or errored and then
returns a list of completed jobs.
#### jobs = installer.list_jobs([source])
Return a list of all active and complete `ModelInstallJobs`. An
optional `source` argument allows you to filter the returned list by a
model source string pattern using a partial string match.
#### jobs = installer.get_job(source)
Return a list of `ModelInstallJob` corresponding to the indicated
model source.
#### installer.prune_jobs
Remove non-pending jobs (completed or errored) from the job list
returned by `list_jobs()` and `get_job()`.
#### installer.app_config, installer.record_store,
installer.event_bus
Properties that provide access to the installer's `InvokeAIAppConfig`,
`ModelRecordServiceBase` and `EventServiceBase` objects.
#### key = installer.register_path(model_path, config), key = installer.install_path(model_path, config)
These methods bypass the download queue and directly register or
install the model at the indicated path, returning the unique ID for
the installed model.
Both methods accept a Path object corresponding to a checkpoint or
diffusers folder, and an optional dict of config attributes to use to
override the values derived from model probing.
The difference between `register_path()` and `install_path()` is that
the former creates a model configuration record without changing the
location of the model in the filesystem. The latter makes a copy of
the model inside the InvokeAI models directory before registering
it.
#### installer.unregister(key)
This will remove the model config record for the model at key, and is
equivalent to `installer.record_store.del_model(key)`
#### installer.delete(key)
This is similar to `unregister()` but has the additional effect of
conditionally deleting the underlying model file(s) if they reside
within the InvokeAI models directory
#### installer.unconditionally_delete(key)
This method is similar to `unregister()`, but also unconditionally
deletes the corresponding model weights file(s), regardless of whether
they are inside or outside the InvokeAI models hierarchy.
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
This method will recursively scan the directory indicated in
`scan_dir` for new models and either install them in the models
directory or register them in place, depending on the setting of
`install` (default False).
The return value is the list of keys of the new installed/registered
models.
#### installer.sync_to_config()
This method synchronizes models in the models directory and autoimport
directory to those in the `ModelConfigRecordService` database. New
models are registered and orphan models are unregistered.
#### installer.start(invoker)
The `start` method is called by the API intialization routines when
the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.
# The remainder of this documentation is provisional, pending implementation of the Download and Load services
## Let's get loaded, the lowdown on ModelLoadService
The `ModelLoadService` is responsible for loading a named model into
@ -1184,3 +863,351 @@ other resources that it might have been using.
This will start/pause/cancel all jobs that have been submitted to the
queue and have not yet reached a terminal state.
## Model installation
The `ModelInstallService` class implements the
`ModelInstallServiceBase` abstract base class, and provides a one-stop
shop for all your model install needs. It provides the following
functionality:
- Registering a model config record for a model already located on the
local filesystem, without moving it or changing its path.
- Installing a model alreadiy located on the local filesystem, by
moving it into the InvokeAI root directory under the
`models` folder (or wherever config parameter `models_dir`
specifies).
- Downloading a model from an arbitrary URL and installing it in
`models_dir`.
- Special handling for Civitai model URLs which allow the user to
paste in a model page's URL or download link. Any metadata provided
by Civitai, such as trigger terms, are captured and placed in the
model config record.
- Special handling for HuggingFace repo_ids to recursively download
the contents of the repository, paying attention to alternative
variants such as fp16.
- Probing of models to determine their type, base type and other key
information.
- Interface with the InvokeAI event bus to provide status updates on
the download, installation and registration process.
### Initializing the installer
A default installer is created at InvokeAI api startup time and stored
in `ApiDependencies.invoker.services.model_install_service` and can
also be retrieved from an invocation's `context` argument with
`context.services.model_install_service`.
In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download_manager import DownloadQueueServive
from invokeai.app.services.model_record_service import ModelRecordServiceBase
config = InvokeAI.get_config()
queue = DownloadQueueService()
store = ModelRecordServiceBase.open(config)
installer = ModelInstallService(config=config, queue=queue, store=store)
```
The full form of `ModelInstallService()` takes the following
parameters. Each parameter will default to a reasonable value, but it
is recommended that you set them explicitly as shown in the above example.
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `config` | InvokeAIAppConfig | Use system-wide config | InvokeAI app configuration object |
| `queue` | DownloadQueueServiceBase | Create a new download queue for internal use | Download queue |
| `store` | ModelRecordServiceBase | Use config to select the database to open | Config storage database |
| `event_bus` | EventServiceBase | None | An event bus to send download/install progress events to |
| `event_handlers` | List[DownloadEventHandler] | None | Event handlers for the download queue |
Note that if `store` is not provided, then the class will use
`ModelRecordServiceBase.open(config)` to select the database to use.
Once initialized, the installer will provide the following methods:
#### install_job = installer.install_model()
The `install_model()` method is the core of the installer. The
following illustrates basic usage:
```
sources = [
Path('/opt/models/sushi.safetensors'), # a local safetensors file
Path('/opt/models/sushi_diffusers/'), # a local diffusers folder
'runwayml/stable-diffusion-v1-5', # a repo_id
'runwayml/stable-diffusion-v1-5:vae', # a subfolder within a repo_id
'https://civitai.com/api/download/models/63006', # a civitai direct download link
'https://civitai.com/models/8765?modelVersionId=10638', # civitai model page
'https://s3.amazon.com/fjacks/sd-3.safetensors', # arbitrary URL
]
for source in sources:
install_job = installer.install_model(source)
source2key = installer.wait_for_installs()
for source in sources:
model_key = source2key[source]
print(f"{source} installed as {model_key}")
```
As shown here, the `install_model()` method accepts a variety of
sources, including local safetensors files, local diffusers folders,
HuggingFace repo_ids with and without a subfolder designation,
Civitai model URLs and arbitrary URLs that point to checkpoint files
(but not to folders).
Each call to `install_model()` will return a `ModelInstallJob` job, a
subclass of `DownloadJobBase`. The install job has additional
install-specific fields described in the next section.
Each install job will run in a series of background threads using
the object's download queue. You may block until all install jobs are
completed (or errored) by calling the `wait_for_installs()` method as
shown in the code example. `wait_for_installs()` will return a `dict`
that maps the requested source to the key of the installed model. In
the case that a model fails to download or install, its value in the
dict will be None. The actual cause of the error will be reported in
the corresponding job's `error` field.
Alternatively you may install event handlers and/or listen for events
on the InvokeAI event bus in order to monitor the progress of the
requested installs.
The full list of arguments to `model_install()` is as follows:
| **Argument** | **Type** | **Default** | **Description** |
|------------------|------------------------------|-------------|-------------------------------------------|
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
| `inplace` | bool | True | Leave a local model in its current location |
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
| `probe_override` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
| `metadata` | ModelSourceMetadata | None | Provide metadata that will be added to model's config |
| `access_token` | str | None | Provide authorization information needed to download |
| `priority` | int | 10 | Download queue priority for the job |
The `inplace` field controls how local model Paths are handled. If
True (the default), then the model is simply registered in its current
location by the installer's `ModelConfigRecordService`. Otherwise, the
model will be moved into the location specified by the `models_dir`
application configuration parameter.
The `variant` field is used for HuggingFace repo_ids only. If
provided, the repo_id download handler will look for and download
tensors files that follow the convention for the selected variant:
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
- "onnx" will select files ending with the suffix ".onnx"
- "openvino" will select files beginning with "openvino_model"
In the special case of the "fp16" variant, the installer will select
the 32-bit version of the files if the 16-bit version is unavailable.
`subfolder` is used for HuggingFace repo_ids only. If provided, the
model will be downloaded from the designated subfolder rather than the
top-level repository folder. If a subfolder is attached to the repo_id
using the format `repo_owner/repo_name:subfolder`, then the subfolder
specified by the repo_id will override the subfolder argument.
`probe_override` can be used to override all or a portion of the
attributes returned by the model prober. This can be used to overcome
cases in which automatic probing is unable to (correctly) determine
the model's attribute. The most common situation is the
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
example of how it works:
```
install_job = installer.install_model(
source='stabilityai/stable-diffusion-2-1',
variant='fp16',
probe_override=dict(
prediction_type=SchedulerPredictionType('v_prediction')
)
)
```
`metadata` allows you to attach custom metadata to the installed
model. See the next section for details.
`priority` and `access_token` are passed to the download queue and
have the same effect as they do for the DownloadQueueServiceBase.
#### Monitoring the install job process
When you create an install job with `model_install()`, events will be
passed to the list of `DownloadEventHandlers` provided at installer
initialization time. Event handlers can also be added to individual
model install jobs by calling their `add_handler()` method as
described earlier for the `DownloadQueueService`.
If the `event_bus` argument was provided, events will also be
broadcast to the InvokeAI event bus. The events will appear on the bus
as a singular event type named `model_event` with a payload of
`job`. You can then retrieve the job and check its status.
** TO DO: ** consider breaking `model_event` into
`model_install_started`, `model_install_completed`, etc. The event bus
features have not yet been tested with FastAPI/websockets, and it may
turn out that the job object is not serializable.
#### Model metadata and probing
The install service has special handling for HuggingFace and Civitai
URLs that capture metadata from the source and include it in the model
configuration record. For example, fetching the Civitai model 8765
will produce a config record similar to this (using YAML
representation):
```
5abc3ef8600b6c1cc058480eaae3091e:
path: sd-1/lora/to8contrast-1-5.safetensors
name: to8contrast-1-5
base_model: sd-1
model_type: lora
model_format: lycoris
key: 5abc3ef8600b6c1cc058480eaae3091e
hash: 5abc3ef8600b6c1cc058480eaae3091e
description: 'Trigger terms: to8contrast style'
author: theovercomer8
license: allowCommercialUse=Sell; allowDerivatives=True; allowNoCredit=True
source: https://civitai.com/models/8765?modelVersionId=10638
thumbnail_url: null
tags:
- model
- style
- portraits
```
For sources that do not provide model metadata, you can attach custom
fields by providing a `metadata` argument to `model_install()` using
an initialized `ModelSourceMetadata` object (available for import from
`model_install_service.py`):
```
from invokeai.app.services.model_install_service import ModelSourceMetadata
meta = ModelSourceMetadata(
name="my model",
author="Sushi Chef",
description="Highly customized model; trigger with 'sushi',"
license="mit",
thumbnail_url="http://s3.amazon.com/ljack/pics/sushi.png",
tags=list('sfw', 'food')
)
install_job = installer.install_model(
source='sushi_chef/model3',
variant='fp16',
metadata=meta,
)
```
It is not currently recommended to provide custom metadata when
installing from Civitai or HuggingFace source, as the metadata
provided by the source will overwrite the fields you provide. Instead,
after the model is installed you can use
`ModelRecordService.update_model()` to change the desired fields.
** TO DO: ** Change the logic so that the caller's metadata fields take
precedence over those provided by the source.
#### Other installer methods
This section describes additional, less-frequently-used attributes and
methods provided by the installer class.
##### installer.wait_for_installs()
This is equivalent to the `DownloadQueue` `join()` method. It will
block until all the active jobs in the install queue have reached a
terminal state (completed, errored or cancelled).
##### installer.queue, installer.store, installer.config
These attributes provide access to the `DownloadQueueServiceBase`,
`ModelConfigRecordServiceBase`, and `InvokeAIAppConfig` objects that
the installer uses.
For example, to temporarily pause all pending installations, you can
do this:
```
installer.queue.pause_all_jobs()
```
##### key = installer.register_path(model_path, overrides), key = installer.install_path(model_path, overrides)
These methods bypass the download queue and directly register or
install the model at the indicated path, returning the unique ID for
the installed model.
Both methods accept a Path object corresponding to a checkpoint or
diffusers folder, and an optional dict of attributes to use to
override the values derived from model probing.
The difference between `register_path()` and `install_path()` is that
the former will not move the model from its current position, while
the latter will move it into the `models_dir` hierarchy.
##### installer.unregister(key)
This will remove the model config record for the model at key, and is
equivalent to `installer.store.unregister(key)`
##### installer.delete(key)
This is similar to `unregister()` but has the additional effect of
deleting the underlying model file(s) -- even if they were outside the
`models_dir` directory!
##### installer.conditionally_delete(key)
This method will call `unregister()` if the model identified by `key`
is outside the `models_dir` hierarchy, and call `delete()` if the
model is inside.
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
This method will recursively scan the directory indicated in
`scan_dir` for new models and either install them in the models
directory or register them in place, depending on the setting of
`install` (default False).
The return value is the list of keys of the new installed/registered
models.
#### installer.scan_models_directory()
This method scans the models directory for new models and registers
them in place. Models that are present in the
`ModelConfigRecordService` database whose paths are not found will be
unregistered.
#### installer.sync_to_config()
This method synchronizes models in the models directory and autoimport
directory to those in the `ModelConfigRecordService` database. New
models are registered and orphan models are unregistered.
#### hash=installer.hash(model_path)
This method is calls the fasthash algorithm on a model's Path
(either a file or a folder) to generate a unique ID based on the
contents of the model.
##### installer.start(invoker)
The `start` method is called by the API intialization routines when
the API starts up. Its effect is to call `sync_to_config()` to
synchronize the model record store database with what's currently on
disk.
This method should not ordinarily be called manually.

View File

@ -46,18 +46,17 @@ We encourage you to ping @psychedelicious and @blessedcoolant on [Discord](http
```bash
node --version
```
2. Install [pnpm](https://pnpm.io/) and confirm it is installed by running this:
2. Install [yarn classic](https://classic.yarnpkg.com/lang/en/) and confirm it is installed by running this:
```bash
npm install --global pnpm
pnpm --version
npm install --global yarn
yarn --version
```
From `invokeai/frontend/web/` run `pnpm install` to get everything set up.
From `invokeai/frontend/web/` run `yarn install` to get everything set up.
Start everything in dev mode:
1. Ensure your virtual environment is running
2. Start the dev server: `pnpm dev`
2. Start the dev server: `yarn dev`
3. Start the InvokeAI Nodes backend: `python scripts/invokeai-web.py # run from the repo root`
4. Point your browser to the dev server address e.g. [http://localhost:5173/](http://localhost:5173/)
@ -73,4 +72,4 @@ For a number of technical and logistical reasons, we need to commit UI build art
If you submit a PR, there is a good chance we will ask you to include a separate commit with a build of the app.
To build for production, run `pnpm build`.
To build for production, run `yarn build`.

View File

@ -154,16 +154,14 @@ groups in `invokeia.yaml`:
### Web Server
| Setting | Default Value | Description |
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
| Setting | Default Value | Description |
|----------|----------------|--------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].

View File

@ -293,19 +293,6 @@ manager, please follow these steps:
## Developer Install
!!! warning
InvokeAI uses a SQLite database. By running on `main`, you accept responsibility for your database. This
means making regular backups (especially before pulling) and/or fixing it yourself in the event that a
PR introduces a schema change.
If you don't need persistent backend storage, you can use an ephemeral in-memory database by setting
`use_memory_db: true` under `Path:` in your `invokeai.yaml` file.
If this is untenable, you should run the application via the official installer or a manual install of the
python package from pypi. These releases will not break your database.
If you have an interest in how InvokeAI works, or you would like to
add features or bugfixes, you are encouraged to install the source
code for InvokeAI. For this to work, you will need to install the
@ -401,5 +388,3 @@ environment variable INVOKEAI_ROOT to point to the installation directory.
Note that if you run into problems with the Conda installation, the InvokeAI
staff will **not** be able to help you out. Caveat Emptor!
[dev-chat]: https://discord.com/channels/1020123559063990373/1049495067846524939

View File

@ -1,10 +0,0 @@
document.addEventListener("DOMContentLoaded", function () {
var script = document.createElement("script");
script.src = "https://widget.kapa.ai/kapa-widget.bundle.js";
script.setAttribute("data-website-id", "b5973bb1-476b-451e-8cf4-98de86745a10");
script.setAttribute("data-project-name", "Invoke.AI");
script.setAttribute("data-project-color", "#11213C");
script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/113954515?s=280&v=4");
script.async = true;
document.head.appendChild(script);
});

View File

@ -13,12 +13,7 @@ If you'd prefer, you can also just download the whole node folder from the linke
To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
- Community Nodes
+ [Adapters-Linked](#adapters-linked-nodes)
+ [Average Images](#average-images)
+ [Clean Image Artifacts After Cut](#clean-image-artifacts-after-cut)
+ [Close Color Mask](#close-color-mask)
+ [Clothing Mask](#clothing-mask)
+ [Contrast Limited Adaptive Histogram Equalization](#contrast-limited-adaptive-histogram-equalization)
+ [Depth Map from Wavefront OBJ](#depth-map-from-wavefront-obj)
+ [Film Grain](#film-grain)
+ [Generative Grammar-Based Prompt Nodes](#generative-grammar-based-prompt-nodes)
@ -27,23 +22,16 @@ To use a community workflow, download the the `.json` node graph file and load i
+ [Halftone](#halftone)
+ [Ideal Size](#ideal-size)
+ [Image and Mask Composition Pack](#image-and-mask-composition-pack)
+ [Image Dominant Color](#image-dominant-color)
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
+ [Image Picker](#image-picker)
+ [Image Resize Plus](#image-resize-plus)
+ [Load Video Frame](#load-video-frame)
+ [Make 3D](#make-3d)
+ [Mask Operations](#mask-operations)
+ [Match Histogram](#match-histogram)
+ [Metadata-Linked](#metadata-linked-nodes)
+ [Negative Image](#negative-image)
+ [Oobabooga](#oobabooga)
+ [Prompt Tools](#prompt-tools)
+ [Remote Image](#remote-image)
+ [Remove Background](#remove-background)
+ [Retroize](#retroize)
+ [Size Stepper Nodes](#size-stepper-nodes)
+ [Simple Skin Detection](#simple-skin-detection)
+ [Text font to Image](#text-font-to-image)
+ [Thresholding](#thresholding)
+ [Unsharp Mask](#unsharp-mask)
@ -53,19 +41,6 @@ To use a community workflow, download the the `.json` node graph file and load i
- [Help](#help)
--------------------------------
### Adapters Linked Nodes
**Description:** A set of nodes for linked adapters (ControlNet, IP-Adaptor & T2I-Adapter). This allows multiple adapters to be chained together without using a `collect` node which means it can be used inside an `iterate` node without any collecting on every iteration issues.
- `ControlNet-Linked` - Collects ControlNet info to pass to other nodes.
- `IP-Adapter-Linked` - Collects IP-Adapter info to pass to other nodes.
- `T2I-Adapter-Linked` - Collects T2I-Adapter info to pass to other nodes.
Note: These are inherited from the core nodes so any update to the core nodes should be reflected in these.
**Node Link:** https://github.com/skunkworxdark/adapters-linked-nodes
--------------------------------
### Average Images
@ -73,46 +48,6 @@ Note: These are inherited from the core nodes so any update to the core nodes sh
**Node Link:** https://github.com/JPPhoto/average-images-node
--------------------------------
### Clean Image Artifacts After Cut
Description: Removes residual artifacts after an image is separated from its background.
Node Link: https://github.com/VeyDlin/clean-artifact-after-cut-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/clean-artifact-after-cut-node/master/.readme/node.png" width="500" />
--------------------------------
### Close Color Mask
Description: Generates a mask for images based on a closely matching color, useful for color-based selections.
Node Link: https://github.com/VeyDlin/close-color-mask-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/close-color-mask-node/master/.readme/node.png" width="500" />
--------------------------------
### Clothing Mask
Description: Employs a U2NET neural network trained for the segmentation of clothing items in images.
Node Link: https://github.com/VeyDlin/clothing-mask-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/clothing-mask-node/master/.readme/node.png" width="500" />
--------------------------------
### Contrast Limited Adaptive Histogram Equalization
Description: Enhances local image contrast using adaptive histogram equalization with contrast limiting.
Node Link: https://github.com/VeyDlin/clahe-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/clahe-node/master/.readme/node.png" width="500" />
--------------------------------
### Depth Map from Wavefront OBJ
@ -229,16 +164,6 @@ This includes 15 Nodes:
</br><img src="https://raw.githubusercontent.com/dwringer/composition-nodes/main/composition_pack_overview.jpg" width="500" />
--------------------------------
### Image Dominant Color
Description: Identifies and extracts the dominant color from an image using k-means clustering.
Node Link: https://github.com/VeyDlin/image-dominant-color-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-dominant-color-node/master/.readme/node.png" width="500" />
--------------------------------
### Image to Character Art Image Nodes
@ -260,17 +185,6 @@ View:
**Node Link:** https://github.com/JPPhoto/image-picker-node
--------------------------------
### Image Resize Plus
Description: Provides various image resizing options such as fill, stretch, fit, center, and crop.
Node Link: https://github.com/VeyDlin/image-resize-plus-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
--------------------------------
### Load Video Frame
@ -295,16 +209,6 @@ View:
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png" width="300" />
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png" width="300" />
--------------------------------
### Mask Operations
Description: Offers logical operations (OR, SUB, AND) for combining and manipulating image masks.
Node Link: https://github.com/VeyDlin/mask-operations-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/mask-operations-node/master/.readme/node.png" width="500" />
--------------------------------
### Match Histogram
@ -322,30 +226,6 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
<img src="https://github.com/skunkworxdark/match_histogram/assets/21961335/ed12f329-a0ef-444a-9bae-129ed60d6097" width="300" />
--------------------------------
### Metadata Linked Nodes
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
- `Metadata From Image` - Provides Metadata from an image.
- `Metadata To String` - Extracts a String value of a label from metadata.
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
- `Metadata To Float` - Extracts a Float value of a label from metadata.
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes
--------------------------------
### Negative Image
Description: Creates a negative version of an image, effective for visual effects and mask inversion.
Node Link: https://github.com/VeyDlin/negative-image-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/negative-image-node/master/.readme/node.png" width="500" />
--------------------------------
### Oobabooga
@ -409,15 +289,6 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
**Node Link:** https://github.com/fieldOfView/InvokeAI-remote_image
--------------------------------
### Remove Background
Description: An integration of the rembg package to remove backgrounds from images using multiple U2NET models.
Node Link: https://github.com/VeyDlin/remove-background-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/remove-background-node/master/.readme/node.png" width="500" />
--------------------------------
### Retroize
@ -430,17 +301,6 @@ View:
<img src="https://github.com/Ar7ific1al/InvokeAI_nodes_retroize/assets/2306586/de8b4fa6-324c-4c2d-b36c-297600c73974" width="500" />
--------------------------------
### Simple Skin Detection
Description: Detects skin in images based on predefined color thresholds.
Node Link: https://github.com/VeyDlin/simple-skin-detection-node
View:
</br><img src="https://raw.githubusercontent.com/VeyDlin/simple-skin-detection-node/master/.readme/node.png" width="500" />
--------------------------------
### Size Stepper Nodes
@ -526,7 +386,6 @@ See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/READ
<img src="https://github.com/skunkworxdark/XYGrid_nodes/blob/main/images/collage.png" width="300" />
--------------------------------
### Example Node Template

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
{
"name": "Text to Image - SD1.5",
"name": "Text to Image",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 1.5/2",
"version": "1.1.0",
"version": "1.0.1",
"contact": "invoke@invoke.ai",
"tags": "text2image, SD1.5, SD2, default",
"notes": "",
@ -18,19 +18,10 @@
{
"nodeId": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"fieldName": "prompt"
},
{
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
"fieldName": "width"
},
{
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
"fieldName": "height"
}
],
"meta": {
"category": "default",
"version": "2.0.0"
"version": "1.0.0"
},
"nodes": [
{
@ -39,56 +30,44 @@
"data": {
"id": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"type": "compel",
"label": "Negative Compel Prompt",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
"name": "prompt",
"type": "string",
"fieldKind": "input",
"label": "Negative Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
"name": "clip",
"type": "ClipField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
"label": ""
}
},
"outputs": {
"conditioning": {
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
"type": "ConditioningField",
"fieldKind": "output"
}
}
},
"label": "Negative Compel Prompt",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.0.0"
},
"width": 320,
"height": 259,
"height": 261,
"position": {
"x": 1000,
"y": 350
"x": 995.7263915923627,
"y": 239.67783573351227
}
},
{
@ -97,60 +76,37 @@
"data": {
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "noise",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"nodePack": "invokeai",
"inputs": {
"seed": {
"id": "6431737c-918a-425d-a3b4-5d57e2f35d4d",
"name": "seed",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"width": {
"id": "38fc5b66-fe6e-47c8-bba9-daf58e454ed7",
"name": "width",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"height": {
"id": "16298330-e2bf-4872-a514-d6923df53cbb",
"name": "height",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"use_cpu": {
"id": "c7c436d3-7a7a-4e76-91e4-c6deb271623c",
"name": "use_cpu",
"type": "boolean",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
@ -158,40 +114,35 @@
"noise": {
"id": "50f650dc-0184-4e23-a927-0497a96fe954",
"name": "noise",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
"type": "LatentsField",
"fieldKind": "output"
},
"width": {
"id": "bb8a452b-133d-42d1-ae4a-3843d7e4109a",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
"type": "integer",
"fieldKind": "output"
},
"height": {
"id": "35cfaa12-3b8b-4b7a-a884-327ff3abddd9",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
"type": "integer",
"fieldKind": "output"
}
}
},
"label": "",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.0.0"
},
"width": 320,
"height": 388,
"height": 389,
"position": {
"x": 600,
"y": 325
"x": 993.4442117555518,
"y": 605.6757415334787
}
},
{
@ -200,24 +151,13 @@
"data": {
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"type": "main_model_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"model": {
"id": "993eabd2-40fd-44fe-bce7-5d0c7075ddab",
"name": "model",
"type": "MainModelField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MainModelField"
},
"value": {
"model_name": "stable-diffusion-v1-5",
"base_model": "sd-1",
@ -229,40 +169,35 @@
"unet": {
"id": "5c18c9db-328d-46d0-8cb9-143391c410be",
"name": "unet",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
"type": "UNetField",
"fieldKind": "output"
},
"clip": {
"id": "6effcac0-ec2f-4bf5-a49e-a2c29cf921f4",
"name": "clip",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
"type": "ClipField",
"fieldKind": "output"
},
"vae": {
"id": "57683ba3-f5f5-4f58-b9a2-4b83dacad4a1",
"name": "vae",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
"type": "VaeField",
"fieldKind": "output"
}
}
},
"label": "",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.0.0"
},
"width": 320,
"height": 226,
"position": {
"x": 600,
"y": 25
"x": 163.04436745878343,
"y": 254.63156870373479
}
},
{
@ -271,56 +206,44 @@
"data": {
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"type": "compel",
"label": "Positive Compel Prompt",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
"name": "prompt",
"type": "string",
"fieldKind": "input",
"label": "Positive Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": "Super cute tiger cub, national geographic award-winning photograph"
"value": ""
},
"clip": {
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
"name": "clip",
"type": "ClipField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
"label": ""
}
},
"outputs": {
"conditioning": {
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
"type": "ConditioningField",
"fieldKind": "output"
}
}
},
"label": "Positive Compel Prompt",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.0.0"
},
"width": 320,
"height": 259,
"height": 261,
"position": {
"x": 1000,
"y": 25
"x": 595.7263915923627,
"y": 239.67783573351227
}
},
{
@ -329,36 +252,21 @@
"data": {
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"type": "rand_int",
"label": "Random Seed",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": false,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"low": {
"id": "3ec65a37-60ba-4b6c-a0b2-553dd7a84b84",
"name": "low",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"high": {
"id": "085f853a-1a5f-494d-8bec-e4ba29a3f2d1",
"name": "high",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 2147483647
}
},
@ -366,20 +274,23 @@
"value": {
"id": "812ade4d-7699-4261-b9fc-a6c9d2ab55ee",
"name": "value",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
"type": "integer",
"fieldKind": "output"
}
}
},
"label": "Random Seed",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": false,
"version": "1.0.0"
},
"width": 320,
"height": 32,
"height": 218,
"position": {
"x": 600,
"y": 275
"x": 541.094822888628,
"y": 694.5704476446829
}
},
{
@ -388,224 +299,144 @@
"data": {
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "denoise_latents",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
"id": "90b7f4f8-ada7-4028-8100-d2e54f192052",
"name": "positive_conditioning",
"type": "ConditioningField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
"label": ""
},
"negative_conditioning": {
"id": "9393779e-796c-4f64-b740-902a1177bf53",
"name": "negative_conditioning",
"type": "ConditioningField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
"label": ""
},
"noise": {
"id": "8e17f1e5-4f98-40b1-b7f4-86aeeb4554c1",
"name": "noise",
"type": "LatentsField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
"label": ""
},
"steps": {
"id": "9b63302d-6bd2-42c9-ac13-9b1afb51af88",
"name": "steps",
"type": "integer",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 50
"value": 10
},
"cfg_scale": {
"id": "87dd04d3-870e-49e1-98bf-af003a810109",
"name": "cfg_scale",
"type": "FloatPolymorphic",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "FloatField"
},
"value": 7.5
},
"denoising_start": {
"id": "f369d80f-4931-4740-9bcd-9f0620719fab",
"name": "denoising_start",
"type": "float",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"denoising_end": {
"id": "747d10e5-6f02-445c-994c-0604d814de8c",
"name": "denoising_end",
"type": "float",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 1
},
"scheduler": {
"id": "1de84a4e-3a24-4ec8-862b-16ce49633b9b",
"name": "scheduler",
"type": "Scheduler",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "SchedulerField"
},
"value": "unipc"
"value": "euler"
},
"unet": {
"id": "ffa6fef4-3ce2-4bdb-9296-9a834849489b",
"name": "unet",
"type": "UNetField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
"label": ""
},
"control": {
"id": "077b64cb-34be-4fcc-83f2-e399807a02bd",
"name": "control",
"type": "ControlPolymorphic",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "ControlField"
}
"label": ""
},
"ip_adapter": {
"id": "1d6948f7-3a65-4a65-a20c-768b287251aa",
"name": "ip_adapter",
"type": "IPAdapterPolymorphic",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "IPAdapterField"
}
"label": ""
},
"t2i_adapter": {
"id": "75e67b09-952f-4083-aaf4-6b804d690412",
"name": "t2i_adapter",
"type": "T2IAdapterPolymorphic",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "T2IAdapterField"
}
},
"cfg_rescale_multiplier": {
"id": "9101f0a6-5fe0-4826-b7b3-47e5d506826c",
"name": "cfg_rescale_multiplier",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
"label": ""
},
"latents": {
"id": "334d4ba3-5a99-4195-82c5-86fb3f4f7d43",
"name": "latents",
"type": "LatentsField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
"label": ""
},
"denoise_mask": {
"id": "0d3dbdbf-b014-4e95-8b18-ff2ff9cb0bfa",
"name": "denoise_mask",
"type": "DenoiseMaskField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "DenoiseMaskField"
}
"label": ""
}
},
"outputs": {
"latents": {
"id": "70fa5bbc-0c38-41bb-861a-74d6d78d2f38",
"name": "latents",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
"type": "LatentsField",
"fieldKind": "output"
},
"width": {
"id": "98ee0e6c-82aa-4e8f-8be5-dc5f00ee47f0",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
"type": "integer",
"fieldKind": "output"
},
"height": {
"id": "e8cb184a-5e1a-47c8-9695-4b8979564f5d",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
"type": "integer",
"fieldKind": "output"
}
}
},
"label": "",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": true,
"useCache": true,
"version": "1.4.0"
},
"width": 320,
"height": 703,
"height": 646,
"position": {
"x": 1400,
"y": 25
"x": 1476.5794704734735,
"y": 256.80174342731783
}
},
{
@ -614,185 +445,153 @@
"data": {
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "l2i",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": false,
"useCache": true,
"version": "1.2.0",
"nodePack": "invokeai",
"inputs": {
"metadata": {
"id": "ab375f12-0042-4410-9182-29e30db82c85",
"name": "metadata",
"type": "MetadataField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MetadataField"
}
"label": ""
},
"latents": {
"id": "3a7e7efd-bff5-47d7-9d48-615127afee78",
"name": "latents",
"type": "LatentsField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
"label": ""
},
"vae": {
"id": "a1f5f7a1-0795-4d58-b036-7820c0b0ef2b",
"name": "vae",
"type": "VaeField",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
"label": ""
},
"tiled": {
"id": "da52059a-0cee-4668-942f-519aa794d739",
"name": "tiled",
"type": "boolean",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
},
"fp32": {
"id": "c4841df3-b24e-4140-be3b-ccd454c2522c",
"name": "fp32",
"type": "boolean",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
"value": false
}
},
"outputs": {
"image": {
"id": "72d667d0-cf85-459d-abf2-28bd8b823fe7",
"name": "image",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ImageField"
}
"type": "ImageField",
"fieldKind": "output"
},
"width": {
"id": "c8c907d8-1066-49d1-b9a6-83bdcd53addc",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
"type": "integer",
"fieldKind": "output"
},
"height": {
"id": "230f359c-b4ea-436c-b372-332d7dcdca85",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
"type": "integer",
"fieldKind": "output"
}
}
},
"label": "",
"isOpen": true,
"notes": "",
"embedWorkflow": false,
"isIntermediate": false,
"useCache": true,
"version": "1.0.0"
},
"width": 320,
"height": 266,
"height": 267,
"position": {
"x": 1800,
"y": 25
"x": 2037.9648469717395,
"y": 426.10844427600136
}
}
],
"edges": [
{
"id": "reactflow__edge-ea94bc37-d995-4a83-aa99-4af42479f2f2value-55705012-79b9-4aac-9f26-c0b10309785bseed",
"source": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
"targetHandle": "seed",
"id": "reactflow__edge-ea94bc37-d995-4a83-aa99-4af42479f2f2value-55705012-79b9-4aac-9f26-c0b10309785bseed",
"type": "default"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-7d8bf987-284f-413a-b2fd-d825445a5d6cclip",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"sourceHandle": "clip",
"target": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
"targetHandle": "clip",
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-7d8bf987-284f-413a-b2fd-d825445a5d6cclip",
"type": "default"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-93dc02a4-d05b-48ed-b99c-c9b616af3402clip",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"sourceHandle": "clip",
"target": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
"targetHandle": "clip",
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-93dc02a4-d05b-48ed-b99c-c9b616af3402clip",
"type": "default"
},
{
"id": "reactflow__edge-55705012-79b9-4aac-9f26-c0b10309785bnoise-eea2702a-19fb-45b5-9d75-56b4211ec03cnoise",
"source": "55705012-79b9-4aac-9f26-c0b10309785b",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"targetHandle": "noise",
"id": "reactflow__edge-55705012-79b9-4aac-9f26-c0b10309785bnoise-eea2702a-19fb-45b5-9d75-56b4211ec03cnoise",
"type": "default"
},
{
"id": "reactflow__edge-7d8bf987-284f-413a-b2fd-d825445a5d6cconditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cpositive_conditioning",
"source": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"targetHandle": "positive_conditioning",
"id": "reactflow__edge-7d8bf987-284f-413a-b2fd-d825445a5d6cconditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cpositive_conditioning",
"type": "default"
},
{
"id": "reactflow__edge-93dc02a4-d05b-48ed-b99c-c9b616af3402conditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cnegative_conditioning",
"source": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8unet-eea2702a-19fb-45b5-9d75-56b4211ec03cunet",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
"targetHandle": "negative_conditioning",
"id": "reactflow__edge-93dc02a4-d05b-48ed-b99c-c9b616af3402conditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cnegative_conditioning",
"type": "default"
},
{
"id": "reactflow__edge-eea2702a-19fb-45b5-9d75-56b4211ec03clatents-58c957f5-0d01-41fc-a803-b2bbf0413d4flatents",
"source": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8vae-58c957f5-0d01-41fc-a803-b2bbf0413d4fvae",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"sourceHandle": "unet",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"targetHandle": "unet",
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8unet-eea2702a-19fb-45b5-9d75-56b4211ec03cunet",
"type": "default"
},
{
"source": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"sourceHandle": "latents",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "default",
"targetHandle": "latents",
"id": "reactflow__edge-eea2702a-19fb-45b5-9d75-56b4211ec03clatents-58c957f5-0d01-41fc-a803-b2bbf0413d4flatents",
"type": "default"
},
{
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"sourceHandle": "vae",
"targetHandle": "vae"
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"targetHandle": "vae",
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8vae-58c957f5-0d01-41fc-a803-b2bbf0413d4fvae",
"type": "default"
}
]
}
}

View File

@ -2,72 +2,43 @@
set -e
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
function is_bin_in_path {
builtin type -P "$1" &>/dev/null
}
function git_show {
git show -s --format='%h %s' $1
}
cd "$(dirname "$0")"
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
echo "The current working directory is $(pwd)"
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
echo
# Some machines only have `python3` in PATH, others have `python` - make an alias.
# We can use a function to approximate an alias within a non-interactive shell.
if ! is_bin_in_path python && is_bin_in_path python3; then
function python {
python3 "$@"
}
fi
if [[ -v "VIRTUAL_ENV" ]]; then
# we can't just call 'deactivate' because this function is not exported
# to the environment of this script from the bash process that runs the script
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
echo "A virtual environment is activated. Please deactivate it before proceeding".
exit -1
fi
VERSION=$(
cd ..
python -c "from invokeai.version import __version__ as version; print(version)"
)
VERSION=$(cd ..; python -c "from invokeai.version import __version__ as version; print(version)")
PATCH=""
VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v3-latest"
echo -e "${BGREEN}HEAD${RESET}:"
git_show
echo
echo Building installer for version $VERSION
echo "Be certain that you're in the 'installer' directory before continuing."
read -p "Press any key to continue, or CTRL-C to exit..."
# ---------------------- FRONTEND ----------------------
read -e -p "Tag this repo with '${VERSION}' and '${LATEST_TAG}'? [n]: " input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
pushd ../invokeai/frontend/web >/dev/null
echo
echo "Installing frontend dependencies..."
echo
pnpm i --frozen-lockfile
echo
echo "Building frontend..."
echo
pnpm build
popd
git push origin :refs/tags/$VERSION
if ! git tag -fa $VERSION ; then
echo "Existing/invalid tag"
exit -1
fi
# ---------------------- BACKEND ----------------------
git push origin :refs/tags/$LATEST_TAG
git tag -fa $LATEST_TAG
echo
echo "Building wheel..."
echo
echo "remember to push --tags!"
fi
# ----------------------
echo Building the wheel
# install the 'build' package in the user site packages, if needed
# could be improved by using a temporary venv, but it's tiny and harmless
@ -75,15 +46,12 @@ if [[ $(python -c 'from importlib.util import find_spec; print(find_spec("build"
pip install --user build
fi
rm -rf ../build
rm -r ../build
python -m build --wheel --outdir dist/ ../.
# ----------------------
echo
echo "Building installer zip files for InvokeAI ${VERSION}..."
echo
echo Building installer zip fles for InvokeAI $VERSION
# get rid of any old ones
rm -f *.zip
@ -91,11 +59,9 @@ rm -rf InvokeAI-Installer
# copy content
mkdir InvokeAI-Installer
for f in templates *.txt *.reg; do
for f in templates lib *.txt *.reg; do
cp -r ${f} InvokeAI-Installer/
done
mkdir InvokeAI-Installer/lib
cp lib/*.py InvokeAI-Installer/lib
# Move the wheel
mv dist/*.whl InvokeAI-Installer/lib/
@ -106,13 +72,13 @@ cp install.sh.in InvokeAI-Installer/install.sh
chmod a+x InvokeAI-Installer/install.sh
# Windows
perl -p -e "s/^set INVOKEAI_VERSION=.*/set INVOKEAI_VERSION=$VERSION/" install.bat.in >InvokeAI-Installer/install.bat
perl -p -e "s/^set INVOKEAI_VERSION=.*/set INVOKEAI_VERSION=$VERSION/" install.bat.in > InvokeAI-Installer/install.bat
cp WinLongPathsEnabled.reg InvokeAI-Installer/
# Zip everything up
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
# clean up
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
rm -rf InvokeAI-Installer tmp dist
exit 0

View File

@ -244,9 +244,9 @@ class InvokeAiInstance:
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch==2.1.2",
"torch==2.1.0",
"torchmetrics==0.11.4",
"torchvision>=0.16.2",
"torchvision>=0.14.1",
"--force-reinstall",
"--find-links" if find_links is not None else None,
find_links,

View File

@ -1,71 +0,0 @@
#!/bin/bash
set -e
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
function does_tag_exist {
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
}
function git_show_ref {
git show-ref --dereference $1 --abbrev 7
}
function git_show {
git show -s --format='%h %s' $1
}
VERSION=$(
cd ..
python -c "from invokeai.version import __version__ as version; print(version)"
)
PATCH=""
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v${MAJOR_VERSION}-latest"
if does_tag_exist $VERSION; then
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
git_show_ref tags/$VERSION
echo
fi
if does_tag_exist $LATEST_TAG; then
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
git_show_ref tags/$LATEST_TAG
echo
fi
echo -e "${BGREEN}HEAD${RESET}:"
git_show
echo
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
read -e -p 'y/n [n]: ' input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
echo
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
git push --delete origin $VERSION
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
if ! git tag -fa $VERSION; then
echo "Existing/invalid tag"
exit -1
fi
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
git push --delete origin $LATEST_TAG
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
git tag -fa $LATEST_TAG
echo -e "Pushing updated tags to remote..."
git push origin --tags
fi
exit 0

View File

@ -2,7 +2,6 @@
from logging import Logger
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -11,7 +10,6 @@ from ..services.board_images.board_images_default import BoardImagesService
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from ..services.boards.boards_default import BoardService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
@ -24,13 +22,14 @@ from ..services.invoker import Invoker
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
from ..services.model_install import ModelInstallService
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.graph import GraphExecutionState
from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -67,9 +66,8 @@ class ApiDependencies:
logger.debug(f"Internet connectivity is {config.internet_available}")
output_folder = config.output_path
image_files = DiskImageFileStorage(f"{output_folder}/images")
db = init_db(config=config, logger=logger, image_files=image_files)
db = SqliteDatabase(config, logger)
configuration = config
logger = logger
@ -80,16 +78,14 @@ class ApiDependencies:
boards = BoardService()
events = FastAPIEventService(event_handler_id)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
image_files = DiskImageFileStorage(f"{output_folder}/images")
image_records = SqliteImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
model_manager = ModelManagerService(config, logger)
model_record_service = ModelRecordServiceSQL(db=db)
download_queue_service = DownloadQueueService(event_bus=events)
model_install_service = ModelInstallService(
app_config=config, record_store=model_record_service, event_bus=events
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
processor = DefaultInvocationProcessor()
@ -107,6 +103,7 @@ class ApiDependencies:
configuration=configuration,
events=events,
graph_execution_manager=graph_execution_manager,
graph_library=graph_library,
image_files=image_files,
image_records=image_records,
images=images,
@ -115,8 +112,6 @@ class ApiDependencies:
logger=logger,
model_manager=model_manager,
model_records=model_record_service,
download_queue=download_queue_service,
model_install=model_install_service,
names=names,
performance_statistics=performance_statistics,
processor=processor,
@ -127,6 +122,8 @@ class ApiDependencies:
workflow_records=workflow_records,
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
db.clean()

View File

@ -1,111 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for the download queue."""
from typing import List, Optional
from fastapi import Body, Path, Response
from fastapi.routing import APIRouter
from pydantic.networks import AnyHttpUrl
from starlette.exceptions import HTTPException
from invokeai.app.services.download import (
DownloadJob,
UnknownJobIDException,
)
from ..dependencies import ApiDependencies
download_queue_router = APIRouter(prefix="/v1/download_queue", tags=["download_queue"])
@download_queue_router.get(
"/",
operation_id="list_downloads",
)
async def list_downloads() -> List[DownloadJob]:
"""Get a list of active and inactive jobs."""
queue = ApiDependencies.invoker.services.download_queue
return queue.list_jobs()
@download_queue_router.patch(
"/",
operation_id="prune_downloads",
responses={
204: {"description": "All completed jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_downloads():
"""Prune completed and errored jobs."""
queue = ApiDependencies.invoker.services.download_queue
queue.prune_jobs()
return Response(status_code=204)
@download_queue_router.post(
"/i/",
operation_id="download",
)
async def download(
source: AnyHttpUrl = Body(description="download source"),
dest: str = Body(description="download destination"),
priority: int = Body(default=10, description="queue priority"),
access_token: Optional[str] = Body(default=None, description="token for authorization to download"),
) -> DownloadJob:
"""Download the source URL to the file or directory indicted in dest."""
queue = ApiDependencies.invoker.services.download_queue
return queue.download(source, dest, priority, access_token)
@download_queue_router.get(
"/i/{id}",
operation_id="get_download_job",
responses={
200: {"description": "Success"},
404: {"description": "The requested download JobID could not be found"},
},
)
async def get_download_job(
id: int = Path(description="ID of the download job to fetch."),
) -> DownloadJob:
"""Get a download job using its ID."""
try:
job = ApiDependencies.invoker.services.download_queue.id_to_job(id)
return job
except UnknownJobIDException as e:
raise HTTPException(status_code=404, detail=str(e))
@download_queue_router.delete(
"/i/{id}",
operation_id="cancel_download_job",
responses={
204: {"description": "Job has been cancelled"},
404: {"description": "The requested download JobID could not be found"},
},
)
async def cancel_download_job(
id: int = Path(description="ID of the download job to cancel."),
):
"""Cancel a download job using its ID."""
try:
queue = ApiDependencies.invoker.services.download_queue
job = queue.id_to_job(id)
queue.cancel_job(job)
return Response(status_code=204)
except UnknownJobIDException as e:
raise HTTPException(status_code=404, detail=str(e))
@download_queue_router.delete(
"/i",
operation_id="cancel_all_download_jobs",
responses={
204: {"description": "Download jobs have been cancelled"},
},
)
async def cancel_all_download_jobs():
"""Cancel all download jobs."""
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
return Response(status_code=204)

View File

@ -4,7 +4,7 @@
from hashlib import sha1
from random import randbytes
from typing import Any, Dict, List, Optional
from typing import List, Optional
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -12,7 +12,6 @@ from pydantic import BaseModel, ConfigDict
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
@ -26,7 +25,7 @@ from invokeai.backend.model_manager.config import (
from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
class ModelsList(BaseModel):
@ -44,25 +43,15 @@ class ModelsList(BaseModel):
async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[str] = Query(
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
found_models: list[AnyModelConfig] = []
if base_models:
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
)
)
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
else:
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
found_models.extend(record_store.search_by_attr(model_type=model_type))
return ModelsList(models=found_models)
@ -128,17 +117,12 @@ async def update_model_record(
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
"""
Delete model record from database.
The configuration record will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory.
"""
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
installer.delete(key)
record_store = ApiDependencies.invoker.services.model_records
record_store.del_model(key)
logger.info(f"Deleted model: {key}")
return Response(status_code=204)
except UnknownModelException as e:
@ -178,145 +162,3 @@ async def add_model_record(
# now fetch it out
return record_store.get_model(config.key)
@model_records_router.post(
"/import",
operation_id="import_model_record",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def import_model(
source: ModelSource,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
The source object is a discriminated Union of LocalModelSource,
HFModelSource and URLModelSource. Set the "type" field to the
appropriate value:
* To install a local path using LocalModelSource, pass a source of form:
`{
"type": "local",
"path": "/path/to/model",
"inplace": false
}`
The "inplace" flag, if true, will register the model in place in its
current filesystem location. Otherwise, the model will be copied
into the InvokeAI models directory.
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
`{
"type": "hf",
"repo_id": "stabilityai/stable-diffusion-2.0",
"variant": "fp16",
"subfolder": "vae",
"access_token": "f5820a918aaf01"
}`
The `variant`, `subfolder` and `access_token` fields are optional.
* To install a remote model using an arbitrary URL, pass:
`{
"type": "url",
"url": "http://www.civitai.com/models/123456",
"access_token": "f5820a918aaf01"
}`
The `access_token` field is optonal
The model's configuration record will be probed and filled in
automatically. To override the default guesses, pass "metadata"
with a Dict containing the attributes you wish to override.
Installation occurs in the background. Either use list_model_install_jobs()
to poll for completion, or listen on the event bus for the following events:
"model_install_started"
"model_install_completed"
"model_install_error"
On successful completion, the event's payload will contain the field "key"
containing the installed ID of the model. On an error, the event's payload
will contain the fields "error_type" and "error" describing the nature of the
error and its traceback, respectively.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_install
result: ModelInstallJob = installer.import_model(
source=source,
config=config,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_records_router.get(
"/import",
operation_id="list_model_install_jobs",
)
async def list_model_install_jobs() -> List[ModelInstallJob]:
"""
Return list of model install jobs.
If the optional 'source' argument is provided, then the list will be filtered
for partial string matches against the install source.
"""
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
return jobs
@model_records_router.patch(
"/import",
operation_id="prune_model_install_jobs",
responses={
204: {"description": "All completed and errored jobs have been pruned"},
400: {"description": "Bad request"},
},
)
async def prune_model_install_jobs() -> Response:
"""
Prune all completed and errored jobs from the install job list.
"""
ApiDependencies.invoker.services.model_install.prune_jobs()
return Response(status_code=204)
@model_records_router.patch(
"/sync",
operation_id="sync_models_to_config",
responses={
204: {"description": "Model config record database resynced with files on disk"},
400: {"description": "Bad request"},
},
)
async def sync_models_to_config() -> Response:
"""
Traverse the models and autoimport directories. Model files without a corresponding
record in the database are added. Orphan records without a models file are deleted.
"""
ApiDependencies.invoker.services.model_install.sync_to_config()
return Response(status_code=204)

View File

@ -20,7 +20,6 @@ class SocketIO:
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
@ -29,13 +28,10 @@ class SocketIO:
room=event[1]["data"]["queue_id"],
)
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])

View File

@ -45,7 +45,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
app_info,
board_images,
boards,
download_queue,
images,
model_records,
models,
@ -117,7 +116,6 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(model_records.model_records_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
@ -221,19 +219,18 @@ def overridden_redoc() -> HTMLResponse:
web_root_path = Path(list(web_dir.__path__)[0])
# Only serve the UI if we it has a build
if (web_root_path / "dist").exists():
# Cannot add headers to StaticFiles, so we must serve index.html with a custom route
# Add cache-control: no-store header to prevent caching of index.html, which leads to broken UIs at release
@app.get("/", include_in_schema=False, name="ui_root")
def get_index() -> FileResponse:
return FileResponse(Path(web_root_path, "dist/index.html"), headers={"Cache-Control": "no-store"})
# # Must mount *after* the other routes else it borks em
app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")), name="assets")
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
# Cannot add headers to StaticFiles, so we must serve index.html with a custom route
# Add cache-control: no-store header to prevent caching of index.html, which leads to broken UIs at release
@app.get("/", include_in_schema=False, name="ui_root")
def get_index() -> FileResponse:
return FileResponse(Path(web_root_path, "dist/index.html"), headers={"Cache-Control": "no-store"})
# # Must mount *after* the other routes else it borks em
app.mount("/static", StaticFiles(directory=Path(web_root_path, "static/")), name="static") # docs favicon is in here
app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")), name="assets")
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
def invoke_api() -> None:
@ -274,8 +271,6 @@ def invoke_api() -> None:
port=port,
loop="asyncio",
log_level=app_config.log_level,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import inspect
import re
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
@ -39,19 +38,6 @@ class InvalidFieldError(TypeError):
pass
class Classification(str, Enum, metaclass=MetaEnum):
"""
The classification of an Invocation.
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
"""
Stable = "stable"
Beta = "beta"
Prototype = "prototype"
class Input(str, Enum, metaclass=MetaEnum):
"""
The type of input a field accepts.
@ -452,7 +438,6 @@ class UIConfigBase(BaseModel):
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
validate_assignment=True,
@ -621,7 +606,6 @@ class BaseInvocation(ABC, BaseModel):
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
@ -725,10 +709,8 @@ class _Model(BaseModel):
pass
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
# Get all pydantic model attrs, methods, etc
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
@ -797,7 +779,6 @@ def invocation(
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@ -808,7 +789,6 @@ def invocation(
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@ -829,7 +809,6 @@ def invocation(
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"

View File

@ -1,3 +1,4 @@
import re
from dataclasses import dataclass
from typing import List, Optional, Union
@ -16,7 +17,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.util.devices import torch_dtype
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -87,7 +87,7 @@ class CompelInvocation(BaseInvocation):
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in extract_ti_triggers_from_prompt(self.prompt):
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
@ -210,7 +210,7 @@ class SDXLPromptInvocationBase:
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in extract_ti_triggers_from_prompt(prompt):
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
name = trigger[1:-1]
try:
ti_list.append(

View File

@ -13,15 +13,7 @@ from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import (
BaseInvocation,
Classification,
Input,
InputField,
InvocationContext,
WithMetadata,
invocation,
)
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@ -121,11 +113,11 @@ class ImageCropInvocation(BaseInvocation, WithMetadata):
@invocation(
invocation_type="img_pad_crop",
title="Center Pad or Crop Image",
"img_paste",
title="Paste Image",
tags=["image", "paste"],
category="image",
tags=["image", "pad", "crop"],
version="1.0.0",
version="1.2.0",
)
class CenterPadCropInvocation(BaseInvocation):
"""Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image."""
@ -176,11 +168,11 @@ class CenterPadCropInvocation(BaseInvocation):
@invocation(
"img_paste",
title="Paste Image",
tags=["image", "paste"],
invocation_type="img_pad_crop",
title="Center Pad or Crop Image",
category="image",
version="1.2.0",
tags=["image", "pad", "crop"],
version="1.0.0",
)
class ImagePasteInvocation(BaseInvocation, WithMetadata):
"""Pastes an image into another image."""
@ -429,64 +421,6 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata):
)
@invocation(
"unsharp_mask",
title="Unsharp Mask",
tags=["image", "unsharp_mask"],
category="image",
version="1.2.0",
classification=Classification.Beta,
)
class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
"""Applies an unsharp mask filter to an image"""
image: ImageField = InputField(description="The image to use")
radius: float = InputField(gt=0, description="Unsharp mask radius", default=2)
strength: float = InputField(ge=0, description="Unsharp mask strength", default=50)
def pil_from_array(self, arr):
return Image.fromarray((arr * 255).astype("uint8"))
def array_from_pil(self, img):
return numpy.array(img) / 255
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
mode = image.mode
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
image = image.convert("RGB")
image_blurred = self.array_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
image = self.array_from_pil(image)
image += (image - image_blurred) * (self.strength / 100.0)
image = numpy.clip(image, 0, 1)
image = self.pil_from_array(image)
image = image.convert(mode)
# Make the image RGBA if we had a source alpha channel
if alpha_channel is not None:
image.putalpha(alpha_channel)
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image.width,
height=image.height,
)
PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
import inspect
import re
# from contextlib import ExitStack
from typing import List, Literal, Union
@ -20,7 +21,6 @@ from invokeai.backend import BaseModelType, ModelType, SubModelType
from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -78,7 +78,7 @@ class ONNXPromptInvocation(BaseInvocation):
]
ti_list = []
for trigger in extract_ti_triggers_from_prompt(self.prompt):
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(

View File

@ -1,5 +1,3 @@
from typing import Literal
import numpy as np
from PIL import Image
from pydantic import BaseModel
@ -7,8 +5,6 @@ from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
Input,
InputField,
InvocationContext,
OutputField,
@ -18,13 +14,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.tiles.tiles import (
calc_tiles_even_split,
calc_tiles_min_overlap,
calc_tiles_with_overlap,
merge_tiles_with_linear_blending,
merge_tiles_with_seam_blending,
)
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
@ -38,14 +28,7 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.")
@invocation(
"calculate_image_tiles",
title="Calculate Image Tiles",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0")
class CalculateImageTilesInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
@ -72,79 +55,6 @@ class CalculateImageTilesInvocation(BaseInvocation):
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_even_split",
title="Calculate Image Tiles Even Split",
tags=["tiles"],
category="tiles",
version="1.1.0",
classification=Classification.Beta,
)
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
num_tiles_x: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the x axis",
)
num_tiles_y: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the y axis",
)
overlap: int = InputField(
default=128,
ge=0,
multiple_of=8,
description="The overlap, in pixels, between adjacent tiles.",
)
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_even_split(
image_height=self.image_height,
image_width=self.image_width,
num_tiles_x=self.num_tiles_x,
num_tiles_y=self.num_tiles_y,
overlap=self.overlap,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_min_overlap",
title="Calculate Image Tiles Minimum Overlap",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_min_overlap(
image_height=self.image_height,
image_width=self.image_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.min_overlap,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation_output("tile_to_properties_output")
class TileToPropertiesOutput(BaseInvocationOutput):
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
@ -166,14 +76,7 @@ class TileToPropertiesOutput(BaseInvocationOutput):
overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
@invocation(
"tile_to_properties",
title="Tile to Properties",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
class TileToPropertiesInvocation(BaseInvocation):
"""Split a Tile into its individual properties."""
@ -199,14 +102,7 @@ class PairTileImageOutput(BaseInvocationOutput):
tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
@invocation(
"pair_tile_image",
title="Pair Tile with Image",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
class PairTileImageInvocation(BaseInvocation):
"""Pair an image with its tile properties."""
@ -225,29 +121,13 @@ class PairTileImageInvocation(BaseInvocation):
)
BLEND_MODES = Literal["Linear", "Seam"]
@invocation(
"merge_tiles_to_image",
title="Merge Tiles to Image",
tags=["tiles"],
category="tiles",
version="1.1.0",
classification=Classification.Beta,
)
@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0")
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
"""Merge multiple tile images into a single image."""
# Inputs
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
blend_mode: BLEND_MODES = InputField(
default="Seam",
description="blending type Linear or Seam",
input=Input.Direct,
)
blend_amount: int = InputField(
default=32,
ge=0,
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
)
@ -277,18 +157,10 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
channels = tile_np_images[0].shape[-1]
dtype = tile_np_images[0].dtype
np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
if self.blend_mode == "Linear":
merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
elif self.blend_mode == "Seam":
merge_tiles_with_seam_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
else:
raise ValueError(f"Unsupported blend mode: '{self.blend_mode}'.")
# Convert into a PIL image and save
merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
pil_image = Image.fromarray(np_image)
image_dto = context.services.images.create(

View File

@ -20,6 +20,63 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for board id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
"""
)
# Add index for board id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,

View File

@ -28,6 +28,52 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()

View File

@ -1,5 +1,6 @@
"""Init file for InvokeAI configure package."""
"""
Init file for InvokeAI configure package
"""
from .config_default import InvokeAIAppConfig, get_invokeai_config
__all__ = ["InvokeAIAppConfig", "get_invokeai_config"]
from .config_base import PagingArgumentParser # noqa F401
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401

View File

@ -173,7 +173,7 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
from omegaconf import DictConfig, OmegaConf
from pydantic import Field, TypeAdapter
@ -221,9 +221,6 @@ class InvokeAIAppConfig(InvokeAISettings):
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
# SSL options correspond to https://www.uvicorn.org/settings/#https
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
# FEATURES
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
@ -337,7 +334,7 @@ class InvokeAIAppConfig(InvokeAISettings):
)
@classmethod
def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
"""Return a singleton InvokeAIAppConfig configuration object."""
if (
cls.singleton_config is None
@ -356,7 +353,7 @@ class InvokeAIAppConfig(InvokeAISettings):
else:
root = self.find_root().expanduser().absolute()
self.root = root # insulate ourselves from relative paths that may change
return root.resolve()
return root
@property
def root_dir(self) -> Path:
@ -386,17 +383,17 @@ class InvokeAIAppConfig(InvokeAISettings):
return db_dir / DB_FILE
@property
def model_conf_path(self) -> Path:
def model_conf_path(self) -> Optional[Path]:
"""Path to models configuration file."""
return self._resolve(self.conf_path)
@property
def legacy_conf_path(self) -> Path:
def legacy_conf_path(self) -> Optional[Path]:
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
return self._resolve(self.legacy_conf_dir)
@property
def models_path(self) -> Path:
def models_path(self) -> Optional[Path]:
"""Path to the models directory."""
return self._resolve(self.models_dir)

View File

@ -1,12 +0,0 @@
"""Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_default import DownloadQueueService, TqdmProgress
__all__ = [
"DownloadJob",
"DownloadQueueServiceBase",
"DownloadQueueService",
"TqdmProgress",
"DownloadJobStatus",
"UnknownJobIDException",
]

View File

@ -1,217 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""Model download service."""
from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
class DownloadJobStatus(str, Enum):
"""State of a download job."""
WAITING = "waiting" # not enqueued, will not run
RUNNING = "running" # actively downloading
COMPLETED = "completed" # finished running
CANCELLED = "cancelled" # user cancelled
ERROR = "error" # terminated with an error message
class DownloadJobCancelledException(Exception):
"""This exception is raised when a download job is cancelled."""
class UnknownJobIDException(Exception):
"""This exception is raised when an invalid job id is referened."""
class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJob"], None]
@total_ordering
class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
# set when an error occurs
error_type: Optional[str] = Field(default=None, description="Name of exception that caused an error")
error: Optional[str] = Field(default=None, description="Traceback of the exception that caused an error")
# internal flag
_cancelled: bool = PrivateAttr(default=False)
# optional event handlers passed in on creation
_on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadEventHandler] = PrivateAttr(default=None)
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
# cancelled and the callbacks are private attributes in order to prevent
# them from being serialized and/or used in the Json Schema
@property
def cancelled(self) -> bool:
"""Call to cancel the job."""
return self._cancelled
@property
def on_start(self) -> Optional[DownloadEventHandler]:
"""Return the on_start event handler."""
return self._on_start
@property
def on_progress(self) -> Optional[DownloadEventHandler]:
"""Return the on_progress event handler."""
return self._on_progress
@property
def on_complete(self) -> Optional[DownloadEventHandler]:
"""Return the on_complete event handler."""
return self._on_complete
@property
def on_error(self) -> Optional[DownloadEventHandler]:
"""Return the on_error event handler."""
return self._on_error
@property
def on_cancelled(self) -> Optional[DownloadEventHandler]:
"""Return the on_cancelled event handler."""
return self._on_cancelled
def set_callbacks(
self,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> None:
"""Set the callbacks for download events."""
self._on_start = on_start
self._on_progress = on_progress
self._on_complete = on_complete
self._on_error = on_error
self._on_cancelled = on_cancelled
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL."""
@abstractmethod
def start(self, *args: Any, **kwargs: Any) -> None:
"""Start the download worker threads."""
@abstractmethod
def stop(self, *args: Any, **kwargs: Any) -> None:
"""Stop the download worker threads."""
@abstractmethod
def download(
self,
source: AnyHttpUrl,
dest: Path,
priority: int = 10,
access_token: Optional[str] = None,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> DownloadJob:
"""
Create a download job.
:param source: Source of the download as a URL.
:param dest: Path to download to. See below.
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
:returns: A DownloadJob object for monitoring the state of the download.
The `dest` argument is a Path object. Its behavior is:
1. If the path exists and is a directory, then the URL contents will be downloaded
into that directory using the filename indicated in the response's `Content-Disposition` field.
If no content-disposition is present, then the last component of the URL will be used (similar to
wget's behavior).
2. If the path does not exist, then it is taken as the name of a new file to create with the downloaded
content.
3. If the path exists and is an existing file, then the downloader will try to resume the download from
the end of the existing file.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJob]:
"""
List active download jobs.
:returns List[DownloadJob]: List of download jobs whose state is not "completed."
"""
pass
@abstractmethod
def id_to_job(self, id: int) -> DownloadJob:
"""
Return the DownloadJob corresponding to the integer ID.
:param id: ID of the DownloadJob.
Exceptions:
* UnknownJobIDException
"""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def prune_jobs(self):
"""Prune completed and errored queue items from the job list."""
pass
@abstractmethod
def cancel_job(self, job: DownloadJob):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def join(self):
"""Wait until all jobs are off the queue."""
pass

View File

@ -1,418 +0,0 @@
# Copyright (c) 2023, Lincoln D. Stein
"""Implementation of multithreaded download queue for invokeai."""
import os
import re
import threading
import traceback
from logging import Logger
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
DownloadEventHandler,
DownloadJob,
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
ServiceInactiveException,
UnknownJobIDException,
)
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
class DownloadQueueService(DownloadQueueServiceBase):
"""Class for queued download of models."""
_jobs: Dict[int, DownloadJob]
_max_parallel_dl: int = 5
_worker_pool: Set[threading.Thread]
_queue: PriorityQueue[DownloadJob]
_stop_event: threading.Event
_lock: threading.Lock
_logger: Logger
_events: Optional[EventServiceBase] = None
_next_job_id: int = 0
_accept_download_requests: bool = False
_requests: requests.sessions.Session
def __init__(
self,
max_parallel_dl: int = 5,
event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._jobs = {}
self._next_job_id = 0
self._queue = PriorityQueue()
self._stop_event = threading.Event()
self._worker_pool = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
self._event_bus = event_bus
self._requests = requests_session or requests.Session()
self._accept_download_requests = False
self._max_parallel_dl = max_parallel_dl
def start(self, *args: Any, **kwargs: Any) -> None:
"""Start the download worker threads."""
with self._lock:
if self._worker_pool:
raise Exception("Attempt to start the download service twice")
self._stop_event.clear()
self._start_workers(self._max_parallel_dl)
self._accept_download_requests = True
def stop(self, *args: Any, **kwargs: Any) -> None:
"""Stop the download worker threads."""
with self._lock:
if not self._worker_pool:
raise Exception("Attempt to stop the download service before it was started")
self._accept_download_requests = False # reject attempts to add new jobs to queue
queued_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.WAITING]
active_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.RUNNING]
if queued_jobs:
self._logger.warning(f"Cancelling {len(queued_jobs)} queued downloads")
if active_jobs:
self._logger.info(f"Waiting for {len(active_jobs)} active download jobs to complete")
with self._queue.mutex:
self._queue.queue.clear()
self.join() # wait for all active jobs to finish
self._stop_event.set()
self._worker_pool.clear()
def download(
self,
source: AnyHttpUrl,
dest: Path,
priority: int = 10,
access_token: Optional[str] = None,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadEventHandler] = None,
) -> DownloadJob:
"""Create a download job and return its ID."""
if not self._accept_download_requests:
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
with self._lock:
id = self._next_job_id
self._next_job_id += 1
job = DownloadJob(
id=id,
source=source,
dest=dest,
priority=priority,
access_token=access_token,
)
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[id] = job
self._queue.put(job)
return job
def join(self) -> None:
"""Wait for all jobs to complete."""
self._queue.join()
def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs."""
return list(self._jobs.values())
def prune_jobs(self) -> None:
"""Prune completed and errored queue items from the job list."""
with self._lock:
to_delete = set()
for job_id, job in self._jobs.items():
if self._in_terminal_state(job):
to_delete.add(job_id)
for job_id in to_delete:
del self._jobs[job_id]
def id_to_job(self, id: int) -> DownloadJob:
"""Translate a job ID into a DownloadJob object."""
try:
return self._jobs[id]
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJob) -> None:
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
job.cancel()
def cancel_all_jobs(self, preserve_partial: bool = False) -> None:
"""Cancel all jobs (those not in enqueued, running or paused state)."""
for job in self._jobs.values():
if not self._in_terminal_state(job):
self.cancel_job(job)
def _in_terminal_state(self, job: DownloadJob) -> bool:
return job.status in [
DownloadJobStatus.COMPLETED,
DownloadJobStatus.CANCELLED,
DownloadJobStatus.ERROR,
]
def _start_workers(self, max_workers: int) -> None:
"""Start the requested number of worker threads."""
self._stop_event.clear()
for i in range(0, max_workers): # noqa B007
worker = threading.Thread(target=self._download_next_item, daemon=True)
self._logger.debug(f"Download queue worker thread {worker.name} starting.")
worker.start()
self._worker_pool.add(worker)
def _download_next_item(self) -> None:
"""Worker thread gets next job on priority queue."""
done = False
while not done:
if self._stop_event.is_set():
done = True
continue
try:
job = self._queue.get(timeout=1)
except Empty:
continue
try:
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
finally:
job.job_ended = get_iso_timestamp()
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
# Make a streaming request. This will retrieve headers including
# content-length and content-disposition, but not fetch any content itself
resp = self._requests.get(str(url), headers=header, stream=True)
if not resp.ok:
raise HTTPError(resp.reason)
content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length
if job.dest.is_dir():
file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL
if match := re.search('filename="(.+)"', resp.headers.get("Content-Disposition", "")):
remote_name = match.group(1)
if self._validate_filename(job.dest.as_posix(), remote_name):
file_name = remote_name
job.download_path = job.dest / file_name
else:
job.dest.parent.mkdir(parents=True, exist_ok=True)
job.download_path = job.dest
assert job.download_path
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():
raise OSError(f"[Errno 17] File {job.download_path} exists")
# append ".downloading" to the path
in_progress_path = self._in_progress_path(job.download_path)
# signal caller that the download is starting. At this point, key fields such as
# download_path and total_bytes will be populated. We call it here because the might
# discover that the local file is already complete and generate a COMPLETED status.
self._signal_job_started(job)
# "range not satisfiable" - local file is at least as large as the remote file
if resp.status_code == 416 or (content_length > 0 and job.bytes >= content_length):
self._logger.warning(f"{job.download_path}: complete file found. Skipping.")
return
# "partial content" - local file is smaller than remote file
elif resp.status_code == 206 or job.bytes > 0:
self._logger.warning(f"{job.download_path}: partial file found. Resuming")
# some other error
elif resp.status_code != 200:
raise HTTPError(resp.reason)
self._logger.debug(f"{job.source}: Downloading {job.download_path}")
report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0
# DOWNLOAD LOOP
with open(in_progress_path, open_mode) as file:
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
if job.cancelled:
raise DownloadJobCancelledException("Job was cancelled at caller's request")
job.bytes += file.write(data)
if (job.bytes - last_report_bytes >= report_delta) or (job.bytes >= job.total_bytes):
last_report_bytes = job.bytes
self._signal_job_progress(job)
# if we get here we are done and can rename the file to the original dest
in_progress_path.rename(job.download_path)
def _validate_filename(self, directory: str, filename: str) -> bool:
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
pc_path_max = (
os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 32767
) # hardcoded for windows with long names enabled
if "/" in filename:
return False
if filename.startswith(".."):
return False
if len(filename) > pc_name_max:
return False
if len(os.path.join(directory, filename)) > pc_path_max:
return False
return True
def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading")
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
if job.on_start:
try:
job.on_start(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
try:
job.on_progress(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
if job.on_complete:
try:
job.on_complete(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.CANCELLED
if job.on_cancelled:
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source))
def _signal_job_error(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.ERROR
if job.on_error:
try:
job.on_error(job)
except Exception as e:
self._logger.error(e)
if self._event_bus:
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.download_path}")
try:
if job.download_path:
partial_file = self._in_progress_path(job.download_path)
partial_file.unlink()
except OSError as excp:
self._logger.warning(excp)
# Example on_progress event handler to display a TQDM status bar
# Activate with:
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
class TqdmProgress(object):
"""TQDM-based progress bar object to use in on_progress handlers."""
_bars: Dict[int, tqdm] # the tqdm object
_last: Dict[int, int] # last bytes downloaded
def __init__(self) -> None: # noqa D107
self._bars = {}
self._last = {}
def update(self, job: DownloadJob) -> None: # noqa D102
job_id = job.id
# new job
if job_id not in self._bars:
assert job.download_path
dest = Path(job.download_path).name
self._bars[job_id] = tqdm(
desc=dest,
initial=0,
total=job.total_bytes,
unit="iB",
unit_scale=True,
)
self._last[job_id] = 0
self._bars[job_id].update(job.bytes - self._last[job_id])
self._last[job_id] = job.bytes

View File

@ -1 +0,0 @@
from .events_base import EventServiceBase # noqa F401

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
@ -17,8 +16,6 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
class EventServiceBase:
queue_event: str = "queue_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
@ -33,20 +30,6 @@ class EventServiceBase:
payload={"event": event_name, "data": payload},
)
def __emit_download_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
@ -330,146 +313,3 @@ class EventServiceBase:
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
def emit_download_started(self, source: str, download_path: str) -> None:
"""
Emit when a download job is started.
:param url: The downloaded url
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
"""
Emit "download_progress" events at regular intervals during a download job.
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
"""
Emit a "download_complete" event at the end of a successful download.
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
"""
Emit a "download_error" event when an download job encounters an exception.
:param source: Source URL
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_install_started(self, source: str) -> None:
"""
Emitted when an install job is started.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_started",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str) -> None:
"""
Emitted when an install job is completed successfully.
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={
"source": source,
"key": key,
},
)
def emit_model_install_progress(
self,
source: str,
current_bytes: int,
total_bytes: int,
) -> None:
"""
Emitted while the install job is in progress.
(Downloaded models only)
:param source: Source of the model
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: Total bytes to download
"""
self.__emit_model_event(
event_name="model_install_progress",
payload={
"source": source,
"current_bytes": int,
"total_bytes": int,
},
)
def emit_model_install_error(
self,
source: str,
error_type: str,
error: str,
) -> None:
"""
Emitted when an install job encounters an exception.
:param source: Source of the model
:param exception: The exception that raised the error
"""
self.__emit_model_event(
event_name="model_install_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)

View File

@ -32,6 +32,101 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "starred" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "has_workflow" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
"""
)
def get(self, image_name: str) -> ImageRecord:
try:
self._lock.acquire()

View File

@ -11,7 +11,6 @@ if TYPE_CHECKING:
from .board_records.board_records_base import BoardRecordStorageBase
from .boards.boards_base import BoardServiceABC
from .config import InvokeAIAppConfig
from .download import DownloadQueueServiceBase
from .events.events_base import EventServiceBase
from .image_files.image_files_base import ImageFileStorageBase
from .image_records.image_records_base import ImageRecordStorageBase
@ -22,13 +21,12 @@ if TYPE_CHECKING:
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .item_storage.item_storage_base import ItemStorageABC
from .latents_storage.latents_storage_base import LatentsStorageBase
from .model_install import ModelInstallServiceBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .model_records import ModelRecordServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState
from .shared.graph import GraphExecutionState, LibraryGraph
from .urls.urls_base import UrlServiceBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@ -44,6 +42,7 @@ class InvocationServices:
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
graph_library: "ItemStorageABC[LibraryGraph]"
images: "ImageServiceABC"
image_records: "ImageRecordStorageBase"
image_files: "ImageFileStorageBase"
@ -51,8 +50,6 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
model_records: "ModelRecordServiceBase"
download_queue: "DownloadQueueServiceBase"
model_install: "ModelInstallServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
@ -72,6 +69,7 @@ class InvocationServices:
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
graph_library: "ItemStorageABC[LibraryGraph]",
images: "ImageServiceABC",
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
@ -79,8 +77,6 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
model_records: "ModelRecordServiceBase",
download_queue: "DownloadQueueServiceBase",
model_install: "ModelInstallServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
@ -98,6 +94,7 @@ class InvocationServices:
self.configuration = configuration
self.events = events
self.graph_execution_manager = graph_execution_manager
self.graph_library = graph_library
self.images = images
self.image_files = image_files
self.image_records = image_records
@ -105,8 +102,6 @@ class InvocationServices:
self.logger = logger
self.model_manager = model_manager
self.model_records = model_records
self.download_queue = download_queue
self.model_install = model_install
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -1,25 +0,0 @@
"""Initialization file for model install service package."""
from .model_install_base import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
URLModelSource,
)
from .model_install_default import ModelInstallService
__all__ = [
"ModelInstallServiceBase",
"ModelInstallService",
"InstallStatus",
"ModelInstallJob",
"UnknownInstallJobException",
"ModelSource",
"LocalModelSource",
"HFModelSource",
"URLModelSource",
]

View File

@ -1,305 +0,0 @@
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""A HuggingFace repo_id, with optional variant and sub-folder."""
repo_id: str
variant: Optional[str] = None
subfolder: Optional[str | Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
base += f":{self.subfolder}" if self.subfolder else ""
base += f" ({self.variant})" if self.variant else ""
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["generic_url"] = "generic_url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self.error_type = e.__class__.__name__
self.error = "".join(traceback.format_exception(e))
self.status = InstallStatus.ERROR
class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""
@abstractmethod
def __init__(
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
event_bus: Optional["EventServiceBase"] = None,
):
"""
Create ModelInstallService object.
:param config: Systemwide InvokeAIAppConfig.
:param store: Systemwide ModelConfigStore
:param event_bus: InvokeAI event bus for reporting events to.
"""
@abstractmethod
def start(self, *args: Any, **kwarg: Any) -> None:
"""Start the installer service."""
@abstractmethod
def stop(self, *args: Any, **kwarg: Any) -> None:
"""Stop the model install service. After this the objection can be safely deleted."""
@property
@abstractmethod
def app_config(self) -> InvokeAIAppConfig:
"""Return the appConfig object associated with the installer."""
@property
@abstractmethod
def record_store(self) -> ModelRecordServiceBase:
"""Return the ModelRecoreService object associated with the installer."""
@property
@abstractmethod
def event_bus(self) -> Optional[EventServiceBase]:
"""Return the event service base object associated with the installer."""
@abstractmethod
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe and register the model at model_path.
This keeps the model in its current location.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
@abstractmethod
def unregister(self, key: str) -> None:
"""Remove model with indicated key from the database."""
@abstractmethod
def delete(self, key: str) -> None:
"""Remove model with indicated key from the database. Delete its files only if they are within our models directory."""
@abstractmethod
def unconditionally_delete(self, key: str) -> None:
"""Remove model with indicated key from the database and unconditionally delete weight files from disk."""
@abstractmethod
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe, register and install the model in the models directory.
This moves the model from its current location into
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
@abstractmethod
def import_model(
self,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""Install the indicated model.
:param source: ModelSource object
:param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record. Use it to override
`name`, `description`, `base_type`, `model_type`, `format`,
`prediction_type`, `image_size`, and/or `ztsnr_training`.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
thread and will issue the following events on the event bus:
- model_install_started
- model_install_error
- model_install_completed
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
The call returns a ModelInstallJob object which can be
polled to learn the current status and/or error message.
Variants recognized by HuggingFace currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
@abstractmethod
def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""
@abstractmethod
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
"""
List active and complete install jobs.
"""
@abstractmethod
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@abstractmethod
def wait_for_installs(self) -> List[ModelInstallJob]:
"""
Wait for all pending installs to complete.
This will block until all pending installs have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
It will return the current list of jobs.
"""
@abstractmethod
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
"""
Recursively scan directory for new models and register or install them.
:param scan_dir: Path to the directory to scan.
:param install: Install if True, otherwise register in place.
:returns list of IDs: Returns list of IDs of models registered/installed
"""
@abstractmethod
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database."""

View File

@ -1,399 +0,0 @@
"""Model installation class."""
import threading
from hashlib import sha256
from logging import Logger
from pathlib import Path
from queue import Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from typing import Any, Dict, List, Optional, Set, Union
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
from .model_install_base import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
)
# marker that the queue is done and that thread should exit
STOP_JOB = ModelInstallJob(
source=LocalModelSource(path="stop"),
local_path=Path("/dev/null"),
)
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
_app_config: InvokeAIAppConfig
_record_store: ModelRecordServiceBase
_event_bus: Optional[EventServiceBase] = None
_install_queue: Queue[ModelInstallJob]
_install_jobs: List[ModelInstallJob]
_logger: Logger
_cached_model_paths: Set[Path]
_models_installed: Set[str]
def __init__(
self,
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
event_bus: Optional[EventServiceBase] = None,
):
"""
Initialize the installer object.
:param app_config: InvokeAIAppConfig object
:param record_store: Previously-opened ModelRecordService database
:param event_bus: Optional EventService object
"""
self._app_config = app_config
self._record_store = record_store
self._event_bus = event_bus
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
self._install_jobs = []
self._install_queue = Queue()
self._cached_model_paths = set()
self._models_installed = set()
@property
def app_config(self) -> InvokeAIAppConfig: # noqa D102
return self._app_config
@property
def record_store(self) -> ModelRecordServiceBase: # noqa D102
return self._record_store
@property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def start(self, *args: Any, **kwarg: Any) -> None:
"""Start the installer thread."""
self._start_installer_thread()
self.sync_to_config()
def stop(self, *args: Any, **kwarg: Any) -> None:
"""Stop the installer thread; after this the object can be deleted and garbage collected."""
self._install_queue.put(STOP_JOB)
def _start_installer_thread(self) -> None:
threading.Thread(target=self._install_next_item, daemon=True).start()
def _install_next_item(self) -> None:
done = False
while not done:
job = self._install_queue.get()
if job == STOP_JOB:
done = True
continue
assert job.local_path is not None
try:
self._signal_job_running(job)
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
else:
key = self.install_path(job.local_path, job.config_in)
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
self._signal_job_errored(job, excp)
finally:
self._install_queue.task_done()
self._logger.info("Install thread exiting")
def _signal_job_running(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.RUNNING
self._logger.info(f"{job.source}: model installation started")
if self._event_bus:
self._event_bus.emit_model_install_started(str(job.source))
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
assert job.config_out
self._logger.info(
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
)
if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key)
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
job.set_error(excp)
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}")
if self._event_bus:
error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
def register_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
config["source"] = model_path.resolve().as_posix()
return self._register(model_path, config)
def install_path(
self,
model_path: Union[Path, str],
config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or {}
if config.get("source") is None:
config["source"] = model_path.resolve().as_posix()
info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.original_hash
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
new_path = self._copy_model(model_path, dest_path)
new_hash = FastModelHash.hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
new_path,
config,
info,
)
def import_model(
self,
source: ModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob: # noqa D102
if not config:
config = {}
# Installing a local path
if isinstance(source, LocalModelSource) and Path(source.path).exists(): # a path that is already on disk
job = ModelInstallJob(
source=source,
config_in=config,
local_path=Path(source.path),
)
self._install_jobs.append(job)
self._install_queue.put(job)
return job
else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
raise UnknownModelException("File or directory not found")
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs
def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
return [x for x in self._install_jobs if x.source == source]
def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102
self._install_queue.join()
return self._install_jobs
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
unfinished_jobs = [
x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
]
self._install_jobs = unfinished_jobs
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the config record store database."""
self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
installed = self.scan_directory(self._app_config.root_path / autoimport)
self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._models_installed: Set[str] = set()
search.search(scan_dir)
return list(self._models_installed)
def _scan_models_directory(self) -> None:
"""
Scan the models directory for new and missing models.
New models will be added to the storage backend. Missing models
will be deleted.
"""
defunct_models = set()
installed = set()
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk")
for model_config in self.record_store.all_models():
path = Path(model_config.path)
if not path.exists():
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
defunct_models.add(model_config.key)
for key in defunct_models:
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
model = self.record_store.get_model(key)
old_path = Path(model.path)
models_dir = self.app_config.models_path
if not old_path.is_relative_to(models_dir):
return model
new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
new_hash = FastModelHash.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.current_hash != new_hash:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
model.current_hash = new_hash
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
self.record_store.update_model(key, model)
return model
def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.register_path(model)
self._sync_model_path(id) # possibly move it to right place in `models`
self._logger.info(f"Registered {model.name} with id {id}")
self._models_installed.add(id)
except DuplicateModelException:
pass
return True
def _scan_install(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.install_path(model)
self._logger.info(f"Installed {model} with id {id}")
self._models_installed.add(id)
except DuplicateModelException:
pass
return True
def unregister(self, key: str) -> None: # noqa D102
self.record_store.del_model(key)
def delete(self, key: str) -> None: # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.unconditionally_delete(key)
else:
self.unregister(key)
def unconditionally_delete(self, key: str) -> None: # noqa D102
model = self.record_store.get_model(key)
path = self.app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
path.unlink()
self.unregister(key)
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
if old_path.is_dir():
copytree(old_path, new_path)
else:
copyfile(old_path, new_path)
return new_path
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while new_path.exists():
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
if not path.exists():
new_path = path
counter += 1
move(old_path, new_path)
return new_path
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
if config: # used to override probe fields
for key, value in config.items():
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
info = info or ModelProbe.probe(model_path, config)
key = self._create_key()
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):
model_path = model_path.relative_to(self.app_config.models_path)
info.path = model_path.as_posix()
# add 'main' specific fields
if hasattr(info, "config"):
# make config relative to our root
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
self.record_store.add_model(key, info)
return key

View File

@ -6,11 +6,3 @@ from .model_records_base import ( # noqa F401
UnknownModelException,
)
from .model_records_sql import ModelRecordServiceSQL # noqa F401
__all__ = [
"ModelRecordServiceBase",
"ModelRecordServiceSQL",
"DuplicateModelException",
"InvalidModelException",
"UnknownModelException",
]

View File

@ -7,7 +7,10 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2.0"
class DuplicateModelException(Exception):
@ -29,6 +32,12 @@ class ConfigFileVersionMismatchException(Exception):
class ModelRecordServiceBase(ABC):
"""Abstract base class for storage and retrieval of model configs."""
@property
@abstractmethod
def version(self) -> str:
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
@ -106,7 +115,6 @@ class ModelRecordServiceBase(ABC):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -114,7 +122,6 @@ class ModelRecordServiceBase(ABC):
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
If none of the optional filters are passed, will return all
models in the database.

View File

@ -49,12 +49,12 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigFactory,
ModelFormat,
ModelType,
)
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,
ModelRecordServiceBase,
UnknownModelException,
@ -78,6 +78,85 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._cursor = self._db.conn.cursor()
with self._db.lock:
# Enable foreign keys
self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._db.conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
def _create_tables(self) -> None:
"""Create sqlite3 tables."""
# model_config table breaks out the fields that are common to all config objects
# and puts class-specific ones in a serialized json object
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
# metadata table
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
)
# Add indexes for searchable fields
for stmt in [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]:
self._cursor.execute(stmt)
# Add our version to the metadata table
self._cursor.execute(
"""--sql
INSERT OR IGNORE into model_manager_metadata (
metadata_key,
metadata_value
)
VALUES (?,?);
""",
("version", CONFIG_FILE_VERSION),
)
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.
@ -96,13 +175,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql
INSERT INTO model_config (
id,
base,
type,
name,
path,
original_hash,
config
)
VALUES (?,?,?);
VALUES (?,?,?,?,?,?,?);
""",
(
key,
record.base,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
@ -127,6 +214,22 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return self.get_model(key)
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
WHERE metadata_key=?;
""",
("version",),
)
rows = self._cursor.fetchone()
if not rows:
raise KeyError("Models database does not have metadata key 'version'")
return rows[0]
def del_model(self, key: str) -> None:
"""
Delete a model.
@ -166,11 +269,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
UPDATE model_config
SET
SET base=?,
type=?,
name=?,
path=?,
config=?
WHERE id=?;
""",
(json_serialized, key),
(record.base, record.type, record.name, record.path, json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
@ -226,7 +332,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -234,7 +339,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
If none of the optional filters are passed, will return all
models in the database.
@ -251,9 +355,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
self._cursor.execute(
@ -273,7 +374,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE path=?;
WHERE model_path=?;
""",
(str(path),),
)

View File

@ -50,6 +50,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock = db.lock
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
@ -97,6 +98,123 @@ class SqliteSessionQueue(SessionQueueBase):
except SessionQueueItemNotFoundError:
return
def _create_tables(self) -> None:
"""Creates the session queue tables, indicies, and triggers"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)
self.__cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in self.__cursor.fetchall()]
if "workflow" not in columns:
self.__cursor.execute(
"""--sql
ALTER TABLE session_queue ADD COLUMN workflow TEXT;
"""
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.

View File

@ -3,65 +3,45 @@ import threading
from logging import Logger
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase:
"""
Manages a connection to an SQLite database.
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config
:param db_path: Path to the database file. If None, an in-memory database is used.
:param logger: Logger to use for logging.
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
This is a light wrapper around the `sqlite3` module, providing a few conveniences:
- The database file is written to disk if it does not exist.
- Foreign key constraints are enabled by default.
- The connection is configured to use the `sqlite3.Row` row factory.
In addition to the constructor args, the instance provides the following attributes and methods:
- `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
- `lock`: A shared re-entrant lock, used to approximate thread safety.
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
"""
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self.logger = logger
self.db_path = db_path
self.verbose = verbose
if not self.db_path:
logger.info("Initializing in-memory database")
if self._config.use_memory_db:
self.db_path = sqlite_memory
logger.info("Using in-memory database")
else:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Initializing database at {self.db_path}")
db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
self.db_path = str(db_path)
self._logger.info(f"Using database at {self.db_path}")
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
if self._config.log_sql:
self.conn.set_trace_callback(self._logger.debug)
self.conn.execute("PRAGMA foreign_keys = ON;")
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self.db_path:
return
with self.lock:
try:
if self.db_path == sqlite_memory:
return
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
self._logger.error(f"Error cleaning database: {e}")
raise

View File

@ -1,34 +0,0 @@
from logging import Logger
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileStorageBase) -> SqliteDatabase:
"""
Initializes the SQLite database.
:param config: The app config
:param logger: The logger
:param image_files: The image files service (used by migration 2)
This function:
- Instantiates a :class:`SqliteDatabase`
- Instantiates a :class:`SqliteMigrator` and registers all migrations
- Runs all migrations
"""
db_path = None if config.use_memory_db else config.db_path
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
migrator.register_migration(build_migration_3())
migrator.run_migrations()
return db

View File

@ -1,372 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration1Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
"""Migration callback for database version 1."""
self._create_board_images(cursor)
self._create_boards(cursor)
self._create_images(cursor)
self._create_model_config(cursor)
self._create_session_queue(cursor)
self._create_workflow_images(cursor)
self._create_workflows(cursor)
def _create_board_images(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `board_images` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);",
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_boards(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `boards` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
]
indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_images(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `images` table, indices and triggers. Adds the `starred` column."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
]
indices = [
"CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);",
"CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);",
"CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);",
"CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
]
# Add the 'starred' column to `images` if it doesn't exist
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "starred" not in columns:
tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;")
indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);")
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_model_config(self, cursor: sqlite3.Cursor) -> None:
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
""",
"""--sql
CREATE TABLE IF NOT EXISTS model_manager_metadata (
metadata_key TEXT NOT NULL PRIMARY KEY,
metadata_value TEXT NOT NULL
);
""",
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
AFTER UPDATE
ON model_config FOR EACH ROW
BEGIN
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_session_queue(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);",
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
""",
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
""",
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
""",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflow_images(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
workflow_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between workflows and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);",
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
AFTER UPDATE
ON workflow_images FOR EACH ROW
BEGIN
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _create_workflows(self, cursor: sqlite3.Cursor) -> None:
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
workflow TEXT NOT NULL,
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
);
"""
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
AFTER UPDATE
ON workflows FOR EACH ROW
BEGIN
UPDATE workflows
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
]
for stmt in tables + triggers:
cursor.execute(stmt)
def build_migration_1() -> Migration:
"""
Builds the migration from database version 0 (init) to 1.
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
version to not use migrations to manage the database.
As such, this migration does include some ALTER statements, and the SQL statements are written
to be idempotent.
- Create `board_images` junction table
- Create `boards` table
- Create `images` table, add `starred` column
- Create `model_config` table
- Create `session_queue` table
- Create `workflow_images` junction table
- Create `workflows` table
"""
migration_1 = Migration(
from_version=0,
to_version=1,
callback=Migration1Callback(),
)
return migration_1

View File

@ -1,209 +0,0 @@
import sqlite3
from logging import Logger
from pydantic import ValidationError
from tqdm import tqdm
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.image_files.image_files_common import ImageFileNotFoundException
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.app.services.workflow_records.workflow_records_common import (
UnsafeWorkflowWithVersionValidator,
)
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration2Callback:
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
self._image_files = image_files
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor):
self._add_images_has_workflow(cursor)
self._add_session_queue_workflow(cursor)
self._drop_old_workflow_tables(cursor)
self._add_workflow_library(cursor)
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_model_config_records(cursor)
self._migrate_embedded_workflows(cursor)
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
"""Add the `has_workflow` column to `images` table."""
cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in cursor.fetchall()]
if "has_workflow" not in columns:
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
def _add_session_queue_workflow(self, cursor: sqlite3.Cursor) -> None:
"""Add the `workflow` column to `session_queue` table."""
cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in cursor.fetchall()]
if "workflow" not in columns:
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
def _drop_old_workflow_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `workflows` and `workflow_images` tables."""
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
cursor.execute("DROP TABLE IF EXISTS workflows;")
def _add_workflow_library(self, cursor: sqlite3.Cursor) -> None:
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated manually when retrieving workflow
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Generated columns, needed for indexing and searching
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
);
""",
]
indices = [
"CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
"CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
]
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
UPDATE workflow_library
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
model_record_migrator = MigrateModelYamlToDb1(cursor)
model_record_migrator.migrate()
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
"""
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
This migrate callback checks each image for the presence of an embedded workflow, then updates its entry
in the database accordingly.
"""
# Get all image names
cursor.execute("SELECT image_name FROM images")
image_names: list[str] = [image[0] for image in cursor.fetchall()]
total_image_names = len(image_names)
if not total_image_names:
return
self._logger.info(f"Migrating workflows for {total_image_names} images")
# Migrate the images
to_migrate: list[tuple[bool, str]] = []
pbar = tqdm(image_names)
for idx, image_name in enumerate(pbar):
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
try:
pil_image = self._image_files.get(image_name)
except ImageFileNotFoundException:
self._logger.warning(f"Image {image_name} not found, skipping")
continue
except Exception as e:
self._logger.warning(f"Error while checking image {image_name}, skipping: {e}")
continue
if "invokeai_workflow" in pil_image.info:
try:
UnsafeWorkflowWithVersionValidator.validate_json(pil_image.info.get("invokeai_workflow", ""))
except ValidationError:
self._logger.warning(f"Image {image_name} has invalid embedded workflow, skipping")
continue
to_migrate.append((True, image_name))
self._logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
def build_migration_2(image_files: ImageFileStorageBase, logger: Logger) -> Migration:
"""
Builds the migration from database version 1 to 2.
Introduced in v3.5.0 for the new workflow library.
:param image_files: The image files service, used to check for embedded workflows
:param logger: The logger, used to log progress during embedded workflows handling
This migration does the following:
- Add `has_workflow` column to `images` table
- Add `workflow` column to `session_queue` table
- Drop `workflows` and `workflow_images` tables
- Add `workflow_library` table
- Drops the `model_manager_metadata` table
- Drops the `model_config` table, recreating it (at this point, there is no user data in this table)
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
"""
migration_2 = Migration(
from_version=1,
to_version=2,
callback=Migration2Callback(image_files=image_files, logger=logger),
)
return migration_2

View File

@ -1,75 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
class Migration3Callback:
def __init__(self) -> None:
pass
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._drop_model_manager_metadata(cursor)
self._recreate_model_config(cursor)
self._migrate_model_config_records(cursor)
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
"""Drops the `model_manager_metadata` table."""
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
"""After updating the model config table, we repopulate it."""
model_record_migrator = MigrateModelYamlToDb1(cursor)
model_record_migrator.migrate()
def build_migration_3() -> Migration:
"""
Build the migration from database version 2 to 3.
This migration does the following:
- Drops the `model_config` table, recreating it
- Migrates data from `models.yaml` into the `model_config` table
"""
migration_3 = Migration(
from_version=2,
to_version=3,
callback=Migration3Callback(),
)
return migration_3

View File

@ -1,148 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
import json
import sqlite3
from hashlib import sha1
from logging import Logger
from pathlib import Path
from typing import Optional
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
UnknownModelException,
)
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigFactory,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
class MigrateModelYamlToDb1:
"""
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.5.0).
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
Use this way:
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
MigrateModelYamlToDb().migrate()
"""
config: InvokeAIAppConfig
logger: Logger
cursor: sqlite3.Cursor
def __init__(self, cursor: sqlite3.Cursor = None) -> None:
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
self.cursor = cursor
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf, DictConfig)
return omegaconf
def migrate(self) -> None:
"""Do the migration from models.yaml to invokeai.db."""
try:
yaml = self.get_yaml()
except OSError:
return
for model_key, stanza in yaml.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
try:
if original_record := self._search_by_path(stanza.path):
key = original_record.key
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
self._add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
except UnknownModelException:
self.logger.warning(f"Model at {stanza.path} could not be found in database")
def _search_by_path(self, path: Path) -> Optional[AnyModelConfig]:
self.cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self.cursor.fetchall()]
return results[0] if results else None
def _update_model(self, key: str, config: AnyModelConfig) -> None:
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = record.model_dump_json() # and turn it into a json string.
self.cursor.execute(
"""--sql
UPDATE model_config
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if self.cursor.rowcount == 0:
raise UnknownModelException("model not found")
def _add_model(self, key: str, config: AnyModelConfig) -> None:
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = record.model_dump_json() # and turn it into a json string.
try:
self.cursor.execute(
"""--sql
INSERT INTO model_config (
id,
original_hash,
config
)
VALUES (?,?,?);
""",
(
key,
record.original_hash,
json_serialized,
),
)
except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc

View File

@ -1,164 +0,0 @@
import sqlite3
from typing import Optional, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, model_validator
@runtime_checkable
class MigrateCallback(Protocol):
"""
A callback that performs a migration.
Migrate callbacks are provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
If the callback needs to access additional dependencies, will be provided to the callback at runtime.
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
class MigrationError(RuntimeError):
"""Raised when a migration fails."""
class MigrationVersionError(ValueError):
"""Raised when a migration version is invalid."""
class Migration(BaseModel):
"""
Represents a migration for a SQLite database.
:param from_version: The database version on which this migration may be run
:param to_version: The database version that results from this migration
:param migrate_callback: The callback to run to perform the migration
Migration callbacks will be provided an open cursor to the database. They should not commit their
transaction; this is handled by the migrator.
It is suggested to use a class to define the migration callback and a builder function to create
the :class:`Migration`. This allows the callback to be provided with additional dependencies and
keeps things tidy, as all migration logic is self-contained.
Example:
```py
# Define the migration callback class
class Migration1Callback:
# This migration needs a logger, so we define a class that accepts a logger in its constructor.
def __init__(self, image_files: ImageFileStorageBase) -> None:
self._image_files = ImageFileStorageBase
# This dunder method allows the instance of the class to be called like a function.
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_with_banana_column(cursor)
self._do_something_with_images(cursor)
def _add_with_banana_column(self, cursor: sqlite3.Cursor) -> None:
\"""Adds the with_banana column to the sushi table.\"""
# Execute SQL using the cursor, taking care to *not commit* a transaction
cursor.execute('ALTER TABLE sushi ADD COLUMN with_banana BOOLEAN DEFAULT TRUE;')
def _do_something_with_images(self, cursor: sqlite3.Cursor) -> None:
\"""Does something with the image files service.\"""
self._image_files.get(...)
# Define the migration builder function. This function creates an instance of the migration callback
# class and returns a Migration.
def build_migration_1(image_files: ImageFileStorageBase) -> Migration:
\"""Builds the migration from database version 0 to 1.
Requires the image files service to...
\"""
migration_1 = Migration(
from_version=0,
to_version=1,
migrate_callback=Migration1Callback(image_files=image_files),
)
return migration_1
# Register the migration after all dependencies have been initialized
db = SqliteDatabase(db_path, logger)
migrator = SqliteMigrator(db)
migrator.register_migration(build_migration_1(image_files))
migrator.run_migrations()
```
"""
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
callback: MigrateCallback = Field(description="The callback to run to perform the migration")
@model_validator(mode="after")
def validate_to_version(self) -> "Migration":
"""Validates that to_version is one greater than from_version."""
if self.to_version != self.from_version + 1:
raise MigrationVersionError("to_version must be one greater than from_version")
return self
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))
model_config = ConfigDict(arbitrary_types_allowed=True)
class MigrationSet:
"""
A set of Migrations. Performs validation during migration registration and provides utility methods.
Migrations should be registered with `register()`. Once all are registered, `validate_migration_chain()`
should be called to ensure that the migrations form a single chain of migrations from version 0 to the latest version.
"""
def __init__(self) -> None:
self._migrations: set[Migration] = set()
def register(self, migration: Migration) -> None:
"""Registers a migration."""
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
if migration_from_already_registered or migration_to_already_registered:
raise MigrationVersionError("Migration with from_version or to_version already registered")
self._migrations.add(migration)
def get(self, from_version: int) -> Optional[Migration]:
"""Gets the migration that may be run on the given database version."""
# register() ensures that there is only one migration with a given from_version, so this is safe.
return next((m for m in self._migrations if m.from_version == from_version), None)
def validate_migration_chain(self) -> None:
"""
Validates that the migrations form a single chain of migrations from version 0 to the latest version,
Raises a MigrationError if there is a problem.
"""
if self.count == 0:
return
if self.latest_version == 0:
return
next_migration = self.get(from_version=0)
if next_migration is None:
raise MigrationError("Migration chain is fragmented")
touched_count = 1
while next_migration is not None:
next_migration = self.get(next_migration.to_version)
if next_migration is not None:
touched_count += 1
if touched_count != self.count:
raise MigrationError("Migration chain is fragmented")
@property
def count(self) -> int:
"""The count of registered migrations."""
return len(self._migrations)
@property
def latest_version(self) -> int:
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
if self.count == 0:
return 0
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version

View File

@ -1,130 +0,0 @@
import sqlite3
from pathlib import Path
from typing import Optional
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
class SqliteMigrator:
"""
Manages migrations for a SQLite database.
:param db: The instance of :class:`SqliteDatabase` to migrate.
Migrations should be registered with :meth:`register_migration`.
Each migration is run in a transaction. If a migration fails, the transaction is rolled back.
Example Usage:
```py
db = SqliteDatabase(db_path="my_db.db", logger=logger)
migrator = SqliteMigrator(db=db)
migrator.register_migration(build_migration_1())
migrator.register_migration(build_migration_2())
migrator.run_migrations()
```
"""
backup_path: Optional[Path] = None
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
self._logger = db.logger
self._migration_set = MigrationSet()
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
self._migration_set.register(migration)
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
def run_migrations(self) -> bool:
"""Migrates the database to the latest version."""
with self._db.lock:
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
self._logger.info("Database update needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
try:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.lock, self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(
f"Database is at version {self._get_current_version(cursor)}, expected {migration.from_version}"
)
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
# Run the actual migration
migration.callback(cursor)
# Update the version
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
self._logger.debug(
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
)
# We want to catch *any* error, mirroring the behaviour of the sqlite3 module.
except Exception as e:
# The connection context manager has already rolled back the migration, so we don't need to do anything.
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
self._logger.error(msg)
raise MigrationError(msg) from e
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist."""
with self._db.lock:
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
@classmethod
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:
"""Gets the current version of the database, or 0 if the migrations table does not exist."""
try:
cursor.execute("SELECT MAX(version) FROM migrations;")
version: int = cursor.fetchone()[0]
if version is None:
return 0
return version
except sqlite3.OperationalError as e:
if "no such table" in str(e):
return 0
raise

View File

@ -1,975 +0,0 @@
{
"name": "Prompt from File",
"author": "InvokeAI",
"description": "Sample workflow using Prompt from File node",
"version": "0.1.0",
"contact": "invoke@invoke.ai",
"tags": "text2image, prompt from file, default",
"notes": "",
"exposedFields": [
{
"nodeId": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"fieldName": "model"
},
{
"nodeId": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"fieldName": "file_path"
}
],
"meta": {
"category": "default",
"version": "2.0.0"
},
"id": "d1609af5-eb0a-4f73-b573-c9af96a8d6bf",
"nodes": [
{
"id": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"type": "invocation",
"data": {
"id": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"type": "compel",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "dcdf3f6d-9b96-4bcd-9b8d-f992fefe4f62",
"name": "prompt",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "3f1981c9-d8a9-42eb-a739-4f120eb80745",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "46205e6c-c5e2-44cb-9c82-1cd20b95674a",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -200
}
},
{
"id": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"type": "invocation",
"data": {
"id": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"type": "prompt_from_file",
"label": "Prompts from File",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"nodePack": "invokeai",
"inputs": {
"file_path": {
"id": "37e37684-4f30-4ec8-beae-b333e550f904",
"name": "file_path",
"fieldKind": "input",
"label": "Prompts File Path",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"pre_prompt": {
"id": "7de02feb-819a-4992-bad3-72a30920ddea",
"name": "pre_prompt",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"post_prompt": {
"id": "95f191d8-a282-428e-bd65-de8cb9b7513a",
"name": "post_prompt",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"start_line": {
"id": "efee9a48-05ab-4829-8429-becfa64a0782",
"name": "start_line",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 1
},
"max_prompts": {
"id": "abebb428-3d3d-49fd-a482-4e96a16fff08",
"name": "max_prompts",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 1
}
},
"outputs": {
"collection": {
"id": "77d5d7f1-9877-4ab1-9a8c-33e9ffa9abf3",
"name": "collection",
"fieldKind": "output",
"type": {
"isCollection": true,
"isCollectionOrScalar": false,
"name": "StringField"
}
}
}
},
"width": 320,
"height": 580,
"position": {
"x": 475,
"y": -400
}
},
{
"id": "1b89067c-3f6b-42c8-991f-e3055789b251",
"type": "invocation",
"data": {
"id": "1b89067c-3f6b-42c8-991f-e3055789b251",
"type": "iterate",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.1.0",
"inputs": {
"collection": {
"id": "4c564bf8-5ed6-441e-ad2c-dda265d5785f",
"name": "collection",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": true,
"isCollectionOrScalar": false,
"name": "CollectionField"
}
}
},
"outputs": {
"item": {
"id": "36340f9a-e7a5-4afa-b4b5-313f4e292380",
"name": "item",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "CollectionItemField"
}
},
"index": {
"id": "1beca95a-2159-460f-97ff-c8bab7d89336",
"name": "index",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"total": {
"id": "ead597b8-108e-4eda-88a8-5c29fa2f8df9",
"name": "total",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -400
}
},
{
"id": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"type": "invocation",
"data": {
"id": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"type": "main_model_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"model": {
"id": "3f264259-3418-47d5-b90d-b6600e36ae46",
"name": "model",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MainModelField"
},
"value": {
"model_name": "stable-diffusion-v1-5",
"base_model": "sd-1",
"model_type": "main"
}
}
},
"outputs": {
"unet": {
"id": "8e182ea2-9d0a-4c02-9407-27819288d4b5",
"name": "unet",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"clip": {
"id": "d67d9d30-058c-46d5-bded-3d09d6d1aa39",
"name": "clip",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
},
"vae": {
"id": "89641601-0429-4448-98d5-190822d920d8",
"name": "vae",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
}
}
},
"width": 320,
"height": 227,
"position": {
"x": 0,
"y": -375
}
},
{
"id": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "invocation",
"data": {
"id": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "compel",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "dcdf3f6d-9b96-4bcd-9b8d-f992fefe4f62",
"name": "prompt",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "3f1981c9-d8a9-42eb-a739-4f120eb80745",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "46205e6c-c5e2-44cb-9c82-1cd20b95674a",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -275
}
},
{
"id": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "invocation",
"data": {
"id": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "noise",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"nodePack": "invokeai",
"inputs": {
"seed": {
"id": "b722d84a-eeee-484f-bef2-0250c027cb67",
"name": "seed",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"width": {
"id": "d5f8ce11-0502-4bfc-9a30-5757dddf1f94",
"name": "width",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"height": {
"id": "f187d5ff-38a5-4c3f-b780-fc5801ef34af",
"name": "height",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"use_cpu": {
"id": "12f112b8-8b76-4816-b79e-662edc9f9aa5",
"name": "use_cpu",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
"outputs": {
"noise": {
"id": "08576ad1-96d9-42d2-96ef-6f5c1961933f",
"name": "noise",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "f3e1f94a-258d-41ff-9789-bd999bd9f40d",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "6cefc357-4339-415e-a951-49b9c2be32f4",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": 25
}
},
{
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"type": "invocation",
"data": {
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"type": "rand_int",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": false,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"low": {
"id": "b9fc6cf1-469c-4037-9bf0-04836965826f",
"name": "low",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"high": {
"id": "06eac725-0f60-4ba2-b8cd-7ad9f757488c",
"name": "high",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 2147483647
}
},
"outputs": {
"value": {
"id": "df08c84e-7346-4e92-9042-9e5cb773aaff",
"name": "value",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 925,
"y": -50
}
},
{
"id": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "invocation",
"data": {
"id": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "l2i",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.2.0",
"nodePack": "invokeai",
"inputs": {
"metadata": {
"id": "022e4b33-562b-438d-b7df-41c3fd931f40",
"name": "metadata",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MetadataField"
}
},
"latents": {
"id": "67cb6c77-a394-4a66-a6a9-a0a7dcca69ec",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"vae": {
"id": "7b3fd9ad-a4ef-4e04-89fa-3832a9902dbd",
"name": "vae",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
},
"tiled": {
"id": "5ac5680d-3add-4115-8ec0-9ef5bb87493b",
"name": "tiled",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
},
"fp32": {
"id": "db8297f5-55f8-452f-98cf-6572c2582152",
"name": "fp32",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
}
},
"outputs": {
"image": {
"id": "d8778d0c-592a-4960-9280-4e77e00a7f33",
"name": "image",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ImageField"
}
},
"width": {
"id": "c8b0a75a-f5de-4ff2-9227-f25bb2b97bec",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "83c05fbf-76b9-49ab-93c4-fa4b10e793e4",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 267,
"position": {
"x": 2037.861329274915,
"y": -329.8393457509562
}
},
{
"id": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "invocation",
"data": {
"id": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "denoise_latents",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
"id": "751fb35b-3f23-45ce-af1c-053e74251337",
"name": "positive_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"negative_conditioning": {
"id": "b9dc06b6-7481-4db1-a8c2-39d22a5eacff",
"name": "negative_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"noise": {
"id": "6e15e439-3390-48a4-8031-01e0e19f0e1d",
"name": "noise",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"steps": {
"id": "bfdfb3df-760b-4d51-b17b-0abb38b976c2",
"name": "steps",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 10
},
"cfg_scale": {
"id": "47770858-322e-41af-8494-d8b63ed735f3",
"name": "cfg_scale",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "FloatField"
},
"value": 7.5
},
"denoising_start": {
"id": "2ba78720-ee02-4130-a348-7bc3531f790b",
"name": "denoising_start",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"denoising_end": {
"id": "a874dffb-d433-4d1a-9f59-af4367bb05e4",
"name": "denoising_end",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 1
},
"scheduler": {
"id": "36e021ad-b762-4fe4-ad4d-17f0291c40b2",
"name": "scheduler",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "SchedulerField"
},
"value": "euler"
},
"unet": {
"id": "98d3282d-f9f6-4b5e-b9e8-58658f1cac78",
"name": "unet",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"control": {
"id": "f2ea3216-43d5-42b4-887f-36e8f7166d53",
"name": "control",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "ControlField"
}
},
"ip_adapter": {
"id": "d0780610-a298-47c8-a54e-70e769e0dfe2",
"name": "ip_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "IPAdapterField"
}
},
"t2i_adapter": {
"id": "fdb40970-185e-4ea8-8bb5-88f06f91f46a",
"name": "t2i_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "T2IAdapterField"
}
},
"cfg_rescale_multiplier": {
"id": "3af2d8c5-de83-425c-a100-49cb0f1f4385",
"name": "cfg_rescale_multiplier",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"latents": {
"id": "e05b538a-1b5a-4aa5-84b1-fd2361289a81",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"denoise_mask": {
"id": "463a419e-df30-4382-8ffb-b25b25abe425",
"name": "denoise_mask",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "DenoiseMaskField"
}
}
},
"outputs": {
"latents": {
"id": "559ee688-66cf-4139-8b82-3d3aa69995ce",
"name": "latents",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "0b4285c2-e8b9-48e5-98f6-0a49d3f98fd2",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "8b0881b9-45e5-47d5-b526-24b6661de0ee",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 705,
"position": {
"x": 1570.9941088179146,
"y": -407.6505491604564
}
}
],
"edges": [
{
"id": "1b89067c-3f6b-42c8-991f-e3055789b251-fc9d0e35-a6de-4a19-84e1-c72497c823f6-collapsed",
"source": "1b89067c-3f6b-42c8-991f-e3055789b251",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "collapsed"
},
{
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77-collapsed",
"source": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "collapsed"
},
{
"id": "reactflow__edge-1b7e0df8-8589-4915-a4ea-c0088f15d642collection-1b89067c-3f6b-42c8-991f-e3055789b251collection",
"source": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
"target": "1b89067c-3f6b-42c8-991f-e3055789b251",
"type": "default",
"sourceHandle": "collection",
"targetHandle": "collection"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426clip-fc9d0e35-a6de-4a19-84e1-c72497c823f6clip",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-1b89067c-3f6b-42c8-991f-e3055789b251item-fc9d0e35-a6de-4a19-84e1-c72497c823f6prompt",
"source": "1b89067c-3f6b-42c8-991f-e3055789b251",
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"type": "default",
"sourceHandle": "item",
"targetHandle": "prompt"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426clip-c2eaf1ba-5708-4679-9e15-945b8b432692clip",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5value-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77seed",
"source": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-fc9d0e35-a6de-4a19-84e1-c72497c823f6conditioning-2fb1577f-0a56-4f12-8711-8afcaaaf1d5epositive_conditioning",
"source": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-c2eaf1ba-5708-4679-9e15-945b8b432692conditioning-2fb1577f-0a56-4f12-8711-8afcaaaf1d5enegative_conditioning",
"source": "c2eaf1ba-5708-4679-9e15-945b8b432692",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77noise-2fb1577f-0a56-4f12-8711-8afcaaaf1d5enoise",
"source": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426unet-2fb1577f-0a56-4f12-8711-8afcaaaf1d5eunet",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-2fb1577f-0a56-4f12-8711-8afcaaaf1d5elatents-491ec988-3c77-4c37-af8a-39a0c4e7a2a1latents",
"source": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
"target": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426vae-491ec988-3c77-4c37-af8a-39a0c4e7a2a1vae",
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
"target": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
}
]
}

View File

@ -1,903 +0,0 @@
{
"name": "Text to Image with LoRA",
"author": "InvokeAI",
"description": "Simple text to image workflow with a LoRA",
"version": "1.0.0",
"contact": "invoke@invoke.ai",
"tags": "text to image, lora, default",
"notes": "",
"exposedFields": [
{
"nodeId": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"fieldName": "model"
},
{
"nodeId": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"fieldName": "lora"
},
{
"nodeId": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"fieldName": "weight"
},
{
"nodeId": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"fieldName": "prompt"
}
],
"meta": {
"version": "2.0.0",
"category": "default"
},
"id": "a9d70c39-4cdd-4176-9942-8ff3fe32d3b1",
"nodes": [
{
"id": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"type": "invocation",
"data": {
"id": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"type": "compel",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"inputs": {
"prompt": {
"id": "39fe92c4-38eb-4cc7-bf5e-cbcd31847b11",
"name": "prompt",
"fieldKind": "input",
"label": "Negative Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "14313164-e5c4-4e40-a599-41b614fe3690",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "02140b9d-50f3-470b-a0b7-01fc6ed2dcd6",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 3425,
"y": -300
}
},
{
"id": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"type": "invocation",
"data": {
"id": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"type": "main_model_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"inputs": {
"model": {
"id": "e2e1c177-ae39-4244-920e-d621fa156a24",
"name": "model",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MainModelField"
},
"value": {
"model_name": "Analog-Diffusion",
"base_model": "sd-1",
"model_type": "main"
}
}
},
"outputs": {
"vae": {
"id": "f91410e8-9378-4298-b285-f0f40ffd9825",
"name": "vae",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
},
"clip": {
"id": "928d91bf-de0c-44a8-b0c8-4de0e2e5b438",
"name": "clip",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
},
"unet": {
"id": "eacaf530-4e7e-472e-b904-462192189fc1",
"name": "unet",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
}
}
},
"width": 320,
"height": 227,
"position": {
"x": 2500,
"y": -600
}
},
{
"id": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "invocation",
"data": {
"id": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "lora_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"inputs": {
"lora": {
"id": "36d867e8-92ea-4c3f-9ad5-ba05c64cf326",
"name": "lora",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LoRAModelField"
},
"value": {
"model_name": "Ink scenery",
"base_model": "sd-1"
}
},
"weight": {
"id": "8be86540-ba81-49b3-b394-2b18fa70b867",
"name": "weight",
"fieldKind": "input",
"label": "LoRA Weight",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0.75
},
"unet": {
"id": "9c4d5668-e9e1-411b-8f4b-e71115bc4a01",
"name": "unet",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"clip": {
"id": "918ec00e-e76f-4ad0-aee1-3927298cf03b",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"unet": {
"id": "c63f7825-1bcf-451d-b7a7-aa79f5c77416",
"name": "unet",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"clip": {
"id": "6f79ef2d-00f7-4917-bee3-53e845bf4192",
"name": "clip",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
}
},
"width": 320,
"height": 252,
"position": {
"x": 2975,
"y": -600
}
},
{
"id": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"type": "invocation",
"data": {
"id": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"type": "compel",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"inputs": {
"prompt": {
"id": "39fe92c4-38eb-4cc7-bf5e-cbcd31847b11",
"name": "prompt",
"fieldKind": "input",
"label": "Positive Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": "cute tiger cub"
},
"clip": {
"id": "14313164-e5c4-4e40-a599-41b614fe3690",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "02140b9d-50f3-470b-a0b7-01fc6ed2dcd6",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 256,
"position": {
"x": 3425,
"y": -575
}
},
{
"id": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "invocation",
"data": {
"id": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "denoise_latents",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.0",
"inputs": {
"positive_conditioning": {
"id": "025ff44b-c4c6-4339-91b4-5f461e2cadc5",
"name": "positive_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"negative_conditioning": {
"id": "2d92b45a-a7fb-4541-9a47-7c7495f50f54",
"name": "negative_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"noise": {
"id": "4d0deeff-24ed-4562-a1ca-7833c0649377",
"name": "noise",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"steps": {
"id": "c9907328-aece-4af9-8a95-211b4f99a325",
"name": "steps",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 10
},
"cfg_scale": {
"id": "7cf0f031-2078-49f4-9273-bb3a64ad7130",
"name": "cfg_scale",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "FloatField"
},
"value": 7.5
},
"denoising_start": {
"id": "44cec3ba-b404-4b51-ba98-add9d783279e",
"name": "denoising_start",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"denoising_end": {
"id": "3e7975f3-e438-4a13-8a14-395eba1fb7cd",
"name": "denoising_end",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 1
},
"scheduler": {
"id": "a6f6509b-7bb4-477d-b5fb-74baefa38111",
"name": "scheduler",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "SchedulerField"
},
"value": "euler"
},
"unet": {
"id": "5a87617a-b09f-417b-9b75-0cea4c255227",
"name": "unet",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"control": {
"id": "db87aace-ace8-4f2a-8f2b-1f752389fa9b",
"name": "control",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "ControlField"
}
},
"ip_adapter": {
"id": "f0c133ed-4d6d-4567-bb9a-b1779810993c",
"name": "ip_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "IPAdapterField"
}
},
"t2i_adapter": {
"id": "59ee1233-887f-45e7-aa14-cbad5f6cb77f",
"name": "t2i_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "T2IAdapterField"
}
},
"cfg_rescale_multiplier": {
"id": "1a12e781-4b30-4707-b432-18c31866b5c3",
"name": "cfg_rescale_multiplier",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"latents": {
"id": "d0e593ae-305c-424b-9acd-3af830085832",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"denoise_mask": {
"id": "b81b5a79-fc2b-4011-aae6-64c92bae59a7",
"name": "denoise_mask",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "DenoiseMaskField"
}
}
},
"outputs": {
"latents": {
"id": "9ae4022a-548e-407e-90cf-cc5ca5ff8a21",
"name": "latents",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "730ba4bd-2c52-46bb-8c87-9b3aec155576",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "52b98f0b-b5ff-41b5-acc7-d0b1d1011a6f",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 705,
"position": {
"x": 3975,
"y": -575
}
},
{
"id": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "invocation",
"data": {
"id": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "noise",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"inputs": {
"seed": {
"id": "446ac80c-ba0a-4fea-a2d7-21128f52e5bf",
"name": "seed",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"width": {
"id": "779831b3-20b4-4f5f-9de7-d17de57288d8",
"name": "width",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"height": {
"id": "08959766-6d67-4276-b122-e54b911f2316",
"name": "height",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"use_cpu": {
"id": "53b36a98-00c4-4dc5-97a4-ef3432c0a805",
"name": "use_cpu",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
"outputs": {
"noise": {
"id": "eed95824-580b-442f-aa35-c073733cecce",
"name": "noise",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "7985a261-dfee-47a8-908a-c5a8754f5dc4",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "3d00f6c1-84b0-4262-83d9-3bf755babeea",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 3425,
"y": 75
}
},
{
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"type": "invocation",
"data": {
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"type": "rand_int",
"label": "",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": false,
"version": "1.0.0",
"inputs": {
"low": {
"id": "d25305f3-bfd6-446c-8e2c-0b025ec9e9ad",
"name": "low",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"high": {
"id": "10376a3d-b8fe-4a51-b81a-ea46d8c12c78",
"name": "high",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 2147483647
}
},
"outputs": {
"value": {
"id": "c64878fa-53b1-4202-b88a-cfb854216a57",
"name": "value",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 3425,
"y": 0
}
},
{
"id": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "invocation",
"data": {
"id": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "l2i",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": false,
"useCache": true,
"version": "1.2.0",
"inputs": {
"metadata": {
"id": "b1982e8a-14ad-4029-a697-beb30af8340f",
"name": "metadata",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MetadataField"
}
},
"latents": {
"id": "f7669388-9f91-46cc-94fc-301fa7041c3e",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"vae": {
"id": "c6f2d4db-4d0a-4e3d-acb4-b5c5a228a3e2",
"name": "vae",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
},
"tiled": {
"id": "19ef7d31-d96f-4e94-b7e5-95914e9076fc",
"name": "tiled",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
},
"fp32": {
"id": "a9454533-8ab7-4225-b411-646dc5e76d00",
"name": "fp32",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
}
},
"outputs": {
"image": {
"id": "4f81274e-e216-47f3-9fb6-f97493a40e6f",
"name": "image",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ImageField"
}
},
"width": {
"id": "61a9acfb-1547-4f1e-8214-e89bd3855ee5",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "b15cc793-4172-4b07-bcf4-5627bbc7d0d7",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 267,
"position": {
"x": 4450,
"y": -550
}
}
],
"edges": [
{
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953-ea18915f-2c5b-4569-b725-8e9e9122e8d3-collapsed",
"source": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "collapsed"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818clip-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip-c3fa6872-2599-4a82-a596-b3446a66cf8bclip",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818unet-c41e705b-f2e3-4d1a-83c4-e34bb9344966unet",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966unet-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63unet",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-85b77bb2-c67a-416a-b3e8-291abe746c44conditioning-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63negative_conditioning",
"source": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-c3fa6872-2599-4a82-a596-b3446a66cf8bconditioning-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63positive_conditioning",
"source": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-ea18915f-2c5b-4569-b725-8e9e9122e8d3noise-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63noise",
"source": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-6fd74a17-6065-47a5-b48b-f4e2b8fa7953value-ea18915f-2c5b-4569-b725-8e9e9122e8d3seed",
"source": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63latents-a9683c0a-6b1f-4a5e-8187-c57e764b3400latents",
"source": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
"target": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818vae-a9683c0a-6b1f-4a5e-8187-c57e764b3400vae",
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
"target": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip-85b77bb2-c67a-416a-b3e8-291abe746c44clip",
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
"target": "85b77bb2-c67a-416a-b3e8-291abe746c44",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
}
]
}

View File

@ -3,9 +3,10 @@ from enum import Enum
from typing import Any, Union
import semver
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
from pydantic import BaseModel, Field, JsonValue, TypeAdapter, field_validator
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string
__workflow_meta_version__ = semver.Version.parse("1.0.0")
@ -31,13 +32,12 @@ class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum):
class WorkflowCategory(str, Enum, metaclass=MetaEnum):
User = "user"
Default = "default"
Project = "project"
class WorkflowMeta(BaseModel):
version: str = Field(description="The version of the workflow schema.")
category: WorkflowCategory = Field(
default=WorkflowCategory.User, description="The category of the workflow (user or default)."
)
category: WorkflowCategory = Field(description="The category of the workflow (user or default).")
@field_validator("version")
def validate_version(cls, version: str):
@ -65,26 +65,12 @@ class WorkflowWithoutID(BaseModel):
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
edges: list[dict[str, JsonValue]] = Field(description="The edges of the workflow.")
model_config = ConfigDict(extra="ignore")
WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
class UnsafeWorkflowWithVersion(BaseModel):
"""
This utility model only requires a workflow to have a valid version string.
It is used to validate a workflow version without having to validate the entire workflow.
"""
meta: WorkflowMeta = Field(description="The meta of the workflow.")
UnsafeWorkflowWithVersionValidator = TypeAdapter(UnsafeWorkflowWithVersion)
class Workflow(WorkflowWithoutID):
id: str = Field(description="The id of the workflow.")
id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
WorkflowValidator = TypeAdapter(Workflow)

View File

@ -14,10 +14,9 @@ from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowRecordListItemDTO,
WorkflowRecordListItemDTOValidator,
WorkflowRecordOrderBy,
WorkflowValidator,
WorkflowWithoutID,
WorkflowWithoutIDValidator,
)
from invokeai.app.util.misc import uuid_string
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
@ -26,6 +25,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
self._create_tables()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@ -66,7 +66,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
try:
# Only user workflows may be created by this method
assert workflow.meta.category is WorkflowCategory.User
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
workflow_with_id = WorkflowValidator.validate_python(workflow.model_dump())
self._lock.acquire()
self._cursor.execute(
"""--sql
@ -204,8 +204,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
workflow_paths = workflows_dir.glob("*.json")
for path in workflow_paths:
bytes_ = path.read_bytes()
workflow_without_id = WorkflowWithoutIDValidator.validate_json(bytes_)
workflow = Workflow(**workflow_without_id.model_dump(), id=uuid_string())
workflow = WorkflowValidator.validate_json(bytes_)
workflows.append(workflow)
# Only default workflows may be managed by this method
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
@ -232,3 +231,87 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
raise
finally:
self._lock.release()
def _create_tables(self) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated manually when retrieving workflow
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Generated columns, needed for indexing and searching
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflow_library FOR EACH ROW
BEGIN
UPDATE workflow_library
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);
"""
)
# We do not need the original `workflows` table or `workflow_images` junction table.
self._cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflow_images;
"""
)
self._cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflows;
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()

View File

@ -1,8 +0,0 @@
import re
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
ti_triggers = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
ti_triggers.append(trigger)
return ti_triggers

View File

@ -28,7 +28,7 @@ def check_invokeai_root(config: InvokeAIAppConfig):
print("== STARTUP ABORTED ==")
print("** One or more necessary files is missing from your InvokeAI root directory **")
print("** Please rerun the configuration script to fix this problem. **")
print("** From the launcher, selection option [6]. **")
print("** From the launcher, selection option [7]. **")
print(
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
)

View File

@ -215,9 +215,7 @@ class ModelPatcher:
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
for ti_name, _ in ti_list:
ti_tokens = []
for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i]

View File

@ -32,8 +32,6 @@ class ModelProbeInfo(object):
upcast_attention: bool
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
image_size: int
name: Optional[str] = None
description: Optional[str] = None
class ProbeBase(object):
@ -115,16 +113,12 @@ class ModelProbe(object):
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
name = cls.get_model_name(model_path)
description = f"{base_type.value} {model_type.value} model {name}"
format = probe.get_format()
model_info = ModelProbeInfo(
model_type=model_type,
base_type=base_type,
variant_type=variant_type,
prediction_type=prediction_type,
name=name,
description=description,
upcast_attention=(
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
@ -148,13 +142,6 @@ class ModelProbe(object):
return model_info
@classmethod
def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem
else:
return model_path.name
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
@ -389,7 +376,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
elif "clip_g" in checkpoint:
token_dim = checkpoint["clip_g"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[-1]
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:

View File

@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
:param checkpoint: The checkpoint
"""
def _get_shape_1(key: str, tensor, checkpoint) -> int:
def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None
if "." not in key:
@ -57,10 +57,6 @@ def lora_token_vector_length(checkpoint: dict) -> int:
for key, tensor in checkpoint.items():
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_unet_") and (
"time_emb_proj.lora_down" in key
): # recognizes format at https://civitai.com/models/224641
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
elif key.startswith("lora_te") and "_self_attn_" in key:
tmp_length = _get_shape_1(key, tensor, checkpoint)
if key.startswith("lora_te_"):

View File

@ -1,29 +0,0 @@
"""Re-export frequently-used symbols from the Model Manager backend."""
from .config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelConfigFactory,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from .probe import ModelProbe
from .search import ModelSearch
__all__ = [
"ModelProbe",
"ModelSearch",
"InvalidModelConfigException",
"ModelConfigFactory",
"BaseModelType",
"ModelType",
"SubModelType",
"ModelVariantType",
"ModelFormat",
"SchedulerPredictionType",
"AnyModelConfig",
]

View File

@ -23,7 +23,7 @@ from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from typing_extensions import Annotated
class InvalidModelConfigException(Exception):
@ -122,7 +122,7 @@ class ModelConfigBase(BaseModel):
validate_assignment=True,
)
def update(self, attributes: Dict[str, Any]) -> None:
def update(self, attributes: dict):
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
@ -195,6 +195,8 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):

View File

@ -0,0 +1,93 @@
# Copyright (c) 2023 Lincoln D. Stein
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
from hashlib import sha1
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
class MigrateModelYamlToDb:
"""
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
Use this way:
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
MigrateModelYamlToDb().migrate()
"""
config: InvokeAIAppConfig
logger: InvokeAILogger
def __init__(self):
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation."""
db = SqliteDatabase(self.config, self.logger)
return ModelRecordServiceSQL(db)
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path)
def migrate(self):
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
for model_key, stanza in yaml.items():
if model_key == "__metadata__":
assert (
stanza["version"] == "3.0.0"
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
continue
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_config = ModelsValidator.validate_python(stanza)
self.logger.info(f"Adding model {model_name} with key {model_key}")
try:
db.add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
def main():
MigrateModelYamlToDb().migrate()
if __name__ == "__main__":
main()

View File

@ -1,686 +0,0 @@
import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.base import read_checkpoint_meta
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from invokeai.backend.model_management.util import lora_token_vector_length
from invokeai.backend.util.util import SilenceWarnings
from .config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelConfigFactory,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
from .hash import FastModelHash
CkptType = Dict[str, Any]
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
class ProbeBase(object):
"""Base class for probes."""
def __init__(self, model_path: Path):
self.model_path = model_path
def get_base_type(self) -> BaseModelType:
"""Get model base type."""
raise NotImplementedError
def get_format(self) -> ModelFormat:
"""Get model file format."""
raise NotImplementedError
def get_variant_type(self) -> Optional[ModelVariantType]:
"""Get model variant type."""
return None
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
"""Get model scheduler prediction type."""
return None
class ModelProbe(object):
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {},
"checkpoint": {},
"onnx": {},
}
CLASS2TYPE = {
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
) -> None:
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
return cls.probe(model_path, fields)
@classmethod
def probe(
cls,
model_path: Path,
fields: Optional[Dict[str, Any]] = None,
) -> AnyModelConfig:
"""
Probe the model at model_path and return its configuration record.
:param model_path: Path to the model file (checkpoint) or directory (diffusers).
:param fields: An optional dictionary that can be used to override probed
fields. Typically used for fields that don't probe well, such as prediction_type.
Returns: The appropriate model configuration derived from ModelConfigBase.
"""
if fields is None:
fields = {}
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = None
if format_type == "diffusers":
model_type = cls.get_model_type_from_folder(model_path)
else:
model_type = cls.get_model_type_from_checkpoint(model_path)
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
hash = FastModelHash.hash(model_path)
probe = probe_class(model_path)
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
fields["variant"] = fields.get("variant") or probe.get_variant_type()
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
# additional fields needed for main and controlnet models
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
fields["config"] = cls._get_checkpoint_config_path(
model_path,
model_type=fields["type"],
base_type=fields["base"],
variant_type=fields["variant"],
prediction_type=fields["prediction_type"],
).as_posix()
# additional fields needed for main non-checkpoint models
elif fields["type"] == ModelType.Main and fields["format"] in [
ModelFormat.Onnx,
ModelFormat.Olive,
ModelFormat.Diffusers,
]:
fields["upcast_attention"] = fields.get("upcast_attention") or (
fields["base"] == BaseModelType.StableDiffusion2
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
model_info = ModelConfigFactory.make_config(fields)
return model_info
@classmethod
def get_model_name(cls, model_path: Path) -> str:
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
return model_path.stem
else:
return model_path.name
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
ckpt = ckpt.get("state_dict", ckpt)
for key in ckpt.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.Vae
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.Lora
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.Lora
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
return ModelType.ControlNet
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
"""Get the model type of a hugging-face style folder."""
class_name = None
error_hint = None
for suffix in ["bin", "safetensors"]:
if (folder_path / f"learned_embeds.{suffix}").exists():
return ModelType.TextualInversion
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
return ModelType.Lora
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
config_path = i if i.exists() else c if c.exists() else None
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up
raise InvalidModelConfigException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
@classmethod
def _get_checkpoint_config_path(
cls,
model_path: Path,
model_type: ModelType,
base_type: BaseModelType,
variant_type: ModelVariantType,
prediction_type: SchedulerPredictionType,
) -> Path:
# look for a YAML file adjacent to the model file first
possible_conf = model_path.with_suffix(".yaml")
if possible_conf.exists():
return possible_conf.absolute()
if model_type == ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
elif model_type == ModelType.ControlNet:
config_file = (
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
)
else:
raise InvalidModelConfigException(
f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
)
assert isinstance(config_file, str)
return Path(config_file)
@classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path.name, model_path)
model = torch.load(model_path)
assert isinstance(model, dict)
return model
else:
return safetensors.torch.load_file(model_path)
@classmethod
def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
# ##################################################3
# Checkpoint probing
# ##################################################3
class CheckpointProbeBase(ProbeBase):
def __init__(self, model_path: Path):
super().__init__(model_path)
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
def get_format(self) -> ModelFormat:
return ModelFormat("checkpoint")
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise InvalidModelConfigException(
f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
)
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelConfigException("Cannot determine base type")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
"""Return model prediction type."""
type = self.get_base_type()
if type == BaseModelType.StableDiffusion2:
checkpoint = self.checkpoint
state_dict = self.checkpoint.get("state_dict") or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if "global_step" in checkpoint:
if checkpoint["global_step"] == 220000:
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
elif type == BaseModelType.StableDiffusion1:
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
else:
return SchedulerPredictionType.Epsilon
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
"""Class for LoRA checkpoints."""
def get_format(self) -> ModelFormat:
return ModelFormat("lycoris")
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
token_vector_length = lora_token_vector_length(checkpoint)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):
"""Class for probing embeddings."""
def get_format(self) -> ModelFormat:
return ModelFormat.EmbeddingFile
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if "string_to_token" in checkpoint:
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1]
elif "clip_g" in checkpoint:
token_dim = checkpoint["clip_g"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
elif token_dim == 1280:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
class ControlNetCheckpointProbe(CheckpointProbeBase):
"""Class for probing controlnets."""
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1
elif checkpoint[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
raise InvalidModelConfigException("{self.model_path}: Unable to determine base type")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
def get_format(self) -> ModelFormat:
return ModelFormat("diffusers")
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
config_file = self.model_path / "unet" / "config.json"
with open(config_file, "r") as file:
conf = json.load(file)
in_channels = conf["in_channels"]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
return ModelVariantType.Depth
elif in_channels == 4:
return ModelVariantType.Normal
except Exception:
pass
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._name_looks_like_sdxl():
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
return BaseModelType.StableDiffusionXL
else:
return BaseModelType.StableDiffusion1
def _config_looks_like_sdxl(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _guess_name(self) -> str:
name = self.model_path.name
if name == "vae":
name = self.model_path.parent.name
return name
class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat.EmbeddingFolder
def get_base_type(self) -> BaseModelType:
path = self.model_path / "learned_embeds.bin"
if not path.exists():
raise InvalidModelConfigException(
f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
)
return TextualInversionCheckpointProbe(path).get_base_type()
class ONNXFolderProbe(FolderProbeBase):
def get_format(self) -> ModelFormat:
return ModelFormat("onnx")
def get_base_type(self) -> BaseModelType:
return BaseModelType.StableDiffusion1
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
return base_model
class LoRAFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
model_file = None
for suffix in ["safetensors", "bin"]:
base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
if base_file.exists():
model_file = base_file
break
if not model_file:
raise InvalidModelConfigException("Unknown LoRA format encountered")
return LoRACheckpointProbe(model_file).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> IPAdapterModelFormat:
return IPAdapterModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
model_file = self.model_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelConfigException("Unknown IP-Adapter model format.")
state_dict = torch.load(model_file, map_location="cpu")
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
adapter_type = config.get("adapter_type", None)
if adapter_type == "full_adapter_xl":
return BaseModelType.StableDiffusionXL
elif adapter_type == "full_adapter" or "light_adapter":
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
return BaseModelType.StableDiffusion1
else:
raise InvalidModelConfigException(
f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
)
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@ -1,190 +0,0 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class and implementation for recursive directory search for models.
Example usage:
```
from invokeai.backend.model_manager import ModelSearch, ModelProbe
def find_main_models(model: Path) -> bool:
info = ModelProbe.probe(model)
if info.model_type == 'main' and info.base_type == 'sd-1':
return True
else:
return False
search = ModelSearch(on_model_found=report_it)
found = search.search('/tmp/models')
print(found) # list of matching model paths
print(search.stats) # search stats
```
"""
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.backend.util.logging import InvokeAILogger
default_logger = InvokeAILogger.get_logger()
class SearchStats(BaseModel):
items_scanned: int = 0
models_found: int = 0
models_filtered: int = 0
class ModelSearchBase(ABC, BaseModel):
"""
Abstract directory traversal model search class
Usage:
search = ModelSearchBase(
on_search_started = search_started_callback,
on_search_completed = search_completed_callback,
on_model_found = model_found_callback,
)
models_found = search.search('/path/to/directory')
"""
# fmt: off
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
class Config:
arbitrary_types_allowed = True
@abstractmethod
def search_started(self) -> None:
"""
Called before the scan starts.
Passes the root search directory to the Callable `on_search_started`.
"""
pass
@abstractmethod
def model_found(self, model: Path) -> None:
"""
Called when a model is found during search.
:param model: Model to process - could be a directory or checkpoint.
Passes the model's Path to the Callable `on_model_found`.
This Callable receives the path to the model and returns a boolean
to indicate whether the model should be returned in the search
results.
"""
pass
@abstractmethod
def search_completed(self) -> None:
"""
Called before the scan starts.
Passes the Set of found model Paths to the Callable `on_search_completed`.
"""
pass
@abstractmethod
def search(self, directory: Union[Path, str]) -> Set[Path]:
"""
Recursively search for models in `directory` and return a set of model paths.
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
Callables will be invoked during the search.
"""
pass
class ModelSearch(ModelSearchBase):
"""
Implementation of ModelSearch with callbacks.
Usage:
search = ModelSearch()
search.model_found = lambda path : 'anime' in path.as_posix()
found = search.list_models(['/tmp/models1','/tmp/models2'])
# returns all models that have 'anime' in the path
"""
models_found: Set[Path] = Field(default=None)
scanned_dirs: Set[Path] = Field(default=None)
pruned_paths: Set[Path] = Field(default=None)
def search_started(self) -> None:
self.models_found = set()
self.scanned_dirs = set()
self.pruned_paths = set()
if self.on_search_started:
self.on_search_started(self._directory)
def model_found(self, model: Path) -> None:
self.stats.models_found += 1
if not self.on_model_found or self.on_model_found(model):
self.stats.models_filtered += 1
self.models_found.add(model)
def search_completed(self) -> None:
if self.on_search_completed:
self.on_search_completed(self._models_found)
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory)
self.search_completed()
return self.models_found
def _walk_directory(self, path: Union[Path, str]) -> None:
for root, dirs, files in os.walk(path, followlinks=True):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
self.pruned_paths.add(Path(root))
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
continue
self.stats.items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path.parent in self.scanned_dirs:
self.scanned_dirs.add(path)
continue
if any(
(path / x).exists()
for x in [
"config.json",
"model_index.json",
"learned_embeds.bin",
"pytorch_lora_weights.bin",
"image_encoder.txt",
]
):
self.scanned_dirs.add(path)
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self.scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.model_found(path)
except KeyboardInterrupt:
raise
except Exception as e:
self.logger.warning(str(e))

View File

@ -242,6 +242,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
control_model: ControlNetModel = None,
):
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker,
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
@ -249,9 +260,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
# FIXME: can't currently register control module
# control_model=control_model,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
self.use_ip_adapter = False
@ -276,11 +287,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.disable_attention_slicing()
return
elif config.attention_type == "torch-sdp":
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
return
else:
raise Exception("torch-sdp attention slicing not available")
raise Exception("torch-sdp attention slicing not yet implemented")
# the remainder if this code is called when attention_type=='auto'
if self.unet.device.type == "cuda":
@ -288,7 +295,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.enable_xformers_memory_efficient_attention()
return
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
# diffusers enable sdp automatically
return
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":

View File

@ -3,42 +3,7 @@ from typing import Union
import numpy as np
from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR
from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend
def calc_overlap(tiles: list[Tile], num_tiles_x: int, num_tiles_y: int) -> list[Tile]:
"""Calculate and update the overlap of a list of tiles.
Args:
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
num_tiles_x: the number of tiles on the x axis.
num_tiles_y: the number of tiles on the y axis.
"""
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
return None
return tiles[idx_y * num_tiles_x + idx_x]
for tile_idx_y in range(num_tiles_y):
for tile_idx_x in range(num_tiles_x):
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
assert cur_tile is not None
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
if top_neighbor_tile is not None:
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
if left_neighbor_tile is not None:
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
left_neighbor_tile.overlap.right = cur_tile.overlap.left
return tiles
from invokeai.backend.tiles.utils import TBLR, Tile, paste
def calc_tiles_with_overlap(
@ -98,133 +63,31 @@ def calc_tiles_with_overlap(
tiles.append(tile)
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
return None
return tiles[idx_y * num_tiles_x + idx_x]
def calc_tiles_even_split(
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap: int = 0
) -> list[Tile]:
"""Calculate the tile coordinates for a given image shape with the number of tiles requested.
Args:
image_height (int): The image height in px.
image_width (int): The image width in px.
num_x_tiles (int): The number of tile to split the image into on the X-axis.
num_y_tiles (int): The number of tile to split the image into on the Y-axis.
overlap (int, optional): The overlap between adjacent tiles in pixels. Defaults to 0.
Returns:
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
"""
# Ensure the image is divisible by LATENT_SCALE_FACTOR
if image_width % LATENT_SCALE_FACTOR != 0 or image_height % LATENT_SCALE_FACTOR != 0:
raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by {LATENT_SCALE_FACTOR}")
# Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down)
if num_tiles_x > 1:
# ensure the overlap is not more than the maximum overlap if we only have 1 tile then we dont care about overlap
assert overlap <= image_width - (LATENT_SCALE_FACTOR * (num_tiles_x - 1))
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
((image_width + overlap * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
)
assert overlap < tile_size_x
else:
tile_size_x = image_width
if num_tiles_y > 1:
# ensure the overlap is not more than the maximum overlap if we only have 1 tile then we dont care about overlap
assert overlap <= image_height - (LATENT_SCALE_FACTOR * (num_tiles_y - 1))
tile_size_y = LATENT_SCALE_FACTOR * math.floor(
((image_height + overlap * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR
)
assert overlap < tile_size_y
else:
tile_size_y = image_height
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = []
# Calculate tile coordinates. (Ignore overlap values for now.)
# Iterate over tiles again and calculate overlaps.
for tile_idx_y in range(num_tiles_y):
# Calculate the top and bottom of the row
top = tile_idx_y * (tile_size_y - overlap)
bottom = min(top + tile_size_y, image_height)
# For the last row adjust bottom to be the height of the image
if tile_idx_y == num_tiles_y - 1:
bottom = image_height
for tile_idx_x in range(num_tiles_x):
# Calculate the left & right coordinate of each tile
left = tile_idx_x * (tile_size_x - overlap)
right = min(left + tile_size_x, image_width)
# For the last tile in the row adjust right to be the width of the image
if tile_idx_x == num_tiles_x - 1:
right = image_width
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
tile = Tile(
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
assert cur_tile is not None
tiles.append(tile)
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
if top_neighbor_tile is not None:
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
if left_neighbor_tile is not None:
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
left_neighbor_tile.overlap.right = cur_tile.overlap.left
def calc_tiles_min_overlap(
image_height: int,
image_width: int,
tile_height: int,
tile_width: int,
min_overlap: int = 0,
) -> list[Tile]:
"""Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps.
Args:
image_height (int): The image height in px.
image_width (int): The image width in px.
tile_height (int): The tile height in px. All tiles will have this height.
tile_width (int): The tile width in px. All tiles will have this width.
min_overlap (int): The target minimum overlap between adjacent tiles. If the tiles do not evenly cover the image
shape, then the overlap will be spread between the tiles.
Returns:
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
"""
assert min_overlap < tile_height
assert min_overlap < tile_width
# catches the cases when the tile size is larger than the images size and adjusts the tile size
if image_width < tile_width:
tile_width = image_width
if image_height < tile_height:
tile_height = image_height
num_tiles_x = math.ceil((image_width - min_overlap) / (tile_width - min_overlap))
num_tiles_y = math.ceil((image_height - min_overlap) / (tile_height - min_overlap))
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = []
# Calculate tile coordinates. (Ignore overlap values for now.)
for tile_idx_y in range(num_tiles_y):
top = (tile_idx_y * (image_height - tile_height)) // (num_tiles_y - 1) if num_tiles_y > 1 else 0
bottom = top + tile_height
for tile_idx_x in range(num_tiles_x):
left = (tile_idx_x * (image_width - tile_width)) // (num_tiles_x - 1) if num_tiles_x > 1 else 0
right = left + tile_width
tile = Tile(
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
tiles.append(tile)
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
return tiles
def merge_tiles_with_linear_blending(
@ -336,91 +199,3 @@ def merge_tiles_with_linear_blending(
),
mask=mask,
)
def merge_tiles_with_seam_blending(
dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int
):
"""Merge a set of image tiles into `dst_image` with seam blending between the tiles.
We expect every tile edge to either:
1) have an overlap of 0, because it is aligned with the image edge, or
2) have an overlap >= blend_amount.
If neither of these conditions are satisfied, we raise an exception.
The seam blending is centered on a seam of least energy of the overlap between adjacent tiles.
Args:
dst_image (np.ndarray): The destination image. Shape: (H, W, C).
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
tile_images (list[np.ndarray]): The tile images to merge into `dst_image`.
blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles.
"""
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
# iterate over tiles left-to-right, top-to-bottom.
tiles_and_images = list(zip(tiles, tile_images, strict=True))
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left)
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top)
# Organize tiles into rows.
tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = []
cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = []
first_tile_in_cur_row, _ = tiles_and_images[0]
for tile_and_image in tiles_and_images:
tile, _ = tile_and_image
if not (
tile.coords.top == first_tile_in_cur_row.coords.top
and tile.coords.bottom == first_tile_in_cur_row.coords.bottom
):
# Store the previous row, and start a new one.
tile_and_image_rows.append(cur_tile_and_image_row)
cur_tile_and_image_row = []
first_tile_in_cur_row, _ = tile_and_image
cur_tile_and_image_row.append(tile_and_image)
tile_and_image_rows.append(cur_tile_and_image_row)
for tile_and_image_row in tile_and_image_rows:
first_tile_in_row, _ = tile_and_image_row[0]
row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top
row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype)
# Blend the tiles in the row horizontally.
for tile, tile_image in tile_and_image_row:
# We expect the tiles to be ordered left-to-right.
# For each tile:
# - extract the overlap regions and pass to seam_blend()
# - apply blended region to the row_image
# - apply the un-blended region to the row_image
tile_height, tile_width, _ = tile_image.shape
overlap_size = tile.overlap.left
# Left blending:
if overlap_size > 0:
assert overlap_size >= blend_amount
overlap_coord_right = tile.coords.left + overlap_size
src_overlap = row_image[:, tile.coords.left : overlap_coord_right]
dst_overlap = tile_image[:, :overlap_size]
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=False)
row_image[:, tile.coords.left : overlap_coord_right] = blended_overlap
row_image[:, overlap_coord_right : tile.coords.right] = tile_image[:, overlap_size:]
else:
# no overlap just paste the tile
row_image[:, tile.coords.left : tile.coords.right] = tile_image
# Blend the row into the dst_image
# We assume that the entire row has the same vertical overlaps as the first_tile_in_row.
# Rows are processed in the same way as tiles (extract overlap, blend, apply)
row_overlap_size = first_tile_in_row.overlap.top
if row_overlap_size > 0:
assert row_overlap_size >= blend_amount
overlap_coords_bottom = first_tile_in_row.coords.top + row_overlap_size
src_overlap = dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :]
dst_overlap = row_image[:row_overlap_size, :]
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=True)
dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] = blended_overlap
dst_image[overlap_coords_bottom : first_tile_in_row.coords.bottom, :] = row_image[row_overlap_size:, :]
else:
# no overlap just paste the row
dst_image[first_tile_in_row.coords.top : first_tile_in_row.coords.bottom, :] = row_image

View File

@ -1,7 +1,5 @@
import math
from typing import Optional
import cv2
import numpy as np
from pydantic import BaseModel, Field
@ -33,10 +31,10 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
"""Paste a source image into a destination image.
Args:
dst_image (np.array): The destination image to paste into. Shape: (H, W, C).
src_image (np.array): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
dst_image (torch.Tensor): The destination image to paste into. Shape: (H, W, C).
src_image (torch.Tensor): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted.
mask (Optional[np.array]): A mask that defines the blending between 'src_image' and 'dst_image'.
mask (Optional[torch.Tensor]): A mask that defines the blending between 'src_image' and 'dst_image'.
Range: [0.0, 1.0], Shape: (H, W). The output is calculate per-pixel according to
`src * mask + dst * (1 - mask)`.
"""
@ -47,106 +45,3 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
mask = np.expand_dims(mask, -1)
dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)
def seam_blend(ia1: np.ndarray, ia2: np.ndarray, blend_amount: int, x_seam: bool) -> np.ndarray:
"""Blend two overlapping tile sections using a seams to find a path.
It is assumed that input images will be RGB np arrays and are the same size.
Args:
ia1 (np.array): Image array 1 Shape: (H, W, C).
ia2 (np.array): Image array 2 Shape: (H, W, C).
x_seam (bool): If the images should be blended on the x axis or not.
blend_amount (int): The size of the blur to use on the seam. Half of this value will be used to avoid the edges of the image.
"""
assert ia1.shape == ia2.shape
assert ia2.size == ia2.size
def shift(arr, num, fill_value=255.0):
result = np.full_like(arr, fill_value)
if num > 0:
result[num:] = arr[:-num]
elif num < 0:
result[:num] = arr[-num:]
else:
result[:] = arr
return result
# Assume RGB and convert to grey
# Could offer other options for the luminance conversion
# BT.709 [0.2126, 0.7152, 0.0722], BT.2020 [0.2627, 0.6780, 0.0593])
# it might not have a huge impact due to the blur that is applied over the seam
iag1 = np.dot(ia1, [0.2989, 0.5870, 0.1140]) # BT.601 perceived brightness
iag2 = np.dot(ia2, [0.2989, 0.5870, 0.1140])
# Calc Difference between the images
ia = iag2 - iag1
# If the seam is on the X-axis rotate the array so we can treat it like a vertical seam
if x_seam:
ia = np.rot90(ia, 1)
# Calc max and min X & Y limits
# gutter is used to avoid the blur hitting the edge of the image
gutter = math.ceil(blend_amount / 2) if blend_amount > 0 else 0
max_y, max_x = ia.shape
max_x -= gutter
min_x = gutter
# Calc the energy in the difference
# Could offer different energy calculations e.g. Sobel or Scharr
energy = np.abs(np.gradient(ia, axis=0)) + np.abs(np.gradient(ia, axis=1))
# Find the starting position of the seam
res = np.copy(energy)
for y in range(1, max_y):
row = res[y, :]
rowl = shift(row, -1)
rowr = shift(row, 1)
res[y, :] = res[y - 1, :] + np.min([row, rowl, rowr], axis=0)
# create an array max_y long
lowest_energy_line = np.empty([max_y], dtype="uint16")
lowest_energy_line[max_y - 1] = np.argmin(res[max_y - 1, min_x : max_x - 1])
# Calc the path of the seam
# could offer options for larger search than just 1 pixel by adjusting lpos and rpos
for ypos in range(max_y - 2, -1, -1):
lowest_pos = lowest_energy_line[ypos + 1]
lpos = lowest_pos - 1
rpos = lowest_pos + 1
lpos = np.clip(lpos, min_x, max_x - 1)
rpos = np.clip(rpos, min_x, max_x - 1)
lowest_energy_line[ypos] = np.argmin(energy[ypos, lpos : rpos + 1]) + lpos
# Draw the mask
mask = np.zeros_like(ia)
for ypos in range(0, max_y):
to_fill = lowest_energy_line[ypos]
mask[ypos, :to_fill] = 1
# If the seam is on the X-axis rotate the array back
if x_seam:
mask = np.rot90(mask, 3)
# blur the seam mask if required
if blend_amount > 0:
mask = cv2.blur(mask, (blend_amount, blend_amount))
# for visual debugging
# from PIL import Image
# m_image = Image.fromarray((mask * 255.0).astype("uint8"))
# copy ia2 over ia1 while applying the seam mask
mask = np.expand_dims(mask, -1)
blended_image = ia1 * mask + ia2 * (1.0 - mask)
# for visual debugging
# i1 = Image.fromarray(ia1.astype("uint8"))
# i2 = Image.fromarray(ia2.astype("uint8"))
# b_image = Image.fromarray(blended_image.astype("uint8"))
# print(f"{ia1.shape}, {ia2.shape}, {mask.shape}, {blended_image.shape}")
# print(f"{i1.size}, {i2.size}, {m_image.size}, {b_image.size}")
return blended_image

View File

@ -11,7 +11,4 @@ from .devices import ( # noqa: F401
normalize_device,
torch_dtype,
)
from .logging import InvokeAILogger
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import platform
from contextlib import nullcontext
from typing import Union
import torch
from packaging import version
from torch import autocast
from invokeai.app.services.config import InvokeAIAppConfig
@ -35,7 +37,7 @@ def choose_precision(device: torch.device) -> str:
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
return "float16"
elif device.type == "mps":
elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse("14.0.0"):
return "float16"
return "float32"

View File

@ -342,13 +342,14 @@ class InvokeAILogger(object): # noqa D102
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger: # noqa D102
if name in cls.loggers:
return cls.loggers[name]
logger = logging.getLogger(name)
logger = cls.loggers[name]
logger.handlers.clear()
else:
logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.get_loggers(config):
logger.addHandler(ch)
cls.loggers[name] = logger
cls.loggers[name] = logger
return cls.loggers[name]
@classmethod
@ -357,7 +358,7 @@ class InvokeAILogger(object): # noqa D102
handlers = []
for handler in handler_strs:
handler_name, *args = handler.split("=", 2)
arg = args[0] if len(args) > 0 else None
args = args[0] if len(args) > 0 else None
# console and file get the fancy formatter.
# syslog gets a simple one
@ -369,16 +370,16 @@ class InvokeAILogger(object): # noqa D102
handlers.append(ch)
elif handler_name == "syslog":
ch = cls._parse_syslog_args(arg)
ch = cls._parse_syslog_args(args)
handlers.append(ch)
elif handler_name == "file":
ch = cls._parse_file_args(arg)
ch = cls._parse_file_args(args)
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name == "http":
ch = cls._parse_http_args(arg)
ch = cls._parse_http_args(args)
handlers.append(ch)
return handlers

View File

@ -32,9 +32,9 @@ sd-1/main/Analog-Diffusion:
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
repo_id: wavymulder/Analog-Diffusion
recommended: False
sd-1/main/Deliberate_v5:
sd-1/main/Deliberate:
description: Versatile model that produces detailed images up to 768px (4.27 GB)
path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors
repo_id: XpucT/Deliberate
recommended: False
sd-1/main/Dungeons-and-Diffusion:
description: Dungeons & Dragons characters (2.13 GB)

View File

@ -4,7 +4,6 @@ pip install <path_to_git_source>.
"""
import os
import platform
from distutils.version import LooseVersion
import pkg_resources
import psutil
@ -32,6 +31,10 @@ else:
console = Console(style=Style(color="grey74", bgcolor="grey19"))
def get_versions() -> dict:
return requests.get(url=INVOKE_AI_REL).json()
def invokeai_is_running() -> bool:
for p in psutil.process_iter():
try:
@ -47,20 +50,6 @@ def invokeai_is_running() -> bool:
return False
def get_pypi_versions():
url = "https://pypi.org/pypi/invokeai/json"
try:
data = requests.get(url).json()
except Exception:
raise Exception("Unable to fetch version information from PyPi")
versions = list(data["releases"].keys())
versions.sort(key=LooseVersion, reverse=True)
latest_version = [v for v in versions if "rc" not in v][0]
latest_release_candidate = [v for v in versions if "rc" in v][0]
return latest_version, latest_release_candidate, versions
def welcome(latest_release: str, latest_prerelease: str):
@group()
def text():
@ -74,7 +63,8 @@ def welcome(latest_release: str, latest_prerelease: str):
yield "[bold yellow]Options:"
yield f"""[1] Update to the latest [bold]official release[/bold] ([italic]{latest_release}[/italic])
[2] Update to the latest [bold]pre-release[/bold] (may be buggy; caveat emptor!) ([italic]{latest_prerelease}[/italic])
[3] Manually enter the [bold]version[/bold] you wish to update to"""
[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to"""
console.rule()
print(
@ -102,35 +92,44 @@ def get_extras():
def main():
versions = get_versions()
released_versions = [x for x in versions if not (x["draft"] or x["prerelease"])]
prerelease_versions = [x for x in versions if not x["draft"] and x["prerelease"]]
latest_release = released_versions[0]["tag_name"] if len(released_versions) else None
latest_prerelease = prerelease_versions[0]["tag_name"] if len(prerelease_versions) else None
if invokeai_is_running():
print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
input("Press any key to continue...")
return
latest_release, latest_prerelease, versions = get_pypi_versions()
welcome(latest_release, latest_prerelease)
release = latest_release
choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
tag = None
branch = None
release = None
choice = Prompt.ask("Choice:", choices=["1", "2", "3", "4"], default="1")
if choice == "1":
release = latest_release
elif choice == "2":
release = latest_prerelease
elif choice == "3":
while True:
release = Prompt.ask("Enter an InvokeAI version")
release.strip()
if release in versions:
break
print(f":exclamation: [bold red]'{release}' is not a recognized InvokeAI release.[/red bold]")
while not tag:
tag = Prompt.ask("Enter an InvokeAI tag name")
elif choice == "4":
while not branch:
branch = Prompt.ask("Enter an InvokeAI branch name")
extras = get_extras()
print(f":crossed_fingers: Upgrading to [yellow]{release}[/yellow]")
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade'
print(f":crossed_fingers: Upgrading to [yellow]{tag or release or branch}[/yellow]")
if release:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
elif tag:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
else:
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
print("")
print("")
if os.system(cmd) == 0:

View File

@ -11,7 +11,6 @@ module.exports = {
'plugin:react-hooks/recommended',
'plugin:react/jsx-runtime',
'prettier',
'plugin:storybook/recommended',
],
parser: '@typescript-eslint/parser',
parserOptions: {
@ -27,7 +26,6 @@ module.exports = {
'eslint-plugin-react-hooks',
'i18next',
'path',
'unused-imports',
],
root: true,
rules: {
@ -46,16 +44,9 @@ module.exports = {
radix: 'error',
'space-before-blocks': 'error',
'import/prefer-default-export': 'off',
'@typescript-eslint/no-unused-vars': 'off',
'unused-imports/no-unused-imports': 'error',
'unused-imports/no-unused-vars': [
'@typescript-eslint/no-unused-vars': [
'warn',
{
vars: 'all',
varsIgnorePattern: '^_',
args: 'after-used',
argsIgnorePattern: '^_',
},
{ varsIgnorePattern: '^_', argsIgnorePattern: '^_' },
],
'@typescript-eslint/ban-ts-comment': 'warn',
'@typescript-eslint/no-explicit-any': 'warn',

View File

@ -9,8 +9,7 @@ lerna-debug.log*
node_modules
# We want to distribute the repo
dist
dist/**
# dist
dist-ssr
*.local
@ -39,4 +38,4 @@ stats.html
# Yalc
.yalc
yalc.lock
yalc.lock

View File

@ -0,0 +1,4 @@
#!/usr/bin/env sh
. "$(dirname -- "$0")/_/husky.sh"
cd invokeai/frontend/web/ && npm run lint-staged

View File

@ -12,4 +12,3 @@ index.html
src/services/api/schema.d.ts
static/
src/theme/css/overlayscrollbars.css
pnpm-lock.yaml

View File

@ -1,21 +0,0 @@
import type { StorybookConfig } from '@storybook/react-vite';
const config: StorybookConfig = {
stories: ['../src/**/*.mdx', '../src/**/*.stories.@(js|jsx|mjs|ts|tsx)'],
addons: [
'@storybook/addon-links',
'@storybook/addon-essentials',
'@storybook/addon-interactions',
],
framework: {
name: '@storybook/react-vite',
options: {},
},
docs: {
autodocs: 'tag',
},
core: {
disableTelemetry: true,
},
};
export default config;

View File

@ -1,6 +0,0 @@
import { addons } from '@storybook/manager-api';
import { themes } from '@storybook/theming';
addons.setConfig({
theme: themes.dark,
});

View File

@ -1,47 +0,0 @@
import { Preview } from '@storybook/react';
import { themes } from '@storybook/theming';
import i18n from 'i18next';
import React from 'react';
import { initReactI18next } from 'react-i18next';
import { Provider } from 'react-redux';
import GlobalHotkeys from '../src/app/components/GlobalHotkeys';
import ThemeLocaleProvider from '../src/app/components/ThemeLocaleProvider';
import { createStore } from '../src/app/store/store';
// TODO: Disabled for IDE performance issues with our translation JSON
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
import translationEN from '../public/locales/en.json';
i18n.use(initReactI18next).init({
lng: 'en',
resources: {
en: { translation: translationEN },
},
debug: true,
interpolation: {
escapeValue: false,
},
returnNull: false,
});
const store = createStore(undefined, false);
const preview: Preview = {
decorators: [
(Story) => (
<Provider store={store}>
<ThemeLocaleProvider>
<GlobalHotkeys />
<Story />
</ThemeLocaleProvider>
</Provider>
),
],
parameters: {
docs: {
theme: themes.dark,
},
},
};
export default preview;

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,5 @@
# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
# yarn lockfile v1
yarn-path ".yarn/releases/yarn-1.22.19.cjs"

View File

@ -0,0 +1 @@
yarnPath: .yarn/releases/yarn-1.22.19.cjs

Some files were not shown because too many files have changed in this diff Show More