mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
287 Commits
feat/workf
...
v3.5.0rc3
Author | SHA1 | Date | |
---|---|---|---|
47682b9910 | |||
aa36554321 | |||
d0fa131010 | |||
2f438431bd | |||
bbeb5cb477 | |||
cd3111c324 | |||
16b7246412 | |||
42be78d328 | |||
e469e24a58 | |||
cb698ff1fb | |||
0e738c4290 | |||
09d1bc513d | |||
aefa828237 | |||
74ea592d02 | |||
457b0dfac0 | |||
96a717c4ba | |||
77b74264a8 | |||
351078e8aa | |||
b8354bd1a4 | |||
3b944b8af6 | |||
b811c037bd | |||
5bf61382a4 | |||
0f1c5f382a | |||
4af1695c60 | |||
df9a903a50 | |||
311be8f97d | |||
3f970c8326 | |||
fc150acde5 | |||
1615df3aa1 | |||
b2a8c45553 | |||
212dbaf9a2 | |||
ac3cf48d7f | |||
454f01e0c1 | |||
72dca55e44 | |||
264ea6d94d | |||
60e3e653fa | |||
082894c377 | |||
4b00f8fc82 | |||
6ea09ba0b6 | |||
296060db63 | |||
d1d8ee71fc | |||
42c04db167 | |||
b935768eeb | |||
ea4ef042f3 | |||
18b2bcbbee | |||
5ad88c7f86 | |||
3b04fef31d | |||
bec888923a | |||
c6235049c7 | |||
e10f6e8962 | |||
77f04ff8d6 | |||
461e474394 | |||
f0c70fe3f1 | |||
442ac2b828 | |||
bb986b97f3 | |||
98655db57b | |||
8845894e83 | |||
937c7e957d | |||
569ae7c482 | |||
340957f920 | |||
076d9b05ea | |||
2b54e240d4 | |||
5127e9df2d | |||
42329a1849 | |||
42bc6ef154 | |||
6c6c45c3da | |||
f76b04a3b8 | |||
821e0326c9 | |||
cc18d86f29 | |||
ed1583383e | |||
c50a49719b | |||
ebf5f5d418 | |||
386b656530 | |||
d7cede6c28 | |||
15de7c21d9 | |||
9620f9336c | |||
a64ced7b29 | |||
dd7deff1a3 | |||
612912a6c9 | |||
bca2372280 | |||
0b860582f0 | |||
87ff380fe4 | |||
2cdda1fda2 | |||
6caa70123d | |||
7e831c8a96 | |||
3d64bc886d | |||
1a136d6167 | |||
43f2837117 | |||
5f77ef7e99 | |||
22ccaa4e9a | |||
d277bd3c38 | |||
fd4e041e7c | |||
15a3e8076f | |||
2fbe3a3104 | |||
b0cfa58526 | |||
285ed26edd | |||
02565b9a00 | |||
78a6024d6c | |||
95198da645 | |||
ee1f1f3363 | |||
18ba7feca1 | |||
55b0c7cdc9 | |||
713a83e7da | |||
f3a97e06ec | |||
50815d36c6 | |||
a69f518c76 | |||
18093c4f1d | |||
0cf7fe43af | |||
6063760ce2 | |||
c5ba4f2ea5 | |||
3414437eea | |||
417db71471 | |||
afe4e55bf9 | |||
55acc16b2d | |||
535ce10e99 | |||
11f4a48144 | |||
67ed4a0245 | |||
fbbc1037cd | |||
0852fd4e88 | |||
c84526fae5 | |||
f762940335 | |||
fefb78795f | |||
ef8284f009 | |||
290851016e | |||
fa7d002175 | |||
f1b6f78319 | |||
26ab917021 | |||
4f3c32a2ee | |||
77065b1ce1 | |||
41db92b9e8 | |||
c823f5667b | |||
3227b30430 | |||
567f107a81 | |||
b3d5955bc7 | |||
8726b203d4 | |||
b3f92e0547 | |||
72c9a7663f | |||
fcb9e89bd7 | |||
56966d6d05 | |||
e46dc9b34e | |||
e461f9925e | |||
abeb1bd3b3 | |||
83e820d721 | |||
f8e4b93a74 | |||
0710ec30cf | |||
c382329e8c | |||
a2dc780188 | |||
abc9dc4d17 | |||
3c692018cd | |||
3ba3c1918c | |||
f2c6819d68 | |||
ef807cf63a | |||
bbcd58e681 | |||
36043bf38b | |||
fd68c47920 | |||
c5c975c7a9 | |||
41ad13c282 | |||
e9d7e6bdd5 | |||
49b74d189e | |||
179bc64490 | |||
1feab3da37 | |||
0a15f3fc35 | |||
daf00efa4d | |||
55cfb879d0 | |||
de2879f602 | |||
3b1ff4a7f4 | |||
d7f7fbc8c2 | |||
e2567a7e31 | |||
2f3457c02a | |||
aab6369ffe | |||
4c97b619fb | |||
abdd840fb9 | |||
e656768eb2 | |||
494c2a9b05 | |||
40d4c7c8e1 | |||
076284c26f | |||
1af4260ab6 | |||
08ef71a74e | |||
8f6e2c0c85 | |||
0ac33f36ef | |||
9661fa5f76 | |||
ca07449fb4 | |||
fb39f621c6 | |||
977d309692 | |||
72cb8b83fe | |||
99f14b1dfe | |||
95a3c89a56 | |||
b271474812 | |||
2272925607 | |||
5902a52e40 | |||
5140056b59 | |||
f17b3d0068 | |||
5b9d25f57e | |||
73dbb8792e | |||
fc6cebb975 | |||
06104f3851 | |||
6e028d691a | |||
6d176601cc | |||
4627a7c75f | |||
d9a0efb20b | |||
7436aa8e3a | |||
d75d3885c3 | |||
db4763a742 | |||
13c9f8ffb7 | |||
e4f67628c0 | |||
283bb73418 | |||
5b5a71d40c | |||
61060f032a | |||
3423b5848f | |||
fd8d1e13a0 | |||
c42d692ea6 | |||
5f37176938 | |||
375a91db32 | |||
b7ba426249 | |||
d3ad356c6a | |||
fdb97c1d02 | |||
8cda42ab0a | |||
fed2bdafeb | |||
9ba5752770 | |||
8648c2c42e | |||
b519b6e1e0 | |||
913c68982a | |||
6e1e67aa72 | |||
ee6fbabbfb | |||
db58efbe65 | |||
cd15d8b7a9 | |||
3b4b4ba40a | |||
eecee472b1 | |||
7b314116be | |||
bc6d4111a2 | |||
674d9796d0 | |||
5816320645 | |||
14254e8be8 | |||
e990235d32 | |||
5f122186bd | |||
3bfaee9c57 | |||
1ca0901cbe | |||
2d7555b7b8 | |||
3c7d1fcd32 | |||
c7fa2db556 | |||
3b06cc6782 | |||
7c9f48b84d | |||
fed2bf6dab | |||
2b583ffcdf | |||
6f46d15c05 | |||
018ccebd6f | |||
620b2d477a | |||
f73b678aae | |||
bdb0d13a2d | |||
2d2ef5d72c | |||
e46ac45741 | |||
75089b7a9d | |||
778fd55f0d | |||
bb87c988cb | |||
049b0239da | |||
932de08fc0 | |||
303791d5c6 | |||
7e4a689370 | |||
04e0fefdee | |||
9b4e6da226 | |||
e1c53a2465 | |||
121b930abf | |||
436560da39 | |||
3980f79ed5 | |||
1d0dc7eeab | |||
1f63fa8236 | |||
caf47dee09 | |||
d742479810 | |||
ecd3dcd5df | |||
a79e814c8d | |||
3fe1bef5cd | |||
dbd0151c0e | |||
6da508f147 | |||
8ef596eac7 | |||
8f4f4d48d5 | |||
60eae7443a | |||
8695ad6f59 | |||
dc5c452ef9 | |||
8aefe2cefe | |||
ec510d34b5 | |||
19baea1883 | |||
80bc9be3ab | |||
8c7a7bc897 | |||
4aab728590 | |||
9cf060115d | |||
9ea3126118 | |||
6c56233edc |
15
.github/pull_request_template.md
vendored
15
.github/pull_request_template.md
vendored
@ -42,6 +42,21 @@ Please provide steps on how to test changes, any hardware or
|
||||
software specifications as well as any other pertinent information.
|
||||
-->
|
||||
|
||||
## Merge Plan
|
||||
|
||||
<!--
|
||||
A merge plan describes how this PR should be handled after it is approved.
|
||||
|
||||
Example merge plans:
|
||||
- "This PR can be merged when approved"
|
||||
- "This must be squash-merged when approved"
|
||||
- "DO NOT MERGE - I will rebase and tidy commits before merging"
|
||||
- "#dev-chat on discord needs to be advised of this change when it is merged"
|
||||
|
||||
A merge plan is particularly important for large PRs or PRs that touch the
|
||||
database in any way.
|
||||
-->
|
||||
|
||||
## Added/updated tests?
|
||||
|
||||
- [ ] Yes
|
||||
|
24
.github/workflows/lint-frontend.yml
vendored
24
.github/workflows/lint-frontend.yml
vendored
@ -22,12 +22,22 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v3
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
- uses: actions/checkout@v3
|
||||
- run: 'yarn install --frozen-lockfile'
|
||||
- run: 'yarn run lint:tsc'
|
||||
- run: 'yarn run lint:madge'
|
||||
- run: 'yarn run lint:eslint'
|
||||
- run: 'yarn run lint:prettier'
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
- name: Install dependencies
|
||||
run: 'pnpm install --prefer-frozen-lockfile'
|
||||
- name: Typescript
|
||||
run: 'pnpm run lint:tsc'
|
||||
- name: Madge
|
||||
run: 'pnpm run lint:madge'
|
||||
- name: ESLint
|
||||
run: 'pnpm run lint:eslint'
|
||||
- name: Prettier
|
||||
run: 'pnpm run lint:prettier'
|
||||
|
50
.github/workflows/pypi-release.yml
vendored
50
.github/workflows/pypi-release.yml
vendored
@ -1,13 +1,15 @@
|
||||
name: PyPI Release
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'invokeai/version/invokeai_version.py'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
publish_package:
|
||||
description: 'Publish build on PyPi? [true/false]'
|
||||
required: true
|
||||
default: 'false'
|
||||
|
||||
jobs:
|
||||
release:
|
||||
build-and-release:
|
||||
if: github.repository == 'invoke-ai/InvokeAI'
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
@ -15,19 +17,43 @@ jobs:
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
TWINE_NON_INTERACTIVE: 1
|
||||
steps:
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: install deps
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
|
||||
- name: Install frontend dependencies
|
||||
run: pnpm install --prefer-frozen-lockfile
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Build frontend
|
||||
run: pnpm run build
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Install python dependencies
|
||||
run: pip install --upgrade build twine
|
||||
|
||||
- name: build package
|
||||
- name: Build python package
|
||||
run: python3 -m build
|
||||
|
||||
- name: check distribution
|
||||
- name: Upload build as workflow artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist
|
||||
|
||||
- name: Check distribution
|
||||
run: twine check dist/*
|
||||
|
||||
- name: check PyPI versions
|
||||
- name: Check PyPI versions
|
||||
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
|
||||
run: |
|
||||
pip install --upgrade requests
|
||||
@ -36,6 +62,6 @@ jobs:
|
||||
EXISTS=scripts.pypi_helper.local_on_pypi(); \
|
||||
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
|
||||
|
||||
- name: upload package
|
||||
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != ''
|
||||
- name: Publish build on PyPi
|
||||
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
|
||||
run: twine upload dist/*
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -16,7 +16,7 @@ __pycache__/
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
# dist/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
@ -187,3 +187,4 @@ installer/install.bat
|
||||
installer/install.sh
|
||||
installer/update.bat
|
||||
installer/update.sh
|
||||
installer/InvokeAI-Installer/
|
||||
|
33
Makefile
33
Makefile
@ -1,6 +1,20 @@
|
||||
# simple Makefile with scripts that are otherwise hard to remember
|
||||
# to use, run from the repo root `make <command>`
|
||||
|
||||
default: help
|
||||
|
||||
help:
|
||||
@echo Developer commands:
|
||||
@echo
|
||||
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
|
||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||
ruff:
|
||||
ruff check . --fix
|
||||
@ -18,4 +32,21 @@ mypy:
|
||||
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
|
||||
# (many files are ignored by the config, so this is useful for checking all files)
|
||||
mypy-all:
|
||||
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
|
||||
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
|
||||
|
||||
# Build the frontend
|
||||
frontend-build:
|
||||
cd invokeai/frontend/web && pnpm build
|
||||
|
||||
# Run the frontend in dev mode
|
||||
frontend-dev:
|
||||
cd invokeai/frontend/web && pnpm dev
|
||||
|
||||
# Installer zip file
|
||||
installer-zip:
|
||||
cd installer && ./create_installer.sh
|
||||
|
||||
# Tag the release
|
||||
tag-release:
|
||||
cd installer && ./tag_release.sh
|
||||
|
||||
|
@ -125,8 +125,8 @@ and go to http://localhost:9090.
|
||||
|
||||
You must have Python 3.10 through 3.11 installed on your machine. Earlier or
|
||||
later versions are not supported.
|
||||
Node.js also needs to be installed along with yarn (can be installed with
|
||||
the command `npm install -g yarn` if needed)
|
||||
Node.js also needs to be installed along with `pnpm` (can be installed with
|
||||
the command `npm install -g pnpm` if needed)
|
||||
|
||||
1. Open a command-line window on your machine. The PowerShell is recommended for Windows.
|
||||
2. Create a directory to install InvokeAI into. You'll need at least 15 GB of free space:
|
||||
|
@ -100,6 +100,8 @@ ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv/invokeai
|
||||
ENV INVOKEAI_ROOT=/invokeai
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$INVOKEAI_SRC:$PATH"
|
||||
ENV CONTAINER_UID=${CONTAINER_UID:-1000}
|
||||
ENV CONTAINER_GID=${CONTAINER_GID:-1000}
|
||||
|
||||
# --link requires buldkit w/ dockerfile syntax 1.4
|
||||
COPY --link --from=builder ${INVOKEAI_SRC} ${INVOKEAI_SRC}
|
||||
@ -117,7 +119,7 @@ WORKDIR ${INVOKEAI_SRC}
|
||||
RUN cd /usr/lib/$(uname -p)-linux-gnu/pkgconfig/ && ln -sf opencv4.pc opencv.pc
|
||||
RUN python3 -c "from patchmatch import patch_match"
|
||||
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R 1000:1000 ${INVOKEAI_ROOT}
|
||||
RUN mkdir -p ${INVOKEAI_ROOT} && chown -R ${CONTAINER_UID}:${CONTAINER_GID} ${INVOKEAI_ROOT}
|
||||
|
||||
COPY docker/docker-entrypoint.sh ./
|
||||
ENTRYPOINT ["/opt/invokeai/docker-entrypoint.sh"]
|
||||
|
@ -23,7 +23,7 @@ This is done via Docker Desktop preferences
|
||||
1. Make a copy of `env.sample` and name it `.env` (`cp env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
|
||||
a. the desired location of the InvokeAI runtime directory, or
|
||||
b. an existing, v3.0.0 compatible runtime directory.
|
||||
1. `docker compose up`
|
||||
1. Execute `run.sh`
|
||||
|
||||
The image will be built automatically if needed.
|
||||
|
||||
@ -39,7 +39,7 @@ The Docker daemon on the system must be already set up to use the GPU. In case o
|
||||
|
||||
## Customize
|
||||
|
||||
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `docker compose up`, your custom values will be used.
|
||||
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `run.sh`, your custom values will be used.
|
||||
|
||||
You can also set these values in `docker-compose.yml` directly, but `.env` will help avoid conflicts when code is updated.
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
build_args=""
|
||||
|
||||
[[ -f ".env" ]] && build_args=$(awk '$1 ~ /\=[^$]/ {print "--build-arg " $0 " "}' .env)
|
||||
|
||||
echo "docker compose build args:"
|
||||
echo $build_args
|
||||
|
||||
docker compose build $build_args
|
@ -2,23 +2,8 @@
|
||||
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
invokeai:
|
||||
x-invokeai: &invokeai
|
||||
image: "local/invokeai:latest"
|
||||
# edit below to run on a container runtime other than nvidia-container-runtime.
|
||||
# not yet tested with rocm/AMD GPUs
|
||||
# Comment out the "deploy" section to run on CPU only
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
# For AMD support, comment out the deploy section above and uncomment the devices section below:
|
||||
#devices:
|
||||
# - /dev/kfd:/dev/kfd
|
||||
# - /dev/dri:/dev/dri
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile
|
||||
@ -50,3 +35,27 @@ services:
|
||||
# - |
|
||||
# invokeai-model-install --yes --default-only --config_file ${INVOKEAI_ROOT}/config_custom.yaml
|
||||
# invokeai-nodes-web --host 0.0.0.0
|
||||
|
||||
services:
|
||||
invokeai-nvidia:
|
||||
<<: *invokeai
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
invokeai-cpu:
|
||||
<<: *invokeai
|
||||
profiles:
|
||||
- cpu
|
||||
|
||||
invokeai-rocm:
|
||||
<<: *invokeai
|
||||
devices:
|
||||
- /dev/kfd:/dev/kfd
|
||||
- /dev/dri:/dev/dri
|
||||
profiles:
|
||||
- rocm
|
||||
|
@ -1,11 +1,28 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
# This script is provided for backwards compatibility with the old docker setup.
|
||||
# it doesn't do much aside from wrapping the usual docker compose CLI.
|
||||
run() {
|
||||
local scriptdir=$(dirname "${BASH_SOURCE[0]}")
|
||||
cd "$scriptdir" || exit 1
|
||||
|
||||
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
|
||||
cd "$SCRIPTDIR" || exit 1
|
||||
local build_args=""
|
||||
local profile=""
|
||||
|
||||
docker compose up -d
|
||||
docker compose logs -f
|
||||
[[ -f ".env" ]] &&
|
||||
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
|
||||
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
|
||||
|
||||
local service_name="invokeai-$profile"
|
||||
|
||||
printf "%s\n" "docker compose build args:"
|
||||
printf "%s\n" "$build_args"
|
||||
|
||||
docker compose build $build_args
|
||||
unset build_args
|
||||
|
||||
printf "%s\n" "starting service $service_name"
|
||||
docker compose --profile "$profile" up -d "$service_name"
|
||||
docker compose logs -f
|
||||
}
|
||||
|
||||
run
|
||||
|
@ -10,40 +10,36 @@ model. These are the:
|
||||
tracks the type of the model, its provenance, and where it can be
|
||||
found on disk.
|
||||
|
||||
* _ModelLoadServiceBase_ Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
|
||||
* _DownloadQueueServiceBase_ A multithreaded downloader responsible
|
||||
for downloading models from a remote source to disk. The download
|
||||
queue has special methods for downloading repo_id folders from
|
||||
Hugging Face, as well as discriminating among model versions in
|
||||
Civitai, but can be used for arbitrary content.
|
||||
|
||||
* _ModelInstallServiceBase_ A service for installing models to
|
||||
disk. It uses `DownloadQueueServiceBase` to download models and
|
||||
their metadata, and `ModelRecordServiceBase` to store that
|
||||
information. It is also responsible for managing the InvokeAI
|
||||
`models` directory and its contents.
|
||||
|
||||
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
||||
A multithreaded downloader responsible
|
||||
for downloading models from a remote source to disk. The download
|
||||
queue has special methods for downloading repo_id folders from
|
||||
Hugging Face, as well as discriminating among model versions in
|
||||
Civitai, but can be used for arbitrary content.
|
||||
|
||||
* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
||||
Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
|
||||
|
||||
## Location of the Code
|
||||
|
||||
All four of these services can be found in
|
||||
`invokeai/app/services` in the following directories:
|
||||
|
||||
* `invokeai/app/services/model_records/`
|
||||
* `invokeai/app/services/downloads/`
|
||||
* `invokeai/app/services/model_loader/`
|
||||
* `invokeai/app/services/model_install/`
|
||||
|
||||
With the exception of the install service, each of these is a thin
|
||||
shell around a corresponding implementation located in
|
||||
`invokeai/backend/model_manager`. The main difference between the
|
||||
modules found in app services and those in the backend folder is that
|
||||
the former add support for event reporting and are more tied to the
|
||||
needs of the InvokeAI API.
|
||||
* `invokeai/app/services/model_loader/` (**under development**)
|
||||
* `invokeai/app/services/downloads/`(**under development**)
|
||||
|
||||
Code related to the FastAPI web API can be found in
|
||||
`invokeai/app/api/routers/models.py`.
|
||||
`invokeai/app/api/routers/model_records.py`.
|
||||
|
||||
***
|
||||
|
||||
@ -165,10 +161,6 @@ of the fields, including `name`, `model_type` and `base_model`, are
|
||||
shared between `ModelConfigBase` and `ModelBase`, and this is a
|
||||
potential source of confusion.
|
||||
|
||||
** TO DO: ** The `ModelBase` code needs to be revised to reduce the
|
||||
duplication of similar classes and to support using the `key` as the
|
||||
primary model identifier.
|
||||
|
||||
## Reading and Writing Model Configuration Records
|
||||
|
||||
The `ModelRecordService` provides the ability to retrieve model
|
||||
@ -362,7 +354,7 @@ model and pass its key to `get_model()`.
|
||||
Several methods allow you to create and update stored model config
|
||||
records.
|
||||
|
||||
#### add_model(key, config) -> ModelConfigBase:
|
||||
#### add_model(key, config) -> AnyModelConfig:
|
||||
|
||||
Given a key and a configuration, this will add the model's
|
||||
configuration record to the database. `config` can either be a subclass of
|
||||
@ -386,27 +378,356 @@ fields to be updated. This will return an `AnyModelConfig` on success,
|
||||
or raise `InvalidModelConfigException` or `UnknownModelException`
|
||||
exceptions on failure.
|
||||
|
||||
***TO DO:*** Investigate why `update_model()` returns an
|
||||
`AnyModelConfig` while `add_model()` returns a `ModelConfigBase`.
|
||||
|
||||
### rename_model(key, new_name) -> ModelConfigBase:
|
||||
|
||||
This is a special case of `update_model()` for the use case of
|
||||
changing the model's name. It is broken out because there are cases in
|
||||
which the InvokeAI application wants to synchronize the model's name
|
||||
with its path in the `models` directory after changing the name, type
|
||||
or base. However, when using the ModelRecordService directly, the call
|
||||
is equivalent to:
|
||||
|
||||
```
|
||||
store.rename_model(key, {'name': 'new_name'})
|
||||
```
|
||||
|
||||
***TO DO:*** Investigate why `rename_model()` is returning a
|
||||
`ModelConfigBase` while `update_model()` returns a `AnyModelConfig`.
|
||||
|
||||
***
|
||||
|
||||
## Model installation
|
||||
|
||||
The `ModelInstallService` class implements the
|
||||
`ModelInstallServiceBase` abstract base class, and provides a one-stop
|
||||
shop for all your model install needs. It provides the following
|
||||
functionality:
|
||||
|
||||
- Registering a model config record for a model already located on the
|
||||
local filesystem, without moving it or changing its path.
|
||||
|
||||
- Installing a model alreadiy located on the local filesystem, by
|
||||
moving it into the InvokeAI root directory under the
|
||||
`models` folder (or wherever config parameter `models_dir`
|
||||
specifies).
|
||||
|
||||
- Probing of models to determine their type, base type and other key
|
||||
information.
|
||||
|
||||
- Interface with the InvokeAI event bus to provide status updates on
|
||||
the download, installation and registration process.
|
||||
|
||||
- Downloading a model from an arbitrary URL and installing it in
|
||||
`models_dir` (_implementation pending_).
|
||||
|
||||
- Special handling for Civitai model URLs which allow the user to
|
||||
paste in a model page's URL or download link (_implementation pending_).
|
||||
|
||||
|
||||
- Special handling for HuggingFace repo_ids to recursively download
|
||||
the contents of the repository, paying attention to alternative
|
||||
variants such as fp16. (_implementation pending_)
|
||||
|
||||
### Initializing the installer
|
||||
|
||||
A default installer is created at InvokeAI api startup time and stored
|
||||
in `ApiDependencies.invoker.services.model_install` and can
|
||||
also be retrieved from an invocation's `context` argument with
|
||||
`context.services.model_install`.
|
||||
|
||||
In the event you wish to create a new installer, you may use the
|
||||
following initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config, logger)
|
||||
|
||||
store = ModelRecordServiceSQL(db)
|
||||
installer = ModelInstallService(config, store)
|
||||
```
|
||||
|
||||
The full form of `ModelInstallService()` takes the following
|
||||
required parameters:
|
||||
|
||||
| **Argument** | **Type** | **Description** |
|
||||
|------------------|------------------------------|------------------------------|
|
||||
| `config` | InvokeAIAppConfig | InvokeAI app configuration object |
|
||||
| `record_store` | ModelRecordServiceBase | Config record storage database |
|
||||
| `event_bus` | EventServiceBase | Optional event bus to send download/install progress events to |
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
#### install_job = installer.import_model()
|
||||
|
||||
The `import_model()` method is the core of the installer. The
|
||||
following illustrates basic usage:
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_install import (
|
||||
LocalModelSource,
|
||||
HFModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
|
||||
source1 = LocalModelSource(path='/opt/models/sushi.safetensors') # a local safetensors file
|
||||
source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local diffusers folder
|
||||
|
||||
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
|
||||
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
|
||||
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
|
||||
|
||||
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
|
||||
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
|
||||
|
||||
for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
install_job = installer.install_model(source)
|
||||
|
||||
source2job = installer.wait_for_installs()
|
||||
for source in sources:
|
||||
job = source2job[source]
|
||||
if job.status == "completed":
|
||||
model_config = job.config_out
|
||||
model_key = model_config.key
|
||||
print(f"{source} installed as {model_key}")
|
||||
elif job.status == "error":
|
||||
print(f"{source}: {job.error_type}.\nStack trace:\n{job.error}")
|
||||
|
||||
```
|
||||
|
||||
As shown here, the `import_model()` method accepts a variety of
|
||||
sources, including local safetensors files, local diffusers folders,
|
||||
HuggingFace repo_ids with and without a subfolder designation,
|
||||
Civitai model URLs and arbitrary URLs that point to checkpoint files
|
||||
(but not to folders).
|
||||
|
||||
Each call to `import_model()` return a `ModelInstallJob` job,
|
||||
an object which tracks the progress of the install.
|
||||
|
||||
If a remote model is requested, the model's files are downloaded in
|
||||
parallel across a multiple set of threads using the download
|
||||
queue. During the download process, the `ModelInstallJob` is updated
|
||||
to provide status and progress information. After the files (if any)
|
||||
are downloaded, the remainder of the installation runs in a single
|
||||
serialized background thread. These are the model probing, file
|
||||
copying, and config record database update steps.
|
||||
|
||||
Multiple install jobs can be queued up. You may block until all
|
||||
install jobs are completed (or errored) by calling the
|
||||
`wait_for_installs()` method as shown in the code
|
||||
example. `wait_for_installs()` will return a `dict` that maps the
|
||||
requested source to its job. This object can be interrogated
|
||||
to determine its status. If the job errored out, then the error type
|
||||
and details can be recovered from `job.error_type` and `job.error`.
|
||||
|
||||
The full list of arguments to `import_model()` is as follows:
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
|
||||
| `inplace` | bool | True | Leave a local model in its current location |
|
||||
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
|
||||
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
|
||||
| `config` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
|
||||
| `access_token` | str | None | Provide authorization information needed to download |
|
||||
|
||||
|
||||
The `inplace` field controls how local model Paths are handled. If
|
||||
True (the default), then the model is simply registered in its current
|
||||
location by the installer's `ModelConfigRecordService`. Otherwise, a
|
||||
copy of the model put into the location specified by the `models_dir`
|
||||
application configuration parameter.
|
||||
|
||||
The `variant` field is used for HuggingFace repo_ids only. If
|
||||
provided, the repo_id download handler will look for and download
|
||||
tensors files that follow the convention for the selected variant:
|
||||
|
||||
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
|
||||
- "onnx" will select files ending with the suffix ".onnx"
|
||||
- "openvino" will select files beginning with "openvino_model"
|
||||
|
||||
In the special case of the "fp16" variant, the installer will select
|
||||
the 32-bit version of the files if the 16-bit version is unavailable.
|
||||
|
||||
`subfolder` is used for HuggingFace repo_ids only. If provided, the
|
||||
model will be downloaded from the designated subfolder rather than the
|
||||
top-level repository folder. If a subfolder is attached to the repo_id
|
||||
using the format `repo_owner/repo_name:subfolder`, then the subfolder
|
||||
specified by the repo_id will override the subfolder argument.
|
||||
|
||||
`config` can be used to override all or a portion of the configuration
|
||||
attributes returned by the model prober. See the section below for
|
||||
details.
|
||||
|
||||
`access_token` is passed to the download queue and used to access
|
||||
repositories that require it.
|
||||
|
||||
#### Monitoring the install job process
|
||||
|
||||
When you create an install job with `import_model()`, it launches the
|
||||
download and installation process in the background and returns a
|
||||
`ModelInstallJob` object for monitoring the process.
|
||||
|
||||
The `ModelInstallJob` class has the following structure:
|
||||
|
||||
| **Attribute** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `status` | `InstallStatus` | An enum of ["waiting", "running", "completed" and "error" |
|
||||
| `config_in` | `dict` | Overriding configuration values provided by the caller |
|
||||
| `config_out` | `AnyModelConfig`| After successful completion, contains the configuration record written to the database |
|
||||
| `inplace` | `boolean` | True if the caller asked to install the model in place using its local path |
|
||||
| `source` | `ModelSource` | The local path, remote URL or repo_id of the model to be installed |
|
||||
| `local_path` | `Path` | If a remote model, holds the path of the model after it is downloaded; if a local model, same as `source` |
|
||||
| `error_type` | `str` | Name of the exception that led to an error status |
|
||||
| `error` | `str` | Traceback of the error |
|
||||
|
||||
|
||||
If the `event_bus` argument was provided, events will also be
|
||||
broadcast to the InvokeAI event bus. The events will appear on the bus
|
||||
as an event of type `EventServiceBase.model_event`, a timestamp and
|
||||
the following event names:
|
||||
|
||||
- `model_install_started`
|
||||
|
||||
The payload will contain the keys `timestamp` and `source`. The latter
|
||||
indicates the requested model source for installation.
|
||||
|
||||
- `model_install_progress`
|
||||
|
||||
Emitted at regular intervals when downloading a remote model, the
|
||||
payload will contain the keys `timestamp`, `source`, `current_bytes`
|
||||
and `total_bytes`. These events are _not_ emitted when a local model
|
||||
already on the filesystem is imported.
|
||||
|
||||
- `model_install_completed`
|
||||
|
||||
Issued once at the end of a successful installation. The payload will
|
||||
contain the keys `timestamp`, `source` and `key`, where `key` is the
|
||||
ID under which the model has been registered.
|
||||
|
||||
- `model_install_error`
|
||||
|
||||
Emitted if the installation process fails for some reason. The payload
|
||||
will contain the keys `timestamp`, `source`, `error_type` and
|
||||
`error`. `error_type` is a short message indicating the nature of the
|
||||
error, and `error` is the long traceback to help debug the problem.
|
||||
|
||||
#### Model confguration and probing
|
||||
|
||||
The install service uses the `invokeai.backend.model_manager.probe`
|
||||
module during import to determine the model's type, base type, and
|
||||
other configuration parameters. Among other things, it assigns a
|
||||
default name and description for the model based on probed
|
||||
fields.
|
||||
|
||||
When downloading remote models is implemented, additional
|
||||
configuration information, such as list of trigger terms, will be
|
||||
retrieved from the HuggingFace and Civitai model repositories.
|
||||
|
||||
The probed values can be overriden by providing a dictionary in the
|
||||
optional `config` argument passed to `import_model()`. You may provide
|
||||
overriding values for any of the model's configuration
|
||||
attributes. Here is an example of setting the
|
||||
`SchedulerPredictionType` and `name` for an sd-2 model:
|
||||
|
||||
This is typically used to set
|
||||
the model's name and description, but can also be used to overcome
|
||||
cases in which automatic probing is unable to (correctly) determine
|
||||
the model's attribute. The most common situation is the
|
||||
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
|
||||
example of how it works:
|
||||
|
||||
```
|
||||
install_job = installer.import_model(
|
||||
source='stabilityai/stable-diffusion-2-1',
|
||||
variant='fp16',
|
||||
config=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
name='stable diffusion 2 base model',
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Other installer methods
|
||||
|
||||
This section describes additional methods provided by the installer class.
|
||||
|
||||
#### jobs = installer.wait_for_installs()
|
||||
|
||||
Block until all pending installs are completed or errored and then
|
||||
returns a list of completed jobs.
|
||||
|
||||
#### jobs = installer.list_jobs([source])
|
||||
|
||||
Return a list of all active and complete `ModelInstallJobs`. An
|
||||
optional `source` argument allows you to filter the returned list by a
|
||||
model source string pattern using a partial string match.
|
||||
|
||||
#### jobs = installer.get_job(source)
|
||||
|
||||
Return a list of `ModelInstallJob` corresponding to the indicated
|
||||
model source.
|
||||
|
||||
#### installer.prune_jobs
|
||||
|
||||
Remove non-pending jobs (completed or errored) from the job list
|
||||
returned by `list_jobs()` and `get_job()`.
|
||||
|
||||
#### installer.app_config, installer.record_store,
|
||||
installer.event_bus
|
||||
|
||||
Properties that provide access to the installer's `InvokeAIAppConfig`,
|
||||
`ModelRecordServiceBase` and `EventServiceBase` objects.
|
||||
|
||||
#### key = installer.register_path(model_path, config), key = installer.install_path(model_path, config)
|
||||
|
||||
These methods bypass the download queue and directly register or
|
||||
install the model at the indicated path, returning the unique ID for
|
||||
the installed model.
|
||||
|
||||
Both methods accept a Path object corresponding to a checkpoint or
|
||||
diffusers folder, and an optional dict of config attributes to use to
|
||||
override the values derived from model probing.
|
||||
|
||||
The difference between `register_path()` and `install_path()` is that
|
||||
the former creates a model configuration record without changing the
|
||||
location of the model in the filesystem. The latter makes a copy of
|
||||
the model inside the InvokeAI models directory before registering
|
||||
it.
|
||||
|
||||
#### installer.unregister(key)
|
||||
|
||||
This will remove the model config record for the model at key, and is
|
||||
equivalent to `installer.record_store.del_model(key)`
|
||||
|
||||
#### installer.delete(key)
|
||||
|
||||
This is similar to `unregister()` but has the additional effect of
|
||||
conditionally deleting the underlying model file(s) if they reside
|
||||
within the InvokeAI models directory
|
||||
|
||||
#### installer.unconditionally_delete(key)
|
||||
|
||||
This method is similar to `unregister()`, but also unconditionally
|
||||
deletes the corresponding model weights file(s), regardless of whether
|
||||
they are inside or outside the InvokeAI models hierarchy.
|
||||
|
||||
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
|
||||
|
||||
This method will recursively scan the directory indicated in
|
||||
`scan_dir` for new models and either install them in the models
|
||||
directory or register them in place, depending on the setting of
|
||||
`install` (default False).
|
||||
|
||||
The return value is the list of keys of the new installed/registered
|
||||
models.
|
||||
|
||||
#### installer.sync_to_config()
|
||||
|
||||
This method synchronizes models in the models directory and autoimport
|
||||
directory to those in the `ModelConfigRecordService` database. New
|
||||
models are registered and orphan models are unregistered.
|
||||
|
||||
#### installer.start(invoker)
|
||||
|
||||
The `start` method is called by the API intialization routines when
|
||||
the API starts up. Its effect is to call `sync_to_config()` to
|
||||
synchronize the model record store database with what's currently on
|
||||
disk.
|
||||
|
||||
# The remainder of this documentation is provisional, pending implementation of the Download and Load services
|
||||
|
||||
## Let's get loaded, the lowdown on ModelLoadService
|
||||
|
||||
The `ModelLoadService` is responsible for loading a named model into
|
||||
@ -863,351 +1184,3 @@ other resources that it might have been using.
|
||||
This will start/pause/cancel all jobs that have been submitted to the
|
||||
queue and have not yet reached a terminal state.
|
||||
|
||||
## Model installation
|
||||
|
||||
The `ModelInstallService` class implements the
|
||||
`ModelInstallServiceBase` abstract base class, and provides a one-stop
|
||||
shop for all your model install needs. It provides the following
|
||||
functionality:
|
||||
|
||||
- Registering a model config record for a model already located on the
|
||||
local filesystem, without moving it or changing its path.
|
||||
|
||||
- Installing a model alreadiy located on the local filesystem, by
|
||||
moving it into the InvokeAI root directory under the
|
||||
`models` folder (or wherever config parameter `models_dir`
|
||||
specifies).
|
||||
|
||||
- Downloading a model from an arbitrary URL and installing it in
|
||||
`models_dir`.
|
||||
|
||||
- Special handling for Civitai model URLs which allow the user to
|
||||
paste in a model page's URL or download link. Any metadata provided
|
||||
by Civitai, such as trigger terms, are captured and placed in the
|
||||
model config record.
|
||||
|
||||
- Special handling for HuggingFace repo_ids to recursively download
|
||||
the contents of the repository, paying attention to alternative
|
||||
variants such as fp16.
|
||||
|
||||
- Probing of models to determine their type, base type and other key
|
||||
information.
|
||||
|
||||
- Interface with the InvokeAI event bus to provide status updates on
|
||||
the download, installation and registration process.
|
||||
|
||||
### Initializing the installer
|
||||
|
||||
A default installer is created at InvokeAI api startup time and stored
|
||||
in `ApiDependencies.invoker.services.model_install_service` and can
|
||||
also be retrieved from an invocation's `context` argument with
|
||||
`context.services.model_install_service`.
|
||||
|
||||
In the event you wish to create a new installer, you may use the
|
||||
following initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download_manager import DownloadQueueServive
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
|
||||
config = InvokeAI.get_config()
|
||||
queue = DownloadQueueService()
|
||||
store = ModelRecordServiceBase.open(config)
|
||||
installer = ModelInstallService(config=config, queue=queue, store=store)
|
||||
```
|
||||
|
||||
The full form of `ModelInstallService()` takes the following
|
||||
parameters. Each parameter will default to a reasonable value, but it
|
||||
is recommended that you set them explicitly as shown in the above example.
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `config` | InvokeAIAppConfig | Use system-wide config | InvokeAI app configuration object |
|
||||
| `queue` | DownloadQueueServiceBase | Create a new download queue for internal use | Download queue |
|
||||
| `store` | ModelRecordServiceBase | Use config to select the database to open | Config storage database |
|
||||
| `event_bus` | EventServiceBase | None | An event bus to send download/install progress events to |
|
||||
| `event_handlers` | List[DownloadEventHandler] | None | Event handlers for the download queue |
|
||||
|
||||
Note that if `store` is not provided, then the class will use
|
||||
`ModelRecordServiceBase.open(config)` to select the database to use.
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
#### install_job = installer.install_model()
|
||||
|
||||
The `install_model()` method is the core of the installer. The
|
||||
following illustrates basic usage:
|
||||
|
||||
```
|
||||
sources = [
|
||||
Path('/opt/models/sushi.safetensors'), # a local safetensors file
|
||||
Path('/opt/models/sushi_diffusers/'), # a local diffusers folder
|
||||
'runwayml/stable-diffusion-v1-5', # a repo_id
|
||||
'runwayml/stable-diffusion-v1-5:vae', # a subfolder within a repo_id
|
||||
'https://civitai.com/api/download/models/63006', # a civitai direct download link
|
||||
'https://civitai.com/models/8765?modelVersionId=10638', # civitai model page
|
||||
'https://s3.amazon.com/fjacks/sd-3.safetensors', # arbitrary URL
|
||||
]
|
||||
|
||||
for source in sources:
|
||||
install_job = installer.install_model(source)
|
||||
|
||||
source2key = installer.wait_for_installs()
|
||||
for source in sources:
|
||||
model_key = source2key[source]
|
||||
print(f"{source} installed as {model_key}")
|
||||
```
|
||||
|
||||
As shown here, the `install_model()` method accepts a variety of
|
||||
sources, including local safetensors files, local diffusers folders,
|
||||
HuggingFace repo_ids with and without a subfolder designation,
|
||||
Civitai model URLs and arbitrary URLs that point to checkpoint files
|
||||
(but not to folders).
|
||||
|
||||
Each call to `install_model()` will return a `ModelInstallJob` job, a
|
||||
subclass of `DownloadJobBase`. The install job has additional
|
||||
install-specific fields described in the next section.
|
||||
|
||||
Each install job will run in a series of background threads using
|
||||
the object's download queue. You may block until all install jobs are
|
||||
completed (or errored) by calling the `wait_for_installs()` method as
|
||||
shown in the code example. `wait_for_installs()` will return a `dict`
|
||||
that maps the requested source to the key of the installed model. In
|
||||
the case that a model fails to download or install, its value in the
|
||||
dict will be None. The actual cause of the error will be reported in
|
||||
the corresponding job's `error` field.
|
||||
|
||||
Alternatively you may install event handlers and/or listen for events
|
||||
on the InvokeAI event bus in order to monitor the progress of the
|
||||
requested installs.
|
||||
|
||||
The full list of arguments to `model_install()` is as follows:
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|------------------|------------------------------|-------------|-------------------------------------------|
|
||||
| `source` | Union[str, Path, AnyHttpUrl] | | The source of the model, Path, URL or repo_id |
|
||||
| `inplace` | bool | True | Leave a local model in its current location |
|
||||
| `variant` | str | None | Desired variant, such as 'fp16' or 'onnx' (HuggingFace only) |
|
||||
| `subfolder` | str | None | Repository subfolder (HuggingFace only) |
|
||||
| `probe_override` | Dict[str, Any] | None | Override all or a portion of model's probed attributes |
|
||||
| `metadata` | ModelSourceMetadata | None | Provide metadata that will be added to model's config |
|
||||
| `access_token` | str | None | Provide authorization information needed to download |
|
||||
| `priority` | int | 10 | Download queue priority for the job |
|
||||
|
||||
|
||||
The `inplace` field controls how local model Paths are handled. If
|
||||
True (the default), then the model is simply registered in its current
|
||||
location by the installer's `ModelConfigRecordService`. Otherwise, the
|
||||
model will be moved into the location specified by the `models_dir`
|
||||
application configuration parameter.
|
||||
|
||||
The `variant` field is used for HuggingFace repo_ids only. If
|
||||
provided, the repo_id download handler will look for and download
|
||||
tensors files that follow the convention for the selected variant:
|
||||
|
||||
- "fp16" will select files named "*model.fp16.{safetensors,bin}"
|
||||
- "onnx" will select files ending with the suffix ".onnx"
|
||||
- "openvino" will select files beginning with "openvino_model"
|
||||
|
||||
In the special case of the "fp16" variant, the installer will select
|
||||
the 32-bit version of the files if the 16-bit version is unavailable.
|
||||
|
||||
`subfolder` is used for HuggingFace repo_ids only. If provided, the
|
||||
model will be downloaded from the designated subfolder rather than the
|
||||
top-level repository folder. If a subfolder is attached to the repo_id
|
||||
using the format `repo_owner/repo_name:subfolder`, then the subfolder
|
||||
specified by the repo_id will override the subfolder argument.
|
||||
|
||||
`probe_override` can be used to override all or a portion of the
|
||||
attributes returned by the model prober. This can be used to overcome
|
||||
cases in which automatic probing is unable to (correctly) determine
|
||||
the model's attribute. The most common situation is the
|
||||
`prediction_type` field for sd-2 (and rare sd-1) models. Here is an
|
||||
example of how it works:
|
||||
|
||||
```
|
||||
install_job = installer.install_model(
|
||||
source='stabilityai/stable-diffusion-2-1',
|
||||
variant='fp16',
|
||||
probe_override=dict(
|
||||
prediction_type=SchedulerPredictionType('v_prediction')
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
`metadata` allows you to attach custom metadata to the installed
|
||||
model. See the next section for details.
|
||||
|
||||
`priority` and `access_token` are passed to the download queue and
|
||||
have the same effect as they do for the DownloadQueueServiceBase.
|
||||
|
||||
#### Monitoring the install job process
|
||||
|
||||
When you create an install job with `model_install()`, events will be
|
||||
passed to the list of `DownloadEventHandlers` provided at installer
|
||||
initialization time. Event handlers can also be added to individual
|
||||
model install jobs by calling their `add_handler()` method as
|
||||
described earlier for the `DownloadQueueService`.
|
||||
|
||||
If the `event_bus` argument was provided, events will also be
|
||||
broadcast to the InvokeAI event bus. The events will appear on the bus
|
||||
as a singular event type named `model_event` with a payload of
|
||||
`job`. You can then retrieve the job and check its status.
|
||||
|
||||
** TO DO: ** consider breaking `model_event` into
|
||||
`model_install_started`, `model_install_completed`, etc. The event bus
|
||||
features have not yet been tested with FastAPI/websockets, and it may
|
||||
turn out that the job object is not serializable.
|
||||
|
||||
#### Model metadata and probing
|
||||
|
||||
The install service has special handling for HuggingFace and Civitai
|
||||
URLs that capture metadata from the source and include it in the model
|
||||
configuration record. For example, fetching the Civitai model 8765
|
||||
will produce a config record similar to this (using YAML
|
||||
representation):
|
||||
|
||||
```
|
||||
5abc3ef8600b6c1cc058480eaae3091e:
|
||||
path: sd-1/lora/to8contrast-1-5.safetensors
|
||||
name: to8contrast-1-5
|
||||
base_model: sd-1
|
||||
model_type: lora
|
||||
model_format: lycoris
|
||||
key: 5abc3ef8600b6c1cc058480eaae3091e
|
||||
hash: 5abc3ef8600b6c1cc058480eaae3091e
|
||||
description: 'Trigger terms: to8contrast style'
|
||||
author: theovercomer8
|
||||
license: allowCommercialUse=Sell; allowDerivatives=True; allowNoCredit=True
|
||||
source: https://civitai.com/models/8765?modelVersionId=10638
|
||||
thumbnail_url: null
|
||||
tags:
|
||||
- model
|
||||
- style
|
||||
- portraits
|
||||
```
|
||||
|
||||
For sources that do not provide model metadata, you can attach custom
|
||||
fields by providing a `metadata` argument to `model_install()` using
|
||||
an initialized `ModelSourceMetadata` object (available for import from
|
||||
`model_install_service.py`):
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_install_service import ModelSourceMetadata
|
||||
meta = ModelSourceMetadata(
|
||||
name="my model",
|
||||
author="Sushi Chef",
|
||||
description="Highly customized model; trigger with 'sushi',"
|
||||
license="mit",
|
||||
thumbnail_url="http://s3.amazon.com/ljack/pics/sushi.png",
|
||||
tags=list('sfw', 'food')
|
||||
)
|
||||
install_job = installer.install_model(
|
||||
source='sushi_chef/model3',
|
||||
variant='fp16',
|
||||
metadata=meta,
|
||||
)
|
||||
```
|
||||
|
||||
It is not currently recommended to provide custom metadata when
|
||||
installing from Civitai or HuggingFace source, as the metadata
|
||||
provided by the source will overwrite the fields you provide. Instead,
|
||||
after the model is installed you can use
|
||||
`ModelRecordService.update_model()` to change the desired fields.
|
||||
|
||||
** TO DO: ** Change the logic so that the caller's metadata fields take
|
||||
precedence over those provided by the source.
|
||||
|
||||
|
||||
#### Other installer methods
|
||||
|
||||
This section describes additional, less-frequently-used attributes and
|
||||
methods provided by the installer class.
|
||||
|
||||
##### installer.wait_for_installs()
|
||||
|
||||
This is equivalent to the `DownloadQueue` `join()` method. It will
|
||||
block until all the active jobs in the install queue have reached a
|
||||
terminal state (completed, errored or cancelled).
|
||||
|
||||
##### installer.queue, installer.store, installer.config
|
||||
|
||||
These attributes provide access to the `DownloadQueueServiceBase`,
|
||||
`ModelConfigRecordServiceBase`, and `InvokeAIAppConfig` objects that
|
||||
the installer uses.
|
||||
|
||||
For example, to temporarily pause all pending installations, you can
|
||||
do this:
|
||||
|
||||
```
|
||||
installer.queue.pause_all_jobs()
|
||||
```
|
||||
##### key = installer.register_path(model_path, overrides), key = installer.install_path(model_path, overrides)
|
||||
|
||||
These methods bypass the download queue and directly register or
|
||||
install the model at the indicated path, returning the unique ID for
|
||||
the installed model.
|
||||
|
||||
Both methods accept a Path object corresponding to a checkpoint or
|
||||
diffusers folder, and an optional dict of attributes to use to
|
||||
override the values derived from model probing.
|
||||
|
||||
The difference between `register_path()` and `install_path()` is that
|
||||
the former will not move the model from its current position, while
|
||||
the latter will move it into the `models_dir` hierarchy.
|
||||
|
||||
##### installer.unregister(key)
|
||||
|
||||
This will remove the model config record for the model at key, and is
|
||||
equivalent to `installer.store.unregister(key)`
|
||||
|
||||
##### installer.delete(key)
|
||||
|
||||
This is similar to `unregister()` but has the additional effect of
|
||||
deleting the underlying model file(s) -- even if they were outside the
|
||||
`models_dir` directory!
|
||||
|
||||
##### installer.conditionally_delete(key)
|
||||
|
||||
This method will call `unregister()` if the model identified by `key`
|
||||
is outside the `models_dir` hierarchy, and call `delete()` if the
|
||||
model is inside.
|
||||
|
||||
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
|
||||
|
||||
This method will recursively scan the directory indicated in
|
||||
`scan_dir` for new models and either install them in the models
|
||||
directory or register them in place, depending on the setting of
|
||||
`install` (default False).
|
||||
|
||||
The return value is the list of keys of the new installed/registered
|
||||
models.
|
||||
|
||||
#### installer.scan_models_directory()
|
||||
|
||||
This method scans the models directory for new models and registers
|
||||
them in place. Models that are present in the
|
||||
`ModelConfigRecordService` database whose paths are not found will be
|
||||
unregistered.
|
||||
|
||||
#### installer.sync_to_config()
|
||||
|
||||
This method synchronizes models in the models directory and autoimport
|
||||
directory to those in the `ModelConfigRecordService` database. New
|
||||
models are registered and orphan models are unregistered.
|
||||
|
||||
#### hash=installer.hash(model_path)
|
||||
|
||||
This method is calls the fasthash algorithm on a model's Path
|
||||
(either a file or a folder) to generate a unique ID based on the
|
||||
contents of the model.
|
||||
|
||||
##### installer.start(invoker)
|
||||
|
||||
The `start` method is called by the API intialization routines when
|
||||
the API starts up. Its effect is to call `sync_to_config()` to
|
||||
synchronize the model record store database with what's currently on
|
||||
disk.
|
||||
|
||||
This method should not ordinarily be called manually.
|
||||
|
@ -154,14 +154,16 @@ groups in `invokeia.yaml`:
|
||||
|
||||
### Web Server
|
||||
|
||||
| Setting | Default Value | Description |
|
||||
|----------|----------------|--------------|
|
||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||
| `port` | `9090` | Network port number that the web server will listen on |
|
||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||
| Setting | Default Value | Description |
|
||||
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
|
||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||
| `port` | `9090` | Network port number that the web server will listen on |
|
||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
|
||||
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
|
||||
|
||||
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
|
||||
|
||||
|
@ -293,6 +293,19 @@ manager, please follow these steps:
|
||||
|
||||
## Developer Install
|
||||
|
||||
!!! warning
|
||||
|
||||
InvokeAI uses a SQLite database. By running on `main`, you accept responsibility for your database. This
|
||||
means making regular backups (especially before pulling) and/or fixing it yourself in the event that a
|
||||
PR introduces a schema change.
|
||||
|
||||
If you don't need persistent backend storage, you can use an ephemeral in-memory database by setting
|
||||
`use_memory_db: true` under `Path:` in your `invokeai.yaml` file.
|
||||
|
||||
If this is untenable, you should run the application via the official installer or a manual install of the
|
||||
python package from pypi. These releases will not break your database.
|
||||
|
||||
|
||||
If you have an interest in how InvokeAI works, or you would like to
|
||||
add features or bugfixes, you are encouraged to install the source
|
||||
code for InvokeAI. For this to work, you will need to install the
|
||||
@ -388,3 +401,5 @@ environment variable INVOKEAI_ROOT to point to the installation directory.
|
||||
|
||||
Note that if you run into problems with the Conda installation, the InvokeAI
|
||||
staff will **not** be able to help you out. Caveat Emptor!
|
||||
|
||||
[dev-chat]: https://discord.com/channels/1020123559063990373/1049495067846524939
|
10
docs/javascripts/init_kapa_widget.js
Normal file
10
docs/javascripts/init_kapa_widget.js
Normal file
@ -0,0 +1,10 @@
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
var script = document.createElement("script");
|
||||
script.src = "https://widget.kapa.ai/kapa-widget.bundle.js";
|
||||
script.setAttribute("data-website-id", "b5973bb1-476b-451e-8cf4-98de86745a10");
|
||||
script.setAttribute("data-project-name", "Invoke.AI");
|
||||
script.setAttribute("data-project-color", "#11213C");
|
||||
script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/113954515?s=280&v=4");
|
||||
script.async = true;
|
||||
document.head.appendChild(script);
|
||||
});
|
@ -14,6 +14,10 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
|
||||
- Community Nodes
|
||||
+ [Average Images](#average-images)
|
||||
+ [Clean Image Artifacts After Cut](#clean-image-artifacts-after-cut)
|
||||
+ [Close Color Mask](#close-color-mask)
|
||||
+ [Clothing Mask](#clothing-mask)
|
||||
+ [Contrast Limited Adaptive Histogram Equalization](#contrast-limited-adaptive-histogram-equalization)
|
||||
+ [Depth Map from Wavefront OBJ](#depth-map-from-wavefront-obj)
|
||||
+ [Film Grain](#film-grain)
|
||||
+ [Generative Grammar-Based Prompt Nodes](#generative-grammar-based-prompt-nodes)
|
||||
@ -22,16 +26,22 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
+ [Halftone](#halftone)
|
||||
+ [Ideal Size](#ideal-size)
|
||||
+ [Image and Mask Composition Pack](#image-and-mask-composition-pack)
|
||||
+ [Image Dominant Color](#image-dominant-color)
|
||||
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
|
||||
+ [Image Picker](#image-picker)
|
||||
+ [Image Resize Plus](#image-resize-plus)
|
||||
+ [Load Video Frame](#load-video-frame)
|
||||
+ [Make 3D](#make-3d)
|
||||
+ [Mask Operations](#mask-operations)
|
||||
+ [Match Histogram](#match-histogram)
|
||||
+ [Negative Image](#negative-image)
|
||||
+ [Oobabooga](#oobabooga)
|
||||
+ [Prompt Tools](#prompt-tools)
|
||||
+ [Remote Image](#remote-image)
|
||||
+ [Remove Background](#remove-background)
|
||||
+ [Retroize](#retroize)
|
||||
+ [Size Stepper Nodes](#size-stepper-nodes)
|
||||
+ [Simple Skin Detection](#simple-skin-detection)
|
||||
+ [Text font to Image](#text-font-to-image)
|
||||
+ [Thresholding](#thresholding)
|
||||
+ [Unsharp Mask](#unsharp-mask)
|
||||
@ -48,6 +58,46 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/average-images-node
|
||||
|
||||
--------------------------------
|
||||
### Clean Image Artifacts After Cut
|
||||
|
||||
Description: Removes residual artifacts after an image is separated from its background.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/clean-artifact-after-cut-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/clean-artifact-after-cut-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Close Color Mask
|
||||
|
||||
Description: Generates a mask for images based on a closely matching color, useful for color-based selections.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/close-color-mask-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/close-color-mask-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Clothing Mask
|
||||
|
||||
Description: Employs a U2NET neural network trained for the segmentation of clothing items in images.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/clothing-mask-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/clothing-mask-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Contrast Limited Adaptive Histogram Equalization
|
||||
|
||||
Description: Enhances local image contrast using adaptive histogram equalization with contrast limiting.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/clahe-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/clahe-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Depth Map from Wavefront OBJ
|
||||
|
||||
@ -164,6 +214,16 @@ This includes 15 Nodes:
|
||||
|
||||
</br><img src="https://raw.githubusercontent.com/dwringer/composition-nodes/main/composition_pack_overview.jpg" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Image Dominant Color
|
||||
|
||||
Description: Identifies and extracts the dominant color from an image using k-means clustering.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/image-dominant-color-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-dominant-color-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Image to Character Art Image Nodes
|
||||
|
||||
@ -185,6 +245,17 @@ This includes 15 Nodes:
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/image-picker-node
|
||||
|
||||
--------------------------------
|
||||
### Image Resize Plus
|
||||
|
||||
Description: Provides various image resizing options such as fill, stretch, fit, center, and crop.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/image-resize-plus-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Load Video Frame
|
||||
|
||||
@ -209,6 +280,16 @@ This includes 15 Nodes:
|
||||
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png" width="300" />
|
||||
<img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png" width="300" />
|
||||
|
||||
--------------------------------
|
||||
### Mask Operations
|
||||
|
||||
Description: Offers logical operations (OR, SUB, AND) for combining and manipulating image masks.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/mask-operations-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/mask-operations-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Match Histogram
|
||||
|
||||
@ -226,6 +307,16 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
<img src="https://github.com/skunkworxdark/match_histogram/assets/21961335/ed12f329-a0ef-444a-9bae-129ed60d6097" width="300" />
|
||||
|
||||
--------------------------------
|
||||
### Negative Image
|
||||
|
||||
Description: Creates a negative version of an image, effective for visual effects and mask inversion.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/negative-image-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/negative-image-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Oobabooga
|
||||
|
||||
@ -289,6 +380,15 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
**Node Link:** https://github.com/fieldOfView/InvokeAI-remote_image
|
||||
|
||||
--------------------------------
|
||||
### Remove Background
|
||||
|
||||
Description: An integration of the rembg package to remove backgrounds from images using multiple U2NET models.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/remove-background-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/remove-background-node/master/.readme/node.png" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Retroize
|
||||
@ -301,6 +401,17 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
<img src="https://github.com/Ar7ific1al/InvokeAI_nodes_retroize/assets/2306586/de8b4fa6-324c-4c2d-b36c-297600c73974" width="500" />
|
||||
|
||||
--------------------------------
|
||||
### Simple Skin Detection
|
||||
|
||||
Description: Detects skin in images based on predefined color thresholds.
|
||||
|
||||
Node Link: https://github.com/VeyDlin/simple-skin-detection-node
|
||||
|
||||
View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/simple-skin-detection-node/master/.readme/node.png" width="500" />
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Size Stepper Nodes
|
||||
|
||||
@ -386,6 +497,7 @@ See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/READ
|
||||
|
||||
<img src="https://github.com/skunkworxdark/XYGrid_nodes/blob/main/images/collage.png" width="300" />
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Example Node Template
|
||||
|
||||
|
@ -2,43 +2,72 @@
|
||||
|
||||
set -e
|
||||
|
||||
BCYAN="\e[1;36m"
|
||||
BYELLOW="\e[1;33m"
|
||||
BGREEN="\e[1;32m"
|
||||
BRED="\e[1;31m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
|
||||
function is_bin_in_path {
|
||||
builtin type -P "$1" &>/dev/null
|
||||
}
|
||||
|
||||
function git_show {
|
||||
git show -s --format='%h %s' $1
|
||||
}
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
echo
|
||||
|
||||
# Some machines only have `python3` in PATH, others have `python` - make an alias.
|
||||
# We can use a function to approximate an alias within a non-interactive shell.
|
||||
if ! is_bin_in_path python && is_bin_in_path python3; then
|
||||
function python {
|
||||
python3 "$@"
|
||||
}
|
||||
fi
|
||||
|
||||
if [[ -v "VIRTUAL_ENV" ]]; then
|
||||
# we can't just call 'deactivate' because this function is not exported
|
||||
# to the environment of this script from the bash process that runs the script
|
||||
echo "A virtual environment is activated. Please deactivate it before proceeding".
|
||||
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
VERSION=$(cd ..; python -c "from invokeai.version import __version__ as version; print(version)")
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
)
|
||||
PATCH=""
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
LATEST_TAG="v3-latest"
|
||||
|
||||
echo Building installer for version $VERSION
|
||||
echo "Be certain that you're in the 'installer' directory before continuing."
|
||||
read -p "Press any key to continue, or CTRL-C to exit..."
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show
|
||||
echo
|
||||
|
||||
read -e -p "Tag this repo with '${VERSION}' and '${LATEST_TAG}'? [n]: " input
|
||||
RESPONSE=${input:='n'}
|
||||
if [ "$RESPONSE" == 'y' ]; then
|
||||
# ---------------------- FRONTEND ----------------------
|
||||
|
||||
git push origin :refs/tags/$VERSION
|
||||
if ! git tag -fa $VERSION ; then
|
||||
echo "Existing/invalid tag"
|
||||
exit -1
|
||||
fi
|
||||
pushd ../invokeai/frontend/web >/dev/null
|
||||
echo
|
||||
echo "Installing frontend dependencies..."
|
||||
echo
|
||||
pnpm i --frozen-lockfile
|
||||
echo
|
||||
echo "Building frontend..."
|
||||
echo
|
||||
pnpm build
|
||||
popd
|
||||
|
||||
git push origin :refs/tags/$LATEST_TAG
|
||||
git tag -fa $LATEST_TAG
|
||||
# ---------------------- BACKEND ----------------------
|
||||
|
||||
echo "remember to push --tags!"
|
||||
fi
|
||||
|
||||
# ----------------------
|
||||
|
||||
echo Building the wheel
|
||||
echo
|
||||
echo "Building wheel..."
|
||||
echo
|
||||
|
||||
# install the 'build' package in the user site packages, if needed
|
||||
# could be improved by using a temporary venv, but it's tiny and harmless
|
||||
@ -46,12 +75,15 @@ if [[ $(python -c 'from importlib.util import find_spec; print(find_spec("build"
|
||||
pip install --user build
|
||||
fi
|
||||
|
||||
rm -r ../build
|
||||
rm -rf ../build
|
||||
|
||||
python -m build --wheel --outdir dist/ ../.
|
||||
|
||||
# ----------------------
|
||||
|
||||
echo Building installer zip fles for InvokeAI $VERSION
|
||||
echo
|
||||
echo "Building installer zip files for InvokeAI ${VERSION}..."
|
||||
echo
|
||||
|
||||
# get rid of any old ones
|
||||
rm -f *.zip
|
||||
@ -59,9 +91,11 @@ rm -rf InvokeAI-Installer
|
||||
|
||||
# copy content
|
||||
mkdir InvokeAI-Installer
|
||||
for f in templates lib *.txt *.reg; do
|
||||
for f in templates *.txt *.reg; do
|
||||
cp -r ${f} InvokeAI-Installer/
|
||||
done
|
||||
mkdir InvokeAI-Installer/lib
|
||||
cp lib/*.py InvokeAI-Installer/lib
|
||||
|
||||
# Move the wheel
|
||||
mv dist/*.whl InvokeAI-Installer/lib/
|
||||
@ -72,13 +106,13 @@ cp install.sh.in InvokeAI-Installer/install.sh
|
||||
chmod a+x InvokeAI-Installer/install.sh
|
||||
|
||||
# Windows
|
||||
perl -p -e "s/^set INVOKEAI_VERSION=.*/set INVOKEAI_VERSION=$VERSION/" install.bat.in > InvokeAI-Installer/install.bat
|
||||
perl -p -e "s/^set INVOKEAI_VERSION=.*/set INVOKEAI_VERSION=$VERSION/" install.bat.in >InvokeAI-Installer/install.bat
|
||||
cp WinLongPathsEnabled.reg InvokeAI-Installer/
|
||||
|
||||
# Zip everything up
|
||||
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
|
||||
|
||||
# clean up
|
||||
rm -rf InvokeAI-Installer tmp dist
|
||||
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
|
||||
|
||||
exit 0
|
||||
|
@ -244,9 +244,9 @@ class InvokeAiInstance:
|
||||
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
|
||||
"urllib3~=1.26.0",
|
||||
"requests~=2.28.0",
|
||||
"torch==2.1.0",
|
||||
"torch==2.1.1",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchvision>=0.14.1",
|
||||
"torchvision>=0.16.1",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
|
71
installer/tag_release.sh
Executable file
71
installer/tag_release.sh
Executable file
@ -0,0 +1,71 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
BCYAN="\e[1;36m"
|
||||
BYELLOW="\e[1;33m"
|
||||
BGREEN="\e[1;32m"
|
||||
BRED="\e[1;31m"
|
||||
RED="\e[31m"
|
||||
RESET="\e[0m"
|
||||
|
||||
function does_tag_exist {
|
||||
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
|
||||
}
|
||||
|
||||
function git_show_ref {
|
||||
git show-ref --dereference $1 --abbrev 7
|
||||
}
|
||||
|
||||
function git_show {
|
||||
git show -s --format='%h %s' $1
|
||||
}
|
||||
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
)
|
||||
PATCH=""
|
||||
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
LATEST_TAG="v${MAJOR_VERSION}-latest"
|
||||
|
||||
if does_tag_exist $VERSION; then
|
||||
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
|
||||
git_show_ref tags/$VERSION
|
||||
echo
|
||||
fi
|
||||
if does_tag_exist $LATEST_TAG; then
|
||||
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
|
||||
git_show_ref tags/$LATEST_TAG
|
||||
echo
|
||||
fi
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show
|
||||
echo
|
||||
|
||||
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
|
||||
read -e -p 'y/n [n]: ' input
|
||||
RESPONSE=${input:='n'}
|
||||
if [ "$RESPONSE" == 'y' ]; then
|
||||
echo
|
||||
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
|
||||
git push --delete origin $VERSION
|
||||
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
|
||||
if ! git tag -fa $VERSION; then
|
||||
echo "Existing/invalid tag"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
|
||||
git push --delete origin $LATEST_TAG
|
||||
|
||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
|
||||
git tag -fa $LATEST_TAG
|
||||
|
||||
echo -e "Pushing updated tags to remote..."
|
||||
git push origin --tags
|
||||
fi
|
||||
exit 0
|
@ -2,7 +2,7 @@
|
||||
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@ -23,6 +23,7 @@ from ..services.invoker import Invoker
|
||||
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_install import ModelInstallService
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
@ -30,7 +31,6 @@ from ..services.session_processor.session_processor_default import DefaultSessio
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.shared.default_graphs import create_system_graphs
|
||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.shared.sqlite import SqliteDatabase
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
@ -67,8 +67,9 @@ class ApiDependencies:
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
output_folder = config.output_path
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
db = SqliteDatabase(config, logger)
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
configuration = config
|
||||
logger = logger
|
||||
@ -80,13 +81,15 @@ class ApiDependencies:
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
model_install_service = ModelInstallService(
|
||||
app_config=config, record_store=model_record_service, event_bus=events
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
processor = DefaultInvocationProcessor()
|
||||
@ -94,7 +97,6 @@ class ApiDependencies:
|
||||
session_processor = DefaultSessionProcessor()
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
@ -114,6 +116,7 @@ class ApiDependencies:
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
model_install=model_install_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
@ -121,14 +124,12 @@ class ApiDependencies:
|
||||
session_processor=session_processor,
|
||||
session_queue=session_queue,
|
||||
urls=urls,
|
||||
workflow_image_records=workflow_image_records,
|
||||
workflow_records=workflow_records,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
db.clean()
|
||||
|
||||
@staticmethod
|
||||
|
@ -8,10 +8,11 @@ from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator, WorkflowFieldValidator
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@ -73,7 +74,7 @@ async def upload_image(
|
||||
workflow_raw = pil_image.info.get("invokeai_workflow", None)
|
||||
if workflow_raw is not None:
|
||||
try:
|
||||
workflow = WorkflowFieldValidator.validate_json(workflow_raw)
|
||||
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
|
||||
except ValidationError:
|
||||
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
|
||||
pass
|
||||
@ -184,6 +185,18 @@ async def get_image_metadata(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
|
||||
)
|
||||
async def get_image_workflow(
|
||||
image_name: str = Path(description="The name of image whose workflow to get"),
|
||||
) -> Optional[WorkflowWithoutID]:
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_workflow(image_name)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.api_route(
|
||||
"/i/{image_name}/full",
|
||||
methods=["GET", "HEAD"],
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
@ -25,7 +26,7 @@ from invokeai.backend.model_manager.config import (
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
@ -43,15 +44,25 @@ class ModelsList(BaseModel):
|
||||
async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||
model_format: Optional[str] = Query(
|
||||
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(
|
||||
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
|
||||
)
|
||||
)
|
||||
else:
|
||||
found_models.extend(record_store.search_by_attr(model_type=model_type))
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@ -117,12 +128,17 @@ async def update_model_record(
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
"""
|
||||
Delete model record from database.
|
||||
|
||||
The configuration record will be removed. The corresponding weights files will be
|
||||
deleted as well if they reside within the InvokeAI "models" directory.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
record_store.del_model(key)
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer.delete(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except UnknownModelException as e:
|
||||
@ -141,7 +157,7 @@ async def del_model_record(
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type.
|
||||
@ -162,3 +178,145 @@ async def add_model_record(
|
||||
|
||||
# now fetch it out
|
||||
return record_store.get_model(config.key)
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
"/import",
|
||||
operation_id="import_model_record",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The source object is a discriminated Union of LocalModelSource,
|
||||
HFModelSource and URLModelSource. Set the "type" field to the
|
||||
appropriate value:
|
||||
|
||||
* To install a local path using LocalModelSource, pass a source of form:
|
||||
`{
|
||||
"type": "local",
|
||||
"path": "/path/to/model",
|
||||
"inplace": false
|
||||
}`
|
||||
The "inplace" flag, if true, will register the model in place in its
|
||||
current filesystem location. Otherwise, the model will be copied
|
||||
into the InvokeAI models directory.
|
||||
|
||||
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
|
||||
`{
|
||||
"type": "hf",
|
||||
"repo_id": "stabilityai/stable-diffusion-2.0",
|
||||
"variant": "fp16",
|
||||
"subfolder": "vae",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}`
|
||||
The `variant`, `subfolder` and `access_token` fields are optional.
|
||||
|
||||
* To install a remote model using an arbitrary URL, pass:
|
||||
`{
|
||||
"type": "url",
|
||||
"url": "http://www.civitai.com/models/123456",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}`
|
||||
The `access_token` field is optonal
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
|
||||
Installation occurs in the background. Either use list_model_install_jobs()
|
||||
to poll for completion, or listen on the event bus for the following events:
|
||||
|
||||
"model_install_started"
|
||||
"model_install_completed"
|
||||
"model_install_error"
|
||||
|
||||
On successful completion, the event's payload will contain the field "key"
|
||||
containing the installed ID of the model. On an error, the event's payload
|
||||
will contain the fields "error_type" and "error" describing the nature of the
|
||||
error and its traceback, respectively.
|
||||
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""
|
||||
Return list of model install jobs.
|
||||
|
||||
If the optional 'source' argument is provided, then the list will be filtered
|
||||
for partial string matches against the install source.
|
||||
"""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
"""
|
||||
Prune all completed and errored jobs from the install job list.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/sync",
|
||||
operation_id="sync_models_to_config",
|
||||
responses={
|
||||
204: {"description": "Model config record database resynced with files on disk"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def sync_models_to_config() -> Response:
|
||||
"""
|
||||
Traverse the models and autoimport directories. Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_install.sync_to_config()
|
||||
return Response(status_code=204)
|
||||
|
@ -1,7 +1,19 @@
|
||||
from fastapi import APIRouter, Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, Query
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.invocations.baseinvocation import WorkflowField
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
Workflow,
|
||||
WorkflowCategory,
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordOrderBy,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
|
||||
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
|
||||
@ -10,11 +22,76 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
"/i/{workflow_id}",
|
||||
operation_id="get_workflow",
|
||||
responses={
|
||||
200: {"model": WorkflowField},
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def get_workflow(
|
||||
workflow_id: str = Path(description="The workflow to get"),
|
||||
) -> WorkflowField:
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
try:
|
||||
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
|
||||
@workflows_router.patch(
|
||||
"/i/{workflow_id}",
|
||||
operation_id="update_workflow",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def update_workflow(
|
||||
workflow: Workflow = Body(description="The updated workflow", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Updates a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
|
||||
|
||||
|
||||
@workflows_router.delete(
|
||||
"/i/{workflow_id}",
|
||||
operation_id="delete_workflow",
|
||||
)
|
||||
async def delete_workflow(
|
||||
workflow_id: str = Path(description="The workflow to delete"),
|
||||
) -> None:
|
||||
"""Deletes a workflow"""
|
||||
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
|
||||
|
||||
|
||||
@workflows_router.post(
|
||||
"/",
|
||||
operation_id="create_workflow",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def create_workflow(
|
||||
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Creates a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
|
||||
|
||||
|
||||
@workflows_router.get(
|
||||
"/",
|
||||
operation_id="list_workflows",
|
||||
responses={
|
||||
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
|
||||
},
|
||||
)
|
||||
async def list_workflows(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of workflows per page"),
|
||||
order_by: WorkflowRecordOrderBy = Query(
|
||||
default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
|
||||
),
|
||||
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
|
||||
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
return ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
page=page, per_page=per_page, order_by=order_by, direction=direction, query=query, category=category
|
||||
)
|
||||
|
@ -20,6 +20,7 @@ class SocketIO:
|
||||
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
|
||||
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
@ -28,10 +29,13 @@ class SocketIO:
|
||||
room=event[1]["data"]["queue_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs):
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.leave_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_model_event(self, event: Event) -> None:
|
||||
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
|
||||
|
@ -219,18 +219,19 @@ def overridden_redoc() -> HTMLResponse:
|
||||
|
||||
web_root_path = Path(list(web_dir.__path__)[0])
|
||||
|
||||
# Only serve the UI if we it has a build
|
||||
if (web_root_path / "dist").exists():
|
||||
# Cannot add headers to StaticFiles, so we must serve index.html with a custom route
|
||||
# Add cache-control: no-store header to prevent caching of index.html, which leads to broken UIs at release
|
||||
@app.get("/", include_in_schema=False, name="ui_root")
|
||||
def get_index() -> FileResponse:
|
||||
return FileResponse(Path(web_root_path, "dist/index.html"), headers={"Cache-Control": "no-store"})
|
||||
|
||||
# Cannot add headers to StaticFiles, so we must serve index.html with a custom route
|
||||
# Add cache-control: no-store header to prevent caching of index.html, which leads to broken UIs at release
|
||||
@app.get("/", include_in_schema=False, name="ui_root")
|
||||
def get_index() -> FileResponse:
|
||||
return FileResponse(Path(web_root_path, "dist/index.html"), headers={"Cache-Control": "no-store"})
|
||||
# # Must mount *after* the other routes else it borks em
|
||||
app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")), name="assets")
|
||||
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
|
||||
|
||||
|
||||
# # Must mount *after* the other routes else it borks em
|
||||
app.mount("/static", StaticFiles(directory=Path(web_root_path, "static/")), name="static") # docs favicon is in here
|
||||
app.mount("/assets", StaticFiles(directory=Path(web_root_path, "dist/assets/")), name="assets")
|
||||
app.mount("/locales", StaticFiles(directory=Path(web_root_path, "dist/locales/")), name="locales")
|
||||
|
||||
|
||||
def invoke_api() -> None:
|
||||
@ -271,6 +272,8 @@ def invoke_api() -> None:
|
||||
port=port,
|
||||
loop="asyncio",
|
||||
log_level=app_config.log_level,
|
||||
ssl_certfile=app_config.ssl_certfile,
|
||||
ssl_keyfile=app_config.ssl_keyfile,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
@ -16,6 +17,7 @@ from pydantic.fields import FieldInfo, _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
@ -37,6 +39,19 @@ class InvalidFieldError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class Classification(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The classification of an Invocation.
|
||||
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
|
||||
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
|
||||
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
|
||||
"""
|
||||
|
||||
Stable = "stable"
|
||||
Beta = "beta"
|
||||
Prototype = "prototype"
|
||||
|
||||
|
||||
class Input(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
@ -437,6 +452,7 @@ class UIConfigBase(BaseModel):
|
||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||
)
|
||||
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
||||
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
@ -452,6 +468,7 @@ class InvocationContext:
|
||||
queue_id: str
|
||||
queue_item_id: int
|
||||
queue_batch_id: str
|
||||
workflow: Optional[WorkflowWithoutID]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -460,12 +477,14 @@ class InvocationContext:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
workflow: Optional[WorkflowWithoutID],
|
||||
):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
self.queue_batch_id = queue_batch_id
|
||||
self.workflow = workflow
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
@ -602,6 +621,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig.node_pack is not None:
|
||||
schema["node_pack"] = uiconfig.node_pack
|
||||
schema["classification"] = uiconfig.classification
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
@ -705,8 +725,10 @@ class _Model(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
# Get all pydantic model attrs, methods, etc
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", category=DeprecationWarning)
|
||||
# Get all pydantic model attrs, methods, etc
|
||||
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
|
||||
|
||||
|
||||
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
|
||||
@ -775,6 +797,7 @@ def invocation(
|
||||
category: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
use_cache: Optional[bool] = True,
|
||||
classification: Classification = Classification.Stable,
|
||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||
"""
|
||||
Registers an invocation.
|
||||
@ -785,6 +808,7 @@ def invocation(
|
||||
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
|
||||
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
|
||||
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
||||
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
|
||||
"""
|
||||
|
||||
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
||||
@ -805,11 +829,12 @@ def invocation(
|
||||
cls.UIConfig.title = title
|
||||
cls.UIConfig.tags = tags
|
||||
cls.UIConfig.category = category
|
||||
cls.UIConfig.classification = classification
|
||||
|
||||
# Grab the node pack's name from the module name, if it's a custom node
|
||||
module_name = cls.__module__.split(".")[0]
|
||||
if module_name.endswith(CUSTOM_NODE_PACK_SUFFIX):
|
||||
cls.UIConfig.node_pack = module_name.split(CUSTOM_NODE_PACK_SUFFIX)[0]
|
||||
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
|
||||
if is_custom_node:
|
||||
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
|
||||
else:
|
||||
cls.UIConfig.node_pack = None
|
||||
|
||||
@ -903,24 +928,6 @@ def invocation_output(
|
||||
return wrapper
|
||||
|
||||
|
||||
class WorkflowField(RootModel):
|
||||
"""
|
||||
Pydantic model for workflows with custom root of type dict[str, Any].
|
||||
Workflows are stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The workflow")
|
||||
|
||||
|
||||
WorkflowFieldValidator = TypeAdapter(WorkflowField)
|
||||
|
||||
|
||||
class WithWorkflow(BaseModel):
|
||||
workflow: Optional[WorkflowField] = Field(
|
||||
default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
|
||||
)
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
@ -943,3 +950,13 @@ class WithMetadata(BaseModel):
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class WithWorkflow:
|
||||
workflow = None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
logger.warn(
|
||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||
)
|
||||
super().__init_subclass__()
|
||||
|
@ -39,7 +39,6 @@ from .baseinvocation import (
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -129,7 +128,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
@ -153,7 +152,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
@ -173,7 +172,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
@ -196,7 +195,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
@ -225,7 +224,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
@ -247,7 +246,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
@ -270,7 +269,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Openpose Processor",
|
||||
tags=["controlnet", "openpose", "pose"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Openpose processing to image"""
|
||||
@ -295,7 +294,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
@ -322,7 +321,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
@ -339,7 +338,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.1.0"
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
@ -362,7 +361,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.1.0"
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
@ -389,7 +388,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
@ -419,7 +418,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
@ -435,7 +434,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
@ -458,7 +457,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
@ -487,7 +486,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
@ -527,7 +526,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
@ -569,7 +568,7 @@ class SamDetectorReproducibleColors(SamDetector):
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
|
@ -6,7 +6,6 @@ import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import CUSTOM_NODE_PACK_SUFFIX
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
@ -34,7 +33,7 @@ for d in Path(__file__).parent.iterdir():
|
||||
continue
|
||||
|
||||
# load the module, appending adding a suffix to identify it as a custom node pack
|
||||
spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute())
|
||||
spec = spec_from_file_location(module_name, init.absolute())
|
||||
|
||||
if spec is None or spec.loader is None:
|
||||
logger.warn(f"Could not load {init}")
|
||||
|
@ -8,11 +8,11 @@ from PIL import Image, ImageOps
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
|
||||
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.1.0")
|
||||
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class CvInpaintInvocation(BaseInvocation, WithMetadata):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
image: ImageField = InputField(description="The image to inpaint")
|
||||
@ -41,7 +41,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
@ -17,7 +17,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -438,8 +437,8 @@ def get_faces_list(
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.1.0")
|
||||
class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
|
||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
image: ImageField = InputField(description="Image for face detection")
|
||||
@ -508,7 +507,7 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
mask_dto = context.services.images.create(
|
||||
@ -532,8 +531,8 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.1.0")
|
||||
class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
|
||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
@ -627,7 +626,7 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
mask_dto = context.services.images.create(
|
||||
@ -650,9 +649,9 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.1.0"
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
@ -716,7 +715,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
@ -13,7 +13,15 @@ from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
)
|
||||
|
||||
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
||||
@ -36,8 +44,14 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.1.0")
|
||||
class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
@invocation(
|
||||
"blank_image",
|
||||
title="Blank Image",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class BlankImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Creates a blank image and forwards it to the pipeline"""
|
||||
|
||||
width: int = InputField(default=512, description="The width of the image")
|
||||
@ -56,7 +70,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -66,8 +80,14 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.1.0")
|
||||
class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_crop",
|
||||
title="Crop Image",
|
||||
tags=["image", "crop"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageCropInvocation(BaseInvocation, WithMetadata):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to crop")
|
||||
@ -90,7 +110,7 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -155,8 +175,14 @@ class CenterPadCropInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.1.0")
|
||||
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_paste",
|
||||
title="Paste Image",
|
||||
tags=["image", "paste"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImagePasteInvocation(BaseInvocation, WithMetadata):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
base_image: ImageField = InputField(description="The base image")
|
||||
@ -199,7 +225,7 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -209,8 +235,14 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.1.0")
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"tomask",
|
||||
title="Mask from Alpha",
|
||||
tags=["image", "mask"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
@ -231,7 +263,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -241,8 +273,14 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.1.0")
|
||||
class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_mul",
|
||||
title="Multiply Images",
|
||||
tags=["image", "multiply"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
image1: ImageField = InputField(description="The first image to multiply")
|
||||
@ -262,7 +300,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -275,8 +313,14 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.1.0")
|
||||
class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_chan",
|
||||
title="Extract Image Channel",
|
||||
tags=["image", "channel"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageChannelInvocation(BaseInvocation, WithMetadata):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to get the channel from")
|
||||
@ -295,7 +339,7 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -308,8 +352,14 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.1.0")
|
||||
class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_conv",
|
||||
title="Convert Image Mode",
|
||||
tags=["image", "convert"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageConvertInvocation(BaseInvocation, WithMetadata):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
image: ImageField = InputField(description="The image to convert")
|
||||
@ -328,7 +378,7 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -338,8 +388,14 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.1.0")
|
||||
class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_blur",
|
||||
title="Blur Image",
|
||||
tags=["image", "blur"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageBlurInvocation(BaseInvocation, WithMetadata):
|
||||
"""Blurs an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to blur")
|
||||
@ -363,7 +419,7 @@ class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -373,6 +429,64 @@ class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"unsharp_mask",
|
||||
title="Unsharp Mask",
|
||||
tags=["image", "unsharp_mask"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Applies an unsharp mask filter to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to use")
|
||||
radius: float = InputField(gt=0, description="Unsharp mask radius", default=2)
|
||||
strength: float = InputField(ge=0, description="Unsharp mask strength", default=50)
|
||||
|
||||
def pil_from_array(self, arr):
|
||||
return Image.fromarray((arr * 255).astype("uint8"))
|
||||
|
||||
def array_from_pil(self, img):
|
||||
return numpy.array(img) / 255
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mode = image.mode
|
||||
|
||||
alpha_channel = image.getchannel("A") if mode == "RGBA" else None
|
||||
image = image.convert("RGB")
|
||||
image_blurred = self.array_from_pil(image.filter(ImageFilter.GaussianBlur(radius=self.radius)))
|
||||
|
||||
image = self.array_from_pil(image)
|
||||
image += (image - image_blurred) * (self.strength / 100.0)
|
||||
image = numpy.clip(image, 0, 1)
|
||||
image = self.pil_from_array(image)
|
||||
|
||||
image = image.convert(mode)
|
||||
|
||||
# Make the image RGBA if we had a source alpha channel
|
||||
if alpha_channel is not None:
|
||||
image.putalpha(alpha_channel)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
PIL_RESAMPLING_MODES = Literal[
|
||||
"nearest",
|
||||
"box",
|
||||
@ -393,8 +507,14 @@ PIL_RESAMPLING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.1.0")
|
||||
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
@invocation(
|
||||
"img_resize",
|
||||
title="Resize Image",
|
||||
tags=["image", "resize"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageResizeInvocation(BaseInvocation, WithMetadata):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
image: ImageField = InputField(description="The image to resize")
|
||||
@ -420,7 +540,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -430,8 +550,14 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.1.0")
|
||||
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
@invocation(
|
||||
"img_scale",
|
||||
title="Scale Image",
|
||||
tags=["image", "scale"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageScaleInvocation(BaseInvocation, WithMetadata):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
image: ImageField = InputField(description="The image to scale")
|
||||
@ -462,7 +588,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -472,8 +598,14 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.1.0")
|
||||
class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_lerp",
|
||||
title="Lerp Image",
|
||||
tags=["image", "lerp"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageLerpInvocation(BaseInvocation, WithMetadata):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
@ -496,7 +628,7 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -506,8 +638,14 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.1.0")
|
||||
class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_ilerp",
|
||||
title="Inverse Lerp Image",
|
||||
tags=["image", "ilerp"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to lerp")
|
||||
@ -530,7 +668,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -540,8 +678,14 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.1.0")
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
@invocation(
|
||||
"img_nsfw",
|
||||
title="Blur NSFW Image",
|
||||
tags=["image", "nsfw"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
@ -566,7 +710,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -587,9 +731,9 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
title="Add Invisible Watermark",
|
||||
tags=["image", "watermark"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
|
||||
"""Add an invisible watermark to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
@ -606,7 +750,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -616,8 +760,14 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
)
|
||||
|
||||
|
||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.1.0")
|
||||
class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"mask_edge",
|
||||
title="Mask Edge",
|
||||
tags=["image", "mask", "inpaint"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class MaskEdgeInvocation(BaseInvocation, WithMetadata):
|
||||
"""Applies an edge mask to an image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to apply the mask to")
|
||||
@ -652,7 +802,7 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -667,9 +817,9 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
title="Combine Masks",
|
||||
tags=["image", "mask", "multiply"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
class MaskCombineInvocation(BaseInvocation, WithMetadata):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
mask1: ImageField = InputField(description="The first mask to combine")
|
||||
@ -689,7 +839,7 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -699,8 +849,14 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.1.0")
|
||||
class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"color_correct",
|
||||
title="Color Correct",
|
||||
tags=["image", "color"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ColorCorrectInvocation(BaseInvocation, WithMetadata):
|
||||
"""
|
||||
Shifts the colors of a target image to match the reference image, optionally
|
||||
using a mask to only color-correct certain regions of the target image.
|
||||
@ -800,7 +956,7 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -810,8 +966,14 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.1.0")
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_hue_adjust",
|
||||
title="Adjust Image Hue",
|
||||
tags=["image", "hue"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
|
||||
"""Adjusts the Hue of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
@ -840,7 +1002,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -913,9 +1075,9 @@ CHANNEL_FORMATS = {
|
||||
"value",
|
||||
],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
|
||||
"""Add or subtract a value from a specific color channel of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
@ -950,7 +1112,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -984,9 +1146,9 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
"value",
|
||||
],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
|
||||
"""Scale a specific color channel of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
@ -1025,7 +1187,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
session_id=context.graph_execution_state_id,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
@ -1043,10 +1205,10 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
|
||||
title="Save Image",
|
||||
tags=["primitives", "image"],
|
||||
category="primitives",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
class SaveImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
|
||||
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
@ -1064,7 +1226,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -1082,7 +1244,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
version="1.0.1",
|
||||
use_cache=False,
|
||||
)
|
||||
class LinearUIOutputInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
|
||||
"""Handles Linear UI Image Outputting tasks."""
|
||||
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
|
@ -13,7 +13,7 @@ from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
|
||||
@ -118,8 +118,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return si
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
|
||||
class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -144,7 +144,7 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -154,8 +154,8 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1")
|
||||
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -181,7 +181,7 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -192,9 +192,9 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0"
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
|
||||
)
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -235,7 +235,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -245,8 +245,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -264,7 +264,7 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@ -274,8 +274,8 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
|
||||
class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@ -293,7 +293,7 @@ class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
@ -64,7 +64,6 @@ from .baseinvocation import (
|
||||
OutputField,
|
||||
UIType,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -802,9 +801,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
title="Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
@ -886,7 +885,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
@ -31,7 +31,6 @@ from .baseinvocation import (
|
||||
UIComponent,
|
||||
UIType,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -326,9 +325,9 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
title="ONNX Latents to Image",
|
||||
tags=["latents", "image", "vae", "onnx"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
@ -378,7 +377,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
@ -5,17 +7,24 @@ from pydantic import BaseModel
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
|
||||
from invokeai.backend.tiles.tiles import (
|
||||
calc_tiles_even_split,
|
||||
calc_tiles_min_overlap,
|
||||
calc_tiles_with_overlap,
|
||||
merge_tiles_with_linear_blending,
|
||||
merge_tiles_with_seam_blending,
|
||||
)
|
||||
from invokeai.backend.tiles.utils import Tile
|
||||
|
||||
|
||||
@ -29,7 +38,14 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
|
||||
tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.")
|
||||
|
||||
|
||||
@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0")
|
||||
@invocation(
|
||||
"calculate_image_tiles",
|
||||
title="Calculate Image Tiles",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
|
||||
@ -56,6 +72,79 @@ class CalculateImageTilesInvocation(BaseInvocation):
|
||||
return CalculateImageTilesOutput(tiles=tiles)
|
||||
|
||||
|
||||
@invocation(
|
||||
"calculate_image_tiles_even_split",
|
||||
title="Calculate Image Tiles Even Split",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
|
||||
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
|
||||
image_height: int = InputField(
|
||||
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
|
||||
)
|
||||
num_tiles_x: int = InputField(
|
||||
default=2,
|
||||
ge=1,
|
||||
description="Number of tiles to divide image into on the x axis",
|
||||
)
|
||||
num_tiles_y: int = InputField(
|
||||
default=2,
|
||||
ge=1,
|
||||
description="Number of tiles to divide image into on the y axis",
|
||||
)
|
||||
overlap: int = InputField(
|
||||
default=128,
|
||||
ge=0,
|
||||
multiple_of=8,
|
||||
description="The overlap, in pixels, between adjacent tiles.",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
|
||||
tiles = calc_tiles_even_split(
|
||||
image_height=self.image_height,
|
||||
image_width=self.image_width,
|
||||
num_tiles_x=self.num_tiles_x,
|
||||
num_tiles_y=self.num_tiles_y,
|
||||
overlap=self.overlap,
|
||||
)
|
||||
return CalculateImageTilesOutput(tiles=tiles)
|
||||
|
||||
|
||||
@invocation(
|
||||
"calculate_image_tiles_min_overlap",
|
||||
title="Calculate Image Tiles Minimum Overlap",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
|
||||
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
|
||||
image_height: int = InputField(
|
||||
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
|
||||
)
|
||||
tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
|
||||
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
|
||||
min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=self.image_height,
|
||||
image_width=self.image_width,
|
||||
tile_height=self.tile_height,
|
||||
tile_width=self.tile_width,
|
||||
min_overlap=self.min_overlap,
|
||||
)
|
||||
return CalculateImageTilesOutput(tiles=tiles)
|
||||
|
||||
|
||||
@invocation_output("tile_to_properties_output")
|
||||
class TileToPropertiesOutput(BaseInvocationOutput):
|
||||
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
|
||||
@ -77,7 +166,14 @@ class TileToPropertiesOutput(BaseInvocationOutput):
|
||||
overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
|
||||
|
||||
|
||||
@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
|
||||
@invocation(
|
||||
"tile_to_properties",
|
||||
title="Tile to Properties",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class TileToPropertiesInvocation(BaseInvocation):
|
||||
"""Split a Tile into its individual properties."""
|
||||
|
||||
@ -103,7 +199,14 @@ class PairTileImageOutput(BaseInvocationOutput):
|
||||
tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
|
||||
|
||||
|
||||
@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
|
||||
@invocation(
|
||||
"pair_tile_image",
|
||||
title="Pair Tile with Image",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class PairTileImageInvocation(BaseInvocation):
|
||||
"""Pair an image with its tile properties."""
|
||||
|
||||
@ -122,13 +225,29 @@ class PairTileImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.0.0")
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
BLEND_MODES = Literal["Linear", "Seam"]
|
||||
|
||||
|
||||
@invocation(
|
||||
"merge_tiles_to_image",
|
||||
title="Merge Tiles to Image",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Merge multiple tile images into a single image."""
|
||||
|
||||
# Inputs
|
||||
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
|
||||
blend_mode: BLEND_MODES = InputField(
|
||||
default="Seam",
|
||||
description="blending type Linear or Seam",
|
||||
input=Input.Direct,
|
||||
)
|
||||
blend_amount: int = InputField(
|
||||
default=32,
|
||||
ge=0,
|
||||
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
|
||||
)
|
||||
@ -158,10 +277,18 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
channels = tile_np_images[0].shape[-1]
|
||||
dtype = tile_np_images[0].dtype
|
||||
np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
|
||||
if self.blend_mode == "Linear":
|
||||
merge_tiles_with_linear_blending(
|
||||
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
|
||||
)
|
||||
elif self.blend_mode == "Seam":
|
||||
merge_tiles_with_seam_blending(
|
||||
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported blend mode: '{self.blend_mode}'.")
|
||||
|
||||
merge_tiles_with_linear_blending(
|
||||
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
|
||||
)
|
||||
# Convert into a PIL image and save
|
||||
pil_image = Image.fromarray(np_image)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
@ -172,7 +299,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
|
@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
|
||||
# TODO: Populate this from disk?
|
||||
# TODO: Use model manager to load?
|
||||
@ -29,8 +29,8 @@ if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.2.0")
|
||||
class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
|
||||
class ESRGANInvocation(BaseInvocation, WithMetadata):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
image: ImageField = InputField(description="The input image")
|
||||
@ -118,7 +118,7 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
@ -4,7 +4,7 @@ from typing import Optional, cast
|
||||
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .board_image_records_base import BoardImageRecordStorageBase
|
||||
|
||||
@ -20,63 +20,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `board_images` junction table."""
|
||||
|
||||
# Create the `board_images` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS board_images (
|
||||
board_id TEXT NOT NULL,
|
||||
image_name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between boards and images using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (image_name),
|
||||
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for board id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for board id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON board_images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE board_id = old.board_id AND image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
|
@ -3,7 +3,7 @@ import threading
|
||||
from typing import Union, cast
|
||||
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from .board_records_base import BoardRecordStorageBase
|
||||
@ -28,52 +28,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `boards` table and `board_images` junction table."""
|
||||
|
||||
# Create the `boards` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS boards (
|
||||
board_id TEXT NOT NULL PRIMARY KEY,
|
||||
board_name TEXT NOT NULL,
|
||||
cover_image_name TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
||||
AFTER UPDATE
|
||||
ON boards FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE boards SET updated_at = current_timestamp
|
||||
WHERE board_id = old.board_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
@ -1,6 +1,5 @@
|
||||
"""
|
||||
Init file for InvokeAI configure package
|
||||
"""
|
||||
"""Init file for InvokeAI configure package."""
|
||||
|
||||
from .config_base import PagingArgumentParser # noqa F401
|
||||
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401
|
||||
from .config_default import InvokeAIAppConfig, get_invokeai_config
|
||||
|
||||
__all__ = ["InvokeAIAppConfig", "get_invokeai_config"]
|
||||
|
@ -173,7 +173,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field, TypeAdapter
|
||||
@ -221,6 +221,9 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||
# SSL options correspond to https://www.uvicorn.org/settings/#https
|
||||
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
|
||||
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
|
||||
|
||||
# FEATURES
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
||||
@ -334,7 +337,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
||||
def get_config(cls, **kwargs: Dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Return a singleton InvokeAIAppConfig configuration object."""
|
||||
if (
|
||||
cls.singleton_config is None
|
||||
@ -383,17 +386,17 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return db_dir / DB_FILE
|
||||
|
||||
@property
|
||||
def model_conf_path(self) -> Optional[Path]:
|
||||
def model_conf_path(self) -> Path:
|
||||
"""Path to models configuration file."""
|
||||
return self._resolve(self.conf_path)
|
||||
|
||||
@property
|
||||
def legacy_conf_path(self) -> Optional[Path]:
|
||||
def legacy_conf_path(self) -> Path:
|
||||
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
|
||||
return self._resolve(self.legacy_conf_dir)
|
||||
|
||||
@property
|
||||
def models_path(self) -> Optional[Path]:
|
||||
def models_path(self) -> Path:
|
||||
"""Path to the models directory."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
|
@ -0,0 +1 @@
|
||||
from .events_base import EventServiceBase # noqa F401
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
||||
@ -16,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
@ -30,6 +32,13 @@ class EventServiceBase:
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_model_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.model_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
@ -313,3 +322,73 @@ class EventServiceBase:
|
||||
event_name="queue_cleared",
|
||||
payload={"queue_id": queue_id},
|
||||
)
|
||||
|
||||
def emit_model_install_started(self, source: str) -> None:
|
||||
"""
|
||||
Emitted when an install job is started.
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_started",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_completed(self, source: str, key: str) -> None:
|
||||
"""
|
||||
Emitted when an install job is completed successfully.
|
||||
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
:param key: Model config record key
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={
|
||||
"source": source,
|
||||
"key": key,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_progress(
|
||||
self,
|
||||
source: str,
|
||||
current_bytes: int,
|
||||
total_bytes: int,
|
||||
) -> None:
|
||||
"""
|
||||
Emitted while the install job is in progress.
|
||||
(Downloaded models only)
|
||||
|
||||
:param source: Source of the model
|
||||
:param current_bytes: Number of bytes downloaded so far
|
||||
:param total_bytes: Total bytes to download
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_progress",
|
||||
payload={
|
||||
"source": source,
|
||||
"current_bytes": int,
|
||||
"total_bytes": int,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_error(
|
||||
self,
|
||||
source: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""
|
||||
Emitted when an install job encounters an exception.
|
||||
|
||||
:param source: Source of the model
|
||||
:param exception: The exception that raised the error
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={
|
||||
"source": source,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
@ -4,7 +4,8 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
class ImageFileStorageBase(ABC):
|
||||
@ -33,7 +34,7 @@ class ImageFileStorageBase(ABC):
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||
@ -43,3 +44,8 @@ class ImageFileStorageBase(ABC):
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
"""Gets the workflow of an image."""
|
||||
pass
|
||||
|
@ -7,8 +7,9 @@ from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
from .image_files_base import ImageFileStorageBase
|
||||
@ -56,7 +57,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
@ -64,12 +65,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
info_dict = {}
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo.add_text("invokeai_metadata", metadata.model_dump_json())
|
||||
metadata_json = metadata.model_dump_json()
|
||||
info_dict["invokeai_metadata"] = metadata_json
|
||||
pnginfo.add_text("invokeai_metadata", metadata_json)
|
||||
if workflow is not None:
|
||||
pnginfo.add_text("invokeai_workflow", workflow.model_dump_json())
|
||||
workflow_json = workflow.model_dump_json()
|
||||
info_dict["invokeai_workflow"] = workflow_json
|
||||
pnginfo.add_text("invokeai_workflow", workflow_json)
|
||||
|
||||
# When saving the image, the image object's info field is not populated. We need to set it
|
||||
image.info = info_dict
|
||||
image.save(
|
||||
image_path,
|
||||
"PNG",
|
||||
@ -121,6 +129,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
|
||||
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
|
||||
image = self.get(image_name)
|
||||
workflow = image.info.get("invokeai_workflow", None)
|
||||
if workflow is not None:
|
||||
return WorkflowWithoutID.model_validate_json(workflow)
|
||||
return None
|
||||
|
||||
def __validate_storage_folders(self) -> None:
|
||||
"""Checks if the required output folders exist and create them if they don't"""
|
||||
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
|
||||
|
@ -75,6 +75,7 @@ class ImageRecordStorageBase(ABC):
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
has_workflow: bool,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
|
@ -100,6 +100,7 @@ IMAGE_DTO_COLS = ", ".join(
|
||||
"height",
|
||||
"session_id",
|
||||
"node_id",
|
||||
"has_workflow",
|
||||
"is_intermediate",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
@ -145,6 +146,7 @@ class ImageRecord(BaseModelExcludeNull):
|
||||
"""The node ID that generated this image, if it is a generated image."""
|
||||
starred: bool = Field(description="Whether this image is starred.")
|
||||
"""Whether this image is starred."""
|
||||
has_workflow: bool = Field(description="Whether this image has a workflow.")
|
||||
|
||||
|
||||
class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
|
||||
@ -188,6 +190,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
starred = image_dict.get("starred", False)
|
||||
has_workflow = image_dict.get("has_workflow", False)
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
@ -202,4 +205,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
starred=starred,
|
||||
has_workflow=has_workflow,
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ from typing import Optional, Union, cast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .image_records_base import ImageRecordStorageBase
|
||||
from .image_records_common import (
|
||||
@ -32,91 +32,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `images` table."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
image_name TEXT NOT NULL PRIMARY KEY,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_origin TEXT NOT NULL,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_category TEXT NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
session_id TEXT,
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute("PRAGMA table_info(images)")
|
||||
columns = [column[1] for column in self._cursor.fetchall()]
|
||||
|
||||
if "starred" not in columns:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `images` table indices.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -408,6 +323,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
has_workflow: bool,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
starred: Optional[bool] = False,
|
||||
session_id: Optional[str] = None,
|
||||
@ -429,9 +345,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred
|
||||
starred,
|
||||
has_workflow
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
@ -444,6 +361,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
@ -3,7 +3,7 @@ from typing import Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecord,
|
||||
@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
class ImageServiceABC(ABC):
|
||||
@ -51,7 +52,7 @@ class ImageServiceABC(ABC):
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -85,6 +86,11 @@ class ImageServiceABC(ABC):
|
||||
"""Gets an image's metadata."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
"""Gets an image's workflow."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's path."""
|
||||
|
@ -24,11 +24,6 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
default=None, description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
workflow_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The workflow that generated this image.",
|
||||
)
|
||||
"""The workflow that generated this image."""
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
@ -36,7 +31,6 @@ def image_record_to_dto(
|
||||
image_url: str,
|
||||
thumbnail_url: str,
|
||||
board_id: Optional[str],
|
||||
workflow_id: Optional[str],
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
@ -44,5 +38,4 @@ def image_record_to_dto(
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
|
@ -2,9 +2,10 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
from ..image_files.image_files_common import (
|
||||
ImageFileDeleteException,
|
||||
@ -42,7 +43,7 @@ class ImageService(ImageServiceABC):
|
||||
board_id: Optional[str] = None,
|
||||
is_intermediate: Optional[bool] = False,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
workflow: Optional[WorkflowField] = None,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
) -> ImageDTO:
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
@ -55,12 +56,6 @@ class ImageService(ImageServiceABC):
|
||||
(width, height) = image.size
|
||||
|
||||
try:
|
||||
if workflow is not None:
|
||||
created_workflow = self.__invoker.services.workflow_records.create(workflow)
|
||||
workflow_id = created_workflow.model_dump()["id"]
|
||||
else:
|
||||
workflow_id = None
|
||||
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
self.__invoker.services.image_records.save(
|
||||
# Non-nullable fields
|
||||
@ -69,6 +64,7 @@ class ImageService(ImageServiceABC):
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
has_workflow=workflow is not None,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
# Nullable fields
|
||||
@ -78,8 +74,6 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
if board_id is not None:
|
||||
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
if workflow_id is not None:
|
||||
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
|
||||
self.__invoker.services.image_files.save(
|
||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
||||
)
|
||||
@ -143,7 +137,6 @@ class ImageService(ImageServiceABC):
|
||||
image_url=self.__invoker.services.urls.get_image_url(image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@ -164,18 +157,15 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowField]:
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
try:
|
||||
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name)
|
||||
if workflow_id is None:
|
||||
return None
|
||||
return self.__invoker.services.workflow_records.get(workflow_id)
|
||||
except ImageRecordNotFoundException:
|
||||
self.__invoker.services.logger.error("Image record not found")
|
||||
return self.__invoker.services.image_files.get_workflow(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self.__invoker.services.logger.error("Image file not found")
|
||||
raise
|
||||
except Exception:
|
||||
self.__invoker.services.logger.error("Problem getting image workflow")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
@ -223,7 +213,6 @@ class ImageService(ImageServiceABC):
|
||||
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
|
||||
)
|
||||
for r in results.items
|
||||
]
|
||||
|
@ -108,6 +108,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
workflow=queue_item.workflow,
|
||||
)
|
||||
)
|
||||
|
||||
@ -178,6 +179,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
session_queue_item_id=queue_item.session_queue_item_id,
|
||||
session_queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state=graph_execution_state,
|
||||
workflow=queue_item.workflow,
|
||||
invoke_all=True,
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -1,9 +1,12 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
@ -15,5 +18,6 @@ class InvocationQueueItem(BaseModel):
|
||||
session_queue_batch_id: str = Field(
|
||||
description="The ID of the session batch from which this invocation queue item came"
|
||||
)
|
||||
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_install import ModelInstallServiceBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
@ -28,7 +29,6 @@ if TYPE_CHECKING:
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState, LibraryGraph
|
||||
from .urls.urls_base import UrlServiceBase
|
||||
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
|
||||
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
|
||||
|
||||
@ -51,6 +51,7 @@ class InvocationServices:
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
model_records: "ModelRecordServiceBase"
|
||||
model_install: "ModelInstallServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
@ -59,7 +60,6 @@ class InvocationServices:
|
||||
invocation_cache: "InvocationCacheBase"
|
||||
names: "NameServiceBase"
|
||||
urls: "UrlServiceBase"
|
||||
workflow_image_records: "WorkflowImageRecordsStorageBase"
|
||||
workflow_records: "WorkflowRecordsStorageBase"
|
||||
|
||||
def __init__(
|
||||
@ -79,6 +79,7 @@ class InvocationServices:
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
model_install: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
@ -87,7 +88,6 @@ class InvocationServices:
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_image_records: "WorkflowImageRecordsStorageBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
@ -105,6 +105,7 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
self.model_install = model_install
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
@ -113,5 +114,4 @@ class InvocationServices:
|
||||
self.invocation_cache = invocation_cache
|
||||
self.names = names
|
||||
self.urls = urls
|
||||
self.workflow_image_records = workflow_image_records
|
||||
self.workflow_records = workflow_records
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
from .invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
from .invocation_services import InvocationServices
|
||||
from .shared.graph import Graph, GraphExecutionState
|
||||
@ -22,6 +24,7 @@ class Invoker:
|
||||
session_queue_item_id: int,
|
||||
session_queue_batch_id: str,
|
||||
graph_execution_state: GraphExecutionState,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
invoke_all: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
@ -43,6 +46,7 @@ class Invoker:
|
||||
session_queue_batch_id=session_queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
workflow=workflow,
|
||||
invoke_all=invoke_all,
|
||||
)
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ from typing import Generic, Optional, TypeVar, get_args
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .item_storage_base import ItemStorageABC
|
||||
|
||||
|
25
invokeai/app/services/model_install/__init__.py
Normal file
25
invokeai/app/services/model_install/__init__.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import (
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
UnknownInstallJobException,
|
||||
URLModelSource,
|
||||
)
|
||||
from .model_install_default import ModelInstallService
|
||||
|
||||
__all__ = [
|
||||
"ModelInstallServiceBase",
|
||||
"ModelInstallService",
|
||||
"InstallStatus",
|
||||
"ModelInstallJob",
|
||||
"UnknownInstallJobException",
|
||||
"ModelSource",
|
||||
"LocalModelSource",
|
||||
"HFModelSource",
|
||||
"URLModelSource",
|
||||
]
|
306
invokeai/app/services/model_install/model_install_base.py
Normal file
306
invokeai/app/services/model_install/model_install_base.py
Normal file
@ -0,0 +1,306 @@
|
||||
import re
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""A HuggingFace repo_id, with optional variant and sub-folder."""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[str] = None
|
||||
subfolder: Optional[str | Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
base += f":{self.subfolder}" if self.subfolder else ""
|
||||
base += f" ({self.variant})" if self.variant else ""
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["generic_url"] = "generic_url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
error_type: Optional[str] = Field(default=None, description="Class name of the exception that led to status==ERROR")
|
||||
error: Optional[str] = Field(default=None, description="Error traceback") # noqa #501
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self.error_type = e.__class__.__name__
|
||||
self.error = "".join(traceback.format_exception(e))
|
||||
self.status = InstallStatus.ERROR
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
"""Abstract base class for InvokeAI model installation."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
Create ModelInstallService object.
|
||||
|
||||
:param config: Systemwide InvokeAIAppConfig.
|
||||
:param store: Systemwide ModelConfigStore
|
||||
:param event_bus: InvokeAI event bus for reporting events to.
|
||||
"""
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Call at InvokeAI startup time."""
|
||||
self.sync_to_config()
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""Stop the model install service. After this the objection can be safely deleted."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def app_config(self) -> InvokeAIAppConfig:
|
||||
"""Return the appConfig object associated with the installer."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def record_store(self) -> ModelRecordServiceBase:
|
||||
"""Return the ModelRecoreService object associated with the installer."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def event_bus(self) -> Optional[EventServiceBase]:
|
||||
"""Return the event service base object associated with the installer."""
|
||||
|
||||
@abstractmethod
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
|
||||
This keeps the model in its current location.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def unregister(self, key: str) -> None:
|
||||
"""Remove model with indicated key from the database."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Remove model with indicated key from the database. Delete its files only if they are within our models directory."""
|
||||
|
||||
@abstractmethod
|
||||
def unconditionally_delete(self, key: str) -> None:
|
||||
"""Remove model with indicated key from the database and unconditionally delete weight files from disk."""
|
||||
|
||||
@abstractmethod
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
|
||||
This moves the model from its current location into
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param config: Dict of attributes that will override autoassigned values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def import_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install the indicated model.
|
||||
|
||||
:param source: ModelSource object
|
||||
|
||||
:param config: Optional dict. Any fields in this dict
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record. Use it to override
|
||||
`name`, `description`, `base_type`, `model_type`, `format`,
|
||||
`prediction_type`, `image_size`, and/or `ztsnr_training`.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
This call is executed asynchronously in a separate
|
||||
thread and will issue the following events on the event bus:
|
||||
|
||||
- model_install_started
|
||||
- model_install_error
|
||||
- model_install_completed
|
||||
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
models, which are always moved into the `models` directory.
|
||||
|
||||
The call returns a ModelInstallJob object which can be
|
||||
polled to learn the current status and/or error message.
|
||||
|
||||
Variants recognized by HuggingFace currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
"""
|
||||
List active and complete install jobs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune all completed and errored jobs."""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> List[ModelInstallJob]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending installs have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return the current list of jobs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
||||
"""
|
||||
Recursively scan directory for new models and register or install them.
|
||||
|
||||
:param scan_dir: Path to the directory to scan.
|
||||
:param install: Install if True, otherwise register in place.
|
||||
:returns list of IDs: Returns list of IDs of models registered/installed
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the model record database."""
|
395
invokeai/app/services/model_install/model_install_default.py
Normal file
395
invokeai/app/services/model_install/model_install_default.py
Normal file
@ -0,0 +1,395 @@
|
||||
"""Model installation class."""
|
||||
|
||||
import threading
|
||||
from hashlib import sha256
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from random import randbytes
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
|
||||
from .model_install_base import (
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
)
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = ModelInstallJob(
|
||||
source=LocalModelSource(path="stop"),
|
||||
local_path=Path("/dev/null"),
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""class for InvokeAI model installation."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_record_store: ModelRecordServiceBase
|
||||
_event_bus: Optional[EventServiceBase] = None
|
||||
_install_queue: Queue[ModelInstallJob]
|
||||
_install_jobs: List[ModelInstallJob]
|
||||
_logger: Logger
|
||||
_cached_model_paths: Set[Path]
|
||||
_models_installed: Set[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the installer object.
|
||||
|
||||
:param app_config: InvokeAIAppConfig object
|
||||
:param record_store: Previously-opened ModelRecordService database
|
||||
:param event_bus: Optional EventService object
|
||||
"""
|
||||
self._app_config = app_config
|
||||
self._record_store = record_store
|
||||
self._event_bus = event_bus
|
||||
self._logger = InvokeAILogger.get_logger(name=self.__class__.__name__)
|
||||
self._install_jobs = []
|
||||
self._install_queue = Queue()
|
||||
self._cached_model_paths = set()
|
||||
self._models_installed = set()
|
||||
self._start_installer_thread()
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
return self._app_config
|
||||
|
||||
@property
|
||||
def record_store(self) -> ModelRecordServiceBase: # noqa D102
|
||||
return self._record_store
|
||||
|
||||
@property
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
||||
self._install_queue.put(STOP_JOB)
|
||||
|
||||
def _start_installer_thread(self) -> None:
|
||||
threading.Thread(target=self._install_next_item, daemon=True).start()
|
||||
|
||||
def _install_next_item(self) -> None:
|
||||
done = False
|
||||
while not done:
|
||||
job = self._install_queue.get()
|
||||
if job == STOP_JOB:
|
||||
done = True
|
||||
continue
|
||||
|
||||
assert job.local_path is not None
|
||||
try:
|
||||
self._signal_job_running(job)
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
key = self.install_path(job.local_path, job.config_in)
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
except (OSError, DuplicateModelException, InvalidModelConfigException) as excp:
|
||||
self._signal_job_errored(job, excp)
|
||||
finally:
|
||||
self._install_queue.task_done()
|
||||
self._logger.info("Install thread exiting")
|
||||
|
||||
def _signal_job_running(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.RUNNING
|
||||
self._logger.info(f"{job.source}: model installation started")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_started(str(job.source))
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
assert job.config_out
|
||||
self._logger.info(
|
||||
f"{job.source}: model installation completed. {job.local_path} registered key {job.config_out.key}"
|
||||
)
|
||||
if self._event_bus:
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||
job.set_error(excp)
|
||||
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}")
|
||||
if self._event_bus:
|
||||
error_type = job.error_type
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
|
||||
|
||||
def register_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if config.get("source") is None:
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
return self._register(model_path, config)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if config.get("source") is None:
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.original_hash
|
||||
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
|
||||
return self._register(
|
||||
new_path,
|
||||
config,
|
||||
info,
|
||||
)
|
||||
|
||||
def import_model(
|
||||
self,
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob: # noqa D102
|
||||
if not config:
|
||||
config = {}
|
||||
|
||||
# Installing a local path
|
||||
if isinstance(source, LocalModelSource) and Path(source.path).exists(): # a path that is already on disk
|
||||
job = ModelInstallJob(
|
||||
source=source,
|
||||
config_in=config,
|
||||
local_path=Path(source.path),
|
||||
)
|
||||
self._install_jobs.append(job)
|
||||
self._install_queue.put(job)
|
||||
return job
|
||||
|
||||
else: # here is where we'd download a URL or repo_id. Implementation pending download queue.
|
||||
raise UnknownModelException("File or directory not found")
|
||||
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
def get_job(self, source: ModelSource) -> List[ModelInstallJob]: # noqa D102
|
||||
return [x for x in self._install_jobs if x.source == source]
|
||||
|
||||
def wait_for_installs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
self._install_queue.join()
|
||||
return self._install_jobs
|
||||
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune all completed and errored jobs."""
|
||||
unfinished_jobs = [
|
||||
x for x in self._install_jobs if x.status not in [InstallStatus.COMPLETED, InstallStatus.ERROR]
|
||||
]
|
||||
self._install_jobs = unfinished_jobs
|
||||
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the config record store database."""
|
||||
self._scan_models_directory()
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
installed = self.scan_directory(self._app_config.root_path / autoimport)
|
||||
self._logger.info(f"{len(installed)} new models registered")
|
||||
self._logger.info("Model installer (re)initialized")
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._models_installed: Set[str] = set()
|
||||
search.search(scan_dir)
|
||||
return list(self._models_installed)
|
||||
|
||||
def _scan_models_directory(self) -> None:
|
||||
"""
|
||||
Scan the models directory for new and missing models.
|
||||
|
||||
New models will be added to the storage backend. Missing models
|
||||
will be deleted.
|
||||
"""
|
||||
defunct_models = set()
|
||||
installed = set()
|
||||
|
||||
with Chdir(self._app_config.models_path):
|
||||
self._logger.info("Checking for models that have been moved or deleted from disk")
|
||||
for model_config in self.record_store.all_models():
|
||||
path = Path(model_config.path)
|
||||
if not path.exists():
|
||||
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
|
||||
defunct_models.add(model_config.key)
|
||||
for key in defunct_models:
|
||||
self.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
Call this after updating a model's attributes in order to move
|
||||
the model's path into the location indicated by its basetype, type and
|
||||
name. Applies only to models whose paths are within the root `models_dir`
|
||||
directory.
|
||||
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
model = self.record_store.get_model(key)
|
||||
old_path = Path(model.path)
|
||||
models_dir = self.app_config.models_path
|
||||
|
||||
if not old_path.is_relative_to(models_dir):
|
||||
return model
|
||||
|
||||
new_path = models_dir / model.base.value / model.type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
new_hash = FastModelHash.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.current_hash != new_hash:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
model.current_hash = new_hash
|
||||
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
|
||||
self.record_store.update_model(key, model)
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.register_path(model)
|
||||
self._sync_model_path(id) # possibly move it to right place in `models`
|
||||
self._logger.info(f"Registered {model.name} with id {id}")
|
||||
self._models_installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
|
||||
def _scan_install(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.install_path(model)
|
||||
self._logger.info(f"Installed {model} with id {id}")
|
||||
self._models_installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
|
||||
def unregister(self, key: str) -> None: # noqa D102
|
||||
self.record_store.del_model(key)
|
||||
|
||||
def delete(self, key: str) -> None: # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self.record_store.get_model(key)
|
||||
models_dir = self.app_config.models_path
|
||||
model_path = models_dir / model.path
|
||||
if model_path.is_relative_to(models_dir):
|
||||
self.unconditionally_delete(key)
|
||||
else:
|
||||
self.unregister(key)
|
||||
|
||||
def unconditionally_delete(self, key: str) -> None: # noqa D102
|
||||
model = self.record_store.get_model(key)
|
||||
path = self.app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def _copy_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if old_path.is_dir():
|
||||
copytree(old_path, new_path)
|
||||
else:
|
||||
copyfile(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# if path already exists then we jigger the name to make it unique
|
||||
counter: int = 1
|
||||
while new_path.exists():
|
||||
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
|
||||
if not path.exists():
|
||||
new_path = path
|
||||
counter += 1
|
||||
move(old_path, new_path)
|
||||
return new_path
|
||||
|
||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
||||
if config: # used to override probe fields
|
||||
for key, value in config.items():
|
||||
setattr(info, key, value)
|
||||
return info
|
||||
|
||||
def _create_key(self) -> str:
|
||||
return sha256(randbytes(100)).hexdigest()[0:32]
|
||||
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
key = self._create_key()
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
model_path = model_path.relative_to(self.app_config.models_path)
|
||||
|
||||
info.path = model_path.as_posix()
|
||||
|
||||
# add 'main' specific fields
|
||||
if hasattr(info, "config"):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(key, info)
|
||||
return key
|
@ -6,3 +6,11 @@ from .model_records_base import ( # noqa F401
|
||||
UnknownModelException,
|
||||
)
|
||||
from .model_records_sql import ModelRecordServiceSQL # noqa F401
|
||||
|
||||
__all__ = [
|
||||
"ModelRecordServiceBase",
|
||||
"ModelRecordServiceSQL",
|
||||
"DuplicateModelException",
|
||||
"InvalidModelException",
|
||||
"UnknownModelException",
|
||||
]
|
||||
|
@ -7,10 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||
|
||||
# should match the InvokeAI version when this is first released.
|
||||
CONFIG_FILE_VERSION = "3.2.0"
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@ -32,12 +29,6 @@ class ConfigFileVersionMismatchException(Exception):
|
||||
class ModelRecordServiceBase(ABC):
|
||||
"""Abstract base class for storage and retrieval of model configs."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def version(self) -> str:
|
||||
"""Return the config file/database schema version."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
@ -115,6 +106,7 @@ class ModelRecordServiceBase(ABC):
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
@ -122,6 +114,7 @@ class ModelRecordServiceBase(ABC):
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
|
@ -49,12 +49,12 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
from ..shared.sqlite import SqliteDatabase
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
CONFIG_FILE_VERSION,
|
||||
DuplicateModelException,
|
||||
ModelRecordServiceBase,
|
||||
UnknownModelException,
|
||||
@ -78,85 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
with self._db.lock:
|
||||
# Enable foreign keys
|
||||
self._db.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._db.conn.commit()
|
||||
assert (
|
||||
str(self.version) == CONFIG_FILE_VERSION
|
||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Create sqlite3 tables."""
|
||||
# model_config table breaks out the fields that are common to all config objects
|
||||
# and puts class-specific ones in a serialized json object
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# metadata table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_manager_metadata (
|
||||
metadata_key TEXT NOT NULL PRIMARY KEY,
|
||||
metadata_value TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
|
||||
AFTER UPDATE
|
||||
ON model_config FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add indexes for searchable fields
|
||||
for stmt in [
|
||||
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
|
||||
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
|
||||
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
|
||||
]:
|
||||
self._cursor.execute(stmt)
|
||||
|
||||
# Add our version to the metadata table
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE into model_manager_metadata (
|
||||
metadata_key,
|
||||
metadata_value
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
("version", CONFIG_FILE_VERSION),
|
||||
)
|
||||
|
||||
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
@ -175,21 +96,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
base,
|
||||
type,
|
||||
name,
|
||||
path,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?,?,?,?,?);
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.base,
|
||||
record.type,
|
||||
record.name,
|
||||
record.path,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
),
|
||||
@ -214,22 +127,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return the version of the database schema."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata_value FROM model_manager_metadata
|
||||
WHERE metadata_key=?;
|
||||
""",
|
||||
("version",),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise KeyError("Models database does not have metadata key 'version'")
|
||||
return rows[0]
|
||||
|
||||
def del_model(self, key: str) -> None:
|
||||
"""
|
||||
Delete a model.
|
||||
@ -269,14 +166,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
SET base=?,
|
||||
type=?,
|
||||
name=?,
|
||||
path=?,
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(record.base, record.type, record.name, record.path, json_serialized, key),
|
||||
(json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
@ -332,6 +226,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
@ -339,6 +234,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
@ -355,6 +251,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
if model_format:
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
@ -374,7 +273,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE model_path=?;
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
|
@ -114,6 +114,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
session_queue_id=queue_item.queue_id,
|
||||
session_queue_item_id=queue_item.item_id,
|
||||
graph_execution_state=queue_item.session,
|
||||
workflow=queue_item.workflow,
|
||||
invoke_all=True,
|
||||
)
|
||||
queue_item = None
|
||||
|
@ -8,6 +8,10 @@ from pydantic_core import to_jsonable_python
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowWithoutID,
|
||||
WorkflowWithoutIDValidator,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# region Errors
|
||||
@ -66,6 +70,9 @@ class Batch(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||
graph: Graph = Field(description="The graph to initialize the session with")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
default=None, description="The workflow to initialize the session with"
|
||||
)
|
||||
runs: int = Field(
|
||||
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
|
||||
)
|
||||
@ -164,6 +171,14 @@ def get_session(queue_item_dict: dict) -> GraphExecutionState:
|
||||
return session
|
||||
|
||||
|
||||
def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
|
||||
workflow_raw = queue_item_dict.get("workflow", None)
|
||||
if workflow_raw is not None:
|
||||
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw, strict=False)
|
||||
return workflow
|
||||
return None
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
@ -213,12 +228,16 @@ class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
|
||||
|
||||
class SessionQueueItem(SessionQueueItemWithoutGraph):
|
||||
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
|
||||
workflow: Optional[WorkflowWithoutID] = Field(
|
||||
default=None, description="The workflow associated with this queue item"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
|
||||
# must parse these manually
|
||||
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
|
||||
queue_item_dict["session"] = get_session(queue_item_dict)
|
||||
queue_item_dict["workflow"] = get_workflow(queue_item_dict)
|
||||
return SessionQueueItem(**queue_item_dict)
|
||||
|
||||
model_config = ConfigDict(
|
||||
@ -334,7 +353,7 @@ def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) ->
|
||||
|
||||
def create_session_nfv_tuples(
|
||||
batch: Batch, maximum: int
|
||||
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
|
||||
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
|
||||
"""
|
||||
Create all graph permutations from the given batch data and graph. Yields tuples
|
||||
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
|
||||
@ -365,7 +384,7 @@ def create_session_nfv_tuples(
|
||||
return
|
||||
flat_node_field_values = list(chain.from_iterable(d))
|
||||
graph = populate_graph(batch.graph, flat_node_field_values)
|
||||
yield (GraphExecutionState(graph=graph), flat_node_field_values)
|
||||
yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
|
||||
count += 1
|
||||
|
||||
|
||||
@ -391,12 +410,14 @@ def calc_session_count(batch: Batch) -> int:
|
||||
class SessionQueueValueToInsert(NamedTuple):
|
||||
"""A tuple of values to insert into the session_queue table"""
|
||||
|
||||
# Careful with the ordering of this - it must match the insert statement
|
||||
queue_id: str # queue_id
|
||||
session: str # session json
|
||||
session_id: str # session_id
|
||||
batch_id: str # batch_id
|
||||
field_values: Optional[str] # field_values json
|
||||
priority: int # priority
|
||||
workflow: Optional[str] # workflow json
|
||||
|
||||
|
||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
@ -404,7 +425,7 @@ ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||
|
||||
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
|
||||
values_to_insert: ValuesToInsert = []
|
||||
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
|
||||
# sessions must have unique id
|
||||
session.id = uuid_string()
|
||||
values_to_insert.append(
|
||||
@ -416,6 +437,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
||||
# must use pydantic_encoder bc field_values is a list of models
|
||||
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
|
||||
priority, # priority
|
||||
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
|
||||
)
|
||||
)
|
||||
return values_to_insert
|
||||
|
@ -28,7 +28,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class SqliteSessionQueue(SessionQueueBase):
|
||||
@ -50,7 +50,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__lock = db.lock
|
||||
self.__conn = db.conn
|
||||
self.__cursor = self.__conn.cursor()
|
||||
self._create_tables()
|
||||
|
||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||
return event[1]["event"] in match_in
|
||||
@ -98,114 +97,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the session queue tables, indicies, and triggers"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS session_queue (
|
||||
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
|
||||
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
||||
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
||||
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
||||
field_values TEXT, -- NULL if no values are associated with this queue item
|
||||
session TEXT NOT NULL, -- the session to be executed
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
|
||||
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
||||
error TEXT, -- any errors associated with this queue item
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
||||
started_at DATETIME, -- updated via trigger
|
||||
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
|
||||
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
||||
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'completed'
|
||||
OR NEW.status = 'failed'
|
||||
OR NEW.status = 'canceled'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'in_progress'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
||||
AFTER UPDATE
|
||||
ON session_queue FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = old.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
@ -281,8 +172,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
|
@ -1,50 +0,0 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SqliteDatabase:
|
||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
||||
self._logger = logger
|
||||
self._config = config
|
||||
|
||||
if self._config.use_memory_db:
|
||||
self.db_path = sqlite_memory
|
||||
logger.info("Using in-memory database")
|
||||
else:
|
||||
db_path = self._config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.db_path = str(db_path)
|
||||
self._logger.info(f"Using database at {self.db_path}")
|
||||
|
||||
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
self.lock = threading.RLock()
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
if self._config.log_sql:
|
||||
self.conn.set_trace_callback(self._logger.debug)
|
||||
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
def clean(self) -> None:
|
||||
try:
|
||||
if self.db_path == sqlite_memory:
|
||||
return
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.lock.acquire()
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error cleaning database: {e}")
|
||||
raise e
|
||||
finally:
|
||||
self.lock.release()
|
10
invokeai/app/services/shared/sqlite/sqlite_common.py
Normal file
10
invokeai/app/services/shared/sqlite/sqlite_common.py
Normal file
@ -0,0 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SQLiteDirection(str, Enum, metaclass=MetaEnum):
|
||||
Ascending = "ASC"
|
||||
Descending = "DESC"
|
67
invokeai/app/services/shared/sqlite/sqlite_database.py
Normal file
67
invokeai/app/services/shared/sqlite/sqlite_database.py
Normal file
@ -0,0 +1,67 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
|
||||
|
||||
class SqliteDatabase:
|
||||
"""
|
||||
Manages a connection to an SQLite database.
|
||||
|
||||
:param db_path: Path to the database file. If None, an in-memory database is used.
|
||||
:param logger: Logger to use for logging.
|
||||
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
|
||||
|
||||
This is a light wrapper around the `sqlite3` module, providing a few conveniences:
|
||||
- The database file is written to disk if it does not exist.
|
||||
- Foreign key constraints are enabled by default.
|
||||
- The connection is configured to use the `sqlite3.Row` row factory.
|
||||
|
||||
In addition to the constructor args, the instance provides the following attributes and methods:
|
||||
- `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
|
||||
- `lock`: A shared re-entrant lock, used to approximate thread safety.
|
||||
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||
"""Initializes the database. This is used internally by the class constructor."""
|
||||
self.logger = logger
|
||||
self.db_path = db_path
|
||||
self.verbose = verbose
|
||||
|
||||
if not self.db_path:
|
||||
logger.info("Initializing in-memory database")
|
||||
else:
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.logger.info(f"Initializing database at {self.db_path}")
|
||||
|
||||
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
|
||||
self.lock = threading.RLock()
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
if self.verbose:
|
||||
self.conn.set_trace_callback(self.logger.debug)
|
||||
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||
"""
|
||||
# No need to clean in-memory database
|
||||
if not self.db_path:
|
||||
return
|
||||
with self.lock:
|
||||
try:
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
32
invokeai/app/services/shared/sqlite/sqlite_util.py
Normal file
32
invokeai/app/services/shared/sqlite/sqlite_util.py
Normal file
@ -0,0 +1,32 @@
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileStorageBase) -> SqliteDatabase:
|
||||
"""
|
||||
Initializes the SQLite database.
|
||||
|
||||
:param config: The app config
|
||||
:param logger: The logger
|
||||
:param image_files: The image files service (used by migration 2)
|
||||
|
||||
This function:
|
||||
- Instantiates a :class:`SqliteDatabase`
|
||||
- Instantiates a :class:`SqliteMigrator` and registers all migrations
|
||||
- Runs all migrations
|
||||
"""
|
||||
db_path = None if config.use_memory_db else config.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=logger, verbose=config.log_sql)
|
||||
|
||||
migrator = SqliteMigrator(db=db)
|
||||
migrator.register_migration(build_migration_1())
|
||||
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
@ -0,0 +1,372 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration1Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Migration callback for database version 1."""
|
||||
|
||||
self._create_board_images(cursor)
|
||||
self._create_boards(cursor)
|
||||
self._create_images(cursor)
|
||||
self._create_model_config(cursor)
|
||||
self._create_session_queue(cursor)
|
||||
self._create_workflow_images(cursor)
|
||||
self._create_workflows(cursor)
|
||||
|
||||
def _create_board_images(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the `board_images` table, indices and triggers."""
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS board_images (
|
||||
board_id TEXT NOT NULL,
|
||||
image_name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between boards and images using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (image_name),
|
||||
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);",
|
||||
]
|
||||
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON board_images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE board_id = old.board_id AND image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
def _create_boards(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the `boards` table, indices and triggers."""
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS boards (
|
||||
board_id TEXT NOT NULL PRIMARY KEY,
|
||||
board_name TEXT NOT NULL,
|
||||
cover_image_name TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
indices = ["CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);"]
|
||||
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
||||
AFTER UPDATE
|
||||
ON boards FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE boards SET updated_at = current_timestamp
|
||||
WHERE board_id = old.board_id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
def _create_images(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the `images` table, indices and triggers. Adds the `starred` column."""
|
||||
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
image_name TEXT NOT NULL PRIMARY KEY,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_origin TEXT NOT NULL,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_category TEXT NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
height INTEGER NOT NULL,
|
||||
session_id TEXT,
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
indices = [
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);",
|
||||
]
|
||||
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add the 'starred' column to `images` if it doesn't exist
|
||||
cursor.execute("PRAGMA table_info(images)")
|
||||
columns = [column[1] for column in cursor.fetchall()]
|
||||
|
||||
if "starred" not in columns:
|
||||
tables.append("ALTER TABLE images ADD COLUMN starred BOOLEAN DEFAULT FALSE;")
|
||||
indices.append("CREATE INDEX IF NOT EXISTS idx_images_starred ON images(starred);")
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
def _create_model_config(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the `model_config` table, `model_manager_metadata` table, indices and triggers."""
|
||||
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
""",
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_manager_metadata (
|
||||
metadata_key TEXT NOT NULL PRIMARY KEY,
|
||||
metadata_value TEXT NOT NULL
|
||||
);
|
||||
""",
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
|
||||
AFTER UPDATE
|
||||
ON model_config FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
|
||||
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
|
||||
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
def _create_session_queue(self, cursor: sqlite3.Cursor) -> None:
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS session_queue (
|
||||
item_id INTEGER PRIMARY KEY AUTOINCREMENT, -- used for ordering, cursor pagination
|
||||
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
||||
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
||||
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
||||
field_values TEXT, -- NULL if no values are associated with this queue item
|
||||
session TEXT NOT NULL, -- the session to be executed
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'completed', 'failed', 'canceled'
|
||||
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
||||
error TEXT, -- any errors associated with this queue item
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
||||
started_at DATETIME, -- updated via trigger
|
||||
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
|
||||
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
||||
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
indices = [
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);",
|
||||
]
|
||||
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'completed'
|
||||
OR NEW.status = 'failed'
|
||||
OR NEW.status = 'canceled'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
""",
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'in_progress'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
""",
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
||||
AFTER UPDATE
|
||||
ON session_queue FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = old.item_id;
|
||||
END;
|
||||
""",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
def _create_workflow_images(self, cursor: sqlite3.Cursor) -> None:
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflow_images (
|
||||
workflow_id TEXT NOT NULL,
|
||||
image_name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between workflows and images using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (image_name),
|
||||
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);",
|
||||
]
|
||||
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflow_images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
def _create_workflows(self, cursor: sqlite3.Cursor) -> None:
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
workflow TEXT NOT NULL,
|
||||
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflows FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflows
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
for stmt in tables + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_1() -> Migration:
|
||||
"""
|
||||
Builds the migration from database version 0 (init) to 1.
|
||||
|
||||
This migration represents the state of the database circa InvokeAI v3.4.0, which was the last
|
||||
version to not use migrations to manage the database.
|
||||
|
||||
As such, this migration does include some ALTER statements, and the SQL statements are written
|
||||
to be idempotent.
|
||||
|
||||
- Create `board_images` junction table
|
||||
- Create `boards` table
|
||||
- Create `images` table, add `starred` column
|
||||
- Create `model_config` table
|
||||
- Create `session_queue` table
|
||||
- Create `workflow_images` junction table
|
||||
- Create `workflows` table
|
||||
"""
|
||||
|
||||
migration_1 = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
callback=Migration1Callback(),
|
||||
)
|
||||
|
||||
return migration_1
|
@ -0,0 +1,198 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from pydantic import ValidationError
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.image_files.image_files_common import ImageFileNotFoundException
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
UnsafeWorkflowWithVersionValidator,
|
||||
)
|
||||
|
||||
|
||||
class Migration2Callback:
|
||||
def __init__(self, image_files: ImageFileStorageBase, logger: Logger):
|
||||
self._image_files = image_files
|
||||
self._logger = logger
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor):
|
||||
self._add_images_has_workflow(cursor)
|
||||
self._add_session_queue_workflow(cursor)
|
||||
self._drop_old_workflow_tables(cursor)
|
||||
self._add_workflow_library(cursor)
|
||||
self._drop_model_manager_metadata(cursor)
|
||||
self._recreate_model_config(cursor)
|
||||
self._migrate_embedded_workflows(cursor)
|
||||
|
||||
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Add the `has_workflow` column to `images` table."""
|
||||
cursor.execute("PRAGMA table_info(images)")
|
||||
columns = [column[1] for column in cursor.fetchall()]
|
||||
|
||||
if "has_workflow" not in columns:
|
||||
cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;")
|
||||
|
||||
def _add_session_queue_workflow(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Add the `workflow` column to `session_queue` table."""
|
||||
|
||||
cursor.execute("PRAGMA table_info(session_queue)")
|
||||
columns = [column[1] for column in cursor.fetchall()]
|
||||
|
||||
if "workflow" not in columns:
|
||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN workflow TEXT;")
|
||||
|
||||
def _drop_old_workflow_tables(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the `workflows` and `workflow_images` tables."""
|
||||
cursor.execute("DROP TABLE IF EXISTS workflow_images;")
|
||||
cursor.execute("DROP TABLE IF EXISTS workflows;")
|
||||
|
||||
def _add_workflow_library(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Adds the `workflow_library` table and drops the `workflows` and `workflow_images` tables."""
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflow_library (
|
||||
workflow_id TEXT NOT NULL PRIMARY KEY,
|
||||
workflow TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated manually when retrieving workflow
|
||||
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Generated columns, needed for indexing and searching
|
||||
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
|
||||
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
|
||||
);
|
||||
""",
|
||||
]
|
||||
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);",
|
||||
"CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);",
|
||||
]
|
||||
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflow_library FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflow_library
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the `model_manager_metadata` table."""
|
||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||
|
||||
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Drops the `model_config` table, recreating it.
|
||||
|
||||
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
|
||||
|
||||
Because this table is not used in production, we are able to simply drop it and recreate it.
|
||||
"""
|
||||
|
||||
cursor.execute("DROP TABLE IF EXISTS model_config;")
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
the database now has a `has_workflow` column, indicating if an image has a workflow embedded.
|
||||
|
||||
This migrate callback checks each image for the presence of an embedded workflow, then updates its entry
|
||||
in the database accordingly.
|
||||
"""
|
||||
# Get all image names
|
||||
cursor.execute("SELECT image_name FROM images")
|
||||
image_names: list[str] = [image[0] for image in cursor.fetchall()]
|
||||
total_image_names = len(image_names)
|
||||
|
||||
if not total_image_names:
|
||||
return
|
||||
|
||||
self._logger.info(f"Migrating workflows for {total_image_names} images")
|
||||
|
||||
# Migrate the images
|
||||
to_migrate: list[tuple[bool, str]] = []
|
||||
pbar = tqdm(image_names)
|
||||
for idx, image_name in enumerate(pbar):
|
||||
pbar.set_description(f"Checking image {idx + 1}/{total_image_names} for workflow")
|
||||
try:
|
||||
pil_image = self._image_files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._logger.warning(f"Image {image_name} not found, skipping")
|
||||
continue
|
||||
if "invokeai_workflow" in pil_image.info:
|
||||
try:
|
||||
UnsafeWorkflowWithVersionValidator.validate_json(pil_image.info.get("invokeai_workflow", ""))
|
||||
except ValidationError:
|
||||
self._logger.warning(f"Image {image_name} has invalid embedded workflow, skipping")
|
||||
continue
|
||||
to_migrate.append((True, image_name))
|
||||
|
||||
self._logger.info(f"Adding {len(to_migrate)} embedded workflows to database")
|
||||
cursor.executemany("UPDATE images SET has_workflow = ? WHERE image_name = ?", to_migrate)
|
||||
|
||||
|
||||
def build_migration_2(image_files: ImageFileStorageBase, logger: Logger) -> Migration:
|
||||
"""
|
||||
Builds the migration from database version 1 to 2.
|
||||
|
||||
Introduced in v3.5.0 for the new workflow library.
|
||||
|
||||
:param image_files: The image files service, used to check for embedded workflows
|
||||
:param logger: The logger, used to log progress during embedded workflows handling
|
||||
|
||||
This migration does the following:
|
||||
- Add `has_workflow` column to `images` table
|
||||
- Add `workflow` column to `session_queue` table
|
||||
- Drop `workflows` and `workflow_images` tables
|
||||
- Add `workflow_library` table
|
||||
- Drops the `model_manager_metadata` table
|
||||
- Drops the `model_config` table, recreating it (at this point, there is no user data in this table)
|
||||
- Populates the `has_workflow` column in the `images` table (requires `image_files` & `logger` dependencies)
|
||||
"""
|
||||
migration_2 = Migration(
|
||||
from_version=1,
|
||||
to_version=2,
|
||||
callback=Migration2Callback(image_files=image_files, logger=logger),
|
||||
)
|
||||
|
||||
return migration_2
|
@ -0,0 +1,164 @@
|
||||
import sqlite3
|
||||
from typing import Optional, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MigrateCallback(Protocol):
|
||||
"""
|
||||
A callback that performs a migration.
|
||||
|
||||
Migrate callbacks are provided an open cursor to the database. They should not commit their
|
||||
transaction; this is handled by the migrator.
|
||||
|
||||
If the callback needs to access additional dependencies, will be provided to the callback at runtime.
|
||||
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
"""Raised when a migration fails."""
|
||||
|
||||
|
||||
class MigrationVersionError(ValueError):
|
||||
"""Raised when a migration version is invalid."""
|
||||
|
||||
|
||||
class Migration(BaseModel):
|
||||
"""
|
||||
Represents a migration for a SQLite database.
|
||||
|
||||
:param from_version: The database version on which this migration may be run
|
||||
:param to_version: The database version that results from this migration
|
||||
:param migrate_callback: The callback to run to perform the migration
|
||||
|
||||
Migration callbacks will be provided an open cursor to the database. They should not commit their
|
||||
transaction; this is handled by the migrator.
|
||||
|
||||
It is suggested to use a class to define the migration callback and a builder function to create
|
||||
the :class:`Migration`. This allows the callback to be provided with additional dependencies and
|
||||
keeps things tidy, as all migration logic is self-contained.
|
||||
|
||||
Example:
|
||||
```py
|
||||
# Define the migration callback class
|
||||
class Migration1Callback:
|
||||
# This migration needs a logger, so we define a class that accepts a logger in its constructor.
|
||||
def __init__(self, image_files: ImageFileStorageBase) -> None:
|
||||
self._image_files = ImageFileStorageBase
|
||||
|
||||
# This dunder method allows the instance of the class to be called like a function.
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._add_with_banana_column(cursor)
|
||||
self._do_something_with_images(cursor)
|
||||
|
||||
def _add_with_banana_column(self, cursor: sqlite3.Cursor) -> None:
|
||||
\"""Adds the with_banana column to the sushi table.\"""
|
||||
# Execute SQL using the cursor, taking care to *not commit* a transaction
|
||||
cursor.execute('ALTER TABLE sushi ADD COLUMN with_banana BOOLEAN DEFAULT TRUE;')
|
||||
|
||||
def _do_something_with_images(self, cursor: sqlite3.Cursor) -> None:
|
||||
\"""Does something with the image files service.\"""
|
||||
self._image_files.get(...)
|
||||
|
||||
# Define the migration builder function. This function creates an instance of the migration callback
|
||||
# class and returns a Migration.
|
||||
def build_migration_1(image_files: ImageFileStorageBase) -> Migration:
|
||||
\"""Builds the migration from database version 0 to 1.
|
||||
Requires the image files service to...
|
||||
\"""
|
||||
|
||||
migration_1 = Migration(
|
||||
from_version=0,
|
||||
to_version=1,
|
||||
migrate_callback=Migration1Callback(image_files=image_files),
|
||||
)
|
||||
|
||||
return migration_1
|
||||
|
||||
# Register the migration after all dependencies have been initialized
|
||||
db = SqliteDatabase(db_path, logger)
|
||||
migrator = SqliteMigrator(db)
|
||||
migrator.register_migration(build_migration_1(image_files))
|
||||
migrator.run_migrations()
|
||||
```
|
||||
"""
|
||||
|
||||
from_version: int = Field(ge=0, strict=True, description="The database version on which this migration may be run")
|
||||
to_version: int = Field(ge=1, strict=True, description="The database version that results from this migration")
|
||||
callback: MigrateCallback = Field(description="The callback to run to perform the migration")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_to_version(self) -> "Migration":
|
||||
"""Validates that to_version is one greater than from_version."""
|
||||
if self.to_version != self.from_version + 1:
|
||||
raise MigrationVersionError("to_version must be one greater than from_version")
|
||||
return self
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
return hash((self.from_version, self.to_version))
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class MigrationSet:
|
||||
"""
|
||||
A set of Migrations. Performs validation during migration registration and provides utility methods.
|
||||
|
||||
Migrations should be registered with `register()`. Once all are registered, `validate_migration_chain()`
|
||||
should be called to ensure that the migrations form a single chain of migrations from version 0 to the latest version.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._migrations: set[Migration] = set()
|
||||
|
||||
def register(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
|
||||
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
|
||||
if migration_from_already_registered or migration_to_already_registered:
|
||||
raise MigrationVersionError("Migration with from_version or to_version already registered")
|
||||
self._migrations.add(migration)
|
||||
|
||||
def get(self, from_version: int) -> Optional[Migration]:
|
||||
"""Gets the migration that may be run on the given database version."""
|
||||
# register() ensures that there is only one migration with a given from_version, so this is safe.
|
||||
return next((m for m in self._migrations if m.from_version == from_version), None)
|
||||
|
||||
def validate_migration_chain(self) -> None:
|
||||
"""
|
||||
Validates that the migrations form a single chain of migrations from version 0 to the latest version,
|
||||
Raises a MigrationError if there is a problem.
|
||||
"""
|
||||
if self.count == 0:
|
||||
return
|
||||
if self.latest_version == 0:
|
||||
return
|
||||
next_migration = self.get(from_version=0)
|
||||
if next_migration is None:
|
||||
raise MigrationError("Migration chain is fragmented")
|
||||
touched_count = 1
|
||||
while next_migration is not None:
|
||||
next_migration = self.get(next_migration.to_version)
|
||||
if next_migration is not None:
|
||||
touched_count += 1
|
||||
if touched_count != self.count:
|
||||
raise MigrationError("Migration chain is fragmented")
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""The count of registered migrations."""
|
||||
return len(self._migrations)
|
||||
|
||||
@property
|
||||
def latest_version(self) -> int:
|
||||
"""Gets latest to_version among registered migrations. Returns 0 if there are no migrations registered."""
|
||||
if self.count == 0:
|
||||
return 0
|
||||
return sorted(self._migrations, key=lambda m: m.to_version)[-1].to_version
|
@ -0,0 +1,130 @@
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration, MigrationError, MigrationSet
|
||||
|
||||
|
||||
class SqliteMigrator:
|
||||
"""
|
||||
Manages migrations for a SQLite database.
|
||||
|
||||
:param db: The instance of :class:`SqliteDatabase` to migrate.
|
||||
|
||||
Migrations should be registered with :meth:`register_migration`.
|
||||
|
||||
Each migration is run in a transaction. If a migration fails, the transaction is rolled back.
|
||||
|
||||
Example Usage:
|
||||
```py
|
||||
db = SqliteDatabase(db_path="my_db.db", logger=logger)
|
||||
migrator = SqliteMigrator(db=db)
|
||||
migrator.register_migration(build_migration_1())
|
||||
migrator.register_migration(build_migration_2())
|
||||
migrator.run_migrations()
|
||||
```
|
||||
"""
|
||||
|
||||
backup_path: Optional[Path] = None
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
self._db = db
|
||||
self._logger = db.logger
|
||||
self._migration_set = MigrationSet()
|
||||
|
||||
def register_migration(self, migration: Migration) -> None:
|
||||
"""Registers a migration."""
|
||||
self._migration_set.register(migration)
|
||||
self._logger.debug(f"Registered migration {migration.from_version} -> {migration.to_version}")
|
||||
|
||||
def run_migrations(self) -> bool:
|
||||
"""Migrates the database to the latest version."""
|
||||
with self._db.lock:
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db.conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
|
||||
if self._migration_set.count == 0:
|
||||
self._logger.debug("No migrations registered")
|
||||
return False
|
||||
|
||||
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
|
||||
self._logger.debug("Database is up to date, no migrations to run")
|
||||
return False
|
||||
|
||||
self._logger.info("Database update needed")
|
||||
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
|
||||
while next_migration is not None:
|
||||
self._run_migration(next_migration)
|
||||
next_migration = self._migration_set.get(self._get_current_version(cursor))
|
||||
self._logger.info("Database updated successfully")
|
||||
return True
|
||||
|
||||
def _run_migration(self, migration: Migration) -> None:
|
||||
"""Runs a single migration."""
|
||||
try:
|
||||
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
|
||||
# exception is raised.
|
||||
with self._db.lock, self._db.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
if self._get_current_version(cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
f"Database is at version {self._get_current_version(cursor)}, expected {migration.from_version}"
|
||||
)
|
||||
self._logger.debug(f"Running migration from {migration.from_version} to {migration.to_version}")
|
||||
|
||||
# Run the actual migration
|
||||
migration.callback(cursor)
|
||||
|
||||
# Update the version
|
||||
cursor.execute("INSERT INTO migrations (version) VALUES (?);", (migration.to_version,))
|
||||
|
||||
self._logger.debug(
|
||||
f"Successfully migrated database from {migration.from_version} to {migration.to_version}"
|
||||
)
|
||||
# We want to catch *any* error, mirroring the behaviour of the sqlite3 module.
|
||||
except Exception as e:
|
||||
# The connection context manager has already rolled back the migration, so we don't need to do anything.
|
||||
msg = f"Error migrating database from {migration.from_version} to {migration.to_version}: {e}"
|
||||
self._logger.error(msg)
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the migrations table for the database, if one does not already exist."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if cursor.fetchone() is not None:
|
||||
return
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
cursor.connection.commit()
|
||||
self._logger.debug("Created migrations table")
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem creating migrations table: {e}"
|
||||
self._logger.error(msg)
|
||||
cursor.connection.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
@classmethod
|
||||
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:
|
||||
"""Gets the current version of the database, or 0 if the migrations table does not exist."""
|
||||
try:
|
||||
cursor.execute("SELECT MAX(version) FROM migrations;")
|
||||
version: int = cursor.fetchone()[0]
|
||||
if version is None:
|
||||
return 0
|
||||
return version
|
||||
except sqlite3.OperationalError as e:
|
||||
if "no such table" in str(e):
|
||||
return 0
|
||||
raise
|
@ -1,23 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class WorkflowImageRecordsStorageBase(ABC):
|
||||
"""Abstract base class for the one-to-many workflow-image relationship record storage."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Creates a workflow-image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_workflow_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
"""Gets an image's workflow id, if it has one."""
|
||||
pass
|
@ -1,122 +0,0 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, cast
|
||||
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
|
||||
|
||||
|
||||
class SqliteWorkflowImageRecordsStorage(WorkflowImageRecordsStorageBase):
|
||||
"""SQLite implementation of WorkflowImageRecordsStorageBase."""
|
||||
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
# Create the `workflow_images` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflow_images (
|
||||
workflow_id TEXT NOT NULL,
|
||||
image_name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between workflows and images using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (image_name),
|
||||
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for workflow id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for workflow id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflow_images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
workflow_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Creates a workflow-image record."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO workflow_images (workflow_id, image_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(workflow_id, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_workflow_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
"""Gets an image's workflow id, if it has one."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id
|
||||
FROM workflow_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
@ -0,0 +1,17 @@
|
||||
# Default Workflows
|
||||
|
||||
Workflows placed in this directory will be synced to the `workflow_library` as
|
||||
_default workflows_ on app startup.
|
||||
|
||||
- Default workflows are not editable by users. If they are loaded and saved,
|
||||
they will save as a copy of the default workflow.
|
||||
- Default workflows must have the `meta.category` property set to `"default"`.
|
||||
An exception will be raised during sync if this is not set correctly.
|
||||
- Default workflows appear on the "Default Workflows" tab of the Workflow
|
||||
Library.
|
||||
|
||||
After adding or updating default workflows, you **must** start the app up and
|
||||
load them to ensure:
|
||||
|
||||
- The workflow loads without warning or errors
|
||||
- The workflow runs successfully
|
@ -0,0 +1,798 @@
|
||||
{
|
||||
"name": "Text to Image - SD1.5",
|
||||
"author": "InvokeAI",
|
||||
"description": "Sample text to image workflow for Stable Diffusion 1.5/2",
|
||||
"version": "1.1.0",
|
||||
"contact": "invoke@invoke.ai",
|
||||
"tags": "text2image, SD1.5, SD2, default",
|
||||
"notes": "",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"fieldName": "width"
|
||||
},
|
||||
{
|
||||
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"fieldName": "height"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"category": "default",
|
||||
"version": "2.0.0"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"type": "compel",
|
||||
"label": "Negative Compel Prompt",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"prompt": {
|
||||
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
|
||||
"name": "prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "Negative Prompt",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": ""
|
||||
},
|
||||
"clip": {
|
||||
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"conditioning": {
|
||||
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
|
||||
"name": "conditioning",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 259,
|
||||
"position": {
|
||||
"x": 1000,
|
||||
"y": 350
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"type": "noise",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.1",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"seed": {
|
||||
"id": "6431737c-918a-425d-a3b4-5d57e2f35d4d",
|
||||
"name": "seed",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"width": {
|
||||
"id": "38fc5b66-fe6e-47c8-bba9-daf58e454ed7",
|
||||
"name": "width",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 512
|
||||
},
|
||||
"height": {
|
||||
"id": "16298330-e2bf-4872-a514-d6923df53cbb",
|
||||
"name": "height",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 512
|
||||
},
|
||||
"use_cpu": {
|
||||
"id": "c7c436d3-7a7a-4e76-91e4-c6deb271623c",
|
||||
"name": "use_cpu",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": true
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"noise": {
|
||||
"id": "50f650dc-0184-4e23-a927-0497a96fe954",
|
||||
"name": "noise",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "bb8a452b-133d-42d1-ae4a-3843d7e4109a",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "35cfaa12-3b8b-4b7a-a884-327ff3abddd9",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 388,
|
||||
"position": {
|
||||
"x": 600,
|
||||
"y": 325
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"type": "main_model_loader",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"model": {
|
||||
"id": "993eabd2-40fd-44fe-bce7-5d0c7075ddab",
|
||||
"name": "model",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "MainModelField"
|
||||
},
|
||||
"value": {
|
||||
"model_name": "stable-diffusion-v1-5",
|
||||
"base_model": "sd-1",
|
||||
"model_type": "main"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"unet": {
|
||||
"id": "5c18c9db-328d-46d0-8cb9-143391c410be",
|
||||
"name": "unet",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"id": "6effcac0-ec2f-4bf5-a49e-a2c29cf921f4",
|
||||
"name": "clip",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"id": "57683ba3-f5f5-4f58-b9a2-4b83dacad4a1",
|
||||
"name": "vae",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "VaeField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 226,
|
||||
"position": {
|
||||
"x": 600,
|
||||
"y": 25
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"type": "compel",
|
||||
"label": "Positive Compel Prompt",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"prompt": {
|
||||
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
|
||||
"name": "prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "Positive Prompt",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": "Super cute tiger cub, national geographic award-winning photograph"
|
||||
},
|
||||
"clip": {
|
||||
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"conditioning": {
|
||||
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
|
||||
"name": "conditioning",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 259,
|
||||
"position": {
|
||||
"x": 1000,
|
||||
"y": 25
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
|
||||
"type": "rand_int",
|
||||
"label": "Random Seed",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"low": {
|
||||
"id": "3ec65a37-60ba-4b6c-a0b2-553dd7a84b84",
|
||||
"name": "low",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"id": "085f853a-1a5f-494d-8bec-e4ba29a3f2d1",
|
||||
"name": "high",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 2147483647
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"value": {
|
||||
"id": "812ade4d-7699-4261-b9fc-a6c9d2ab55ee",
|
||||
"name": "value",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 600,
|
||||
"y": 275
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "denoise_latents",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.5.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"positive_conditioning": {
|
||||
"id": "90b7f4f8-ada7-4028-8100-d2e54f192052",
|
||||
"name": "positive_conditioning",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
},
|
||||
"negative_conditioning": {
|
||||
"id": "9393779e-796c-4f64-b740-902a1177bf53",
|
||||
"name": "negative_conditioning",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
},
|
||||
"noise": {
|
||||
"id": "8e17f1e5-4f98-40b1-b7f4-86aeeb4554c1",
|
||||
"name": "noise",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"steps": {
|
||||
"id": "9b63302d-6bd2-42c9-ac13-9b1afb51af88",
|
||||
"name": "steps",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 50
|
||||
},
|
||||
"cfg_scale": {
|
||||
"id": "87dd04d3-870e-49e1-98bf-af003a810109",
|
||||
"name": "cfg_scale",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 7.5
|
||||
},
|
||||
"denoising_start": {
|
||||
"id": "f369d80f-4931-4740-9bcd-9f0620719fab",
|
||||
"name": "denoising_start",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"id": "747d10e5-6f02-445c-994c-0604d814de8c",
|
||||
"name": "denoising_end",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 1
|
||||
},
|
||||
"scheduler": {
|
||||
"id": "1de84a4e-3a24-4ec8-862b-16ce49633b9b",
|
||||
"name": "scheduler",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "SchedulerField"
|
||||
},
|
||||
"value": "unipc"
|
||||
},
|
||||
"unet": {
|
||||
"id": "ffa6fef4-3ce2-4bdb-9296-9a834849489b",
|
||||
"name": "unet",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"control": {
|
||||
"id": "077b64cb-34be-4fcc-83f2-e399807a02bd",
|
||||
"name": "control",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "ControlField"
|
||||
}
|
||||
},
|
||||
"ip_adapter": {
|
||||
"id": "1d6948f7-3a65-4a65-a20c-768b287251aa",
|
||||
"name": "ip_adapter",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "IPAdapterField"
|
||||
}
|
||||
},
|
||||
"t2i_adapter": {
|
||||
"id": "75e67b09-952f-4083-aaf4-6b804d690412",
|
||||
"name": "t2i_adapter",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "T2IAdapterField"
|
||||
}
|
||||
},
|
||||
"cfg_rescale_multiplier": {
|
||||
"id": "9101f0a6-5fe0-4826-b7b3-47e5d506826c",
|
||||
"name": "cfg_rescale_multiplier",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"latents": {
|
||||
"id": "334d4ba3-5a99-4195-82c5-86fb3f4f7d43",
|
||||
"name": "latents",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"denoise_mask": {
|
||||
"id": "0d3dbdbf-b014-4e95-8b18-ff2ff9cb0bfa",
|
||||
"name": "denoise_mask",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "DenoiseMaskField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"latents": {
|
||||
"id": "70fa5bbc-0c38-41bb-861a-74d6d78d2f38",
|
||||
"name": "latents",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "98ee0e6c-82aa-4e8f-8be5-dc5f00ee47f0",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "e8cb184a-5e1a-47c8-9695-4b8979564f5d",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 703,
|
||||
"position": {
|
||||
"x": 1400,
|
||||
"y": 25
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
|
||||
"type": "l2i",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"version": "1.2.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"metadata": {
|
||||
"id": "ab375f12-0042-4410-9182-29e30db82c85",
|
||||
"name": "metadata",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "MetadataField"
|
||||
}
|
||||
},
|
||||
"latents": {
|
||||
"id": "3a7e7efd-bff5-47d7-9d48-615127afee78",
|
||||
"name": "latents",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"id": "a1f5f7a1-0795-4d58-b036-7820c0b0ef2b",
|
||||
"name": "vae",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "VaeField"
|
||||
}
|
||||
},
|
||||
"tiled": {
|
||||
"id": "da52059a-0cee-4668-942f-519aa794d739",
|
||||
"name": "tiled",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": false
|
||||
},
|
||||
"fp32": {
|
||||
"id": "c4841df3-b24e-4140-be3b-ccd454c2522c",
|
||||
"name": "fp32",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": true
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"image": {
|
||||
"id": "72d667d0-cf85-459d-abf2-28bd8b823fe7",
|
||||
"name": "image",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ImageField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "c8c907d8-1066-49d1-b9a6-83bdcd53addc",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "230f359c-b4ea-436c-b372-332d7dcdca85",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 266,
|
||||
"position": {
|
||||
"x": 1800,
|
||||
"y": 25
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-ea94bc37-d995-4a83-aa99-4af42479f2f2value-55705012-79b9-4aac-9f26-c0b10309785bseed",
|
||||
"source": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
|
||||
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"type": "default",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-7d8bf987-284f-413a-b2fd-d825445a5d6cclip",
|
||||
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"target": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-93dc02a4-d05b-48ed-b99c-c9b616af3402clip",
|
||||
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"target": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-55705012-79b9-4aac-9f26-c0b10309785bnoise-eea2702a-19fb-45b5-9d75-56b4211ec03cnoise",
|
||||
"source": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "default",
|
||||
"sourceHandle": "noise",
|
||||
"targetHandle": "noise"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-7d8bf987-284f-413a-b2fd-d825445a5d6cconditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cpositive_conditioning",
|
||||
"source": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "default",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-93dc02a4-d05b-48ed-b99c-c9b616af3402conditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cnegative_conditioning",
|
||||
"source": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "default",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "negative_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8unet-eea2702a-19fb-45b5-9d75-56b4211ec03cunet",
|
||||
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "default",
|
||||
"sourceHandle": "unet",
|
||||
"targetHandle": "unet"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-eea2702a-19fb-45b5-9d75-56b4211ec03clatents-58c957f5-0d01-41fc-a803-b2bbf0413d4flatents",
|
||||
"source": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
|
||||
"type": "default",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8vae-58c957f5-0d01-41fc-a803-b2bbf0413d4fvae",
|
||||
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
|
||||
"type": "default",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
}
|
||||
]
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,17 +1,50 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import WorkflowField
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
Workflow,
|
||||
WorkflowCategory,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordOrderBy,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRecordsStorageBase(ABC):
|
||||
"""Base class for workflow storage services."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, workflow_id: str) -> WorkflowField:
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Get workflow by id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
|
||||
"""Creates a workflow."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
"""Updates a workflow."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
"""Deletes a workflow."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
page: int,
|
||||
per_page: int,
|
||||
order_by: WorkflowRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
category: WorkflowCategory,
|
||||
query: Optional[str],
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
|
@ -1,2 +1,118 @@
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Union
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_validator
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
|
||||
__workflow_meta_version__ = semver.Version.parse("1.0.0")
|
||||
|
||||
|
||||
class ExposedField(BaseModel):
|
||||
nodeId: str
|
||||
fieldName: str
|
||||
|
||||
|
||||
class WorkflowNotFoundError(Exception):
|
||||
"""Raised when a workflow is not found"""
|
||||
|
||||
|
||||
class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum):
|
||||
"""The order by options for workflow records"""
|
||||
|
||||
CreatedAt = "created_at"
|
||||
UpdatedAt = "updated_at"
|
||||
OpenedAt = "opened_at"
|
||||
Name = "name"
|
||||
|
||||
|
||||
class WorkflowCategory(str, Enum, metaclass=MetaEnum):
|
||||
User = "user"
|
||||
Default = "default"
|
||||
|
||||
|
||||
class WorkflowMeta(BaseModel):
|
||||
version: str = Field(description="The version of the workflow schema.")
|
||||
category: WorkflowCategory = Field(
|
||||
default=WorkflowCategory.User, description="The category of the workflow (user or default)."
|
||||
)
|
||||
|
||||
@field_validator("version")
|
||||
def validate_version(cls, version: str):
|
||||
try:
|
||||
semver.Version.parse(version)
|
||||
return version
|
||||
except Exception:
|
||||
raise ValueError(f"Invalid workflow meta version: {version}")
|
||||
|
||||
def to_semver(self) -> semver.Version:
|
||||
return semver.Version.parse(self.version)
|
||||
|
||||
|
||||
class WorkflowWithoutID(BaseModel):
|
||||
name: str = Field(description="The name of the workflow.")
|
||||
author: str = Field(description="The author of the workflow.")
|
||||
description: str = Field(description="The description of the workflow.")
|
||||
version: str = Field(description="The version of the workflow.")
|
||||
contact: str = Field(description="The contact of the workflow.")
|
||||
tags: str = Field(description="The tags of the workflow.")
|
||||
notes: str = Field(description="The notes of the workflow.")
|
||||
exposedFields: list[ExposedField] = Field(description="The exposed fields of the workflow.")
|
||||
meta: WorkflowMeta = Field(description="The meta of the workflow.")
|
||||
# TODO: nodes and edges are very loosely typed
|
||||
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
|
||||
edges: list[dict[str, JsonValue]] = Field(description="The edges of the workflow.")
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
|
||||
|
||||
|
||||
class UnsafeWorkflowWithVersion(BaseModel):
|
||||
"""
|
||||
This utility model only requires a workflow to have a valid version string.
|
||||
It is used to validate a workflow version without having to validate the entire workflow.
|
||||
"""
|
||||
|
||||
meta: WorkflowMeta = Field(description="The meta of the workflow.")
|
||||
|
||||
|
||||
UnsafeWorkflowWithVersionValidator = TypeAdapter(UnsafeWorkflowWithVersion)
|
||||
|
||||
|
||||
class Workflow(WorkflowWithoutID):
|
||||
id: str = Field(description="The id of the workflow.")
|
||||
|
||||
|
||||
WorkflowValidator = TypeAdapter(Workflow)
|
||||
|
||||
|
||||
class WorkflowRecordDTOBase(BaseModel):
|
||||
workflow_id: str = Field(description="The id of the workflow.")
|
||||
name: str = Field(description="The name of the workflow.")
|
||||
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.")
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.")
|
||||
opened_at: Union[datetime.datetime, str] = Field(description="The opened timestamp of the workflow.")
|
||||
|
||||
|
||||
class WorkflowRecordDTO(WorkflowRecordDTOBase):
|
||||
workflow: Workflow = Field(description="The workflow.")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRecordDTO":
|
||||
data["workflow"] = WorkflowValidator.validate_json(data.get("workflow", ""))
|
||||
return WorkflowRecordDTOValidator.validate_python(data)
|
||||
|
||||
|
||||
WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)
|
||||
|
||||
|
||||
class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):
|
||||
description: str = Field(description="The description of the workflow.")
|
||||
category: WorkflowCategory = Field(description="The description of the workflow.")
|
||||
|
||||
|
||||
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)
|
||||
|
@ -1,37 +1,53 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
Workflow,
|
||||
WorkflowCategory,
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordListItemDTOValidator,
|
||||
WorkflowRecordOrderBy,
|
||||
WorkflowWithoutID,
|
||||
WorkflowWithoutIDValidator,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
_invoker: Invoker
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
self._create_tables()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._sync_default_workflows()
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowField:
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow by ID. Updates the opened_at column."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow
|
||||
FROM workflows
|
||||
UPDATE workflow_library
|
||||
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
@ -39,25 +55,28 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowFieldValidator.validate_json(row[0])
|
||||
return WorkflowRecordDTO.from_dict(dict(row))
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
|
||||
try:
|
||||
# workflows do not have ids until they are saved
|
||||
workflow_id = uuid_string()
|
||||
workflow.root["id"] = workflow_id
|
||||
# Only user workflows may be created by this method
|
||||
assert workflow.meta.category is WorkflowCategory.User
|
||||
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO workflows(workflow)
|
||||
VALUES (?);
|
||||
INSERT OR IGNORE INTO workflow_library (
|
||||
workflow_id,
|
||||
workflow
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(workflow.model_dump_json(),),
|
||||
(workflow_with_id.id, workflow_with_id.model_dump_json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
@ -65,35 +84,148 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow_id)
|
||||
return self.get(workflow_with_id.id)
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
workflow TEXT NOT NULL,
|
||||
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||
);
|
||||
"""
|
||||
UPDATE workflow_library
|
||||
SET workflow = ?
|
||||
WHERE workflow_id = ? AND category = 'user';
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflows FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflows
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow.id)
|
||||
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
WHERE workflow_id = ? AND category = 'user';
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
page: int,
|
||||
per_page: int,
|
||||
order_by: WorkflowRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
category: WorkflowCategory,
|
||||
query: Optional[str] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
assert category in WorkflowCategory
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at
|
||||
FROM workflow_library
|
||||
WHERE category = ?
|
||||
"""
|
||||
main_params: list[int | str] = [category.value]
|
||||
count_params: list[int | str] = [category.value]
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
main_query += " AND name LIKE ? OR description LIKE ? "
|
||||
count_query += " AND name LIKE ? OR description LIKE ?;"
|
||||
main_params.extend([wildcard_query, wildcard_query])
|
||||
count_params.extend([wildcard_query, wildcard_query])
|
||||
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value} LIMIT ? OFFSET ?;"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
self._cursor.execute(main_query, main_params)
|
||||
rows = self._cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
|
||||
self._cursor.execute(count_query, count_params)
|
||||
total = self._cursor.fetchone()[0]
|
||||
pages = int(total / per_page) + 1
|
||||
|
||||
return PaginatedResults(
|
||||
items=workflows,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
pages=pages,
|
||||
total=total,
|
||||
)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _sync_default_workflows(self) -> None:
|
||||
"""Syncs default workflows to the database. Internal use only."""
|
||||
|
||||
"""
|
||||
An enhancement might be to only update workflows that have changed. This would require stable
|
||||
default workflow IDs, and properly incrementing the workflow version.
|
||||
|
||||
It's much simpler to just replace them all with whichever workflows are in the directory.
|
||||
|
||||
The downside is that the `updated_at` and `opened_at` timestamps for default workflows are
|
||||
meaningless, as they are overwritten every time the server starts.
|
||||
"""
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
workflows: list[Workflow] = []
|
||||
workflows_dir = Path(__file__).parent / Path("default_workflows")
|
||||
workflow_paths = workflows_dir.glob("*.json")
|
||||
for path in workflow_paths:
|
||||
bytes_ = path.read_bytes()
|
||||
workflow_without_id = WorkflowWithoutIDValidator.validate_json(bytes_)
|
||||
workflow = Workflow(**workflow_without_id.model_dump(), id=uuid_string())
|
||||
workflows.append(workflow)
|
||||
# Only default workflows may be managed by this method
|
||||
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM workflow_library
|
||||
WHERE category = 'default';
|
||||
"""
|
||||
)
|
||||
for w in workflows:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR REPLACE INTO workflow_library (
|
||||
workflow_id,
|
||||
workflow
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(w.id, w.model_dump_json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
|
@ -32,6 +32,8 @@ class ModelProbeInfo(object):
|
||||
upcast_attention: bool
|
||||
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
||||
image_size: int
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
@ -113,12 +115,16 @@ class ModelProbe(object):
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
name = cls.get_model_name(model_path)
|
||||
description = f"{base_type.value} {model_type.value} model {name}"
|
||||
format = probe.get_format()
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
base_type=base_type,
|
||||
variant_type=variant_type,
|
||||
prediction_type=prediction_type,
|
||||
name=name,
|
||||
description=description,
|
||||
upcast_attention=(
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
@ -142,6 +148,13 @@ class ModelProbe(object):
|
||||
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
def get_model_name(cls, model_path: Path) -> str:
|
||||
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
return model_path.stem
|
||||
else:
|
||||
return model_path.name
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||
@ -376,7 +389,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
elif "clip_g" in checkpoint:
|
||||
token_dim = checkpoint["clip_g"].shape[-1]
|
||||
else:
|
||||
token_dim = list(checkpoint.values())[0].shape[0]
|
||||
token_dim = list(checkpoint.values())[0].shape[-1]
|
||||
if token_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_dim == 1024:
|
||||
|
@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
:param checkpoint: The checkpoint
|
||||
"""
|
||||
|
||||
def _get_shape_1(key, tensor, checkpoint):
|
||||
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
||||
lora_token_vector_length = None
|
||||
|
||||
if "." not in key:
|
||||
@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
for key, tensor in checkpoint.items():
|
||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_unet_") and (
|
||||
"time_emb_proj.lora_down" in key
|
||||
): # recognizes format at https://civitai.com/models/224641
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||
if key.startswith("lora_te_"):
|
||||
|
29
invokeai/backend/model_manager/__init__.py
Normal file
29
invokeai/backend/model_manager/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Re-export frequently-used symbols from the Model Manager backend."""
|
||||
|
||||
from .config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from .probe import ModelProbe
|
||||
from .search import ModelSearch
|
||||
|
||||
__all__ = [
|
||||
"ModelProbe",
|
||||
"ModelSearch",
|
||||
"InvalidModelConfigException",
|
||||
"ModelConfigFactory",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
"ModelVariantType",
|
||||
"ModelFormat",
|
||||
"SchedulerPredictionType",
|
||||
"AnyModelConfig",
|
||||
]
|
@ -23,7 +23,7 @@ from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from typing_extensions import Annotated
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
@ -122,7 +122,7 @@ class ModelConfigBase(BaseModel):
|
||||
validate_assignment=True,
|
||||
)
|
||||
|
||||
def update(self, attributes: dict):
|
||||
def update(self, attributes: Dict[str, Any]) -> None:
|
||||
"""Update the object with fields in dict."""
|
||||
for key, value in attributes.items():
|
||||
setattr(self, key, value) # may raise a validation error
|
||||
@ -195,8 +195,6 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
# Note that we do not need prediction_type or upcast_attention here
|
||||
# because they are provided in the checkpoint's own config file.
|
||||
|
||||
|
||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
|
@ -2,6 +2,7 @@
|
||||
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import TypeAdapter
|
||||
@ -10,8 +11,9 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordServiceSQL,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@ -38,24 +40,27 @@ class MigrateModelYamlToDb:
|
||||
"""
|
||||
|
||||
config: InvokeAIAppConfig
|
||||
logger: InvokeAILogger
|
||||
logger: Logger
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.config.parse_args()
|
||||
self.logger = InvokeAILogger.get_logger()
|
||||
|
||||
def get_db(self) -> ModelRecordServiceSQL:
|
||||
"""Fetch the sqlite3 database for this installation."""
|
||||
db = SqliteDatabase(self.config, self.logger)
|
||||
db_path = None if self.config.use_memory_db else self.config.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=self.logger, verbose=self.config.log_sql)
|
||||
return ModelRecordServiceSQL(db)
|
||||
|
||||
def get_yaml(self) -> DictConfig:
|
||||
"""Fetch the models.yaml DictConfig for this installation."""
|
||||
yaml_path = self.config.model_conf_path
|
||||
return OmegaConf.load(yaml_path)
|
||||
omegaconf = OmegaConf.load(yaml_path)
|
||||
assert isinstance(omegaconf, DictConfig)
|
||||
return omegaconf
|
||||
|
||||
def migrate(self):
|
||||
def migrate(self) -> None:
|
||||
"""Do the migration from models.yaml to invokeai.db."""
|
||||
db = self.get_db()
|
||||
yaml = self.get_yaml()
|
||||
@ -69,6 +74,7 @@ class MigrateModelYamlToDb:
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
stanza["base"] = BaseModelType(base_type)
|
||||
@ -77,12 +83,20 @@ class MigrateModelYamlToDb:
|
||||
stanza["original_hash"] = hash
|
||||
stanza["current_hash"] = hash
|
||||
|
||||
new_config = ModelsValidator.validate_python(stanza)
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
try:
|
||||
db.add_model(new_key, new_config)
|
||||
if original_record := db.search_by_path(stanza.path):
|
||||
key = original_record[0].key
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
db.update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
db.add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
except UnknownModelException:
|
||||
self.logger.warning(f"Model at {stanza.path} could not be found in database")
|
||||
|
||||
|
||||
def main():
|
||||
|
686
invokeai/backend/model_manager/probe.py
Normal file
686
invokeai/backend/model_manager/probe.py
Normal file
@ -0,0 +1,686 @@
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_management.models.base import read_checkpoint_meta
|
||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||
from invokeai.backend.model_management.util import lora_token_vector_length
|
||||
from invokeai.backend.util.util import SilenceWarnings
|
||||
|
||||
from .config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
|
||||
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: "v1-inference.yaml",
|
||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
"""Base class for probes."""
|
||||
|
||||
def __init__(self, model_path: Path):
|
||||
self.model_path = model_path
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
"""Get model base type."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
"""Get model file format."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_variant_type(self) -> Optional[ModelVariantType]:
|
||||
"""Get model variant type."""
|
||||
return None
|
||||
|
||||
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
|
||||
"""Get model scheduler prediction type."""
|
||||
return None
|
||||
|
||||
|
||||
class ModelProbe(object):
|
||||
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
||||
"diffusers": {},
|
||||
"checkpoint": {},
|
||||
"onnx": {},
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
"StableDiffusionPipeline": ModelType.Main,
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
|
||||
) -> None:
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
) -> AnyModelConfig:
|
||||
return cls.probe(model_path, fields)
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
fields: Optional[Dict[str, Any]] = None,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Probe the model at model_path and return its configuration record.
|
||||
|
||||
:param model_path: Path to the model file (checkpoint) or directory (diffusers).
|
||||
:param fields: An optional dictionary that can be used to override probed
|
||||
fields. Typically used for fields that don't probe well, such as prediction_type.
|
||||
|
||||
Returns: The appropriate model configuration derived from ModelConfigBase.
|
||||
"""
|
||||
if fields is None:
|
||||
fields = {}
|
||||
|
||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||
model_info = None
|
||||
model_type = None
|
||||
if format_type == "diffusers":
|
||||
model_type = cls.get_model_type_from_folder(model_path)
|
||||
else:
|
||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
||||
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = FastModelHash.hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["path"] = model_path.as_posix()
|
||||
fields["type"] = fields.get("type") or model_type
|
||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||
fields["variant"] = fields.get("variant") or probe.get_variant_type()
|
||||
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
|
||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||
fields["description"] = (
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
||||
fields["config"] = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
base_type=fields["base"],
|
||||
variant_type=fields["variant"],
|
||||
prediction_type=fields["prediction_type"],
|
||||
).as_posix()
|
||||
|
||||
# additional fields needed for main non-checkpoint models
|
||||
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||
ModelFormat.Onnx,
|
||||
ModelFormat.Olive,
|
||||
ModelFormat.Diffusers,
|
||||
]:
|
||||
fields["upcast_attention"] = fields.get("upcast_attention") or (
|
||||
fields["base"] == BaseModelType.StableDiffusion2
|
||||
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
|
||||
model_info = ModelConfigFactory.make_config(fields)
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
def get_model_name(cls, model_path: Path) -> str:
|
||||
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
return model_path.stem
|
||||
else:
|
||||
return model_path.name
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType:
|
||||
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||
raise InvalidModelConfigException(f"{model_path}: unrecognized suffix")
|
||||
|
||||
if model_path.name == "learned_embeds.bin":
|
||||
return ModelType.TextualInversion
|
||||
|
||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in ckpt.keys():
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
return ModelType.Vae
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
|
||||
else:
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path) -> ModelType:
|
||||
"""Get the model type of a hugging-face style folder."""
|
||||
class_name = None
|
||||
error_hint = None
|
||||
for suffix in ["bin", "safetensors"]:
|
||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||
return ModelType.Lora
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
conf = json.load(file)
|
||||
if "_class_name" in conf:
|
||||
class_name = conf["_class_name"]
|
||||
elif "architectures" in conf:
|
||||
class_name = conf["architectures"][0]
|
||||
else:
|
||||
class_name = None
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
else:
|
||||
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||
|
||||
# give up
|
||||
raise InvalidModelConfigException(
|
||||
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_checkpoint_config_path(
|
||||
cls,
|
||||
model_path: Path,
|
||||
model_type: ModelType,
|
||||
base_type: BaseModelType,
|
||||
variant_type: ModelVariantType,
|
||||
prediction_type: SchedulerPredictionType,
|
||||
) -> Path:
|
||||
# look for a YAML file adjacent to the model file first
|
||||
possible_conf = model_path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type == ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
elif model_type == ModelType.ControlNet:
|
||||
config_file = (
|
||||
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
||||
)
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}"
|
||||
)
|
||||
assert isinstance(config_file, str)
|
||||
return Path(config_file)
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType:
|
||||
with SilenceWarnings():
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model_path.name, model_path)
|
||||
model = torch.load(model_path)
|
||||
assert isinstance(model, dict)
|
||||
return model
|
||||
else:
|
||||
return safetensors.torch.load_file(model_path)
|
||||
|
||||
@classmethod
|
||||
def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
# ##################################################3
|
||||
# Checkpoint probing
|
||||
# ##################################################3
|
||||
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
|
||||
def __init__(self, model_path: Path):
|
||||
super().__init__(model_path)
|
||||
self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat("checkpoint")
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}"
|
||||
)
|
||||
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
else:
|
||||
raise InvalidModelConfigException("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
"""Return model prediction type."""
|
||||
type = self.get_base_type()
|
||||
if type == BaseModelType.StableDiffusion2:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
if "global_step" in checkpoint:
|
||||
if checkpoint["global_step"] == 220000:
|
||||
return SchedulerPredictionType.Epsilon
|
||||
elif checkpoint["global_step"] == 110000:
|
||||
return SchedulerPredictionType.VPrediction
|
||||
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
|
||||
|
||||
elif type == BaseModelType.StableDiffusion1:
|
||||
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
|
||||
else:
|
||||
return SchedulerPredictionType.Epsilon
|
||||
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
"""Class for LoRA checkpoints."""
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat("lycoris")
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}")
|
||||
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
"""Class for probing embeddings."""
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat.EmbeddingFile
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
if "string_to_token" in checkpoint:
|
||||
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
||||
elif "emb_params" in checkpoint:
|
||||
token_dim = checkpoint["emb_params"].shape[-1]
|
||||
elif "clip_g" in checkpoint:
|
||||
token_dim = checkpoint["clip_g"].shape[-1]
|
||||
else:
|
||||
token_dim = list(checkpoint.values())[0].shape[0]
|
||||
if token_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_dim == 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
|
||||
|
||||
|
||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
"""Class for probing controlnets."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
for key_name in (
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
):
|
||||
if key_name not in checkpoint:
|
||||
continue
|
||||
if checkpoint[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif checkpoint[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
raise InvalidModelConfigException("{self.model_path}: Unable to determine base type")
|
||||
|
||||
|
||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
class FolderProbeBase(ProbeBase):
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat("diffusers")
|
||||
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
# "normal" variant type
|
||||
try:
|
||||
config_file = self.model_path / "unet" / "config.json"
|
||||
with open(config_file, "r") as file:
|
||||
conf = json.load(file)
|
||||
|
||||
in_channels = conf["in_channels"]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
except Exception:
|
||||
pass
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def _config_looks_like_sdxl(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.model_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.model_path.name
|
||||
if name == "vae":
|
||||
name = self.model_path.parent.name
|
||||
return name
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat.EmbeddingFolder
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
path = self.model_path / "learned_embeds.bin"
|
||||
if not path.exists():
|
||||
raise InvalidModelConfigException(
|
||||
f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
|
||||
)
|
||||
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||
|
||||
|
||||
class ONNXFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat("onnx")
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
dimension = config["cross_attention_dim"]
|
||||
base_model = (
|
||||
BaseModelType.StableDiffusion1
|
||||
if dimension == 768
|
||||
else (
|
||||
BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
)
|
||||
)
|
||||
if not base_model:
|
||||
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
|
||||
return base_model
|
||||
|
||||
|
||||
class LoRAFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = None
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
|
||||
if base_file.exists():
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise InvalidModelConfigException("Unknown LoRA format encountered")
|
||||
return LoRACheckpointProbe(model_file).get_base_type()
|
||||
|
||||
|
||||
class IPAdapterFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> IPAdapterModelFormat:
|
||||
return IPAdapterModelFormat.InvokeAI.value
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = self.model_path / "ip_adapter.bin"
|
||||
if not model_file.exists():
|
||||
raise InvalidModelConfigException("Unknown IP-Adapter model format.")
|
||||
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif cross_attention_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif cross_attention_dim == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
|
||||
adapter_type = config.get("adapter_type", None)
|
||||
if adapter_type == "full_adapter_xl":
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif adapter_type == "full_adapter" or "light_adapter":
|
||||
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
|
||||
return BaseModelType.StableDiffusion1
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})."
|
||||
)
|
||||
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
190
invokeai/backend/model_manager/search.py
Normal file
190
invokeai/backend/model_manager/search.py
Normal file
@ -0,0 +1,190 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class and implementation for recursive directory search for models.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
from invokeai.backend.model_manager import ModelSearch, ModelProbe
|
||||
|
||||
def find_main_models(model: Path) -> bool:
|
||||
info = ModelProbe.probe(model)
|
||||
if info.model_type == 'main' and info.base_type == 'sd-1':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
search = ModelSearch(on_model_found=report_it)
|
||||
found = search.search('/tmp/models')
|
||||
print(found) # list of matching model paths
|
||||
print(search.stats) # search stats
|
||||
```
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
default_logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class SearchStats(BaseModel):
|
||||
items_scanned: int = 0
|
||||
models_found: int = 0
|
||||
models_filtered: int = 0
|
||||
|
||||
|
||||
class ModelSearchBase(ABC, BaseModel):
|
||||
"""
|
||||
Abstract directory traversal model search class
|
||||
|
||||
Usage:
|
||||
search = ModelSearchBase(
|
||||
on_search_started = search_started_callback,
|
||||
on_search_completed = search_completed_callback,
|
||||
on_model_found = model_found_callback,
|
||||
)
|
||||
models_found = search.search('/path/to/directory')
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
|
||||
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
|
||||
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
|
||||
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
|
||||
logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def search_started(self) -> None:
|
||||
"""
|
||||
Called before the scan starts.
|
||||
|
||||
Passes the root search directory to the Callable `on_search_started`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_found(self, model: Path) -> None:
|
||||
"""
|
||||
Called when a model is found during search.
|
||||
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
|
||||
Passes the model's Path to the Callable `on_model_found`.
|
||||
This Callable receives the path to the model and returns a boolean
|
||||
to indicate whether the model should be returned in the search
|
||||
results.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_completed(self) -> None:
|
||||
"""
|
||||
Called before the scan starts.
|
||||
|
||||
Passes the Set of found model Paths to the Callable `on_search_completed`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
"""
|
||||
Recursively search for models in `directory` and return a set of model paths.
|
||||
|
||||
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
|
||||
Callables will be invoked during the search.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelSearch(ModelSearchBase):
|
||||
"""
|
||||
Implementation of ModelSearch with callbacks.
|
||||
Usage:
|
||||
search = ModelSearch()
|
||||
search.model_found = lambda path : 'anime' in path.as_posix()
|
||||
found = search.list_models(['/tmp/models1','/tmp/models2'])
|
||||
# returns all models that have 'anime' in the path
|
||||
"""
|
||||
|
||||
models_found: Set[Path] = Field(default=None)
|
||||
scanned_dirs: Set[Path] = Field(default=None)
|
||||
pruned_paths: Set[Path] = Field(default=None)
|
||||
|
||||
def search_started(self) -> None:
|
||||
self.models_found = set()
|
||||
self.scanned_dirs = set()
|
||||
self.pruned_paths = set()
|
||||
if self.on_search_started:
|
||||
self.on_search_started(self._directory)
|
||||
|
||||
def model_found(self, model: Path) -> None:
|
||||
self.stats.models_found += 1
|
||||
if not self.on_model_found or self.on_model_found(model):
|
||||
self.stats.models_filtered += 1
|
||||
self.models_found.add(model)
|
||||
|
||||
def search_completed(self) -> None:
|
||||
if self.on_search_completed:
|
||||
self.on_search_completed(self._models_found)
|
||||
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
self._directory = Path(directory)
|
||||
self.stats = SearchStats() # zero out
|
||||
self.search_started() # This will initialize _models_found to empty
|
||||
self._walk_directory(directory)
|
||||
self.search_completed()
|
||||
return self.models_found
|
||||
|
||||
def _walk_directory(self, path: Union[Path, str]) -> None:
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
# don't descend into directories that start with a "."
|
||||
# to avoid the Mac .DS_STORE issue.
|
||||
if str(Path(root).name).startswith("."):
|
||||
self.pruned_paths.add(Path(root))
|
||||
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
|
||||
continue
|
||||
|
||||
self.stats.items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path.parent in self.scanned_dirs:
|
||||
self.scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
(path / x).exists()
|
||||
for x in [
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
]
|
||||
):
|
||||
self.scanned_dirs.add(path)
|
||||
try:
|
||||
self.model_found(path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self.scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.model_found(path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
@ -242,17 +242,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
control_model: ControlNetModel = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker,
|
||||
feature_extractor,
|
||||
requires_safety_checker,
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
@ -260,9 +249,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
# FIXME: can't currently register control module
|
||||
# control_model=control_model,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
)
|
||||
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||
self.control_model = control_model
|
||||
self.use_ip_adapter = False
|
||||
|
@ -3,7 +3,42 @@ from typing import Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile, paste
|
||||
from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR
|
||||
from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend
|
||||
|
||||
|
||||
def calc_overlap(tiles: list[Tile], num_tiles_x: int, num_tiles_y: int) -> list[Tile]:
|
||||
"""Calculate and update the overlap of a list of tiles.
|
||||
|
||||
Args:
|
||||
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
|
||||
num_tiles_x: the number of tiles on the x axis.
|
||||
num_tiles_y: the number of tiles on the y axis.
|
||||
"""
|
||||
|
||||
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
|
||||
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
|
||||
return None
|
||||
return tiles[idx_y * num_tiles_x + idx_x]
|
||||
|
||||
for tile_idx_y in range(num_tiles_y):
|
||||
for tile_idx_x in range(num_tiles_x):
|
||||
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
|
||||
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
|
||||
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
|
||||
|
||||
assert cur_tile is not None
|
||||
|
||||
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
|
||||
if top_neighbor_tile is not None:
|
||||
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
|
||||
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
|
||||
|
||||
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
|
||||
if left_neighbor_tile is not None:
|
||||
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
|
||||
left_neighbor_tile.overlap.right = cur_tile.overlap.left
|
||||
return tiles
|
||||
|
||||
|
||||
def calc_tiles_with_overlap(
|
||||
@ -63,31 +98,133 @@ def calc_tiles_with_overlap(
|
||||
|
||||
tiles.append(tile)
|
||||
|
||||
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
|
||||
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
|
||||
return None
|
||||
return tiles[idx_y * num_tiles_x + idx_x]
|
||||
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
|
||||
|
||||
# Iterate over tiles again and calculate overlaps.
|
||||
|
||||
def calc_tiles_even_split(
|
||||
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap: int = 0
|
||||
) -> list[Tile]:
|
||||
"""Calculate the tile coordinates for a given image shape with the number of tiles requested.
|
||||
|
||||
Args:
|
||||
image_height (int): The image height in px.
|
||||
image_width (int): The image width in px.
|
||||
num_x_tiles (int): The number of tile to split the image into on the X-axis.
|
||||
num_y_tiles (int): The number of tile to split the image into on the Y-axis.
|
||||
overlap (int, optional): The overlap between adjacent tiles in pixels. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
|
||||
"""
|
||||
# Ensure the image is divisible by LATENT_SCALE_FACTOR
|
||||
if image_width % LATENT_SCALE_FACTOR != 0 or image_height % LATENT_SCALE_FACTOR != 0:
|
||||
raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by {LATENT_SCALE_FACTOR}")
|
||||
|
||||
# Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down)
|
||||
if num_tiles_x > 1:
|
||||
# ensure the overlap is not more than the maximum overlap if we only have 1 tile then we dont care about overlap
|
||||
assert overlap <= image_width - (LATENT_SCALE_FACTOR * (num_tiles_x - 1))
|
||||
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
|
||||
((image_width + overlap * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
|
||||
)
|
||||
assert overlap < tile_size_x
|
||||
else:
|
||||
tile_size_x = image_width
|
||||
|
||||
if num_tiles_y > 1:
|
||||
# ensure the overlap is not more than the maximum overlap if we only have 1 tile then we dont care about overlap
|
||||
assert overlap <= image_height - (LATENT_SCALE_FACTOR * (num_tiles_y - 1))
|
||||
tile_size_y = LATENT_SCALE_FACTOR * math.floor(
|
||||
((image_height + overlap * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR
|
||||
)
|
||||
assert overlap < tile_size_y
|
||||
else:
|
||||
tile_size_y = image_height
|
||||
|
||||
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
|
||||
tiles: list[Tile] = []
|
||||
|
||||
# Calculate tile coordinates. (Ignore overlap values for now.)
|
||||
for tile_idx_y in range(num_tiles_y):
|
||||
# Calculate the top and bottom of the row
|
||||
top = tile_idx_y * (tile_size_y - overlap)
|
||||
bottom = min(top + tile_size_y, image_height)
|
||||
# For the last row adjust bottom to be the height of the image
|
||||
if tile_idx_y == num_tiles_y - 1:
|
||||
bottom = image_height
|
||||
|
||||
for tile_idx_x in range(num_tiles_x):
|
||||
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
|
||||
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
|
||||
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
|
||||
# Calculate the left & right coordinate of each tile
|
||||
left = tile_idx_x * (tile_size_x - overlap)
|
||||
right = min(left + tile_size_x, image_width)
|
||||
# For the last tile in the row adjust right to be the width of the image
|
||||
if tile_idx_x == num_tiles_x - 1:
|
||||
right = image_width
|
||||
|
||||
assert cur_tile is not None
|
||||
tile = Tile(
|
||||
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
|
||||
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||
)
|
||||
|
||||
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
|
||||
if top_neighbor_tile is not None:
|
||||
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
|
||||
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
|
||||
tiles.append(tile)
|
||||
|
||||
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
|
||||
if left_neighbor_tile is not None:
|
||||
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
|
||||
left_neighbor_tile.overlap.right = cur_tile.overlap.left
|
||||
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
|
||||
|
||||
return tiles
|
||||
|
||||
def calc_tiles_min_overlap(
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
tile_height: int,
|
||||
tile_width: int,
|
||||
min_overlap: int = 0,
|
||||
) -> list[Tile]:
|
||||
"""Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps.
|
||||
|
||||
Args:
|
||||
image_height (int): The image height in px.
|
||||
image_width (int): The image width in px.
|
||||
tile_height (int): The tile height in px. All tiles will have this height.
|
||||
tile_width (int): The tile width in px. All tiles will have this width.
|
||||
min_overlap (int): The target minimum overlap between adjacent tiles. If the tiles do not evenly cover the image
|
||||
shape, then the overlap will be spread between the tiles.
|
||||
|
||||
Returns:
|
||||
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
|
||||
"""
|
||||
|
||||
assert min_overlap < tile_height
|
||||
assert min_overlap < tile_width
|
||||
|
||||
# catches the cases when the tile size is larger than the images size and adjusts the tile size
|
||||
if image_width < tile_width:
|
||||
tile_width = image_width
|
||||
|
||||
if image_height < tile_height:
|
||||
tile_height = image_height
|
||||
|
||||
num_tiles_x = math.ceil((image_width - min_overlap) / (tile_width - min_overlap))
|
||||
num_tiles_y = math.ceil((image_height - min_overlap) / (tile_height - min_overlap))
|
||||
|
||||
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
|
||||
tiles: list[Tile] = []
|
||||
|
||||
# Calculate tile coordinates. (Ignore overlap values for now.)
|
||||
for tile_idx_y in range(num_tiles_y):
|
||||
top = (tile_idx_y * (image_height - tile_height)) // (num_tiles_y - 1) if num_tiles_y > 1 else 0
|
||||
bottom = top + tile_height
|
||||
|
||||
for tile_idx_x in range(num_tiles_x):
|
||||
left = (tile_idx_x * (image_width - tile_width)) // (num_tiles_x - 1) if num_tiles_x > 1 else 0
|
||||
right = left + tile_width
|
||||
|
||||
tile = Tile(
|
||||
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
|
||||
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||
)
|
||||
|
||||
tiles.append(tile)
|
||||
|
||||
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
|
||||
|
||||
|
||||
def merge_tiles_with_linear_blending(
|
||||
@ -199,3 +336,91 @@ def merge_tiles_with_linear_blending(
|
||||
),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def merge_tiles_with_seam_blending(
|
||||
dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int
|
||||
):
|
||||
"""Merge a set of image tiles into `dst_image` with seam blending between the tiles.
|
||||
|
||||
We expect every tile edge to either:
|
||||
1) have an overlap of 0, because it is aligned with the image edge, or
|
||||
2) have an overlap >= blend_amount.
|
||||
If neither of these conditions are satisfied, we raise an exception.
|
||||
|
||||
The seam blending is centered on a seam of least energy of the overlap between adjacent tiles.
|
||||
|
||||
Args:
|
||||
dst_image (np.ndarray): The destination image. Shape: (H, W, C).
|
||||
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
|
||||
tile_images (list[np.ndarray]): The tile images to merge into `dst_image`.
|
||||
blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles.
|
||||
"""
|
||||
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
|
||||
# iterate over tiles left-to-right, top-to-bottom.
|
||||
tiles_and_images = list(zip(tiles, tile_images, strict=True))
|
||||
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left)
|
||||
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top)
|
||||
|
||||
# Organize tiles into rows.
|
||||
tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = []
|
||||
cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = []
|
||||
first_tile_in_cur_row, _ = tiles_and_images[0]
|
||||
for tile_and_image in tiles_and_images:
|
||||
tile, _ = tile_and_image
|
||||
if not (
|
||||
tile.coords.top == first_tile_in_cur_row.coords.top
|
||||
and tile.coords.bottom == first_tile_in_cur_row.coords.bottom
|
||||
):
|
||||
# Store the previous row, and start a new one.
|
||||
tile_and_image_rows.append(cur_tile_and_image_row)
|
||||
cur_tile_and_image_row = []
|
||||
first_tile_in_cur_row, _ = tile_and_image
|
||||
|
||||
cur_tile_and_image_row.append(tile_and_image)
|
||||
tile_and_image_rows.append(cur_tile_and_image_row)
|
||||
|
||||
for tile_and_image_row in tile_and_image_rows:
|
||||
first_tile_in_row, _ = tile_and_image_row[0]
|
||||
row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top
|
||||
row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype)
|
||||
|
||||
# Blend the tiles in the row horizontally.
|
||||
for tile, tile_image in tile_and_image_row:
|
||||
# We expect the tiles to be ordered left-to-right.
|
||||
# For each tile:
|
||||
# - extract the overlap regions and pass to seam_blend()
|
||||
# - apply blended region to the row_image
|
||||
# - apply the un-blended region to the row_image
|
||||
tile_height, tile_width, _ = tile_image.shape
|
||||
overlap_size = tile.overlap.left
|
||||
# Left blending:
|
||||
if overlap_size > 0:
|
||||
assert overlap_size >= blend_amount
|
||||
|
||||
overlap_coord_right = tile.coords.left + overlap_size
|
||||
src_overlap = row_image[:, tile.coords.left : overlap_coord_right]
|
||||
dst_overlap = tile_image[:, :overlap_size]
|
||||
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=False)
|
||||
row_image[:, tile.coords.left : overlap_coord_right] = blended_overlap
|
||||
row_image[:, overlap_coord_right : tile.coords.right] = tile_image[:, overlap_size:]
|
||||
else:
|
||||
# no overlap just paste the tile
|
||||
row_image[:, tile.coords.left : tile.coords.right] = tile_image
|
||||
|
||||
# Blend the row into the dst_image
|
||||
# We assume that the entire row has the same vertical overlaps as the first_tile_in_row.
|
||||
# Rows are processed in the same way as tiles (extract overlap, blend, apply)
|
||||
row_overlap_size = first_tile_in_row.overlap.top
|
||||
if row_overlap_size > 0:
|
||||
assert row_overlap_size >= blend_amount
|
||||
|
||||
overlap_coords_bottom = first_tile_in_row.coords.top + row_overlap_size
|
||||
src_overlap = dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :]
|
||||
dst_overlap = row_image[:row_overlap_size, :]
|
||||
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=True)
|
||||
dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] = blended_overlap
|
||||
dst_image[overlap_coords_bottom : first_tile_in_row.coords.bottom, :] = row_image[row_overlap_size:, :]
|
||||
else:
|
||||
# no overlap just paste the row
|
||||
dst_image[first_tile_in_row.coords.top : first_tile_in_row.coords.bottom, :] = row_image
|
||||
|
@ -1,5 +1,7 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -31,10 +33,10 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
|
||||
"""Paste a source image into a destination image.
|
||||
|
||||
Args:
|
||||
dst_image (torch.Tensor): The destination image to paste into. Shape: (H, W, C).
|
||||
src_image (torch.Tensor): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
|
||||
dst_image (np.array): The destination image to paste into. Shape: (H, W, C).
|
||||
src_image (np.array): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
|
||||
box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted.
|
||||
mask (Optional[torch.Tensor]): A mask that defines the blending between 'src_image' and 'dst_image'.
|
||||
mask (Optional[np.array]): A mask that defines the blending between 'src_image' and 'dst_image'.
|
||||
Range: [0.0, 1.0], Shape: (H, W). The output is calculate per-pixel according to
|
||||
`src * mask + dst * (1 - mask)`.
|
||||
"""
|
||||
@ -45,3 +47,106 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
|
||||
mask = np.expand_dims(mask, -1)
|
||||
dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
|
||||
dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)
|
||||
|
||||
|
||||
def seam_blend(ia1: np.ndarray, ia2: np.ndarray, blend_amount: int, x_seam: bool) -> np.ndarray:
|
||||
"""Blend two overlapping tile sections using a seams to find a path.
|
||||
|
||||
It is assumed that input images will be RGB np arrays and are the same size.
|
||||
|
||||
Args:
|
||||
ia1 (np.array): Image array 1 Shape: (H, W, C).
|
||||
ia2 (np.array): Image array 2 Shape: (H, W, C).
|
||||
x_seam (bool): If the images should be blended on the x axis or not.
|
||||
blend_amount (int): The size of the blur to use on the seam. Half of this value will be used to avoid the edges of the image.
|
||||
"""
|
||||
assert ia1.shape == ia2.shape
|
||||
assert ia2.size == ia2.size
|
||||
|
||||
def shift(arr, num, fill_value=255.0):
|
||||
result = np.full_like(arr, fill_value)
|
||||
if num > 0:
|
||||
result[num:] = arr[:-num]
|
||||
elif num < 0:
|
||||
result[:num] = arr[-num:]
|
||||
else:
|
||||
result[:] = arr
|
||||
return result
|
||||
|
||||
# Assume RGB and convert to grey
|
||||
# Could offer other options for the luminance conversion
|
||||
# BT.709 [0.2126, 0.7152, 0.0722], BT.2020 [0.2627, 0.6780, 0.0593])
|
||||
# it might not have a huge impact due to the blur that is applied over the seam
|
||||
iag1 = np.dot(ia1, [0.2989, 0.5870, 0.1140]) # BT.601 perceived brightness
|
||||
iag2 = np.dot(ia2, [0.2989, 0.5870, 0.1140])
|
||||
|
||||
# Calc Difference between the images
|
||||
ia = iag2 - iag1
|
||||
|
||||
# If the seam is on the X-axis rotate the array so we can treat it like a vertical seam
|
||||
if x_seam:
|
||||
ia = np.rot90(ia, 1)
|
||||
|
||||
# Calc max and min X & Y limits
|
||||
# gutter is used to avoid the blur hitting the edge of the image
|
||||
gutter = math.ceil(blend_amount / 2) if blend_amount > 0 else 0
|
||||
max_y, max_x = ia.shape
|
||||
max_x -= gutter
|
||||
min_x = gutter
|
||||
|
||||
# Calc the energy in the difference
|
||||
# Could offer different energy calculations e.g. Sobel or Scharr
|
||||
energy = np.abs(np.gradient(ia, axis=0)) + np.abs(np.gradient(ia, axis=1))
|
||||
|
||||
# Find the starting position of the seam
|
||||
res = np.copy(energy)
|
||||
for y in range(1, max_y):
|
||||
row = res[y, :]
|
||||
rowl = shift(row, -1)
|
||||
rowr = shift(row, 1)
|
||||
res[y, :] = res[y - 1, :] + np.min([row, rowl, rowr], axis=0)
|
||||
|
||||
# create an array max_y long
|
||||
lowest_energy_line = np.empty([max_y], dtype="uint16")
|
||||
lowest_energy_line[max_y - 1] = np.argmin(res[max_y - 1, min_x : max_x - 1])
|
||||
|
||||
# Calc the path of the seam
|
||||
# could offer options for larger search than just 1 pixel by adjusting lpos and rpos
|
||||
for ypos in range(max_y - 2, -1, -1):
|
||||
lowest_pos = lowest_energy_line[ypos + 1]
|
||||
lpos = lowest_pos - 1
|
||||
rpos = lowest_pos + 1
|
||||
lpos = np.clip(lpos, min_x, max_x - 1)
|
||||
rpos = np.clip(rpos, min_x, max_x - 1)
|
||||
lowest_energy_line[ypos] = np.argmin(energy[ypos, lpos : rpos + 1]) + lpos
|
||||
|
||||
# Draw the mask
|
||||
mask = np.zeros_like(ia)
|
||||
for ypos in range(0, max_y):
|
||||
to_fill = lowest_energy_line[ypos]
|
||||
mask[ypos, :to_fill] = 1
|
||||
|
||||
# If the seam is on the X-axis rotate the array back
|
||||
if x_seam:
|
||||
mask = np.rot90(mask, 3)
|
||||
|
||||
# blur the seam mask if required
|
||||
if blend_amount > 0:
|
||||
mask = cv2.blur(mask, (blend_amount, blend_amount))
|
||||
|
||||
# for visual debugging
|
||||
# from PIL import Image
|
||||
# m_image = Image.fromarray((mask * 255.0).astype("uint8"))
|
||||
|
||||
# copy ia2 over ia1 while applying the seam mask
|
||||
mask = np.expand_dims(mask, -1)
|
||||
blended_image = ia1 * mask + ia2 * (1.0 - mask)
|
||||
|
||||
# for visual debugging
|
||||
# i1 = Image.fromarray(ia1.astype("uint8"))
|
||||
# i2 = Image.fromarray(ia2.astype("uint8"))
|
||||
# b_image = Image.fromarray(blended_image.astype("uint8"))
|
||||
# print(f"{ia1.shape}, {ia2.shape}, {mask.shape}, {blended_image.shape}")
|
||||
# print(f"{i1.size}, {i2.size}, {m_image.size}, {b_image.size}")
|
||||
|
||||
return blended_image
|
||||
|
@ -11,4 +11,7 @@ from .devices import ( # noqa: F401
|
||||
normalize_device,
|
||||
torch_dtype,
|
||||
)
|
||||
from .logging import InvokeAILogger
|
||||
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
|
||||
|
||||
__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]
|
||||
|
@ -342,14 +342,13 @@ class InvokeAILogger(object): # noqa D102
|
||||
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||
) -> logging.Logger: # noqa D102
|
||||
if name in cls.loggers:
|
||||
logger = cls.loggers[name]
|
||||
logger.handlers.clear()
|
||||
else:
|
||||
logger = logging.getLogger(name)
|
||||
return cls.loggers[name]
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(config.log_level.upper()) # yes, strings work here
|
||||
for ch in cls.get_loggers(config):
|
||||
logger.addHandler(ch)
|
||||
cls.loggers[name] = logger
|
||||
cls.loggers[name] = logger
|
||||
return cls.loggers[name]
|
||||
|
||||
@classmethod
|
||||
@ -358,7 +357,7 @@ class InvokeAILogger(object): # noqa D102
|
||||
handlers = []
|
||||
for handler in handler_strs:
|
||||
handler_name, *args = handler.split("=", 2)
|
||||
args = args[0] if len(args) > 0 else None
|
||||
arg = args[0] if len(args) > 0 else None
|
||||
|
||||
# console and file get the fancy formatter.
|
||||
# syslog gets a simple one
|
||||
@ -370,16 +369,16 @@ class InvokeAILogger(object): # noqa D102
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name == "syslog":
|
||||
ch = cls._parse_syslog_args(args)
|
||||
ch = cls._parse_syslog_args(arg)
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name == "file":
|
||||
ch = cls._parse_file_args(args)
|
||||
ch = cls._parse_file_args(arg)
|
||||
ch.setFormatter(formatter())
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name == "http":
|
||||
ch = cls._parse_http_args(args)
|
||||
ch = cls._parse_http_args(arg)
|
||||
handlers.append(ch)
|
||||
return handlers
|
||||
|
||||
|
@ -32,9 +32,9 @@ sd-1/main/Analog-Diffusion:
|
||||
description: An SD-1.5 model trained on diverse analog photographs (2.13 GB)
|
||||
repo_id: wavymulder/Analog-Diffusion
|
||||
recommended: False
|
||||
sd-1/main/Deliberate:
|
||||
sd-1/main/Deliberate_v5:
|
||||
description: Versatile model that produces detailed images up to 768px (4.27 GB)
|
||||
repo_id: XpucT/Deliberate
|
||||
path: https://huggingface.co/XpucT/Deliberate/resolve/main/Deliberate_v5.safetensors
|
||||
recommended: False
|
||||
sd-1/main/Dungeons-and-Diffusion:
|
||||
description: Dungeons & Dragons characters (2.13 GB)
|
||||
|
@ -4,6 +4,7 @@ pip install <path_to_git_source>.
|
||||
"""
|
||||
import os
|
||||
import platform
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import pkg_resources
|
||||
import psutil
|
||||
@ -31,10 +32,6 @@ else:
|
||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||
|
||||
|
||||
def get_versions() -> dict:
|
||||
return requests.get(url=INVOKE_AI_REL).json()
|
||||
|
||||
|
||||
def invokeai_is_running() -> bool:
|
||||
for p in psutil.process_iter():
|
||||
try:
|
||||
@ -50,6 +47,20 @@ def invokeai_is_running() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_pypi_versions():
|
||||
url = "https://pypi.org/pypi/invokeai/json"
|
||||
try:
|
||||
data = requests.get(url).json()
|
||||
except Exception:
|
||||
raise Exception("Unable to fetch version information from PyPi")
|
||||
|
||||
versions = list(data["releases"].keys())
|
||||
versions.sort(key=LooseVersion, reverse=True)
|
||||
latest_version = [v for v in versions if "rc" not in v][0]
|
||||
latest_release_candidate = [v for v in versions if "rc" in v][0]
|
||||
return latest_version, latest_release_candidate, versions
|
||||
|
||||
|
||||
def welcome(latest_release: str, latest_prerelease: str):
|
||||
@group()
|
||||
def text():
|
||||
@ -63,8 +74,7 @@ def welcome(latest_release: str, latest_prerelease: str):
|
||||
yield "[bold yellow]Options:"
|
||||
yield f"""[1] Update to the latest [bold]official release[/bold] ([italic]{latest_release}[/italic])
|
||||
[2] Update to the latest [bold]pre-release[/bold] (may be buggy; caveat emptor!) ([italic]{latest_prerelease}[/italic])
|
||||
[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
|
||||
[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to"""
|
||||
[3] Manually enter the [bold]version[/bold] you wish to update to"""
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
@ -92,44 +102,35 @@ def get_extras():
|
||||
|
||||
|
||||
def main():
|
||||
versions = get_versions()
|
||||
released_versions = [x for x in versions if not (x["draft"] or x["prerelease"])]
|
||||
prerelease_versions = [x for x in versions if not x["draft"] and x["prerelease"]]
|
||||
latest_release = released_versions[0]["tag_name"] if len(released_versions) else None
|
||||
latest_prerelease = prerelease_versions[0]["tag_name"] if len(prerelease_versions) else None
|
||||
|
||||
if invokeai_is_running():
|
||||
print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
|
||||
input("Press any key to continue...")
|
||||
return
|
||||
|
||||
latest_release, latest_prerelease, versions = get_pypi_versions()
|
||||
|
||||
welcome(latest_release, latest_prerelease)
|
||||
|
||||
tag = None
|
||||
branch = None
|
||||
release = None
|
||||
choice = Prompt.ask("Choice:", choices=["1", "2", "3", "4"], default="1")
|
||||
release = latest_release
|
||||
choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
|
||||
|
||||
if choice == "1":
|
||||
release = latest_release
|
||||
elif choice == "2":
|
||||
release = latest_prerelease
|
||||
elif choice == "3":
|
||||
while not tag:
|
||||
tag = Prompt.ask("Enter an InvokeAI tag name")
|
||||
elif choice == "4":
|
||||
while not branch:
|
||||
branch = Prompt.ask("Enter an InvokeAI branch name")
|
||||
while True:
|
||||
release = Prompt.ask("Enter an InvokeAI version")
|
||||
release.strip()
|
||||
if release in versions:
|
||||
break
|
||||
print(f":exclamation: [bold red]'{release}' is not a recognized InvokeAI release.[/red bold]")
|
||||
|
||||
extras = get_extras()
|
||||
|
||||
print(f":crossed_fingers: Upgrading to [yellow]{tag or release or branch}[/yellow]")
|
||||
if release:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
|
||||
elif tag:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
|
||||
else:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
|
||||
print(f":crossed_fingers: Upgrading to [yellow]{release}[/yellow]")
|
||||
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade'
|
||||
|
||||
print("")
|
||||
print("")
|
||||
if os.system(cmd) == 0:
|
||||
|
@ -11,6 +11,7 @@ module.exports = {
|
||||
'plugin:react-hooks/recommended',
|
||||
'plugin:react/jsx-runtime',
|
||||
'prettier',
|
||||
'plugin:storybook/recommended',
|
||||
],
|
||||
parser: '@typescript-eslint/parser',
|
||||
parserOptions: {
|
||||
@ -26,6 +27,7 @@ module.exports = {
|
||||
'eslint-plugin-react-hooks',
|
||||
'i18next',
|
||||
'path',
|
||||
'unused-imports',
|
||||
],
|
||||
root: true,
|
||||
rules: {
|
||||
@ -44,9 +46,16 @@ module.exports = {
|
||||
radix: 'error',
|
||||
'space-before-blocks': 'error',
|
||||
'import/prefer-default-export': 'off',
|
||||
'@typescript-eslint/no-unused-vars': [
|
||||
'@typescript-eslint/no-unused-vars': 'off',
|
||||
'unused-imports/no-unused-imports': 'error',
|
||||
'unused-imports/no-unused-vars': [
|
||||
'warn',
|
||||
{ varsIgnorePattern: '^_', argsIgnorePattern: '^_' },
|
||||
{
|
||||
vars: 'all',
|
||||
varsIgnorePattern: '^_',
|
||||
args: 'after-used',
|
||||
argsIgnorePattern: '^_',
|
||||
},
|
||||
],
|
||||
'@typescript-eslint/ban-ts-comment': 'warn',
|
||||
'@typescript-eslint/no-explicit-any': 'warn',
|
||||
|
3
invokeai/frontend/web/.gitignore
vendored
3
invokeai/frontend/web/.gitignore
vendored
@ -10,6 +10,7 @@ lerna-debug.log*
|
||||
node_modules
|
||||
# We want to distribute the repo
|
||||
# dist
|
||||
# dist/**
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
@ -38,4 +39,4 @@ stats.html
|
||||
|
||||
# Yalc
|
||||
.yalc
|
||||
yalc.lock
|
||||
yalc.lock
|
||||
|
@ -1,4 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
. "$(dirname -- "$0")/_/husky.sh"
|
||||
|
||||
cd invokeai/frontend/web/ && npm run lint-staged
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user