mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
68 Commits
release/ad
...
refactor/m
Author | SHA1 | Date | |
---|---|---|---|
26c77e8522 | |||
fbede84405 | |||
756cb9c27e | |||
78b29db458 | |||
1225c3fb47 | |||
4957a360ff | |||
32ad742f3e | |||
2d11d97dad | |||
64858b2523 | |||
d5134325f6 | |||
702d0f68af | |||
a0d0e9f474 | |||
475823835f | |||
b95d547ccc | |||
19be196d50 | |||
f123ee61d0 | |||
9b4758f02f | |||
a626ca3e1c | |||
8d2952695d | |||
8dd55cc45e | |||
562fb1f3a1 | |||
1940169925 | |||
29b049b9d9 | |||
79cf3ec9a5 | |||
37b76caccf | |||
a4f9bfc8f7 | |||
9afdd0f4a8 | |||
bee6ad1547 | |||
87a5b771c4 | |||
fa3f1b6e41 | |||
e86f3fe29e | |||
d0fa131010 | |||
2f438431bd | |||
bbeb5cb477 | |||
cd3111c324 | |||
16b7246412 | |||
42be78d328 | |||
e469e24a58 | |||
cb698ff1fb | |||
0e738c4290 | |||
09d1bc513d | |||
c610283158 | |||
aefa828237 | |||
74ea592d02 | |||
457b0dfac0 | |||
96a717c4ba | |||
77b74264a8 | |||
351078e8aa | |||
b8354bd1a4 | |||
3b944b8af6 | |||
b811c037bd | |||
5bf61382a4 | |||
0f1c5f382a | |||
4af1695c60 | |||
df9a903a50 | |||
311be8f97d | |||
3f970c8326 | |||
fc150acde5 | |||
1615df3aa1 | |||
b2a8c45553 | |||
212dbaf9a2 | |||
ac3cf48d7f | |||
296060db63 | |||
d1d8ee71fc | |||
612912a6c9 | |||
bca2372280 | |||
0b860582f0 | |||
87ff380fe4 |
6
.github/workflows/lint-frontend.yml
vendored
6
.github/workflows/lint-frontend.yml
vendored
@ -21,16 +21,16 @@ jobs:
|
||||
if: github.event.pull_request.draft == false
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Setup Node 20
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
node-version: '18'
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: 8
|
||||
version: '8.12.1'
|
||||
- name: Install dependencies
|
||||
run: 'pnpm install --prefer-frozen-lockfile'
|
||||
- name: Typescript
|
||||
|
50
.github/workflows/pypi-release.yml
vendored
50
.github/workflows/pypi-release.yml
vendored
@ -1,13 +1,15 @@
|
||||
name: PyPI Release
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'invokeai/version/invokeai_version.py'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
publish_package:
|
||||
description: 'Publish build on PyPi? [true/false]'
|
||||
required: true
|
||||
default: 'false'
|
||||
|
||||
jobs:
|
||||
release:
|
||||
build-and-release:
|
||||
if: github.repository == 'invoke-ai/InvokeAI'
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
@ -15,19 +17,43 @@ jobs:
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||
TWINE_NON_INTERACTIVE: 1
|
||||
steps:
|
||||
- name: checkout sources
|
||||
uses: actions/checkout@v3
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: install deps
|
||||
- name: Setup Node 18
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v2
|
||||
with:
|
||||
version: '8.12.1'
|
||||
|
||||
- name: Install frontend dependencies
|
||||
run: pnpm install --prefer-frozen-lockfile
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Build frontend
|
||||
run: pnpm run build
|
||||
working-directory: invokeai/frontend/web
|
||||
|
||||
- name: Install python dependencies
|
||||
run: pip install --upgrade build twine
|
||||
|
||||
- name: build package
|
||||
- name: Build python package
|
||||
run: python3 -m build
|
||||
|
||||
- name: check distribution
|
||||
- name: Upload build as workflow artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist
|
||||
path: dist
|
||||
|
||||
- name: Check distribution
|
||||
run: twine check dist/*
|
||||
|
||||
- name: check PyPI versions
|
||||
- name: Check PyPI versions
|
||||
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
|
||||
run: |
|
||||
pip install --upgrade requests
|
||||
@ -36,6 +62,6 @@ jobs:
|
||||
EXISTS=scripts.pypi_helper.local_on_pypi(); \
|
||||
print(f'PACKAGE_EXISTS={EXISTS}')" >> $GITHUB_ENV
|
||||
|
||||
- name: upload package
|
||||
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != ''
|
||||
- name: Publish build on PyPi
|
||||
if: env.PACKAGE_EXISTS == 'False' && env.TWINE_PASSWORD != '' && github.event.inputs.publish_package == 'true'
|
||||
run: twine upload dist/*
|
||||
|
@ -270,7 +270,7 @@ upgrade script.** See the next section for a Windows recipe.
|
||||
3. Select option [1] to upgrade to the latest release.
|
||||
|
||||
4. Once the upgrade is finished you will be returned to the launcher
|
||||
menu. Select option [7] "Re-run the configure script to fix a broken
|
||||
menu. Select option [6] "Re-run the configure script to fix a broken
|
||||
install or to complete a major upgrade".
|
||||
|
||||
This will run the configure script against the v2.3 directory and
|
||||
|
@ -59,14 +59,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
|
||||
# #### Build the Web UI ------------------------------------
|
||||
|
||||
FROM node:18 AS web-builder
|
||||
FROM node:18-slim AS web-builder
|
||||
ENV PNPM_HOME="/pnpm"
|
||||
ENV PATH="$PNPM_HOME:$PATH"
|
||||
RUN corepack enable
|
||||
|
||||
WORKDIR /build
|
||||
COPY invokeai/frontend/web/ ./
|
||||
RUN --mount=type=cache,target=/usr/lib/node_modules \
|
||||
npm install --include dev
|
||||
RUN --mount=type=cache,target=/usr/lib/node_modules \
|
||||
yarn vite build
|
||||
|
||||
RUN --mount=type=cache,target=/pnpm/store \
|
||||
pnpm install --frozen-lockfile
|
||||
RUN pnpm run build
|
||||
|
||||
#### Runtime stage ---------------------------------------
|
||||
|
||||
|
@ -23,7 +23,7 @@ This is done via Docker Desktop preferences
|
||||
1. Make a copy of `env.sample` and name it `.env` (`cp env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
|
||||
a. the desired location of the InvokeAI runtime directory, or
|
||||
b. an existing, v3.0.0 compatible runtime directory.
|
||||
1. `docker compose up`
|
||||
1. Execute `run.sh`
|
||||
|
||||
The image will be built automatically if needed.
|
||||
|
||||
@ -39,7 +39,7 @@ The Docker daemon on the system must be already set up to use the GPU. In case o
|
||||
|
||||
## Customize
|
||||
|
||||
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `docker compose up`, your custom values will be used.
|
||||
Check the `.env.sample` file. It contains some environment variables for running in Docker. Copy it, name it `.env`, and fill it in with your own values. Next time you run `run.sh`, your custom values will be used.
|
||||
|
||||
You can also set these values in `docker-compose.yml` directly, but `.env` will help avoid conflicts when code is updated.
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
build_args=""
|
||||
|
||||
[[ -f ".env" ]] && build_args=$(awk '$1 ~ /\=[^$]/ {print "--build-arg " $0 " "}' .env)
|
||||
|
||||
echo "docker compose build args:"
|
||||
echo $build_args
|
||||
|
||||
docker compose build $build_args
|
@ -2,23 +2,8 @@
|
||||
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
invokeai:
|
||||
x-invokeai: &invokeai
|
||||
image: "local/invokeai:latest"
|
||||
# edit below to run on a container runtime other than nvidia-container-runtime.
|
||||
# not yet tested with rocm/AMD GPUs
|
||||
# Comment out the "deploy" section to run on CPU only
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
# For AMD support, comment out the deploy section above and uncomment the devices section below:
|
||||
#devices:
|
||||
# - /dev/kfd:/dev/kfd
|
||||
# - /dev/dri:/dev/dri
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile
|
||||
@ -50,3 +35,27 @@ services:
|
||||
# - |
|
||||
# invokeai-model-install --yes --default-only --config_file ${INVOKEAI_ROOT}/config_custom.yaml
|
||||
# invokeai-nodes-web --host 0.0.0.0
|
||||
|
||||
services:
|
||||
invokeai-nvidia:
|
||||
<<: *invokeai
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
invokeai-cpu:
|
||||
<<: *invokeai
|
||||
profiles:
|
||||
- cpu
|
||||
|
||||
invokeai-rocm:
|
||||
<<: *invokeai
|
||||
devices:
|
||||
- /dev/kfd:/dev/kfd
|
||||
- /dev/dri:/dev/dri
|
||||
profiles:
|
||||
- rocm
|
||||
|
@ -1,11 +1,28 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
# This script is provided for backwards compatibility with the old docker setup.
|
||||
# it doesn't do much aside from wrapping the usual docker compose CLI.
|
||||
run() {
|
||||
local scriptdir=$(dirname "${BASH_SOURCE[0]}")
|
||||
cd "$scriptdir" || exit 1
|
||||
|
||||
SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}")
|
||||
cd "$SCRIPTDIR" || exit 1
|
||||
local build_args=""
|
||||
local profile=""
|
||||
|
||||
docker compose up -d
|
||||
docker compose logs -f
|
||||
[[ -f ".env" ]] &&
|
||||
build_args=$(awk '$1 ~ /=[^$]/ && $0 !~ /^#/ {print "--build-arg " $0 " "}' .env) &&
|
||||
profile="$(awk -F '=' '/GPU_DRIVER/ {print $2}' .env)"
|
||||
|
||||
local service_name="invokeai-$profile"
|
||||
|
||||
printf "%s\n" "docker compose build args:"
|
||||
printf "%s\n" "$build_args"
|
||||
|
||||
docker compose build $build_args
|
||||
unset build_args
|
||||
|
||||
printf "%s\n" "starting service $service_name"
|
||||
docker compose --profile "$profile" up -d "$service_name"
|
||||
docker compose logs -f
|
||||
}
|
||||
|
||||
run
|
||||
|
277
docs/contributing/DOWNLOAD_QUEUE.md
Normal file
277
docs/contributing/DOWNLOAD_QUEUE.md
Normal file
@ -0,0 +1,277 @@
|
||||
# The InvokeAI Download Queue
|
||||
|
||||
The DownloadQueueService provides a multithreaded parallel download
|
||||
queue for arbitrary URLs, with queue prioritization, event handling,
|
||||
and restart capabilities.
|
||||
|
||||
## Simple Example
|
||||
|
||||
```
|
||||
from invokeai.app.services.download import DownloadQueueService, TqdmProgress
|
||||
|
||||
download_queue = DownloadQueueService()
|
||||
for url in ['https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true',
|
||||
'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true',
|
||||
'https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png',
|
||||
'https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor',
|
||||
]:
|
||||
|
||||
# urls start downloading as soon as download() is called
|
||||
download_queue.download(source=url,
|
||||
dest='/tmp/downloads',
|
||||
on_progress=TqdmProgress().update
|
||||
)
|
||||
|
||||
download_queue.join() # wait for all downloads to finish
|
||||
for job in download_queue.list_jobs():
|
||||
print(job.model_dump_json(exclude_none=True, indent=4),"\n")
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
{
|
||||
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/a-painting-of-a-fire.png?raw=true",
|
||||
"dest": "/tmp/downloads",
|
||||
"id": 0,
|
||||
"priority": 10,
|
||||
"status": "completed",
|
||||
"download_path": "/tmp/downloads/a-painting-of-a-fire.png",
|
||||
"job_started": "2023-12-04T05:34:41.742174",
|
||||
"job_ended": "2023-12-04T05:34:42.592035",
|
||||
"bytes": 666734,
|
||||
"total_bytes": 666734
|
||||
}
|
||||
|
||||
{
|
||||
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/birdhouse.png?raw=true",
|
||||
"dest": "/tmp/downloads",
|
||||
"id": 1,
|
||||
"priority": 10,
|
||||
"status": "completed",
|
||||
"download_path": "/tmp/downloads/birdhouse.png",
|
||||
"job_started": "2023-12-04T05:34:41.741975",
|
||||
"job_ended": "2023-12-04T05:34:42.652841",
|
||||
"bytes": 774949,
|
||||
"total_bytes": 774949
|
||||
}
|
||||
|
||||
{
|
||||
"source": "https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/assets/missing.png",
|
||||
"dest": "/tmp/downloads",
|
||||
"id": 2,
|
||||
"priority": 10,
|
||||
"status": "error",
|
||||
"job_started": "2023-12-04T05:34:41.742079",
|
||||
"job_ended": "2023-12-04T05:34:42.147625",
|
||||
"bytes": 0,
|
||||
"total_bytes": 0,
|
||||
"error_type": "HTTPError(Not Found)",
|
||||
"error": "Traceback (most recent call last):\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 182, in _download_next_item\n self._do_download(job)\n File \"/home/lstein/Projects/InvokeAI/invokeai/app/services/download/download_default.py\", line 206, in _do_download\n raise HTTPError(resp.reason)\nrequests.exceptions.HTTPError: Not Found\n"
|
||||
}
|
||||
|
||||
{
|
||||
"source": "https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor",
|
||||
"dest": "/tmp/downloads",
|
||||
"id": 3,
|
||||
"priority": 10,
|
||||
"status": "completed",
|
||||
"download_path": "/tmp/downloads/xl_more_art-full_v1.safetensors",
|
||||
"job_started": "2023-12-04T05:34:42.147645",
|
||||
"job_ended": "2023-12-04T05:34:43.735990",
|
||||
"bytes": 719020768,
|
||||
"total_bytes": 719020768
|
||||
}
|
||||
```
|
||||
|
||||
## The API
|
||||
|
||||
The default download queue is `DownloadQueueService`, an
|
||||
implementation of ABC `DownloadQueueServiceBase`. It juggles multiple
|
||||
background download requests and provides facilities for interrogating
|
||||
and cancelling the requests. Access to a current or past download task
|
||||
is mediated via `DownloadJob` objects which report the current status
|
||||
of a job request
|
||||
|
||||
### The Queue Object
|
||||
|
||||
A default download queue is located in
|
||||
`ApiDependencies.invoker.services.download_queue`. However, you can
|
||||
create additional instances if you need to isolate your queue from the
|
||||
main one.
|
||||
|
||||
```
|
||||
queue = DownloadQueueService(event_bus=events)
|
||||
```
|
||||
|
||||
`DownloadQueueService()` takes three optional arguments:
|
||||
|
||||
| **Argument** | **Type** | **Default** | **Description** |
|
||||
|----------------|-----------------|---------------|-----------------|
|
||||
| `max_parallel_dl` | int | 5 | Maximum number of simultaneous downloads allowed |
|
||||
| `event_bus` | EventServiceBase | None | System-wide FastAPI event bus for reporting download events |
|
||||
| `requests_session` | requests.sessions.Session | None | An alternative requests Session object to use for the download |
|
||||
|
||||
`max_parallel_dl` specifies how many download jobs are allowed to run
|
||||
simultaneously. Each will run in a different thread of execution.
|
||||
|
||||
`event_bus` is an EventServiceBase, typically the one created at
|
||||
InvokeAI startup. If present, download events are periodically emitted
|
||||
on this bus to allow clients to follow download progress.
|
||||
|
||||
`requests_session` is a url library requests Session object. It is
|
||||
used for testing.
|
||||
|
||||
### The Job object
|
||||
|
||||
The queue operates on a series of download job objects. These objects
|
||||
specify the source and destination of the download, and keep track of
|
||||
the progress of the download.
|
||||
|
||||
The only job type currently implemented is `DownloadJob`, a pydantic object with the
|
||||
following fields:
|
||||
|
||||
| **Field** | **Type** | **Default** | **Description** |
|
||||
|----------------|-----------------|---------------|-----------------|
|
||||
| _Fields passed in at job creation time_ |
|
||||
| `source` | AnyHttpUrl | | Where to download from |
|
||||
| `dest` | Path | | Where to download to |
|
||||
| `access_token` | str | | [optional] string containing authentication token for access |
|
||||
| `on_start` | Callable | | [optional] callback when the download starts |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_complete` | Callable | | [optional] callback called after successful download completion |
|
||||
| `on_error` | Callable | | [optional] callback called after an error occurs |
|
||||
| `id` | int | auto assigned | Job ID, an integer >= 0 |
|
||||
| `priority` | int | 10 | Job priority. Lower priorities run before higher priorities |
|
||||
| |
|
||||
| _Fields updated over the course of the download task_
|
||||
| `status` | DownloadJobStatus| | Status code |
|
||||
| `download_path` | Path | | Path to the location of the downloaded file |
|
||||
| `job_started` | float | | Timestamp for when the job started running |
|
||||
| `job_ended` | float | | Timestamp for when the job completed or errored out |
|
||||
| `job_sequence` | int | | A counter that is incremented each time a model is dequeued |
|
||||
| `bytes` | int | 0 | Bytes downloaded so far |
|
||||
| `total_bytes` | int | 0 | Total size of the file at the remote site |
|
||||
| `error_type` | str | | String version of the exception that caused an error during download |
|
||||
| `error` | str | | String version of the traceback associated with an error |
|
||||
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
|
||||
|
||||
When you create a job, you can assign it a `priority`. If multiple
|
||||
jobs are queued, the job with the lowest priority runs first.
|
||||
|
||||
Every job has a `source` and a `dest`. `source` is a pydantic.networks AnyHttpUrl object.
|
||||
The `dest` is a path on the local filesystem that specifies the
|
||||
destination for the downloaded object. Its semantics are
|
||||
described below.
|
||||
|
||||
When the job is submitted, it is assigned a numeric `id`. The id can
|
||||
then be used to fetch the job object from the queue.
|
||||
|
||||
The `status` field is updated by the queue to indicate where the job
|
||||
is in its lifecycle. Values are defined in the string enum
|
||||
`DownloadJobStatus`, a symbol available from
|
||||
`invokeai.app.services.download_manager`. Possible values are:
|
||||
|
||||
| **Value** | **String Value** | ** Description ** |
|
||||
|--------------|---------------------|-------------------|
|
||||
| `WAITING` | waiting | Job is on the queue but not yet running|
|
||||
| `RUNNING` | running | The download is started |
|
||||
| `COMPLETED` | completed | Job has finished its work without an error |
|
||||
| `ERROR` | error | Job encountered an error and will not run again|
|
||||
|
||||
`job_started` and `job_ended` indicate when the job
|
||||
was started (using a python timestamp) and when it completed.
|
||||
|
||||
In case of an error, the job's status will be set to `DownloadJobStatus.ERROR`, the text of the
|
||||
Exception that caused the error will be placed in the `error_type`
|
||||
field and the traceback that led to the error will be in `error`.
|
||||
|
||||
A cancelled job will have status `DownloadJobStatus.ERROR` and an
|
||||
`error_type` field of "DownloadJobCancelledException". In addition,
|
||||
the job's `cancelled` property will be set to True.
|
||||
|
||||
### Callbacks
|
||||
|
||||
Download jobs can be associated with a series of callbacks, each with
|
||||
the signature `Callable[["DownloadJob"], None]`. The callbacks are assigned
|
||||
using optional arguments `on_start`, `on_progress`, `on_complete` and
|
||||
`on_error`. When the corresponding event occurs, the callback wil be
|
||||
invoked and passed the job. The callback will be run in a `try:`
|
||||
context in the same thread as the download job. Any exceptions that
|
||||
occur during execution of the callback will be caught and converted
|
||||
into a log error message, thereby allowing the download to continue.
|
||||
|
||||
#### `TqdmProgress`
|
||||
|
||||
The `invokeai.app.services.download.download_default` module defines a
|
||||
class named `TqdmProgress` which can be used as an `on_progress`
|
||||
handler to display a completion bar in the console. Use as follows:
|
||||
|
||||
```
|
||||
from invokeai.app.services.download import TqdmProgress
|
||||
|
||||
download_queue.download(source='http://some.server.somewhere/some_file',
|
||||
dest='/tmp/downloads',
|
||||
on_progress=TqdmProgress().update
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Events
|
||||
|
||||
If the queue was initialized with the InvokeAI event bus (the case
|
||||
when using `ApiDependencies.invoker.services.download_queue`), then
|
||||
download events will also be issued on the bus. The events are:
|
||||
|
||||
* `download_started` -- This is issued when a job is taken off the
|
||||
queue and a request is made to the remote server for the URL headers, but before any data
|
||||
has been downloaded. The event payload will contain the keys `source`
|
||||
and `download_path`. The latter contains the path that the URL will be
|
||||
downloaded to.
|
||||
|
||||
* `download_progress -- This is issued periodically as the download
|
||||
runs. The payload contains the keys `source`, `download_path`,
|
||||
`current_bytes` and `total_bytes`. The latter two fields can be
|
||||
used to display the percent complete.
|
||||
|
||||
* `download_complete` -- This is issued when the download completes
|
||||
successfully. The payload contains the keys `source`, `download_path`
|
||||
and `total_bytes`.
|
||||
|
||||
* `download_error` -- This is issued when the download stops because
|
||||
of an error condition. The payload contains the fields `error_type`
|
||||
and `error`. The former is the text representation of the exception,
|
||||
and the latter is a traceback showing where the error occurred.
|
||||
|
||||
### Job control
|
||||
|
||||
To create a job call the queue's `download()` method. You can list all
|
||||
jobs using `list_jobs()`, fetch a single job by its with
|
||||
`id_to_job()`, cancel a running job with `cancel_job()`, cancel all
|
||||
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
|
||||
with `join()`.
|
||||
|
||||
#### job = queue.download(source, dest, priority, access_token)
|
||||
|
||||
Create a new download job and put it on the queue, returning the
|
||||
DownloadJob object.
|
||||
|
||||
#### jobs = queue.list_jobs()
|
||||
|
||||
Return a list of all active and inactive `DownloadJob`s.
|
||||
|
||||
#### job = queue.id_to_job(id)
|
||||
|
||||
Return the job corresponding to given ID.
|
||||
|
||||
Return a list of all active and inactive `DownloadJob`s.
|
||||
|
||||
#### queue.prune_jobs()
|
||||
|
||||
Remove inactive (complete or errored) jobs from the listing returned
|
||||
by `list_jobs()`.
|
||||
|
||||
#### queue.join()
|
||||
|
||||
Block until all pending jobs have run to completion or errored out.
|
||||
|
@ -15,7 +15,12 @@ model. These are the:
|
||||
their metadata, and `ModelRecordServiceBase` to store that
|
||||
information. It is also responsible for managing the InvokeAI
|
||||
`models` directory and its contents.
|
||||
|
||||
|
||||
* _ModelMetadataStore_ and _ModelMetaDataFetch_ Backend modules that
|
||||
are able to retrieve metadata from online model repositories,
|
||||
transform them into Pydantic models, and cache them to the InvokeAI
|
||||
SQL database.
|
||||
|
||||
* _DownloadQueueServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
||||
A multithreaded downloader responsible
|
||||
for downloading models from a remote source to disk. The download
|
||||
@ -1184,3 +1189,248 @@ other resources that it might have been using.
|
||||
This will start/pause/cancel all jobs that have been submitted to the
|
||||
queue and have not yet reached a terminal state.
|
||||
|
||||
***
|
||||
|
||||
## This Meta be Good: Model Metadata Storage
|
||||
|
||||
The modules found under `invokeai.backend.model_manager.metadata`
|
||||
provide a straightforward API for fetching model metadatda from online
|
||||
repositories. Currently two repositories are supported: HuggingFace
|
||||
and Civitai. However, the modules are easily extended for additional
|
||||
repos, provided that they have defined APIs for metadata access.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in `invokeai.backend.model_manager.config`.
|
||||
|
||||
### Example Usage:
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CivitaiMetadata
|
||||
ModelMetadataStore,
|
||||
)
|
||||
# to access the initialized sql database
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
civitai = CivitaiMetadataFetch()
|
||||
|
||||
# fetch the metadata
|
||||
model_metadata = civitai.from_url("https://civitai.com/models/215796")
|
||||
|
||||
# get some common metadata fields
|
||||
author = model_metadata.author
|
||||
tags = model_metadata.tags
|
||||
|
||||
# get some Civitai-specific fields
|
||||
assert isinstance(model_metadata, CivitaiMetadata)
|
||||
|
||||
trained_words = model_metadata.trained_words
|
||||
base_model = model_metadata.base_model_trained_on
|
||||
thumbnail = model_metadata.thumbnail_url
|
||||
|
||||
# cache the metadata to the database using the key corresponding to
|
||||
# an existing model config record in the `model_config` table
|
||||
sql_cache = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||
sql_cache.add_metadata('fb237ace520b6716adc98bcb16e8462c', model_metadata)
|
||||
|
||||
# now we can search the database by tag, author or model name
|
||||
# matches will contain a list of model keys that match the search
|
||||
matches = sql_cache.search_by_tag({"tool", "turbo"})
|
||||
```
|
||||
|
||||
### Structure of the Metadata objects
|
||||
|
||||
There is a short class hierarchy of Metadata objects, all of which
|
||||
descend from the Pydantic `BaseModel`.
|
||||
|
||||
#### `ModelMetadataBase`
|
||||
|
||||
This is the common base class for metadata:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `name` | str | Repository's name for the model |
|
||||
| `author` | str | Model's author |
|
||||
| `tags` | Set[str] | Model tags |
|
||||
|
||||
|
||||
Note that the model config record also has a `name` field. It is
|
||||
intended that the config record version be locally customizable, while
|
||||
the metadata version is read-only. However, enforcing this is expected
|
||||
to be part of the business logic.
|
||||
|
||||
Descendents of the base add additional fields.
|
||||
|
||||
#### `HuggingFaceMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `type` | Literal["huggingface"] | Used for the discriminated union of metadata classes|
|
||||
| `id` | str | HuggingFace repo_id |
|
||||
| `tag_dict` | Dict[str, Any] | A dictionary of tag/value pairs provided in addition to `tags` |
|
||||
| `last_modified`| datetime | Date of last commit of this model to the repo |
|
||||
| `files` | List[Path] | List of the files in the model repo |
|
||||
|
||||
|
||||
#### `CivitaiMetadata`
|
||||
|
||||
This descends from `ModelMetadataBase` and adds the following fields:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `type` | Literal["civitai"] | Used for the discriminated union of metadata classes|
|
||||
| `id` | int | Civitai model id |
|
||||
| `version_name` | str | Name of this version of the model (distinct from model name) |
|
||||
| `version_id` | int | Civitai model version id (distinct from model id) |
|
||||
| `created` | datetime | Date this version of the model was created |
|
||||
| `updated` | datetime | Date this version of the model was last updated |
|
||||
| `published` | datetime | Date this version of the model was published to Civitai |
|
||||
| `description` | str | Model description. Quite verbose and contains HTML tags |
|
||||
| `version_description` | str | Model version description, usually describes changes to the model |
|
||||
| `nsfw` | bool | Whether the model tends to generate NSFW content |
|
||||
| `restrictions` | LicenseRestrictions | An object that describes what is and isn't allowed with this model |
|
||||
| `trained_words`| Set[str] | Trigger words for this model, if any |
|
||||
| `download_url` | AnyHttpUrl | URL for downloading this version of the model |
|
||||
| `base_model_trained_on` | str | Name of the model that this version was trained on |
|
||||
| `thumbnail_url` | AnyHttpUrl | URL to access a representative thumbnail image of the model's output |
|
||||
| `weight_min` | int | For LoRA sliders, the minimum suggested weight to apply |
|
||||
| `weight_max` | int | For LoRA sliders, the maximum suggested weight to apply |
|
||||
|
||||
Note that `weight_min` and `weight_max` are not currently populated
|
||||
and take the default values of (-1.0, +2.0). The issue is that these
|
||||
values aren't part of the structured data but appear in the text
|
||||
description. Some regular expression or LLM coding may be able to
|
||||
extract these values.
|
||||
|
||||
Also be aware that `base_model_trained_on` is free text and doesn't
|
||||
correspond to our `ModelType` enum.
|
||||
|
||||
`CivitaiMetadata` also defines some convenience properties relating to
|
||||
licensing restrictions: `credit_required`, `allow_commercial_use`,
|
||||
`allow_derivatives` and `allow_different_license`.
|
||||
|
||||
#### `AnyModelRepoMetadata`
|
||||
|
||||
This is a discriminated Union of `CivitaiMetadata` and
|
||||
`HuggingFaceMetadata`.
|
||||
|
||||
### Fetching Metadata from Online Repos
|
||||
|
||||
The `HuggingFaceMetadataFetch` and `CivitaiMetadataFetch` classes will
|
||||
retrieve metadata from their corresponding repositories and return
|
||||
`AnyModelRepoMetadata` objects. Their base class
|
||||
`ModelMetadataFetchBase` is an abstract class that defines two
|
||||
methods: `from_url()` and `from_id()`. The former accepts the type of
|
||||
model URLs that the user will try to cut and paste into the model
|
||||
import form. The latter accepts a string ID in the format recognized
|
||||
by the repository of choice. Both methods return an
|
||||
`AnyModelRepoMetadata`.
|
||||
|
||||
The base class also has a class method `from_json()` which will take
|
||||
the JSON representation of a `ModelMetadata` object, validate it, and
|
||||
return the corresponding `AnyModelRepoMetadata` object.
|
||||
|
||||
When initializing one of the metadata fetching classes, you may
|
||||
provide a `requests.Session` argument. This allows you to customize
|
||||
the low-level HTTP fetch requests and is used, for instance, in the
|
||||
testing suite to avoid hitting the internet.
|
||||
|
||||
The HuggingFace and Civitai fetcher subclasses add additional
|
||||
repo-specific fetching methods:
|
||||
|
||||
|
||||
#### HuggingFaceMetadataFetch
|
||||
|
||||
This overrides its base class `from_json()` method to return a
|
||||
`HuggingFaceMetadata` object directly.
|
||||
|
||||
#### CivitaiMetadataFetch
|
||||
|
||||
This adds the following methods:
|
||||
|
||||
`from_civitai_modelid()` This takes the ID of a model, finds the
|
||||
default version of the model, and then retrieves the metadata for
|
||||
that version, returning a `CivitaiMetadata` object directly.
|
||||
|
||||
`from_civitai_versionid()` This takes the ID of a model version and
|
||||
retrieves its metadata. Functionally equivalent to `from_id()`, the
|
||||
only difference is that it returna a `CivitaiMetadata` object rather
|
||||
than an `AnyModelRepoMetadata`.
|
||||
|
||||
|
||||
### Metadata Storage
|
||||
|
||||
The `ModelMetadataStore` provides a simple facility to store model
|
||||
metadata in the `invokeai.db` database. The data is stored as a JSON
|
||||
blob, with a few common fields (`name`, `author`, `tags`) broken out
|
||||
to be searchable.
|
||||
|
||||
When a metadata object is saved to the database, it is identified
|
||||
using the model key, _and this key must correspond to an existing
|
||||
model key in the model_config table_. There is a foreign key integrity
|
||||
constraint between the `model_config.id` field and the
|
||||
`model_metadata.id` field such that if you attempt to save metadata
|
||||
under an unknown key, the attempt will result in an
|
||||
`UnknownModelException`. Likewise, when a model is deleted from
|
||||
`model_config`, the deletion of the corresponding metadata record will
|
||||
be triggered.
|
||||
|
||||
Tags are stored in a normalized fashion in the tables `model_tags` and
|
||||
`tags`. Triggers keep the tag table in sync with the `model_metadata`
|
||||
table.
|
||||
|
||||
To create the storage object, initialize it with the InvokeAI
|
||||
`SqliteDatabase` object. This is often done this way:
|
||||
|
||||
```
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
metadata_store = ModelMetadataStore(ApiDependencies.invoker.services.db)
|
||||
```
|
||||
|
||||
You can then access the storage with the following methods:
|
||||
|
||||
#### `add_metadata(key, metadata)`
|
||||
|
||||
Add the metadata using a previously-defined model key.
|
||||
|
||||
There is currently no `delete_metadata()` method. The metadata will
|
||||
persist until the matching config is deleted from the `model_config`
|
||||
table.
|
||||
|
||||
#### `get_metadata(key) -> AnyModelRepoMetadata`
|
||||
|
||||
Retrieve the metadata corresponding to the model key.
|
||||
|
||||
#### `update_metadata(key, new_metadata)`
|
||||
|
||||
Update an existing metadata record with new metadata.
|
||||
|
||||
#### `search_by_tag(tags: Set[str]) -> Set[str]`
|
||||
|
||||
Given a set of tags, find models that are tagged with them. If
|
||||
multiple tags are provided then a matching model must be tagged with
|
||||
*all* the tags in the set. This method returns a set of model keys and
|
||||
is intended to be used in conjunction with the `ModelRecordService`:
|
||||
|
||||
```
|
||||
model_config_store = ApiDependencies.invoker.services.model_records
|
||||
matches = metadata_store.search_by_tag({'license:other'})
|
||||
models = [model_config_store.get(x) for x in matches]
|
||||
```
|
||||
|
||||
#### `search_by_name(name: str) -> Set[str]
|
||||
|
||||
Find all model metadata records that have the given name and return a
|
||||
set of keys to the corresponding model config objects.
|
||||
|
||||
#### `search_by_author(author: str) -> Set[str]
|
||||
|
||||
Find all model metadata records that have the given author and return
|
||||
a set of keys to the corresponding model config objects.
|
||||
|
||||
|
@ -46,17 +46,18 @@ We encourage you to ping @psychedelicious and @blessedcoolant on [Discord](http
|
||||
```bash
|
||||
node --version
|
||||
```
|
||||
2. Install [yarn classic](https://classic.yarnpkg.com/lang/en/) and confirm it is installed by running this:
|
||||
|
||||
2. Install [pnpm](https://pnpm.io/) and confirm it is installed by running this:
|
||||
```bash
|
||||
npm install --global yarn
|
||||
yarn --version
|
||||
npm install --global pnpm
|
||||
pnpm --version
|
||||
```
|
||||
|
||||
From `invokeai/frontend/web/` run `yarn install` to get everything set up.
|
||||
From `invokeai/frontend/web/` run `pnpm install` to get everything set up.
|
||||
|
||||
Start everything in dev mode:
|
||||
1. Ensure your virtual environment is running
|
||||
2. Start the dev server: `yarn dev`
|
||||
2. Start the dev server: `pnpm dev`
|
||||
3. Start the InvokeAI Nodes backend: `python scripts/invokeai-web.py # run from the repo root`
|
||||
4. Point your browser to the dev server address e.g. [http://localhost:5173/](http://localhost:5173/)
|
||||
|
||||
@ -72,4 +73,4 @@ For a number of technical and logistical reasons, we need to commit UI build art
|
||||
|
||||
If you submit a PR, there is a good chance we will ask you to include a separate commit with a build of the app.
|
||||
|
||||
To build for production, run `yarn build`.
|
||||
To build for production, run `pnpm build`.
|
||||
|
@ -13,6 +13,7 @@ If you'd prefer, you can also just download the whole node folder from the linke
|
||||
To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
|
||||
|
||||
- Community Nodes
|
||||
+ [Adapters-Linked](#adapters-linked-nodes)
|
||||
+ [Average Images](#average-images)
|
||||
+ [Clean Image Artifacts After Cut](#clean-image-artifacts-after-cut)
|
||||
+ [Close Color Mask](#close-color-mask)
|
||||
@ -32,8 +33,9 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
+ [Image Resize Plus](#image-resize-plus)
|
||||
+ [Load Video Frame](#load-video-frame)
|
||||
+ [Make 3D](#make-3d)
|
||||
+ [Mask Operations](#mask-operations)
|
||||
+ [Mask Operations](#mask-operations)
|
||||
+ [Match Histogram](#match-histogram)
|
||||
+ [Metadata-Linked](#metadata-linked-nodes)
|
||||
+ [Negative Image](#negative-image)
|
||||
+ [Oobabooga](#oobabooga)
|
||||
+ [Prompt Tools](#prompt-tools)
|
||||
@ -51,6 +53,19 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
- [Help](#help)
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Adapters Linked Nodes
|
||||
|
||||
**Description:** A set of nodes for linked adapters (ControlNet, IP-Adaptor & T2I-Adapter). This allows multiple adapters to be chained together without using a `collect` node which means it can be used inside an `iterate` node without any collecting on every iteration issues.
|
||||
|
||||
- `ControlNet-Linked` - Collects ControlNet info to pass to other nodes.
|
||||
- `IP-Adapter-Linked` - Collects IP-Adapter info to pass to other nodes.
|
||||
- `T2I-Adapter-Linked` - Collects T2I-Adapter info to pass to other nodes.
|
||||
|
||||
Note: These are inherited from the core nodes so any update to the core nodes should be reflected in these.
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/adapters-linked-nodes
|
||||
|
||||
--------------------------------
|
||||
### Average Images
|
||||
|
||||
@ -307,6 +322,20 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
<img src="https://github.com/skunkworxdark/match_histogram/assets/21961335/ed12f329-a0ef-444a-9bae-129ed60d6097" width="300" />
|
||||
|
||||
--------------------------------
|
||||
### Metadata Linked Nodes
|
||||
|
||||
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
|
||||
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
|
||||
- `Metadata From Image` - Provides Metadata from an image.
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata.
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata.
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes
|
||||
|
||||
--------------------------------
|
||||
### Negative Image
|
||||
|
||||
|
@ -91,9 +91,11 @@ rm -rf InvokeAI-Installer
|
||||
|
||||
# copy content
|
||||
mkdir InvokeAI-Installer
|
||||
for f in templates lib *.txt *.reg; do
|
||||
for f in templates *.txt *.reg; do
|
||||
cp -r ${f} InvokeAI-Installer/
|
||||
done
|
||||
mkdir InvokeAI-Installer/lib
|
||||
cp lib/*.py InvokeAI-Installer/lib
|
||||
|
||||
# Move the wheel
|
||||
mv dist/*.whl InvokeAI-Installer/lib/
|
||||
@ -111,6 +113,6 @@ cp WinLongPathsEnabled.reg InvokeAI-Installer/
|
||||
zip -r InvokeAI-installer-$VERSION.zip InvokeAI-Installer
|
||||
|
||||
# clean up
|
||||
rm -rf InvokeAI-Installer tmp dist
|
||||
rm -rf InvokeAI-Installer tmp dist ../invokeai/frontend/web/dist/
|
||||
|
||||
exit 0
|
||||
|
@ -11,6 +11,7 @@ from ..services.board_images.board_images_default import BoardImagesService
|
||||
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from ..services.boards.boards_default import BoardService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.download import DownloadQueueService
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from ..services.images.images_default import ImageService
|
||||
@ -29,8 +30,7 @@ from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.shared.default_graphs import create_system_graphs
|
||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.shared.graph import GraphExecutionState
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
@ -80,13 +80,13 @@ class ApiDependencies:
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_install_service = ModelInstallService(
|
||||
app_config=config, record_store=model_record_service, event_bus=events
|
||||
)
|
||||
@ -107,7 +107,6 @@ class ApiDependencies:
|
||||
configuration=configuration,
|
||||
events=events,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
graph_library=graph_library,
|
||||
image_files=image_files,
|
||||
image_records=image_records,
|
||||
images=images,
|
||||
@ -116,6 +115,7 @@ class ApiDependencies:
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
model_install=model_install_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
@ -127,8 +127,6 @@ class ApiDependencies:
|
||||
workflow_records=workflow_records,
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
db.clean()
|
||||
|
||||
|
111
invokeai/app/api/routers/download_queue.py
Normal file
111
invokeai/app/api/routers/download_queue.py
Normal file
@ -0,0 +1,111 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for the download queue."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import Body, Path, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.app.services.download import (
|
||||
DownloadJob,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
download_queue_router = APIRouter(prefix="/v1/download_queue", tags=["download_queue"])
|
||||
|
||||
|
||||
@download_queue_router.get(
|
||||
"/",
|
||||
operation_id="list_downloads",
|
||||
)
|
||||
async def list_downloads() -> List[DownloadJob]:
|
||||
"""Get a list of active and inactive jobs."""
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
return queue.list_jobs()
|
||||
|
||||
|
||||
@download_queue_router.patch(
|
||||
"/",
|
||||
operation_id="prune_downloads",
|
||||
responses={
|
||||
204: {"description": "All completed jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_downloads():
|
||||
"""Prune completed and errored jobs."""
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
queue.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@download_queue_router.post(
|
||||
"/i/",
|
||||
operation_id="download",
|
||||
)
|
||||
async def download(
|
||||
source: AnyHttpUrl = Body(description="download source"),
|
||||
dest: str = Body(description="download destination"),
|
||||
priority: int = Body(default=10, description="queue priority"),
|
||||
access_token: Optional[str] = Body(default=None, description="token for authorization to download"),
|
||||
) -> DownloadJob:
|
||||
"""Download the source URL to the file or directory indicted in dest."""
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
return queue.download(source, dest, priority, access_token)
|
||||
|
||||
|
||||
@download_queue_router.get(
|
||||
"/i/{id}",
|
||||
operation_id="get_download_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "The requested download JobID could not be found"},
|
||||
},
|
||||
)
|
||||
async def get_download_job(
|
||||
id: int = Path(description="ID of the download job to fetch."),
|
||||
) -> DownloadJob:
|
||||
"""Get a download job using its ID."""
|
||||
try:
|
||||
job = ApiDependencies.invoker.services.download_queue.id_to_job(id)
|
||||
return job
|
||||
except UnknownJobIDException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@download_queue_router.delete(
|
||||
"/i/{id}",
|
||||
operation_id="cancel_download_job",
|
||||
responses={
|
||||
204: {"description": "Job has been cancelled"},
|
||||
404: {"description": "The requested download JobID could not be found"},
|
||||
},
|
||||
)
|
||||
async def cancel_download_job(
|
||||
id: int = Path(description="ID of the download job to cancel."),
|
||||
):
|
||||
"""Cancel a download job using its ID."""
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
job = queue.id_to_job(id)
|
||||
queue.cancel_job(job)
|
||||
return Response(status_code=204)
|
||||
except UnknownJobIDException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@download_queue_router.delete(
|
||||
"/i",
|
||||
operation_id="cancel_all_download_jobs",
|
||||
responses={
|
||||
204: {"description": "Download jobs have been cancelled"},
|
||||
},
|
||||
)
|
||||
async def cancel_all_download_jobs():
|
||||
"""Cancel all download jobs."""
|
||||
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
@ -26,7 +26,7 @@ from invokeai.backend.model_manager.config import (
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2"])
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
|
@ -45,6 +45,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
app_info,
|
||||
board_images,
|
||||
boards,
|
||||
download_queue,
|
||||
images,
|
||||
model_records,
|
||||
models,
|
||||
@ -116,6 +117,7 @@ app.include_router(sessions.session_router, prefix="/api")
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
app.include_router(model_records.model_records_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
@ -1,4 +1,3 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@ -17,6 +16,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -87,7 +87,7 @@ class CompelInvocation(BaseInvocation):
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
@ -210,7 +210,7 @@ class SDXLPromptInvocationBase:
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
||||
|
||||
import inspect
|
||||
import re
|
||||
|
||||
# from contextlib import ExitStack
|
||||
from typing import List, Literal, Union
|
||||
@ -21,6 +20,7 @@ from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -78,7 +78,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
|
@ -77,7 +77,7 @@ class CalculateImageTilesInvocation(BaseInvocation):
|
||||
title="Calculate Image Tiles Even Split",
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
@ -97,11 +97,11 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
ge=1,
|
||||
description="Number of tiles to divide image into on the y axis",
|
||||
)
|
||||
overlap_fraction: float = InputField(
|
||||
default=0.25,
|
||||
overlap: int = InputField(
|
||||
default=128,
|
||||
ge=0,
|
||||
lt=1,
|
||||
description="Overlap between adjacent tiles as a fraction of the tile's dimensions (0-1)",
|
||||
multiple_of=8,
|
||||
description="The overlap, in pixels, between adjacent tiles.",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
|
||||
@ -110,7 +110,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
image_width=self.image_width,
|
||||
num_tiles_x=self.num_tiles_x,
|
||||
num_tiles_y=self.num_tiles_y,
|
||||
overlap_fraction=self.overlap_fraction,
|
||||
overlap=self.overlap,
|
||||
)
|
||||
return CalculateImageTilesOutput(tiles=tiles)
|
||||
|
||||
|
@ -356,7 +356,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
else:
|
||||
root = self.find_root().expanduser().absolute()
|
||||
self.root = root # insulate ourselves from relative paths that may change
|
||||
return root
|
||||
return root.resolve()
|
||||
|
||||
@property
|
||||
def root_dir(self) -> Path:
|
||||
|
12
invokeai/app/services/download/__init__.py
Normal file
12
invokeai/app/services/download/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""Init file for download queue."""
|
||||
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
|
||||
from .download_default import DownloadQueueService, TqdmProgress
|
||||
|
||||
__all__ = [
|
||||
"DownloadJob",
|
||||
"DownloadQueueServiceBase",
|
||||
"DownloadQueueService",
|
||||
"TqdmProgress",
|
||||
"DownloadJobStatus",
|
||||
"UnknownJobIDException",
|
||||
]
|
217
invokeai/app/services/download/download_base.py
Normal file
217
invokeai/app/services/download/download_base.py
Normal file
@ -0,0 +1,217 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Model download service."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
|
||||
class DownloadJobStatus(str, Enum):
|
||||
"""State of a download job."""
|
||||
|
||||
WAITING = "waiting" # not enqueued, will not run
|
||||
RUNNING = "running" # actively downloading
|
||||
COMPLETED = "completed" # finished running
|
||||
CANCELLED = "cancelled" # user cancelled
|
||||
ERROR = "error" # terminated with an error message
|
||||
|
||||
|
||||
class DownloadJobCancelledException(Exception):
|
||||
"""This exception is raised when a download job is cancelled."""
|
||||
|
||||
|
||||
class UnknownJobIDException(Exception):
|
||||
"""This exception is raised when an invalid job id is referened."""
|
||||
|
||||
|
||||
class ServiceInactiveException(Exception):
|
||||
"""This exception is raised when user attempts to initiate a download before the service is started."""
|
||||
|
||||
|
||||
DownloadEventHandler = Callable[["DownloadJob"], None]
|
||||
|
||||
|
||||
@total_ordering
|
||||
class DownloadJob(BaseModel):
|
||||
"""Class to monitor and control a model download request."""
|
||||
|
||||
# required variables to be passed in on creation
|
||||
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
|
||||
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
||||
# automatically assigned on creation
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
|
||||
# set internally during download process
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
|
||||
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
|
||||
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
|
||||
job_ended: Optional[str] = Field(
|
||||
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
|
||||
)
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total file size (bytes)")
|
||||
|
||||
# set when an error occurs
|
||||
error_type: Optional[str] = Field(default=None, description="Name of exception that caused an error")
|
||||
error: Optional[str] = Field(default=None, description="Traceback of the exception that caused an error")
|
||||
|
||||
# internal flag
|
||||
_cancelled: bool = PrivateAttr(default=False)
|
||||
|
||||
# optional event handlers passed in on creation
|
||||
_on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_error: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
|
||||
def __le__(self, other: "DownloadJob") -> bool:
|
||||
"""Return True if this job's priority is less than another's."""
|
||||
return self.priority <= other.priority
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self._cancelled = True
|
||||
|
||||
# cancelled and the callbacks are private attributes in order to prevent
|
||||
# them from being serialized and/or used in the Json Schema
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Call to cancel the job."""
|
||||
return self._cancelled
|
||||
|
||||
@property
|
||||
def on_start(self) -> Optional[DownloadEventHandler]:
|
||||
"""Return the on_start event handler."""
|
||||
return self._on_start
|
||||
|
||||
@property
|
||||
def on_progress(self) -> Optional[DownloadEventHandler]:
|
||||
"""Return the on_progress event handler."""
|
||||
return self._on_progress
|
||||
|
||||
@property
|
||||
def on_complete(self) -> Optional[DownloadEventHandler]:
|
||||
"""Return the on_complete event handler."""
|
||||
return self._on_complete
|
||||
|
||||
@property
|
||||
def on_error(self) -> Optional[DownloadEventHandler]:
|
||||
"""Return the on_error event handler."""
|
||||
return self._on_error
|
||||
|
||||
@property
|
||||
def on_cancelled(self) -> Optional[DownloadEventHandler]:
|
||||
"""Return the on_cancelled event handler."""
|
||||
return self._on_cancelled
|
||||
|
||||
def set_callbacks(
|
||||
self,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadEventHandler] = None,
|
||||
) -> None:
|
||||
"""Set the callbacks for download events."""
|
||||
self._on_start = on_start
|
||||
self._on_progress = on_progress
|
||||
self._on_complete = on_complete
|
||||
self._on_error = on_error
|
||||
self._on_cancelled = on_cancelled
|
||||
|
||||
|
||||
class DownloadQueueServiceBase(ABC):
|
||||
"""Multithreaded queue for downloading models via URL."""
|
||||
|
||||
@abstractmethod
|
||||
def start(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Start the download worker threads."""
|
||||
|
||||
@abstractmethod
|
||||
def stop(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Stop the download worker threads."""
|
||||
|
||||
@abstractmethod
|
||||
def download(
|
||||
self,
|
||||
source: AnyHttpUrl,
|
||||
dest: Path,
|
||||
priority: int = 10,
|
||||
access_token: Optional[str] = None,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadEventHandler] = None,
|
||||
) -> DownloadJob:
|
||||
"""
|
||||
Create a download job.
|
||||
|
||||
:param source: Source of the download as a URL.
|
||||
:param dest: Path to download to. See below.
|
||||
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
|
||||
events.
|
||||
:returns: A DownloadJob object for monitoring the state of the download.
|
||||
|
||||
The `dest` argument is a Path object. Its behavior is:
|
||||
|
||||
1. If the path exists and is a directory, then the URL contents will be downloaded
|
||||
into that directory using the filename indicated in the response's `Content-Disposition` field.
|
||||
If no content-disposition is present, then the last component of the URL will be used (similar to
|
||||
wget's behavior).
|
||||
2. If the path does not exist, then it is taken as the name of a new file to create with the downloaded
|
||||
content.
|
||||
3. If the path exists and is an existing file, then the downloader will try to resume the download from
|
||||
the end of the existing file.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[DownloadJob]:
|
||||
"""
|
||||
List active download jobs.
|
||||
|
||||
:returns List[DownloadJob]: List of download jobs whose state is not "completed."
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def id_to_job(self, id: int) -> DownloadJob:
|
||||
"""
|
||||
Return the DownloadJob corresponding to the integer ID.
|
||||
|
||||
:param id: ID of the DownloadJob.
|
||||
|
||||
Exceptions:
|
||||
* UnknownJobIDException
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active and enquedjobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJob):
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def join(self):
|
||||
"""Wait until all jobs are off the queue."""
|
||||
pass
|
418
invokeai/app/services/download/download_default.py
Normal file
418
invokeai/app/services/download/download_default.py
Normal file
@ -0,0 +1,418 @@
|
||||
# Copyright (c) 2023, Lincoln D. Stein
|
||||
"""Implementation of multithreaded download queue for invokeai."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import traceback
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .download_base import (
|
||||
DownloadEventHandler,
|
||||
DownloadJob,
|
||||
DownloadJobCancelledException,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueServiceBase,
|
||||
ServiceInactiveException,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||
DOWNLOAD_CHUNK_SIZE = 100000
|
||||
|
||||
|
||||
class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Class for queued download of models."""
|
||||
|
||||
_jobs: Dict[int, DownloadJob]
|
||||
_max_parallel_dl: int = 5
|
||||
_worker_pool: Set[threading.Thread]
|
||||
_queue: PriorityQueue[DownloadJob]
|
||||
_stop_event: threading.Event
|
||||
_lock: threading.Lock
|
||||
_logger: Logger
|
||||
_events: Optional[EventServiceBase] = None
|
||||
_next_job_id: int = 0
|
||||
_accept_download_requests: bool = False
|
||||
_requests: requests.sessions.Session
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
"""
|
||||
self._jobs = {}
|
||||
self._next_job_id = 0
|
||||
self._queue = PriorityQueue()
|
||||
self._stop_event = threading.Event()
|
||||
self._worker_pool = set()
|
||||
self._lock = threading.Lock()
|
||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||
self._event_bus = event_bus
|
||||
self._requests = requests_session or requests.Session()
|
||||
self._accept_download_requests = False
|
||||
self._max_parallel_dl = max_parallel_dl
|
||||
|
||||
def start(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Start the download worker threads."""
|
||||
with self._lock:
|
||||
if self._worker_pool:
|
||||
raise Exception("Attempt to start the download service twice")
|
||||
self._stop_event.clear()
|
||||
self._start_workers(self._max_parallel_dl)
|
||||
self._accept_download_requests = True
|
||||
|
||||
def stop(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Stop the download worker threads."""
|
||||
with self._lock:
|
||||
if not self._worker_pool:
|
||||
raise Exception("Attempt to stop the download service before it was started")
|
||||
self._accept_download_requests = False # reject attempts to add new jobs to queue
|
||||
queued_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.WAITING]
|
||||
active_jobs = [x for x in self.list_jobs() if x.status == DownloadJobStatus.RUNNING]
|
||||
if queued_jobs:
|
||||
self._logger.warning(f"Cancelling {len(queued_jobs)} queued downloads")
|
||||
if active_jobs:
|
||||
self._logger.info(f"Waiting for {len(active_jobs)} active download jobs to complete")
|
||||
with self._queue.mutex:
|
||||
self._queue.queue.clear()
|
||||
self.join() # wait for all active jobs to finish
|
||||
self._stop_event.set()
|
||||
self._worker_pool.clear()
|
||||
|
||||
def download(
|
||||
self,
|
||||
source: AnyHttpUrl,
|
||||
dest: Path,
|
||||
priority: int = 10,
|
||||
access_token: Optional[str] = None,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadEventHandler] = None,
|
||||
) -> DownloadJob:
|
||||
"""Create a download job and return its ID."""
|
||||
if not self._accept_download_requests:
|
||||
raise ServiceInactiveException(
|
||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||
)
|
||||
with self._lock:
|
||||
id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
job = DownloadJob(
|
||||
id=id,
|
||||
source=source,
|
||||
dest=dest,
|
||||
priority=priority,
|
||||
access_token=access_token,
|
||||
)
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[id] = job
|
||||
self._queue.put(job)
|
||||
return job
|
||||
|
||||
def join(self) -> None:
|
||||
"""Wait for all jobs to complete."""
|
||||
self._queue.join()
|
||||
|
||||
def list_jobs(self) -> List[DownloadJob]:
|
||||
"""List all the jobs."""
|
||||
return list(self._jobs.values())
|
||||
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
with self._lock:
|
||||
to_delete = set()
|
||||
for job_id, job in self._jobs.items():
|
||||
if self._in_terminal_state(job):
|
||||
to_delete.add(job_id)
|
||||
for job_id in to_delete:
|
||||
del self._jobs[job_id]
|
||||
|
||||
def id_to_job(self, id: int) -> DownloadJob:
|
||||
"""Translate a job ID into a DownloadJob object."""
|
||||
try:
|
||||
return self._jobs[id]
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
|
||||
If it is running it will be stopped.
|
||||
job.status will be set to DownloadJobStatus.CANCELLED
|
||||
"""
|
||||
with self._lock:
|
||||
job.cancel()
|
||||
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False) -> None:
|
||||
"""Cancel all jobs (those not in enqueued, running or paused state)."""
|
||||
for job in self._jobs.values():
|
||||
if not self._in_terminal_state(job):
|
||||
self.cancel_job(job)
|
||||
|
||||
def _in_terminal_state(self, job: DownloadJob) -> bool:
|
||||
return job.status in [
|
||||
DownloadJobStatus.COMPLETED,
|
||||
DownloadJobStatus.CANCELLED,
|
||||
DownloadJobStatus.ERROR,
|
||||
]
|
||||
|
||||
def _start_workers(self, max_workers: int) -> None:
|
||||
"""Start the requested number of worker threads."""
|
||||
self._stop_event.clear()
|
||||
for i in range(0, max_workers): # noqa B007
|
||||
worker = threading.Thread(target=self._download_next_item, daemon=True)
|
||||
self._logger.debug(f"Download queue worker thread {worker.name} starting.")
|
||||
worker.start()
|
||||
self._worker_pool.add(worker)
|
||||
|
||||
def _download_next_item(self) -> None:
|
||||
"""Worker thread gets next job on priority queue."""
|
||||
done = False
|
||||
while not done:
|
||||
if self._stop_event.is_set():
|
||||
done = True
|
||||
continue
|
||||
try:
|
||||
job = self._queue.get(timeout=1)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
try:
|
||||
job.job_started = get_iso_timestamp()
|
||||
self._do_download(job)
|
||||
self._signal_job_complete(job)
|
||||
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
self._signal_job_error(job)
|
||||
except DownloadJobCancelledException:
|
||||
self._signal_job_cancelled(job)
|
||||
self._cleanup_cancelled_job(job)
|
||||
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
self._queue.task_done()
|
||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||
|
||||
def _do_download(self, job: DownloadJob) -> None:
|
||||
"""Do the actual download."""
|
||||
url = job.source
|
||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||
open_mode = "wb"
|
||||
|
||||
# Make a streaming request. This will retrieve headers including
|
||||
# content-length and content-disposition, but not fetch any content itself
|
||||
resp = self._requests.get(str(url), headers=header, stream=True)
|
||||
if not resp.ok:
|
||||
raise HTTPError(resp.reason)
|
||||
content_length = int(resp.headers.get("content-length", 0))
|
||||
job.total_bytes = content_length
|
||||
|
||||
if job.dest.is_dir():
|
||||
file_name = os.path.basename(str(url.path)) # default is to use the last bit of the URL
|
||||
|
||||
if match := re.search('filename="(.+)"', resp.headers.get("Content-Disposition", "")):
|
||||
remote_name = match.group(1)
|
||||
if self._validate_filename(job.dest.as_posix(), remote_name):
|
||||
file_name = remote_name
|
||||
|
||||
job.download_path = job.dest / file_name
|
||||
|
||||
else:
|
||||
job.dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
job.download_path = job.dest
|
||||
|
||||
assert job.download_path
|
||||
|
||||
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
|
||||
# for code that instead resumes an interrupted download.
|
||||
if job.download_path.exists():
|
||||
raise OSError(f"[Errno 17] File {job.download_path} exists")
|
||||
|
||||
# append ".downloading" to the path
|
||||
in_progress_path = self._in_progress_path(job.download_path)
|
||||
|
||||
# signal caller that the download is starting. At this point, key fields such as
|
||||
# download_path and total_bytes will be populated. We call it here because the might
|
||||
# discover that the local file is already complete and generate a COMPLETED status.
|
||||
self._signal_job_started(job)
|
||||
|
||||
# "range not satisfiable" - local file is at least as large as the remote file
|
||||
if resp.status_code == 416 or (content_length > 0 and job.bytes >= content_length):
|
||||
self._logger.warning(f"{job.download_path}: complete file found. Skipping.")
|
||||
return
|
||||
|
||||
# "partial content" - local file is smaller than remote file
|
||||
elif resp.status_code == 206 or job.bytes > 0:
|
||||
self._logger.warning(f"{job.download_path}: partial file found. Resuming")
|
||||
|
||||
# some other error
|
||||
elif resp.status_code != 200:
|
||||
raise HTTPError(resp.reason)
|
||||
|
||||
self._logger.debug(f"{job.source}: Downloading {job.download_path}")
|
||||
report_delta = job.total_bytes / 100 # report every 1% change
|
||||
last_report_bytes = 0
|
||||
|
||||
# DOWNLOAD LOOP
|
||||
with open(in_progress_path, open_mode) as file:
|
||||
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
|
||||
if job.cancelled:
|
||||
raise DownloadJobCancelledException("Job was cancelled at caller's request")
|
||||
|
||||
job.bytes += file.write(data)
|
||||
if (job.bytes - last_report_bytes >= report_delta) or (job.bytes >= job.total_bytes):
|
||||
last_report_bytes = job.bytes
|
||||
self._signal_job_progress(job)
|
||||
|
||||
# if we get here we are done and can rename the file to the original dest
|
||||
in_progress_path.rename(job.download_path)
|
||||
|
||||
def _validate_filename(self, directory: str, filename: str) -> bool:
|
||||
pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260 # hardcoded for windows
|
||||
pc_path_max = (
|
||||
os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 32767
|
||||
) # hardcoded for windows with long names enabled
|
||||
if "/" in filename:
|
||||
return False
|
||||
if filename.startswith(".."):
|
||||
return False
|
||||
if len(filename) > pc_name_max:
|
||||
return False
|
||||
if len(os.path.join(directory, filename)) > pc_path_max:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _in_progress_path(self, path: Path) -> Path:
|
||||
return path.with_name(path.name + ".downloading")
|
||||
|
||||
def _signal_job_started(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.RUNNING
|
||||
if job.on_start:
|
||||
try:
|
||||
job.on_start(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
|
||||
|
||||
def _signal_job_progress(self, job: DownloadJob) -> None:
|
||||
if job.on_progress:
|
||||
try:
|
||||
job.on_progress(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_progress(
|
||||
str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
def _signal_job_complete(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.COMPLETED
|
||||
if job.on_complete:
|
||||
try:
|
||||
job.on_complete(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
if self._event_bus:
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_complete(
|
||||
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
|
||||
)
|
||||
|
||||
def _signal_job_cancelled(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.CANCELLED
|
||||
if job.on_cancelled:
|
||||
try:
|
||||
job.on_cancelled(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(str(job.source))
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
if job.on_error:
|
||||
try:
|
||||
job.on_error(job)
|
||||
except Exception as e:
|
||||
self._logger.error(e)
|
||||
if self._event_bus:
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
|
||||
self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.download_path}")
|
||||
try:
|
||||
if job.download_path:
|
||||
partial_file = self._in_progress_path(job.download_path)
|
||||
partial_file.unlink()
|
||||
except OSError as excp:
|
||||
self._logger.warning(excp)
|
||||
|
||||
|
||||
# Example on_progress event handler to display a TQDM status bar
|
||||
# Activate with:
|
||||
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
|
||||
class TqdmProgress(object):
|
||||
"""TQDM-based progress bar object to use in on_progress handlers."""
|
||||
|
||||
_bars: Dict[int, tqdm] # the tqdm object
|
||||
_last: Dict[int, int] # last bytes downloaded
|
||||
|
||||
def __init__(self) -> None: # noqa D107
|
||||
self._bars = {}
|
||||
self._last = {}
|
||||
|
||||
def update(self, job: DownloadJob) -> None: # noqa D102
|
||||
job_id = job.id
|
||||
# new job
|
||||
if job_id not in self._bars:
|
||||
assert job.download_path
|
||||
dest = Path(job.download_path).name
|
||||
self._bars[job_id] = tqdm(
|
||||
desc=dest,
|
||||
initial=0,
|
||||
total=job.total_bytes,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
)
|
||||
self._last[job_id] = 0
|
||||
self._bars[job_id].update(job.bytes - self._last[job_id])
|
||||
self._last[job_id] = job.bytes
|
@ -17,6 +17,7 @@ from invokeai.backend.model_management.models.base import BaseModelType, ModelTy
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
download_event: str = "download_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
@ -32,6 +33,13 @@ class EventServiceBase:
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_download_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_model_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
@ -323,6 +331,79 @@ class EventServiceBase:
|
||||
payload={"queue_id": queue_id},
|
||||
)
|
||||
|
||||
def emit_download_started(self, source: str, download_path: str) -> None:
|
||||
"""
|
||||
Emit when a download job is started.
|
||||
|
||||
:param url: The downloaded url
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_started",
|
||||
payload={"source": source, "download_path": download_path},
|
||||
)
|
||||
|
||||
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit "download_progress" events at regular intervals during a download job.
|
||||
|
||||
:param source: The downloaded source
|
||||
:param download_path: The local downloaded file
|
||||
:param current_bytes: Number of bytes downloaded so far
|
||||
:param total_bytes: The size of the file being downloaded (if known)
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_progress",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"current_bytes": current_bytes,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit a "download_complete" event at the end of a successful download.
|
||||
|
||||
:param source: Source URL
|
||||
:param download_path: Path to the locally downloaded file
|
||||
:param total_bytes: The size of the downloaded file
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_complete",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_download_cancelled(self, source: str) -> None:
|
||||
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
|
||||
self.__emit_download_event(
|
||||
event_name="download_cancelled",
|
||||
payload={
|
||||
"source": source,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
|
||||
"""
|
||||
Emit a "download_error" event when an download job encounters an exception.
|
||||
|
||||
:param source: Source URL
|
||||
:param error_type: The name of the exception that raised the error
|
||||
:param error: The traceback from this error
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_error",
|
||||
payload={
|
||||
"source": source,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_started(self, source: str) -> None:
|
||||
"""
|
||||
Emitted when an install job is started.
|
||||
|
@ -11,6 +11,7 @@ if TYPE_CHECKING:
|
||||
from .board_records.board_records_base import BoardRecordStorageBase
|
||||
from .boards.boards_base import BoardServiceABC
|
||||
from .config import InvokeAIAppConfig
|
||||
from .download import DownloadQueueServiceBase
|
||||
from .events.events_base import EventServiceBase
|
||||
from .image_files.image_files_base import ImageFileStorageBase
|
||||
from .image_records.image_records_base import ImageRecordStorageBase
|
||||
@ -27,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState, LibraryGraph
|
||||
from .shared.graph import GraphExecutionState
|
||||
from .urls.urls_base import UrlServiceBase
|
||||
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
|
||||
@ -43,7 +44,6 @@ class InvocationServices:
|
||||
configuration: "InvokeAIAppConfig"
|
||||
events: "EventServiceBase"
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||
images: "ImageServiceABC"
|
||||
image_records: "ImageRecordStorageBase"
|
||||
image_files: "ImageFileStorageBase"
|
||||
@ -51,6 +51,7 @@ class InvocationServices:
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
model_records: "ModelRecordServiceBase"
|
||||
download_queue: "DownloadQueueServiceBase"
|
||||
model_install: "ModelInstallServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
@ -71,7 +72,6 @@ class InvocationServices:
|
||||
configuration: "InvokeAIAppConfig",
|
||||
events: "EventServiceBase",
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||
graph_library: "ItemStorageABC[LibraryGraph]",
|
||||
images: "ImageServiceABC",
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
@ -79,6 +79,7 @@ class InvocationServices:
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
model_install: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
@ -97,7 +98,6 @@ class InvocationServices:
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.graph_library = graph_library
|
||||
self.images = images
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
@ -105,6 +105,7 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
self.download_queue = download_queue
|
||||
self.model_install = model_install
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
|
@ -11,7 +11,6 @@ from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
@ -157,12 +156,12 @@ class ModelInstallServiceBase(ABC):
|
||||
:param event_bus: InvokeAI event bus for reporting events to.
|
||||
"""
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Call at InvokeAI startup time."""
|
||||
self.sync_to_config()
|
||||
@abstractmethod
|
||||
def start(self, *args: Any, **kwarg: Any) -> None:
|
||||
"""Start the installer service."""
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
def stop(self, *args: Any, **kwarg: Any) -> None:
|
||||
"""Stop the model install service. After this the objection can be safely deleted."""
|
||||
|
||||
@property
|
||||
|
@ -71,7 +71,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_queue = Queue()
|
||||
self._cached_model_paths = set()
|
||||
self._models_installed = set()
|
||||
self._start_installer_thread()
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@ -85,8 +84,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
||||
def start(self, *args: Any, **kwarg: Any) -> None:
|
||||
"""Start the installer thread."""
|
||||
self._start_installer_thread()
|
||||
self.sync_to_config()
|
||||
|
||||
def stop(self, *args: Any, **kwarg: Any) -> None:
|
||||
"""Stop the installer thread; after this the object can be deleted and garbage collected."""
|
||||
self._install_queue.put(STOP_JOB)
|
||||
|
||||
def _start_installer_thread(self) -> None:
|
||||
|
@ -5,6 +5,8 @@ from invokeai.app.services.image_files.image_files_base import ImageFileStorageB
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -27,6 +29,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator = SqliteMigrator(db=db)
|
||||
migrator.register_migration(build_migration_1())
|
||||
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
|
||||
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -23,7 +23,6 @@ class Migration2Callback:
|
||||
self._drop_old_workflow_tables(cursor)
|
||||
self._add_workflow_library(cursor)
|
||||
self._drop_model_manager_metadata(cursor)
|
||||
self._recreate_model_config(cursor)
|
||||
self._migrate_embedded_workflows(cursor)
|
||||
|
||||
def _add_images_has_workflow(self, cursor: sqlite3.Cursor) -> None:
|
||||
@ -97,40 +96,6 @@ class Migration2Callback:
|
||||
"""Drops the `model_manager_metadata` table."""
|
||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||
|
||||
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Drops the `model_config` table, recreating it.
|
||||
|
||||
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
|
||||
|
||||
Because this table is not used in production, we are able to simply drop it and recreate it.
|
||||
"""
|
||||
|
||||
cursor.execute("DROP TABLE IF EXISTS model_config;")
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _migrate_embedded_workflows(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
In the v3.5.0 release, InvokeAI changed how it handles embedded workflows. The `images` table in
|
||||
|
@ -0,0 +1,79 @@
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
from .util.migrate_yaml_config_1 import MigrateModelYamlToDb1
|
||||
|
||||
|
||||
class Migration3Callback:
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._drop_model_manager_metadata(cursor)
|
||||
self._recreate_model_config(cursor)
|
||||
self._migrate_model_config_records(cursor)
|
||||
|
||||
def _drop_model_manager_metadata(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the `model_manager_metadata` table."""
|
||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||
|
||||
def _recreate_model_config(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Drops the `model_config` table, recreating it.
|
||||
|
||||
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
|
||||
|
||||
Because this table is not used in production, we are able to simply drop it and recreate it.
|
||||
"""
|
||||
|
||||
cursor.execute("DROP TABLE IF EXISTS model_config;")
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_config (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
original_hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _migrate_model_config_records(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""After updating the model config table, we repopulate it."""
|
||||
self._logger.info("Migrating model config records from models.yaml to database")
|
||||
model_record_migrator = MigrateModelYamlToDb1(self._app_config, self._logger, cursor)
|
||||
model_record_migrator.migrate()
|
||||
|
||||
|
||||
def build_migration_3(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||
"""
|
||||
Build the migration from database version 2 to 3.
|
||||
|
||||
This migration does the following:
|
||||
- Drops the `model_config` table, recreating it
|
||||
- Migrates data from `models.yaml` into the `model_config` table
|
||||
"""
|
||||
migration_3 = Migration(
|
||||
from_version=2,
|
||||
to_version=3,
|
||||
callback=Migration3Callback(app_config=app_config, logger=logger),
|
||||
)
|
||||
|
||||
return migration_3
|
@ -0,0 +1,94 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration4Callback:
|
||||
"""Callback to do step 4 of migration."""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: # noqa D102
|
||||
self._create_model_metadata(cursor)
|
||||
self._create_model_tags(cursor)
|
||||
self._create_tags(cursor)
|
||||
self._create_triggers(cursor)
|
||||
|
||||
def _create_model_metadata(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Create the table used to store model metadata downloaded from remote sources."""
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_metadata (
|
||||
id TEXT NOT NULL,
|
||||
name TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.name')) VIRTUAL NOT NULL,
|
||||
author TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.author')) VIRTUAL NOT NULL,
|
||||
-- Serialized JSON representation of the whole metadata object,
|
||||
-- which will contain additional fields from subclasses
|
||||
metadata TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
FOREIGN KEY(id) REFERENCES model_config(id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_model_tags(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS model_tags (
|
||||
id TEXT NOT NULL,
|
||||
tag_id INTEGER NOT NULL,
|
||||
FOREIGN KEY(id) REFERENCES model_config(id),
|
||||
FOREIGN KEY(tag_id) REFERENCES tags(tag_id),
|
||||
UNIQUE(id,tag_id)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_tags(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS tags (
|
||||
tag_id INTEGER NOT NULL PRIMARY KEY,
|
||||
tag_text TEXT NOT NULL UNIQUE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_triggers(self, cursor: sqlite3.Cursor) -> None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_metadata_updated_at
|
||||
AFTER UPDATE
|
||||
ON model_metadata FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE model_metadata SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_config_deleted
|
||||
AFTER DELETE
|
||||
ON model_config
|
||||
BEGIN
|
||||
DELETE from model_metadata WHERE id=old.id;
|
||||
DELETE from model_tags WHERE id=old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_4() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 3 to 4.
|
||||
|
||||
Adds the tables needed to store model metadata and tags.
|
||||
"""
|
||||
migration_4 = Migration(
|
||||
from_version=3,
|
||||
to_version=4,
|
||||
callback=Migration4Callback(),
|
||||
)
|
||||
|
||||
return migration_4
|
@ -1,8 +1,12 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from hashlib import sha1
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import TypeAdapter
|
||||
@ -10,24 +14,22 @@ from pydantic import TypeAdapter
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordServiceSQL,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigFactory,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
class MigrateModelYamlToDb:
|
||||
class MigrateModelYamlToDb1:
|
||||
"""
|
||||
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
|
||||
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.5.0).
|
||||
|
||||
The class has one externally useful method, migrate(), which scans the
|
||||
currently models.yaml file and imports all its entries into invokeai.db.
|
||||
@ -41,17 +43,12 @@ class MigrateModelYamlToDb:
|
||||
|
||||
config: InvokeAIAppConfig
|
||||
logger: Logger
|
||||
cursor: sqlite3.Cursor
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.config.parse_args()
|
||||
self.logger = InvokeAILogger.get_logger()
|
||||
|
||||
def get_db(self) -> ModelRecordServiceSQL:
|
||||
"""Fetch the sqlite3 database for this installation."""
|
||||
db_path = None if self.config.use_memory_db else self.config.db_path
|
||||
db = SqliteDatabase(db_path=db_path, logger=self.logger, verbose=self.config.log_sql)
|
||||
return ModelRecordServiceSQL(db)
|
||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger, cursor: sqlite3.Cursor = None) -> None:
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.cursor = cursor
|
||||
|
||||
def get_yaml(self) -> DictConfig:
|
||||
"""Fetch the models.yaml DictConfig for this installation."""
|
||||
@ -62,8 +59,10 @@ class MigrateModelYamlToDb:
|
||||
|
||||
def migrate(self) -> None:
|
||||
"""Do the migration from models.yaml to invokeai.db."""
|
||||
db = self.get_db()
|
||||
yaml = self.get_yaml()
|
||||
try:
|
||||
yaml = self.get_yaml()
|
||||
except OSError:
|
||||
return
|
||||
|
||||
for model_key, stanza in yaml.items():
|
||||
if model_key == "__metadata__":
|
||||
@ -86,22 +85,62 @@ class MigrateModelYamlToDb:
|
||||
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||
|
||||
try:
|
||||
if original_record := db.search_by_path(stanza.path):
|
||||
key = original_record[0].key
|
||||
if original_record := self._search_by_path(stanza.path):
|
||||
key = original_record.key
|
||||
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||
db.update_model(key, new_config)
|
||||
self._update_model(key, new_config)
|
||||
else:
|
||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||
db.add_model(new_key, new_config)
|
||||
self._add_model(new_key, new_config)
|
||||
except DuplicateModelException:
|
||||
self.logger.warning(f"Model {model_name} is already in the database")
|
||||
except UnknownModelException:
|
||||
self.logger.warning(f"Model at {stanza.path} could not be found in database")
|
||||
|
||||
def _search_by_path(self, path: Path) -> Optional[AnyModelConfig]:
|
||||
self.cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self.cursor.fetchall()]
|
||||
return results[0] if results else None
|
||||
|
||||
def main():
|
||||
MigrateModelYamlToDb().migrate()
|
||||
def _update_model(self, key: str, config: AnyModelConfig) -> None:
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
self.cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, key),
|
||||
)
|
||||
if self.cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
def _add_model(self, key: str, config: AnyModelConfig) -> None:
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = record.model_dump_json() # and turn it into a json string.
|
||||
try:
|
||||
self.cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
id,
|
||||
original_hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
except sqlite3.IntegrityError as exc:
|
||||
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,975 @@
|
||||
{
|
||||
"name": "Prompt from File",
|
||||
"author": "InvokeAI",
|
||||
"description": "Sample workflow using Prompt from File node",
|
||||
"version": "0.1.0",
|
||||
"contact": "invoke@invoke.ai",
|
||||
"tags": "text2image, prompt from file, default",
|
||||
"notes": "",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
|
||||
"fieldName": "file_path"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"category": "default",
|
||||
"version": "2.0.0"
|
||||
},
|
||||
"id": "d1609af5-eb0a-4f73-b573-c9af96a8d6bf",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "c2eaf1ba-5708-4679-9e15-945b8b432692",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "c2eaf1ba-5708-4679-9e15-945b8b432692",
|
||||
"type": "compel",
|
||||
"label": "",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"prompt": {
|
||||
"id": "dcdf3f6d-9b96-4bcd-9b8d-f992fefe4f62",
|
||||
"name": "prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": ""
|
||||
},
|
||||
"clip": {
|
||||
"id": "3f1981c9-d8a9-42eb-a739-4f120eb80745",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"conditioning": {
|
||||
"id": "46205e6c-c5e2-44cb-9c82-1cd20b95674a",
|
||||
"name": "conditioning",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 925,
|
||||
"y": -200
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
|
||||
"type": "prompt_from_file",
|
||||
"label": "Prompts from File",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.1",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"file_path": {
|
||||
"id": "37e37684-4f30-4ec8-beae-b333e550f904",
|
||||
"name": "file_path",
|
||||
"fieldKind": "input",
|
||||
"label": "Prompts File Path",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": ""
|
||||
},
|
||||
"pre_prompt": {
|
||||
"id": "7de02feb-819a-4992-bad3-72a30920ddea",
|
||||
"name": "pre_prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": ""
|
||||
},
|
||||
"post_prompt": {
|
||||
"id": "95f191d8-a282-428e-bd65-de8cb9b7513a",
|
||||
"name": "post_prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": ""
|
||||
},
|
||||
"start_line": {
|
||||
"id": "efee9a48-05ab-4829-8429-becfa64a0782",
|
||||
"name": "start_line",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 1
|
||||
},
|
||||
"max_prompts": {
|
||||
"id": "abebb428-3d3d-49fd-a482-4e96a16fff08",
|
||||
"name": "max_prompts",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 1
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"collection": {
|
||||
"id": "77d5d7f1-9877-4ab1-9a8c-33e9ffa9abf3",
|
||||
"name": "collection",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": true,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 580,
|
||||
"position": {
|
||||
"x": 475,
|
||||
"y": -400
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "1b89067c-3f6b-42c8-991f-e3055789b251",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "1b89067c-3f6b-42c8-991f-e3055789b251",
|
||||
"type": "iterate",
|
||||
"label": "",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.1.0",
|
||||
"inputs": {
|
||||
"collection": {
|
||||
"id": "4c564bf8-5ed6-441e-ad2c-dda265d5785f",
|
||||
"name": "collection",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": true,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "CollectionField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"item": {
|
||||
"id": "36340f9a-e7a5-4afa-b4b5-313f4e292380",
|
||||
"name": "item",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "CollectionItemField"
|
||||
}
|
||||
},
|
||||
"index": {
|
||||
"id": "1beca95a-2159-460f-97ff-c8bab7d89336",
|
||||
"name": "index",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"total": {
|
||||
"id": "ead597b8-108e-4eda-88a8-5c29fa2f8df9",
|
||||
"name": "total",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 925,
|
||||
"y": -400
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
|
||||
"type": "main_model_loader",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"model": {
|
||||
"id": "3f264259-3418-47d5-b90d-b6600e36ae46",
|
||||
"name": "model",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "MainModelField"
|
||||
},
|
||||
"value": {
|
||||
"model_name": "stable-diffusion-v1-5",
|
||||
"base_model": "sd-1",
|
||||
"model_type": "main"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"unet": {
|
||||
"id": "8e182ea2-9d0a-4c02-9407-27819288d4b5",
|
||||
"name": "unet",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"id": "d67d9d30-058c-46d5-bded-3d09d6d1aa39",
|
||||
"name": "clip",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"id": "89641601-0429-4448-98d5-190822d920d8",
|
||||
"name": "vae",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "VaeField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 227,
|
||||
"position": {
|
||||
"x": 0,
|
||||
"y": -375
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
|
||||
"type": "compel",
|
||||
"label": "",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"prompt": {
|
||||
"id": "dcdf3f6d-9b96-4bcd-9b8d-f992fefe4f62",
|
||||
"name": "prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": ""
|
||||
},
|
||||
"clip": {
|
||||
"id": "3f1981c9-d8a9-42eb-a739-4f120eb80745",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"conditioning": {
|
||||
"id": "46205e6c-c5e2-44cb-9c82-1cd20b95674a",
|
||||
"name": "conditioning",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 925,
|
||||
"y": -275
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
|
||||
"type": "noise",
|
||||
"label": "",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.1",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"seed": {
|
||||
"id": "b722d84a-eeee-484f-bef2-0250c027cb67",
|
||||
"name": "seed",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"width": {
|
||||
"id": "d5f8ce11-0502-4bfc-9a30-5757dddf1f94",
|
||||
"name": "width",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 512
|
||||
},
|
||||
"height": {
|
||||
"id": "f187d5ff-38a5-4c3f-b780-fc5801ef34af",
|
||||
"name": "height",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 512
|
||||
},
|
||||
"use_cpu": {
|
||||
"id": "12f112b8-8b76-4816-b79e-662edc9f9aa5",
|
||||
"name": "use_cpu",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": true
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"noise": {
|
||||
"id": "08576ad1-96d9-42d2-96ef-6f5c1961933f",
|
||||
"name": "noise",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "f3e1f94a-258d-41ff-9789-bd999bd9f40d",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "6cefc357-4339-415e-a951-49b9c2be32f4",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 925,
|
||||
"y": 25
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
|
||||
"type": "rand_int",
|
||||
"label": "",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"low": {
|
||||
"id": "b9fc6cf1-469c-4037-9bf0-04836965826f",
|
||||
"name": "low",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"id": "06eac725-0f60-4ba2-b8cd-7ad9f757488c",
|
||||
"name": "high",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 2147483647
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"value": {
|
||||
"id": "df08c84e-7346-4e92-9042-9e5cb773aaff",
|
||||
"name": "value",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 925,
|
||||
"y": -50
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
|
||||
"type": "l2i",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.2.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"metadata": {
|
||||
"id": "022e4b33-562b-438d-b7df-41c3fd931f40",
|
||||
"name": "metadata",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "MetadataField"
|
||||
}
|
||||
},
|
||||
"latents": {
|
||||
"id": "67cb6c77-a394-4a66-a6a9-a0a7dcca69ec",
|
||||
"name": "latents",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"id": "7b3fd9ad-a4ef-4e04-89fa-3832a9902dbd",
|
||||
"name": "vae",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "VaeField"
|
||||
}
|
||||
},
|
||||
"tiled": {
|
||||
"id": "5ac5680d-3add-4115-8ec0-9ef5bb87493b",
|
||||
"name": "tiled",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": false
|
||||
},
|
||||
"fp32": {
|
||||
"id": "db8297f5-55f8-452f-98cf-6572c2582152",
|
||||
"name": "fp32",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": false
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"image": {
|
||||
"id": "d8778d0c-592a-4960-9280-4e77e00a7f33",
|
||||
"name": "image",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ImageField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "c8b0a75a-f5de-4ff2-9227-f25bb2b97bec",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "83c05fbf-76b9-49ab-93c4-fa4b10e793e4",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 267,
|
||||
"position": {
|
||||
"x": 2037.861329274915,
|
||||
"y": -329.8393457509562
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
|
||||
"type": "denoise_latents",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.5.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"positive_conditioning": {
|
||||
"id": "751fb35b-3f23-45ce-af1c-053e74251337",
|
||||
"name": "positive_conditioning",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
},
|
||||
"negative_conditioning": {
|
||||
"id": "b9dc06b6-7481-4db1-a8c2-39d22a5eacff",
|
||||
"name": "negative_conditioning",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
},
|
||||
"noise": {
|
||||
"id": "6e15e439-3390-48a4-8031-01e0e19f0e1d",
|
||||
"name": "noise",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"steps": {
|
||||
"id": "bfdfb3df-760b-4d51-b17b-0abb38b976c2",
|
||||
"name": "steps",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 10
|
||||
},
|
||||
"cfg_scale": {
|
||||
"id": "47770858-322e-41af-8494-d8b63ed735f3",
|
||||
"name": "cfg_scale",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 7.5
|
||||
},
|
||||
"denoising_start": {
|
||||
"id": "2ba78720-ee02-4130-a348-7bc3531f790b",
|
||||
"name": "denoising_start",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"id": "a874dffb-d433-4d1a-9f59-af4367bb05e4",
|
||||
"name": "denoising_end",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 1
|
||||
},
|
||||
"scheduler": {
|
||||
"id": "36e021ad-b762-4fe4-ad4d-17f0291c40b2",
|
||||
"name": "scheduler",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "SchedulerField"
|
||||
},
|
||||
"value": "euler"
|
||||
},
|
||||
"unet": {
|
||||
"id": "98d3282d-f9f6-4b5e-b9e8-58658f1cac78",
|
||||
"name": "unet",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"control": {
|
||||
"id": "f2ea3216-43d5-42b4-887f-36e8f7166d53",
|
||||
"name": "control",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "ControlField"
|
||||
}
|
||||
},
|
||||
"ip_adapter": {
|
||||
"id": "d0780610-a298-47c8-a54e-70e769e0dfe2",
|
||||
"name": "ip_adapter",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "IPAdapterField"
|
||||
}
|
||||
},
|
||||
"t2i_adapter": {
|
||||
"id": "fdb40970-185e-4ea8-8bb5-88f06f91f46a",
|
||||
"name": "t2i_adapter",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "T2IAdapterField"
|
||||
}
|
||||
},
|
||||
"cfg_rescale_multiplier": {
|
||||
"id": "3af2d8c5-de83-425c-a100-49cb0f1f4385",
|
||||
"name": "cfg_rescale_multiplier",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"latents": {
|
||||
"id": "e05b538a-1b5a-4aa5-84b1-fd2361289a81",
|
||||
"name": "latents",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"denoise_mask": {
|
||||
"id": "463a419e-df30-4382-8ffb-b25b25abe425",
|
||||
"name": "denoise_mask",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "DenoiseMaskField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"latents": {
|
||||
"id": "559ee688-66cf-4139-8b82-3d3aa69995ce",
|
||||
"name": "latents",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "0b4285c2-e8b9-48e5-98f6-0a49d3f98fd2",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "8b0881b9-45e5-47d5-b526-24b6661de0ee",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 705,
|
||||
"position": {
|
||||
"x": 1570.9941088179146,
|
||||
"y": -407.6505491604564
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "1b89067c-3f6b-42c8-991f-e3055789b251-fc9d0e35-a6de-4a19-84e1-c72497c823f6-collapsed",
|
||||
"source": "1b89067c-3f6b-42c8-991f-e3055789b251",
|
||||
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
|
||||
"type": "collapsed"
|
||||
},
|
||||
{
|
||||
"id": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77-collapsed",
|
||||
"source": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
|
||||
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
|
||||
"type": "collapsed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-1b7e0df8-8589-4915-a4ea-c0088f15d642collection-1b89067c-3f6b-42c8-991f-e3055789b251collection",
|
||||
"source": "1b7e0df8-8589-4915-a4ea-c0088f15d642",
|
||||
"target": "1b89067c-3f6b-42c8-991f-e3055789b251",
|
||||
"type": "default",
|
||||
"sourceHandle": "collection",
|
||||
"targetHandle": "collection"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426clip-fc9d0e35-a6de-4a19-84e1-c72497c823f6clip",
|
||||
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
|
||||
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-1b89067c-3f6b-42c8-991f-e3055789b251item-fc9d0e35-a6de-4a19-84e1-c72497c823f6prompt",
|
||||
"source": "1b89067c-3f6b-42c8-991f-e3055789b251",
|
||||
"target": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
|
||||
"type": "default",
|
||||
"sourceHandle": "item",
|
||||
"targetHandle": "prompt"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426clip-c2eaf1ba-5708-4679-9e15-945b8b432692clip",
|
||||
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
|
||||
"target": "c2eaf1ba-5708-4679-9e15-945b8b432692",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5value-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77seed",
|
||||
"source": "dfc20e07-7aef-4fc0-a3a1-7bf68ec6a4e5",
|
||||
"target": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
|
||||
"type": "default",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-fc9d0e35-a6de-4a19-84e1-c72497c823f6conditioning-2fb1577f-0a56-4f12-8711-8afcaaaf1d5epositive_conditioning",
|
||||
"source": "fc9d0e35-a6de-4a19-84e1-c72497c823f6",
|
||||
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
|
||||
"type": "default",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c2eaf1ba-5708-4679-9e15-945b8b432692conditioning-2fb1577f-0a56-4f12-8711-8afcaaaf1d5enegative_conditioning",
|
||||
"source": "c2eaf1ba-5708-4679-9e15-945b8b432692",
|
||||
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
|
||||
"type": "default",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "negative_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77noise-2fb1577f-0a56-4f12-8711-8afcaaaf1d5enoise",
|
||||
"source": "0eb5f3f5-1b91-49eb-9ef0-41d67c7eae77",
|
||||
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
|
||||
"type": "default",
|
||||
"sourceHandle": "noise",
|
||||
"targetHandle": "noise"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426unet-2fb1577f-0a56-4f12-8711-8afcaaaf1d5eunet",
|
||||
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
|
||||
"target": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
|
||||
"type": "default",
|
||||
"sourceHandle": "unet",
|
||||
"targetHandle": "unet"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-2fb1577f-0a56-4f12-8711-8afcaaaf1d5elatents-491ec988-3c77-4c37-af8a-39a0c4e7a2a1latents",
|
||||
"source": "2fb1577f-0a56-4f12-8711-8afcaaaf1d5e",
|
||||
"target": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
|
||||
"type": "default",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-d6353b7f-b447-4e17-8f2e-80a88c91d426vae-491ec988-3c77-4c37-af8a-39a0c4e7a2a1vae",
|
||||
"source": "d6353b7f-b447-4e17-8f2e-80a88c91d426",
|
||||
"target": "491ec988-3c77-4c37-af8a-39a0c4e7a2a1",
|
||||
"type": "default",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
}
|
||||
]
|
||||
}
|
@ -0,0 +1,903 @@
|
||||
{
|
||||
"name": "Text to Image with LoRA",
|
||||
"author": "InvokeAI",
|
||||
"description": "Simple text to image workflow with a LoRA",
|
||||
"version": "1.0.0",
|
||||
"contact": "invoke@invoke.ai",
|
||||
"tags": "text to image, lora, default",
|
||||
"notes": "",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"fieldName": "lora"
|
||||
},
|
||||
{
|
||||
"nodeId": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"fieldName": "weight"
|
||||
},
|
||||
{
|
||||
"nodeId": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
|
||||
"fieldName": "prompt"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"version": "2.0.0",
|
||||
"category": "default"
|
||||
},
|
||||
"id": "a9d70c39-4cdd-4176-9942-8ff3fe32d3b1",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "85b77bb2-c67a-416a-b3e8-291abe746c44",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "85b77bb2-c67a-416a-b3e8-291abe746c44",
|
||||
"type": "compel",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"inputs": {
|
||||
"prompt": {
|
||||
"id": "39fe92c4-38eb-4cc7-bf5e-cbcd31847b11",
|
||||
"name": "prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "Negative Prompt",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": ""
|
||||
},
|
||||
"clip": {
|
||||
"id": "14313164-e5c4-4e40-a599-41b614fe3690",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"conditioning": {
|
||||
"id": "02140b9d-50f3-470b-a0b7-01fc6ed2dcd6",
|
||||
"name": "conditioning",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 256,
|
||||
"position": {
|
||||
"x": 3425,
|
||||
"y": -300
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
|
||||
"type": "main_model_loader",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"inputs": {
|
||||
"model": {
|
||||
"id": "e2e1c177-ae39-4244-920e-d621fa156a24",
|
||||
"name": "model",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "MainModelField"
|
||||
},
|
||||
"value": {
|
||||
"model_name": "Analog-Diffusion",
|
||||
"base_model": "sd-1",
|
||||
"model_type": "main"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"vae": {
|
||||
"id": "f91410e8-9378-4298-b285-f0f40ffd9825",
|
||||
"name": "vae",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "VaeField"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"id": "928d91bf-de0c-44a8-b0c8-4de0e2e5b438",
|
||||
"name": "clip",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
},
|
||||
"unet": {
|
||||
"id": "eacaf530-4e7e-472e-b904-462192189fc1",
|
||||
"name": "unet",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 227,
|
||||
"position": {
|
||||
"x": 2500,
|
||||
"y": -600
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"type": "lora_loader",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"inputs": {
|
||||
"lora": {
|
||||
"id": "36d867e8-92ea-4c3f-9ad5-ba05c64cf326",
|
||||
"name": "lora",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LoRAModelField"
|
||||
},
|
||||
"value": {
|
||||
"model_name": "Ink scenery",
|
||||
"base_model": "sd-1"
|
||||
}
|
||||
},
|
||||
"weight": {
|
||||
"id": "8be86540-ba81-49b3-b394-2b18fa70b867",
|
||||
"name": "weight",
|
||||
"fieldKind": "input",
|
||||
"label": "LoRA Weight",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0.75
|
||||
},
|
||||
"unet": {
|
||||
"id": "9c4d5668-e9e1-411b-8f4b-e71115bc4a01",
|
||||
"name": "unet",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"id": "918ec00e-e76f-4ad0-aee1-3927298cf03b",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"unet": {
|
||||
"id": "c63f7825-1bcf-451d-b7a7-aa79f5c77416",
|
||||
"name": "unet",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"id": "6f79ef2d-00f7-4917-bee3-53e845bf4192",
|
||||
"name": "clip",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 252,
|
||||
"position": {
|
||||
"x": 2975,
|
||||
"y": -600
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
|
||||
"type": "compel",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"inputs": {
|
||||
"prompt": {
|
||||
"id": "39fe92c4-38eb-4cc7-bf5e-cbcd31847b11",
|
||||
"name": "prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "Positive Prompt",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": "cute tiger cub"
|
||||
},
|
||||
"clip": {
|
||||
"id": "14313164-e5c4-4e40-a599-41b614fe3690",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"conditioning": {
|
||||
"id": "02140b9d-50f3-470b-a0b7-01fc6ed2dcd6",
|
||||
"name": "conditioning",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 256,
|
||||
"position": {
|
||||
"x": 3425,
|
||||
"y": -575
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
|
||||
"type": "denoise_latents",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.5.0",
|
||||
"inputs": {
|
||||
"positive_conditioning": {
|
||||
"id": "025ff44b-c4c6-4339-91b4-5f461e2cadc5",
|
||||
"name": "positive_conditioning",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
},
|
||||
"negative_conditioning": {
|
||||
"id": "2d92b45a-a7fb-4541-9a47-7c7495f50f54",
|
||||
"name": "negative_conditioning",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
},
|
||||
"noise": {
|
||||
"id": "4d0deeff-24ed-4562-a1ca-7833c0649377",
|
||||
"name": "noise",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"steps": {
|
||||
"id": "c9907328-aece-4af9-8a95-211b4f99a325",
|
||||
"name": "steps",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 10
|
||||
},
|
||||
"cfg_scale": {
|
||||
"id": "7cf0f031-2078-49f4-9273-bb3a64ad7130",
|
||||
"name": "cfg_scale",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 7.5
|
||||
},
|
||||
"denoising_start": {
|
||||
"id": "44cec3ba-b404-4b51-ba98-add9d783279e",
|
||||
"name": "denoising_start",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"id": "3e7975f3-e438-4a13-8a14-395eba1fb7cd",
|
||||
"name": "denoising_end",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 1
|
||||
},
|
||||
"scheduler": {
|
||||
"id": "a6f6509b-7bb4-477d-b5fb-74baefa38111",
|
||||
"name": "scheduler",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "SchedulerField"
|
||||
},
|
||||
"value": "euler"
|
||||
},
|
||||
"unet": {
|
||||
"id": "5a87617a-b09f-417b-9b75-0cea4c255227",
|
||||
"name": "unet",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"control": {
|
||||
"id": "db87aace-ace8-4f2a-8f2b-1f752389fa9b",
|
||||
"name": "control",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "ControlField"
|
||||
}
|
||||
},
|
||||
"ip_adapter": {
|
||||
"id": "f0c133ed-4d6d-4567-bb9a-b1779810993c",
|
||||
"name": "ip_adapter",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "IPAdapterField"
|
||||
}
|
||||
},
|
||||
"t2i_adapter": {
|
||||
"id": "59ee1233-887f-45e7-aa14-cbad5f6cb77f",
|
||||
"name": "t2i_adapter",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "T2IAdapterField"
|
||||
}
|
||||
},
|
||||
"cfg_rescale_multiplier": {
|
||||
"id": "1a12e781-4b30-4707-b432-18c31866b5c3",
|
||||
"name": "cfg_rescale_multiplier",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"latents": {
|
||||
"id": "d0e593ae-305c-424b-9acd-3af830085832",
|
||||
"name": "latents",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"denoise_mask": {
|
||||
"id": "b81b5a79-fc2b-4011-aae6-64c92bae59a7",
|
||||
"name": "denoise_mask",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "DenoiseMaskField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"latents": {
|
||||
"id": "9ae4022a-548e-407e-90cf-cc5ca5ff8a21",
|
||||
"name": "latents",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "730ba4bd-2c52-46bb-8c87-9b3aec155576",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "52b98f0b-b5ff-41b5-acc7-d0b1d1011a6f",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 705,
|
||||
"position": {
|
||||
"x": 3975,
|
||||
"y": -575
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
|
||||
"type": "noise",
|
||||
"label": "",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.1",
|
||||
"inputs": {
|
||||
"seed": {
|
||||
"id": "446ac80c-ba0a-4fea-a2d7-21128f52e5bf",
|
||||
"name": "seed",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"width": {
|
||||
"id": "779831b3-20b4-4f5f-9de7-d17de57288d8",
|
||||
"name": "width",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 512
|
||||
},
|
||||
"height": {
|
||||
"id": "08959766-6d67-4276-b122-e54b911f2316",
|
||||
"name": "height",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 512
|
||||
},
|
||||
"use_cpu": {
|
||||
"id": "53b36a98-00c4-4dc5-97a4-ef3432c0a805",
|
||||
"name": "use_cpu",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": true
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"noise": {
|
||||
"id": "eed95824-580b-442f-aa35-c073733cecce",
|
||||
"name": "noise",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "7985a261-dfee-47a8-908a-c5a8754f5dc4",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "3d00f6c1-84b0-4262-83d9-3bf755babeea",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 3425,
|
||||
"y": 75
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
|
||||
"type": "rand_int",
|
||||
"label": "",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"version": "1.0.0",
|
||||
"inputs": {
|
||||
"low": {
|
||||
"id": "d25305f3-bfd6-446c-8e2c-0b025ec9e9ad",
|
||||
"name": "low",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"id": "10376a3d-b8fe-4a51-b81a-ea46d8c12c78",
|
||||
"name": "high",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 2147483647
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"value": {
|
||||
"id": "c64878fa-53b1-4202-b88a-cfb854216a57",
|
||||
"name": "value",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 3425,
|
||||
"y": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
|
||||
"type": "l2i",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"version": "1.2.0",
|
||||
"inputs": {
|
||||
"metadata": {
|
||||
"id": "b1982e8a-14ad-4029-a697-beb30af8340f",
|
||||
"name": "metadata",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "MetadataField"
|
||||
}
|
||||
},
|
||||
"latents": {
|
||||
"id": "f7669388-9f91-46cc-94fc-301fa7041c3e",
|
||||
"name": "latents",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"id": "c6f2d4db-4d0a-4e3d-acb4-b5c5a228a3e2",
|
||||
"name": "vae",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "VaeField"
|
||||
}
|
||||
},
|
||||
"tiled": {
|
||||
"id": "19ef7d31-d96f-4e94-b7e5-95914e9076fc",
|
||||
"name": "tiled",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": false
|
||||
},
|
||||
"fp32": {
|
||||
"id": "a9454533-8ab7-4225-b411-646dc5e76d00",
|
||||
"name": "fp32",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": false
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"image": {
|
||||
"id": "4f81274e-e216-47f3-9fb6-f97493a40e6f",
|
||||
"name": "image",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ImageField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "61a9acfb-1547-4f1e-8214-e89bd3855ee5",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "b15cc793-4172-4b07-bcf4-5627bbc7d0d7",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 267,
|
||||
"position": {
|
||||
"x": 4450,
|
||||
"y": -550
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953-ea18915f-2c5b-4569-b725-8e9e9122e8d3-collapsed",
|
||||
"source": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
|
||||
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
|
||||
"type": "collapsed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818clip-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip",
|
||||
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
|
||||
"target": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip-c3fa6872-2599-4a82-a596-b3446a66cf8bclip",
|
||||
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"target": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818unet-c41e705b-f2e3-4d1a-83c4-e34bb9344966unet",
|
||||
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
|
||||
"target": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"type": "default",
|
||||
"sourceHandle": "unet",
|
||||
"targetHandle": "unet"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966unet-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63unet",
|
||||
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
|
||||
"type": "default",
|
||||
"sourceHandle": "unet",
|
||||
"targetHandle": "unet"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-85b77bb2-c67a-416a-b3e8-291abe746c44conditioning-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63negative_conditioning",
|
||||
"source": "85b77bb2-c67a-416a-b3e8-291abe746c44",
|
||||
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
|
||||
"type": "default",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "negative_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c3fa6872-2599-4a82-a596-b3446a66cf8bconditioning-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63positive_conditioning",
|
||||
"source": "c3fa6872-2599-4a82-a596-b3446a66cf8b",
|
||||
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
|
||||
"type": "default",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-ea18915f-2c5b-4569-b725-8e9e9122e8d3noise-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63noise",
|
||||
"source": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
|
||||
"target": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
|
||||
"type": "default",
|
||||
"sourceHandle": "noise",
|
||||
"targetHandle": "noise"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-6fd74a17-6065-47a5-b48b-f4e2b8fa7953value-ea18915f-2c5b-4569-b725-8e9e9122e8d3seed",
|
||||
"source": "6fd74a17-6065-47a5-b48b-f4e2b8fa7953",
|
||||
"target": "ea18915f-2c5b-4569-b725-8e9e9122e8d3",
|
||||
"type": "default",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63latents-a9683c0a-6b1f-4a5e-8187-c57e764b3400latents",
|
||||
"source": "ad487d0c-dcbb-49c5-bb8e-b28d4cbc5a63",
|
||||
"target": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
|
||||
"type": "default",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-24e9d7ed-4836-4ec4-8f9e-e747721f9818vae-a9683c0a-6b1f-4a5e-8187-c57e764b3400vae",
|
||||
"source": "24e9d7ed-4836-4ec4-8f9e-e747721f9818",
|
||||
"target": "a9683c0a-6b1f-4a5e-8187-c57e764b3400",
|
||||
"type": "default",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c41e705b-f2e3-4d1a-83c4-e34bb9344966clip-85b77bb2-c67a-416a-b3e8-291abe746c44clip",
|
||||
"source": "c41e705b-f2e3-4d1a-83c4-e34bb9344966",
|
||||
"target": "85b77bb2-c67a-416a-b3e8-291abe746c44",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
}
|
||||
]
|
||||
}
|
8
invokeai/app/util/ti_utils.py
Normal file
8
invokeai/app/util/ti_utils.py
Normal file
@ -0,0 +1,8 @@
|
||||
import re
|
||||
|
||||
|
||||
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
|
||||
ti_triggers = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
ti_triggers.append(trigger)
|
||||
return ti_triggers
|
@ -28,7 +28,7 @@ def check_invokeai_root(config: InvokeAIAppConfig):
|
||||
print("== STARTUP ABORTED ==")
|
||||
print("** One or more necessary files is missing from your InvokeAI root directory **")
|
||||
print("** Please rerun the configuration script to fix this problem. **")
|
||||
print("** From the launcher, selection option [7]. **")
|
||||
print("** From the launcher, selection option [6]. **")
|
||||
print(
|
||||
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
||||
)
|
||||
|
@ -389,7 +389,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
elif "clip_g" in checkpoint:
|
||||
token_dim = checkpoint["clip_g"].shape[-1]
|
||||
else:
|
||||
token_dim = list(checkpoint.values())[0].shape[0]
|
||||
token_dim = list(checkpoint.values())[0].shape[-1]
|
||||
if token_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_dim == 1024:
|
||||
|
@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
:param checkpoint: The checkpoint
|
||||
"""
|
||||
|
||||
def _get_shape_1(key, tensor, checkpoint):
|
||||
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
||||
lora_token_vector_length = None
|
||||
|
||||
if "." not in key:
|
||||
@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
for key, tensor in checkpoint.items():
|
||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_unet_") and (
|
||||
"time_emb_proj.lora_down" in key
|
||||
): # recognizes format at https://civitai.com/models/224641
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||
if key.startswith("lora_te_"):
|
||||
|
43
invokeai/backend/model_manager/metadata/__init__.py
Normal file
43
invokeai/backend/model_manager/metadata/__init__.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.metadata
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata import(
|
||||
AnyModelRepoMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
HuggingFaceMetadata,
|
||||
CivitaiMetadata,
|
||||
)
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||
|
||||
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
|
||||
assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
|
||||
from .fetch import CivitaiMetadataFetch, HuggingFaceMetadataFetch
|
||||
from .metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
LicenseRestrictions,
|
||||
)
|
||||
from .metadata_store import ModelMetadataStore
|
||||
|
||||
__all__ = [
|
||||
"AnyModelRepoMetadata",
|
||||
"AnyModelRepoMetadataValidator",
|
||||
"CommercialUsage",
|
||||
"LicenseRestrictions",
|
||||
"HuggingFaceMetadata",
|
||||
"CivitaiMetadata",
|
||||
"ModelMetadataStore",
|
||||
"CivitaiMetadataFetch",
|
||||
"HuggingFaceMetadataFetch",
|
||||
]
|
21
invokeai/backend/model_manager/metadata/fetch/__init__.py
Normal file
21
invokeai/backend/model_manager/metadata/fetch/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_manager.metadata.fetch
|
||||
|
||||
Usage:
|
||||
from invokeai.backend.model_manager.metadata.fetch import (
|
||||
CivitaiMetadataFetch,
|
||||
HuggingFaceMetadataFetch,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import CivitaiMetadata
|
||||
|
||||
data = CivitaiMetadataFetch().from_url("https://civitai.com/models/206883/split")
|
||||
assert isinstance(data, CivitaiMetadata)
|
||||
if data.allow_commercial_use:
|
||||
print("Commercial use of this model is allowed")
|
||||
"""
|
||||
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
from .fetch_civitai import CivitaiMetadataFetch
|
||||
from .fetch_huggingface import HuggingFaceMetadataFetch
|
||||
|
||||
__all__ = ["ModelMetadataFetchBase", "CivitaiMetadataFetch", "HuggingFaceMetadataFetch"]
|
63
invokeai/backend/model_manager/metadata/fetch/fetch_base.py
Normal file
63
invokeai/backend/model_manager/metadata/fetch/fetch_base.py
Normal file
@ -0,0 +1,63 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module is the base class for subclasses that fetch metadata from model repositories
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitAIMetadataFetch
|
||||
|
||||
fetcher = CivitaiMetadataFetch()
|
||||
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
|
||||
|
||||
|
||||
class ModelMetadataFetchBase(ABC):
|
||||
"""Fetch metadata from remote generative model repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
By providing a configurable Session object, we can support unit tests on
|
||||
this module without an internet connection.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Given a URL to a model repository, return a ModelMetadata object.
|
||||
|
||||
This method will raise an `invokeai.app.services.model_records.UnknownModelException`
|
||||
in the event that the requested model metadata is not found at the provided location.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Given an ID for a model, return a ModelMetadata object.
|
||||
|
||||
This method will raise an `invokeai.app.services.model_records.UnknownModelException`
|
||||
in the event that the requested model's metadata is not found at the provided id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_json(self, json: str) -> AnyModelRepoMetadata:
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = AnyModelRepoMetadataValidator.validate_json(json)
|
||||
return (
|
||||
metadata # mypy complains that metadata is a <typing special form> and issues a type checking error. Why?
|
||||
)
|
157
invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py
Normal file
157
invokeai/backend/model_manager/metadata/fetch/fetch_civitai.py
Normal file
@ -0,0 +1,157 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module fetches model metadata objects from the Civitai model repository.
|
||||
In addition to the `from_url()` and `from_id()` methods inherited from the
|
||||
`ModelMetadataFetchBase` base class.
|
||||
|
||||
Civitai has two separate ID spaces: a model ID and a version ID. The
|
||||
version ID corresponds to a specific model, and is the ID accepted by
|
||||
`from_id()`. The model ID corresponds to a family of related models,
|
||||
such as different training checkpoints or 16 vs 32-bit versions. The
|
||||
`from_civitai_modelid()` method will accept a model ID and return the
|
||||
metadata from the default version within this model set. The default
|
||||
version is the same as what the user sees when they click on a model's
|
||||
thumbnail.
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import CivitaiMetadataFetch
|
||||
|
||||
fetcher = CivitaiMetadataFetch()
|
||||
metadata = fetcher.from_url("https://civitai.com/models/206883/split")
|
||||
print(metadata.trained_words)
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
|
||||
from ..metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
CivitaiMetadata,
|
||||
CommercialUsage,
|
||||
LicenseRestrictions,
|
||||
)
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
|
||||
CIVITAI_MODEL_PAGE_RE = r"https?://civitai.com/models/(\d+)"
|
||||
CIVITAI_VERSION_PAGE_RE = r"https?://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
||||
CIVITAI_DOWNLOAD_RE = r"https?://civitai.com/api/download/models/(\d+)"
|
||||
|
||||
CIVITAI_VERSION_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
CIVITAI_MODEL_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
|
||||
|
||||
class CivitaiMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from Civitai."""
|
||||
|
||||
_requests: Session
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
By providing a configurable Session object, we can support unit tests on
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Given a URL to a CivitAI model or version page, return a ModelMetadata object.
|
||||
|
||||
In the event that the URL points to a model page without the particular version
|
||||
indicated, the default model version is returned. Otherwise, the requested version
|
||||
is returned.
|
||||
"""
|
||||
if match := re.match(CIVITAI_MODEL_PAGE_RE, str(url)):
|
||||
model_id = match.group(1)
|
||||
return self.from_civitai_modelid(int(model_id))
|
||||
elif match := re.match(CIVITAI_VERSION_PAGE_RE, str(url)):
|
||||
version_id = match.group(1)
|
||||
return self.from_civitai_versionid(int(version_id))
|
||||
elif match := re.match(CIVITAI_DOWNLOAD_RE, str(url)):
|
||||
version_id = match.group(1)
|
||||
return self.from_civitai_versionid(int(version_id))
|
||||
raise UnknownModelException("The url '{url}' does not match any known Civitai URL patterns")
|
||||
|
||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Given a Civitai model version ID, return a ModelRepoMetadata object.
|
||||
|
||||
May raise an `UnknownModelException`.
|
||||
"""
|
||||
return self.from_civitai_versionid(int(id))
|
||||
|
||||
def from_civitai_modelid(self, model_id: int) -> CivitaiMetadata:
|
||||
"""
|
||||
Return metadata from the default version of the indicated model.
|
||||
|
||||
May raise an `UnknownModelException`.
|
||||
"""
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(model_id)
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json)
|
||||
|
||||
def _from_model_json(self, model_json: Dict[str, Any], version_id: Optional[int] = None) -> CivitaiMetadata:
|
||||
version_id = version_id or model_json["modelVersions"][0]["id"]
|
||||
|
||||
# loop till we find the section containing the version requested
|
||||
version_sections = [x for x in model_json["modelVersions"] if x["id"] == version_id]
|
||||
if not version_sections:
|
||||
raise UnknownModelException(f"Version {version_id} not found in model metadata")
|
||||
|
||||
version_json = version_sections[0]
|
||||
safe_thumbnails = [x["url"] for x in version_json["images"] if x["nsfw"] == "None"]
|
||||
return CivitaiMetadata(
|
||||
id=model_json["id"],
|
||||
name=model_json["name"],
|
||||
version_id=version_json["id"],
|
||||
version_name=version_json["name"],
|
||||
created=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["createdAt"])),
|
||||
updated=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["updatedAt"])),
|
||||
published=datetime.fromisoformat(re.sub(r"Z$", "+00:00", version_json["publishedAt"])),
|
||||
base_model_trained_on=version_json["baseModel"], # note - need a dictionary to turn into a BaseModelType
|
||||
download_url=version_json["downloadUrl"],
|
||||
thumbnail_url=safe_thumbnails[0] if safe_thumbnails else None,
|
||||
author=model_json["creator"]["username"],
|
||||
description=model_json["description"],
|
||||
version_description=version_json["description"] or "",
|
||||
tags=model_json["tags"],
|
||||
trained_words=version_json["trainedWords"],
|
||||
nsfw=model_json["nsfw"],
|
||||
restrictions=LicenseRestrictions(
|
||||
AllowNoCredit=model_json["allowNoCredit"],
|
||||
AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
|
||||
AllowDerivatives=model_json["allowDerivatives"],
|
||||
AllowDifferentLicense=model_json["allowDifferentLicense"],
|
||||
),
|
||||
)
|
||||
|
||||
def from_civitai_versionid(self, version_id: int) -> CivitaiMetadata:
|
||||
"""
|
||||
Return a CivitaiMetadata object given a model version id.
|
||||
|
||||
May raise an `UnknownModelException`.
|
||||
"""
|
||||
version_url = CIVITAI_VERSION_ENDPOINT + str(version_id)
|
||||
version = self._requests.get(version_url).json()
|
||||
|
||||
model_url = CIVITAI_MODEL_ENDPOINT + str(version["modelId"])
|
||||
model_json = self._requests.get(model_url).json()
|
||||
return self._from_model_json(model_json, version_id)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> CivitaiMetadata:
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = AnyModelRepoMetadataValidator.validate_json(json)
|
||||
assert isinstance(metadata, CivitaiMetadata)
|
||||
return metadata
|
@ -0,0 +1,84 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
This module fetches model metadata objects from the HuggingFace model repository,
|
||||
using either a `repo_id` or the model page URL.
|
||||
|
||||
Usage:
|
||||
|
||||
from invokeai.backend.model_manager.metadata.fetch import HuggingFaceMetadataFetch
|
||||
|
||||
fetcher = HuggingFaceMetadataFetch()
|
||||
metadata = fetcher.from_url("https://huggingface.co/stabilityai/sdxl-turbo")
|
||||
print(metadata.tags)
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, configure_http_backend
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, HuggingFaceMetadata
|
||||
from .fetch_base import ModelMetadataFetchBase
|
||||
|
||||
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
|
||||
|
||||
|
||||
class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
"""Fetch model metadata from HuggingFace."""
|
||||
|
||||
_requests: Session
|
||||
|
||||
def __init__(self, session: Optional[Session] = None):
|
||||
"""
|
||||
Initialize the fetcher with an optional requests.sessions.Session object.
|
||||
|
||||
By providing a configurable Session object, we can support unit tests on
|
||||
this module without an internet connection.
|
||||
"""
|
||||
self._requests = session or requests.Session()
|
||||
configure_http_backend(backend_factory=lambda: self._requests)
|
||||
|
||||
def from_id(self, id: str) -> AnyModelRepoMetadata:
|
||||
"""Return a HuggingFaceMetadata object given the model's repo_id."""
|
||||
try:
|
||||
model_info = HfApi().model_info(repo_id=id, files_metadata=True)
|
||||
except RepositoryNotFoundError as excp:
|
||||
raise UnknownModelException(f"'{id}' not found. See trace for details.") from excp
|
||||
|
||||
_, name = id.split("/")
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.modelId,
|
||||
author=model_info.author,
|
||||
name=name,
|
||||
last_modified=model_info.lastModified,
|
||||
tag_dict=model_info.card_data.to_dict(),
|
||||
tags=model_info.tags,
|
||||
files=[Path(x.rfilename) for x in model_info.siblings],
|
||||
)
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Return a HuggingFaceMetadata object given the model's web page URL.
|
||||
|
||||
In the case of an invalid or missing URL, raises a ModelNotFound exception.
|
||||
"""
|
||||
if match := re.match(HF_MODEL_RE, str(url)):
|
||||
repo_id = match.group(1)
|
||||
return self.from_id(repo_id)
|
||||
else:
|
||||
raise UnknownModelException(f"'{url}' does not look like a HuggingFace model page")
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> HuggingFaceMetadata:
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = AnyModelRepoMetadataValidator.validate_json(json)
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
return metadata
|
119
invokeai/backend/model_manager/metadata/metadata_base.py
Normal file
119
invokeai/backend/model_manager/metadata/metadata_base.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""This module defines core text-to-image model metadata fields.
|
||||
|
||||
Metadata comprises any descriptive information that is not essential
|
||||
for getting the model to run. For example "author" is metadata, while
|
||||
"type", "base" and "format" are not. The latter fields are part of the
|
||||
model's config, as defined in invokeai.backend.model_manager.config.
|
||||
|
||||
Note that the "name" and "description" are also present in `config`
|
||||
records. This is intentional. The config record fields are intended to
|
||||
be editable by the user as a form of customization. The metadata
|
||||
versions of these fields are intended to be kept in sync with the
|
||||
remote repo.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class CommercialUsage(str, Enum):
|
||||
"""Type of commercial usage allowed."""
|
||||
|
||||
No = "None"
|
||||
Image = "Image"
|
||||
Rent = "Rent"
|
||||
RentCivit = "RentCivit"
|
||||
Sell = "Sell"
|
||||
|
||||
|
||||
class LicenseRestrictions(BaseModel):
|
||||
"""Broad categories of licensing restrictions."""
|
||||
|
||||
AllowNoCredit: bool = Field(
|
||||
description="if true, model can be redistributed without crediting author", default=False
|
||||
)
|
||||
AllowDerivatives: bool = Field(description="if true, derivatives of this model can be redistributed", default=False)
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: CommercialUsage = Field(
|
||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set
|
||||
)
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
||||
name: str = Field(description="model's name")
|
||||
author: str = Field(description="model's author")
|
||||
tags: Set[str] = Field(description="tags provided by model source")
|
||||
|
||||
|
||||
class HuggingFaceMetadata(ModelMetadataBase):
|
||||
"""Extended metadata fields provided by HuggingFace."""
|
||||
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
id: str = Field(description="huggingface model id")
|
||||
tag_dict: Dict[str, Any]
|
||||
last_modified: datetime = Field(description="date of last commit to repo")
|
||||
files: List[Path] = Field(description="sibling files that belong to this model", default_factory=list)
|
||||
|
||||
|
||||
class CivitaiMetadata(ModelMetadataBase):
|
||||
"""Extended metadata fields provided by Civitai."""
|
||||
|
||||
type: Literal["civitai"] = "civitai"
|
||||
id: int = Field(description="Civitai model identifier")
|
||||
version_name: str = Field(description="Version identifier, such as 'V2-alpha'")
|
||||
version_id: int = Field(description="Civitai model version identifier")
|
||||
created: datetime = Field(description="date the model was created")
|
||||
updated: datetime = Field(description="date the model was last modified")
|
||||
published: datetime = Field(description="date the model was published to Civitai")
|
||||
description: str = Field(description="text description of model; may contain HTML")
|
||||
version_description: str = Field(
|
||||
description="text description of the model's reversion; usually change history; may contain HTML"
|
||||
)
|
||||
nsfw: bool = Field(description="whether the model tends to generate NSFW content", default=False)
|
||||
restrictions: LicenseRestrictions = Field(description="license terms", default_factory=LicenseRestrictions)
|
||||
trained_words: Set[str] = Field(description="words to trigger the model", default_factory=set)
|
||||
download_url: AnyHttpUrl = Field(description="download URL for this model")
|
||||
base_model_trained_on: str = Field(description="base model on which this model was trained (currently not an enum)")
|
||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="a thumbnail image for this model", default=None)
|
||||
weight_min: float = Field(
|
||||
description="minimum suggested value for a LoRA or other secondary model", default=-1.0
|
||||
) # note: For future use; not currently easily
|
||||
weight_max: float = Field(
|
||||
description="maximum suggested value for a LoRA or other secondary model", default=+2.0
|
||||
) # recoverable from metadata
|
||||
|
||||
@property
|
||||
def credit_required(self) -> bool:
|
||||
"""Return True if you must give credit for derivatives of this model and images generated from it."""
|
||||
return not self.restrictions.AllowNoCredit
|
||||
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
"""Return True if derivatives of this model can be redistributed."""
|
||||
return self.restrictions.AllowDerivatives
|
||||
|
||||
@property
|
||||
def allow_different_license(self) -> bool:
|
||||
"""Return true if derivatives of this model can use a different license."""
|
||||
return self.restrictions.AllowDifferentLicense
|
||||
|
||||
|
||||
AnyModelRepoMetadata = Annotated[Union[HuggingFaceMetadata, CivitaiMetadata], Field(discriminator="type")]
|
||||
AnyModelRepoMetadataValidator = TypeAdapter(AnyModelRepoMetadata)
|
208
invokeai/backend/model_manager/metadata/metadata_store.py
Normal file
208
invokeai/backend/model_manager/metadata/metadata_store.py
Normal file
@ -0,0 +1,208 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import Optional, Set
|
||||
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .fetch import ModelMetadataFetchBase
|
||||
from .metadata_base import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStore:
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
_db: SqliteDatabase
|
||||
_cursor: sqlite3.Cursor
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
self._enable_foreign_key_constraints()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownModelException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
|
||||
def _enable_foreign_key_constraints(self) -> None:
|
||||
self._cursor.execute("PRAGMA foreign_keys = ON;")
|
@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
|
@ -102,7 +102,7 @@ def calc_tiles_with_overlap(
|
||||
|
||||
|
||||
def calc_tiles_even_split(
|
||||
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap_fraction: float = 0
|
||||
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap: int = 0
|
||||
) -> list[Tile]:
|
||||
"""Calculate the tile coordinates for a given image shape with the number of tiles requested.
|
||||
|
||||
@ -111,31 +111,35 @@ def calc_tiles_even_split(
|
||||
image_width (int): The image width in px.
|
||||
num_x_tiles (int): The number of tile to split the image into on the X-axis.
|
||||
num_y_tiles (int): The number of tile to split the image into on the Y-axis.
|
||||
overlap_fraction (float, optional): The target overlap as fraction of the tiles size. Defaults to 0.
|
||||
overlap (int, optional): The overlap between adjacent tiles in pixels. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
|
||||
"""
|
||||
|
||||
# Ensure tile size is divisible by 8
|
||||
# Ensure the image is divisible by LATENT_SCALE_FACTOR
|
||||
if image_width % LATENT_SCALE_FACTOR != 0 or image_height % LATENT_SCALE_FACTOR != 0:
|
||||
raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by {LATENT_SCALE_FACTOR}")
|
||||
|
||||
# Calculate the overlap size based on the percentage and adjust it to be divisible by 8 (rounding up)
|
||||
overlap_x = LATENT_SCALE_FACTOR * math.ceil(
|
||||
int((image_width / num_tiles_x) * overlap_fraction) / LATENT_SCALE_FACTOR
|
||||
)
|
||||
overlap_y = LATENT_SCALE_FACTOR * math.ceil(
|
||||
int((image_height / num_tiles_y) * overlap_fraction) / LATENT_SCALE_FACTOR
|
||||
)
|
||||
|
||||
# Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down)
|
||||
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
|
||||
((image_width + overlap_x * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
|
||||
)
|
||||
tile_size_y = LATENT_SCALE_FACTOR * math.floor(
|
||||
((image_height + overlap_y * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR
|
||||
)
|
||||
if num_tiles_x > 1:
|
||||
# ensure the overlap is not more than the maximum overlap if we only have 1 tile then we dont care about overlap
|
||||
assert overlap <= image_width - (LATENT_SCALE_FACTOR * (num_tiles_x - 1))
|
||||
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
|
||||
((image_width + overlap * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
|
||||
)
|
||||
assert overlap < tile_size_x
|
||||
else:
|
||||
tile_size_x = image_width
|
||||
|
||||
if num_tiles_y > 1:
|
||||
# ensure the overlap is not more than the maximum overlap if we only have 1 tile then we dont care about overlap
|
||||
assert overlap <= image_height - (LATENT_SCALE_FACTOR * (num_tiles_y - 1))
|
||||
tile_size_y = LATENT_SCALE_FACTOR * math.floor(
|
||||
((image_height + overlap * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR
|
||||
)
|
||||
assert overlap < tile_size_y
|
||||
else:
|
||||
tile_size_y = image_height
|
||||
|
||||
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
|
||||
tiles: list[Tile] = []
|
||||
@ -143,7 +147,7 @@ def calc_tiles_even_split(
|
||||
# Calculate tile coordinates. (Ignore overlap values for now.)
|
||||
for tile_idx_y in range(num_tiles_y):
|
||||
# Calculate the top and bottom of the row
|
||||
top = tile_idx_y * (tile_size_y - overlap_y)
|
||||
top = tile_idx_y * (tile_size_y - overlap)
|
||||
bottom = min(top + tile_size_y, image_height)
|
||||
# For the last row adjust bottom to be the height of the image
|
||||
if tile_idx_y == num_tiles_y - 1:
|
||||
@ -151,7 +155,7 @@ def calc_tiles_even_split(
|
||||
|
||||
for tile_idx_x in range(num_tiles_x):
|
||||
# Calculate the left & right coordinate of each tile
|
||||
left = tile_idx_x * (tile_size_x - overlap_x)
|
||||
left = tile_idx_x * (tile_size_x - overlap)
|
||||
right = min(left + tile_size_x, image_width)
|
||||
# For the last tile in the row adjust right to be the width of the image
|
||||
if tile_idx_x == num_tiles_x - 1:
|
||||
|
@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
from contextlib import nullcontext
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import autocast
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -37,7 +35,7 @@ def choose_precision(device: torch.device) -> str:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||
return "float16"
|
||||
elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse("14.0.0"):
|
||||
elif device.type == "mps":
|
||||
return "float16"
|
||||
return "float32"
|
||||
|
||||
|
@ -4,6 +4,7 @@ pip install <path_to_git_source>.
|
||||
"""
|
||||
import os
|
||||
import platform
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import pkg_resources
|
||||
import psutil
|
||||
@ -31,10 +32,6 @@ else:
|
||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||
|
||||
|
||||
def get_versions() -> dict:
|
||||
return requests.get(url=INVOKE_AI_REL).json()
|
||||
|
||||
|
||||
def invokeai_is_running() -> bool:
|
||||
for p in psutil.process_iter():
|
||||
try:
|
||||
@ -50,6 +47,20 @@ def invokeai_is_running() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_pypi_versions():
|
||||
url = "https://pypi.org/pypi/invokeai/json"
|
||||
try:
|
||||
data = requests.get(url).json()
|
||||
except Exception:
|
||||
raise Exception("Unable to fetch version information from PyPi")
|
||||
|
||||
versions = list(data["releases"].keys())
|
||||
versions.sort(key=LooseVersion, reverse=True)
|
||||
latest_version = [v for v in versions if "rc" not in v][0]
|
||||
latest_release_candidate = [v for v in versions if "rc" in v][0]
|
||||
return latest_version, latest_release_candidate, versions
|
||||
|
||||
|
||||
def welcome(latest_release: str, latest_prerelease: str):
|
||||
@group()
|
||||
def text():
|
||||
@ -63,8 +74,7 @@ def welcome(latest_release: str, latest_prerelease: str):
|
||||
yield "[bold yellow]Options:"
|
||||
yield f"""[1] Update to the latest [bold]official release[/bold] ([italic]{latest_release}[/italic])
|
||||
[2] Update to the latest [bold]pre-release[/bold] (may be buggy; caveat emptor!) ([italic]{latest_prerelease}[/italic])
|
||||
[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
|
||||
[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to"""
|
||||
[3] Manually enter the [bold]version[/bold] you wish to update to"""
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
@ -92,44 +102,35 @@ def get_extras():
|
||||
|
||||
|
||||
def main():
|
||||
versions = get_versions()
|
||||
released_versions = [x for x in versions if not (x["draft"] or x["prerelease"])]
|
||||
prerelease_versions = [x for x in versions if not x["draft"] and x["prerelease"]]
|
||||
latest_release = released_versions[0]["tag_name"] if len(released_versions) else None
|
||||
latest_prerelease = prerelease_versions[0]["tag_name"] if len(prerelease_versions) else None
|
||||
|
||||
if invokeai_is_running():
|
||||
print(":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]")
|
||||
input("Press any key to continue...")
|
||||
return
|
||||
|
||||
latest_release, latest_prerelease, versions = get_pypi_versions()
|
||||
|
||||
welcome(latest_release, latest_prerelease)
|
||||
|
||||
tag = None
|
||||
branch = None
|
||||
release = None
|
||||
choice = Prompt.ask("Choice:", choices=["1", "2", "3", "4"], default="1")
|
||||
release = latest_release
|
||||
choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
|
||||
|
||||
if choice == "1":
|
||||
release = latest_release
|
||||
elif choice == "2":
|
||||
release = latest_prerelease
|
||||
elif choice == "3":
|
||||
while not tag:
|
||||
tag = Prompt.ask("Enter an InvokeAI tag name")
|
||||
elif choice == "4":
|
||||
while not branch:
|
||||
branch = Prompt.ask("Enter an InvokeAI branch name")
|
||||
while True:
|
||||
release = Prompt.ask("Enter an InvokeAI version")
|
||||
release.strip()
|
||||
if release in versions:
|
||||
break
|
||||
print(f":exclamation: [bold red]'{release}' is not a recognized InvokeAI release.[/red bold]")
|
||||
|
||||
extras = get_extras()
|
||||
|
||||
print(f":crossed_fingers: Upgrading to [yellow]{tag or release or branch}[/yellow]")
|
||||
if release:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade'
|
||||
elif tag:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade'
|
||||
else:
|
||||
cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade'
|
||||
print(f":crossed_fingers: Upgrading to [yellow]{release}[/yellow]")
|
||||
cmd = f'pip install "invokeai{extras}=={release}" --use-pep517 --upgrade'
|
||||
|
||||
print("")
|
||||
print("")
|
||||
if os.system(cmd) == 0:
|
||||
|
@ -727,9 +727,6 @@
|
||||
"showMinimapnodes": "Mostrar el minimapa",
|
||||
"reloadNodeTemplates": "Recargar las plantillas de nodos",
|
||||
"loadWorkflow": "Cargar el flujo de trabajo",
|
||||
"resetWorkflow": "Reiniciar e flujo de trabajo",
|
||||
"resetWorkflowDesc": "¿Está seguro de que deseas restablecer este flujo de trabajo?",
|
||||
"resetWorkflowDesc2": "Al reiniciar el flujo de trabajo se borrarán todos los nodos, aristas y detalles del flujo de trabajo.",
|
||||
"downloadWorkflow": "Descargar el flujo de trabajo en un archivo JSON"
|
||||
}
|
||||
}
|
||||
|
@ -898,11 +898,8 @@
|
||||
"zoomInNodes": "Ingrandire",
|
||||
"fitViewportNodes": "Adatta vista",
|
||||
"showGraphNodes": "Mostra sovrapposizione grafico",
|
||||
"resetWorkflowDesc2": "Il ripristino dell'editor del flusso di lavoro cancellerà tutti i nodi, le connessioni e i dettagli del flusso di lavoro. I flussi di lavoro salvati non saranno interessati.",
|
||||
"reloadNodeTemplates": "Ricarica i modelli di nodo",
|
||||
"loadWorkflow": "Importa flusso di lavoro JSON",
|
||||
"resetWorkflow": "Reimposta l'editor del flusso di lavoro",
|
||||
"resetWorkflowDesc": "Sei sicuro di voler reimpostare l'editor del flusso di lavoro?",
|
||||
"downloadWorkflow": "Esporta flusso di lavoro JSON",
|
||||
"scheduler": "Campionatore",
|
||||
"addNode": "Aggiungi nodo",
|
||||
@ -1112,7 +1109,10 @@
|
||||
"deletedInvalidEdge": "Eliminata connessione non valida {{source}} -> {{target}}",
|
||||
"unknownInput": "Input sconosciuto: {{name}}",
|
||||
"prototypeDesc": "Questa invocazione è un prototipo. Potrebbe subire modifiche sostanziali durante gli aggiornamenti dell'app e potrebbe essere rimossa in qualsiasi momento.",
|
||||
"betaDesc": "Questa invocazione è in versione beta. Fino a quando non sarà stabile, potrebbe subire modifiche importanti durante gli aggiornamenti dell'app. Abbiamo intenzione di supportare questa invocazione a lungo termine."
|
||||
"betaDesc": "Questa invocazione è in versione beta. Fino a quando non sarà stabile, potrebbe subire modifiche importanti durante gli aggiornamenti dell'app. Abbiamo intenzione di supportare questa invocazione a lungo termine.",
|
||||
"newWorkflow": "Nuovo flusso di lavoro",
|
||||
"newWorkflowDesc": "Creare un nuovo flusso di lavoro?",
|
||||
"newWorkflowDesc2": "Il flusso di lavoro attuale presenta modifiche non salvate."
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Aggiungi automaticamente bacheca",
|
||||
@ -1619,7 +1619,6 @@
|
||||
"saveWorkflow": "Salva flusso di lavoro",
|
||||
"openWorkflow": "Apri flusso di lavoro",
|
||||
"clearWorkflowSearchFilter": "Cancella il filtro di ricerca del flusso di lavoro",
|
||||
"workflowEditorReset": "Reimpostazione dell'editor del flusso di lavoro",
|
||||
"workflowLibrary": "Libreria",
|
||||
"noRecentWorkflows": "Nessun flusso di lavoro recente",
|
||||
"workflowSaved": "Flusso di lavoro salvato",
|
||||
@ -1633,7 +1632,10 @@
|
||||
"deleteWorkflow": "Elimina flusso di lavoro",
|
||||
"workflows": "Flussi di lavoro",
|
||||
"noDescription": "Nessuna descrizione",
|
||||
"userWorkflows": "I miei flussi di lavoro"
|
||||
"userWorkflows": "I miei flussi di lavoro",
|
||||
"newWorkflowCreated": "Nuovo flusso di lavoro creato",
|
||||
"downloadWorkflow": "Salva su file",
|
||||
"uploadWorkflow": "Carica da file"
|
||||
},
|
||||
"app": {
|
||||
"storeNotInitialized": "Il negozio non è inizializzato"
|
||||
|
@ -844,9 +844,6 @@
|
||||
"hideLegendNodes": "Typelegende veld verbergen",
|
||||
"reloadNodeTemplates": "Herlaad knooppuntsjablonen",
|
||||
"loadWorkflow": "Laad werkstroom",
|
||||
"resetWorkflow": "Herstel werkstroom",
|
||||
"resetWorkflowDesc": "Weet je zeker dat je deze werkstroom wilt herstellen?",
|
||||
"resetWorkflowDesc2": "Herstel van een werkstroom haalt alle knooppunten, randen en werkstroomdetails weg.",
|
||||
"downloadWorkflow": "Download JSON van werkstroom",
|
||||
"booleanPolymorphicDescription": "Een verzameling Booleanse waarden.",
|
||||
"scheduler": "Planner",
|
||||
|
@ -909,9 +909,6 @@
|
||||
"hideLegendNodes": "Скрыть тип поля",
|
||||
"showMinimapnodes": "Показать миникарту",
|
||||
"loadWorkflow": "Загрузить рабочий процесс",
|
||||
"resetWorkflowDesc2": "Сброс рабочего процесса очистит все узлы, ребра и детали рабочего процесса.",
|
||||
"resetWorkflow": "Сбросить рабочий процесс",
|
||||
"resetWorkflowDesc": "Вы уверены, что хотите сбросить этот рабочий процесс?",
|
||||
"reloadNodeTemplates": "Перезагрузить шаблоны узлов",
|
||||
"downloadWorkflow": "Скачать JSON рабочего процесса",
|
||||
"booleanPolymorphicDescription": "Коллекция логических значений.",
|
||||
@ -1599,7 +1596,6 @@
|
||||
"saveWorkflow": "Сохранить рабочий процесс",
|
||||
"openWorkflow": "Открытый рабочий процесс",
|
||||
"clearWorkflowSearchFilter": "Очистить фильтр поиска рабочих процессов",
|
||||
"workflowEditorReset": "Сброс редактора рабочих процессов",
|
||||
"workflowLibrary": "Библиотека",
|
||||
"downloadWorkflow": "Скачать рабочий процесс",
|
||||
"noRecentWorkflows": "Нет недавних рабочих процессов",
|
||||
|
@ -892,11 +892,8 @@
|
||||
},
|
||||
"nodes": {
|
||||
"zoomInNodes": "放大",
|
||||
"resetWorkflowDesc": "是否确定要重置工作流编辑器?",
|
||||
"resetWorkflow": "重置工作流编辑器",
|
||||
"loadWorkflow": "加载工作流",
|
||||
"zoomOutNodes": "缩小",
|
||||
"resetWorkflowDesc2": "重置工作流编辑器将清除所有节点、边际和节点图详情。不影响已保存的工作流。",
|
||||
"reloadNodeTemplates": "重载节点模板",
|
||||
"hideGraphNodes": "隐藏节点图信息",
|
||||
"fitViewportNodes": "自适应视图",
|
||||
@ -1122,7 +1119,10 @@
|
||||
"deletedInvalidEdge": "已删除无效的边缘 {{source}} -> {{target}}",
|
||||
"unknownInput": "未知输入:{{name}}",
|
||||
"prototypeDesc": "此调用是一个原型 (prototype)。它可能会在本项目更新期间发生破坏性更改,并且随时可能被删除。",
|
||||
"betaDesc": "此调用尚处于测试阶段。在稳定之前,它可能会在项目更新期间发生破坏性更改。本项目计划长期支持这种调用。"
|
||||
"betaDesc": "此调用尚处于测试阶段。在稳定之前,它可能会在项目更新期间发生破坏性更改。本项目计划长期支持这种调用。",
|
||||
"newWorkflow": "新建工作流",
|
||||
"newWorkflowDesc": "是否创建一个新的工作流?",
|
||||
"newWorkflowDesc2": "当前工作流有未保存的更改。"
|
||||
},
|
||||
"controlnet": {
|
||||
"resize": "直接缩放",
|
||||
@ -1637,9 +1637,8 @@
|
||||
"saveWorkflow": "保存工作流",
|
||||
"openWorkflow": "打开工作流",
|
||||
"clearWorkflowSearchFilter": "清除工作流检索过滤器",
|
||||
"workflowEditorReset": "工作流编辑器重置",
|
||||
"workflowLibrary": "工作流库",
|
||||
"downloadWorkflow": "下载工作流",
|
||||
"downloadWorkflow": "保存到文件",
|
||||
"noRecentWorkflows": "无最近工作流",
|
||||
"workflowSaved": "已保存工作流",
|
||||
"workflowIsOpen": "工作流已打开",
|
||||
@ -1652,8 +1651,9 @@
|
||||
"deleteWorkflow": "删除工作流",
|
||||
"workflows": "工作流",
|
||||
"noDescription": "无描述",
|
||||
"uploadWorkflow": "上传工作流",
|
||||
"userWorkflows": "我的工作流"
|
||||
"uploadWorkflow": "从文件中加载",
|
||||
"userWorkflows": "我的工作流",
|
||||
"newWorkflowCreated": "已创建新的工作流"
|
||||
},
|
||||
"app": {
|
||||
"storeNotInitialized": "商店尚未初始化"
|
||||
|
@ -34,6 +34,7 @@ import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
import { authToastMiddleware } from 'services/api/authToastMiddleware';
|
||||
|
||||
const allReducers = {
|
||||
canvas: canvasReducer,
|
||||
@ -96,6 +97,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
})
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
.concat(authToastMiddleware)
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
enhancers: (getDefaultEnhancers) => {
|
||||
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());
|
||||
|
@ -2,11 +2,14 @@ import { Text } from '@chakra-ui/layout';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
const { t } = useTranslation();
|
||||
const name = useAppSelector((state) => state.workflow.name);
|
||||
const isTouched = useAppSelector((state) => state.workflow.isTouched);
|
||||
const isWorkflowLibraryEnabled =
|
||||
useFeatureStatus('workflowLibrary').isFeatureEnabled;
|
||||
|
||||
return (
|
||||
<Text
|
||||
@ -19,7 +22,7 @@ const TopCenterPanel = () => {
|
||||
opacity={0.8}
|
||||
>
|
||||
{name || t('workflows.unnamedWorkflow')}
|
||||
{isTouched ? ` (${t('common.unsaved')})` : ''}
|
||||
{isTouched && isWorkflowLibraryEnabled ? ` (${t('common.unsaved')})` : ''}
|
||||
</Text>
|
||||
);
|
||||
};
|
||||
|
@ -144,6 +144,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
type: 'l2i',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
@ -255,6 +256,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
is_intermediate,
|
||||
width: width,
|
||||
height: height,
|
||||
use_cache: false,
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
@ -295,6 +297,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
fp32,
|
||||
use_cache: false,
|
||||
};
|
||||
|
||||
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =
|
||||
|
@ -191,6 +191,7 @@ export const buildCanvasInpaintGraph = (
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
reference: canvasInitImage,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -199,6 +199,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
type: 'color_correct',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -266,6 +266,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
is_intermediate,
|
||||
width: width,
|
||||
height: height,
|
||||
use_cache: false,
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
@ -306,6 +307,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
fp32,
|
||||
use_cache: false,
|
||||
};
|
||||
|
||||
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =
|
||||
|
@ -196,6 +196,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
reference: canvasInitImage,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -204,6 +204,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
type: 'color_correct',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -258,6 +258,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
is_intermediate,
|
||||
width: width,
|
||||
height: height,
|
||||
use_cache: false,
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
@ -288,6 +289,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
fp32,
|
||||
use_cache: false,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
|
@ -246,6 +246,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
is_intermediate,
|
||||
width: width,
|
||||
height: height,
|
||||
use_cache: false,
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
@ -276,6 +277,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate,
|
||||
fp32,
|
||||
use_cache: false,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
|
@ -143,6 +143,7 @@ export const buildLinearImageToImageGraph = (
|
||||
// },
|
||||
fp32,
|
||||
is_intermediate,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -154,6 +154,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
||||
// },
|
||||
fp32,
|
||||
is_intermediate,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -127,6 +127,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32,
|
||||
is_intermediate,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -146,6 +146,7 @@ export const buildLinearTextToImageGraph = (
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32,
|
||||
is_intermediate,
|
||||
use_cache: false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -5,12 +5,10 @@ import { t } from 'i18next';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zRejectedForbiddenAction = z.object({
|
||||
action: z.object({
|
||||
payload: z.object({
|
||||
status: z.literal(403),
|
||||
data: z.object({
|
||||
detail: z.string(),
|
||||
}),
|
||||
payload: z.object({
|
||||
status: z.literal(403),
|
||||
data: z.object({
|
||||
detail: z.string(),
|
||||
}),
|
||||
}),
|
||||
});
|
||||
@ -22,8 +20,8 @@ export const authToastMiddleware: Middleware =
|
||||
const parsed = zRejectedForbiddenAction.parse(action);
|
||||
const { dispatch } = api;
|
||||
const customMessage =
|
||||
parsed.action.payload.data.detail !== 'Forbidden'
|
||||
? parsed.action.payload.data.detail
|
||||
parsed.payload.data.detail !== 'Forbidden'
|
||||
? parsed.payload.data.detail
|
||||
: undefined;
|
||||
dispatch(
|
||||
addToast({
|
||||
@ -32,7 +30,7 @@ export const authToastMiddleware: Middleware =
|
||||
description: customMessage,
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
} catch (error) {
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
|
@ -172,6 +172,8 @@ nav:
|
||||
- Adding Tests: 'contributing/TESTS.md'
|
||||
- Documentation: 'contributing/contribution_guides/documentation.md'
|
||||
- Nodes: 'contributing/INVOCATIONS.md'
|
||||
- Model Manager: 'contributing/MODEL_MANAGER.md'
|
||||
- Download Queue: 'contributing/DOWNLOAD_QUEUE.md'
|
||||
- Translation: 'contributing/contribution_guides/translation.md'
|
||||
- Tutorials: 'contributing/contribution_guides/tutorials.md'
|
||||
- Changelog: 'CHANGELOG.md'
|
||||
|
@ -105,6 +105,7 @@ dependencies = [
|
||||
"pytest>6.0.0",
|
||||
"pytest-cov",
|
||||
"pytest-datadir",
|
||||
"requests_testadapter",
|
||||
]
|
||||
"xformers" = [
|
||||
"xformers==0.0.23; sys_platform!='darwin'",
|
||||
@ -138,7 +139,6 @@ dependencies = [
|
||||
"invokeai-node-web" = "invokeai.app.api_app:invoke_api"
|
||||
"invokeai-import-images" = "invokeai.frontend.install.import_images:main"
|
||||
"invokeai-db-maintenance" = "invokeai.backend.util.db_maintenance:main"
|
||||
"invokeai-migrate-models-to-db" = "invokeai.backend.model_manager.migrate_to_db:main"
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"
|
||||
|
@ -26,7 +26,6 @@ from invokeai.app.services.shared.graph import (
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
IterateInvocation,
|
||||
LibraryGraph,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
@ -61,7 +60,6 @@ def mock_services() -> InvocationServices:
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
@ -70,6 +68,7 @@ def mock_services() -> InvocationServices:
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
model_records=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
model_install=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
|
@ -24,7 +24,7 @@ from invokeai.app.services.invocation_stats.invocation_stats_default import Invo
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -66,7 +66,6 @@ def mock_services() -> InvocationServices:
|
||||
configuration=configuration,
|
||||
events=TestEventService(),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
|
||||
image_files=None, # type: ignore
|
||||
image_records=None, # type: ignore
|
||||
images=None, # type: ignore
|
||||
@ -75,6 +74,7 @@ def mock_services() -> InvocationServices:
|
||||
logger=logging, # type: ignore
|
||||
model_manager=None, # type: ignore
|
||||
model_records=None, # type: ignore
|
||||
download_queue=None, # type: ignore
|
||||
model_install=None, # type: ignore
|
||||
names=None, # type: ignore
|
||||
performance_statistics=InvocationStatsService(),
|
||||
|
223
tests/app/services/download/test_download_queue.py
Normal file
223
tests/app/services/download/test_download_queue.py
Normal file
@ -0,0 +1,223 @@
|
||||
"""Test the queued download facility"""
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
TestAdapter.__test__ = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> requests.sessions.Session:
|
||||
sess = requests.Session()
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
sess.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found test
|
||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
return sess
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
# A dummy event service for testing event issuing
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
|
||||
|
||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
on_start=event_handler,
|
||||
on_progress=event_handler,
|
||||
on_complete=event_handler,
|
||||
on_error=event_handler,
|
||||
)
|
||||
assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJobBase"
|
||||
assert isinstance(job.id, int), "expected the job id to be numeric"
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
|
||||
queue.download(AnyHttpUrl(bad_url), dest=tmp_path)
|
||||
|
||||
queue.join()
|
||||
jobs = queue.list_jobs()
|
||||
print(jobs)
|
||||
assert len(jobs) == 2
|
||||
jobs_dict = {str(x.source): x for x in jobs}
|
||||
assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
|
||||
assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
|
||||
assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
|
||||
assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
event_bus = DummyEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue.start()
|
||||
queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
)
|
||||
queue.join()
|
||||
events = event_bus.events
|
||||
assert len(events) == 3
|
||||
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
|
||||
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
|
||||
assert events[0].event_name == "download_started"
|
||||
assert events[1].event_name == "download_progress"
|
||||
assert events[1].payload["total_bytes"] > 0
|
||||
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
|
||||
assert events[2].event_name == "download_complete"
|
||||
assert events[2].payload["total_bytes"] == 32029
|
||||
|
||||
# test a failure
|
||||
event_bus.events = [] # reset our accumulator
|
||||
queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
|
||||
queue.join()
|
||||
events = event_bus.events
|
||||
print("\n".join([x.model_dump_json() for x in events]))
|
||||
assert len(events) == 1
|
||||
assert events[0].event_name == "download_error"
|
||||
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
|
||||
assert events[0].payload["error"] is not None
|
||||
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, capsys) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
callback_ran = False
|
||||
|
||||
def broken_callback(job: DownloadJob) -> None:
|
||||
nonlocal callback_ran
|
||||
callback_ran = True
|
||||
print(1 / 0) # deliberate error here
|
||||
|
||||
job = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
on_progress=broken_callback,
|
||||
)
|
||||
|
||||
queue.join()
|
||||
assert job.status == DownloadJobStatus.COMPLETED # should complete even though the callback is borked
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists()
|
||||
assert callback_ran
|
||||
# LS: The pytest capsys fixture does not seem to be working. I can see the
|
||||
# correct stderr message in the pytest log, but it is not appearing in
|
||||
# capsys.readouterr().
|
||||
# captured = capsys.readouterr()
|
||||
# assert re.search("division by zero", captured.err)
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_cancel(tmp_path: Path, session: requests.sessions.Session) -> None:
|
||||
event_bus = DummyEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue.start()
|
||||
|
||||
cancelled = False
|
||||
|
||||
def slow_callback(job: DownloadJob) -> None:
|
||||
time.sleep(2)
|
||||
|
||||
def cancelled_callback(job: DownloadJob) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
job = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
on_start=slow_callback,
|
||||
on_cancelled=cancelled_callback,
|
||||
)
|
||||
queue.cancel_job(job)
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus.CANCELLED
|
||||
assert cancelled
|
||||
events = event_bus.events
|
||||
assert events[-1].event_name == "download_cancelled"
|
||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
@ -48,11 +48,13 @@ def store(
|
||||
|
||||
@pytest.fixture
|
||||
def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
|
||||
return ModelInstallService(
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=store,
|
||||
event_bus=DummyEventService(),
|
||||
)
|
||||
installer.start()
|
||||
return installer
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,244 @@
|
||||
"""
|
||||
Test model metadata fetching and storage.
|
||||
"""
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic.networks import HttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
CivitaiMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
||||
RepoCivitaiModelMetadata1,
|
||||
RepoCivitaiVersionMetadata1,
|
||||
RepoHFMetadata1,
|
||||
)
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||
return InvokeAIAppConfig(
|
||||
root=datadir / "root",
|
||||
models_dir=datadir / "root/models",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
db = create_mock_sqlite_database(app_config, logger)
|
||||
store = ModelRecordServiceSQL(db)
|
||||
# add three simple config records to the database
|
||||
raw1 = {
|
||||
"path": "/tmp/foo1",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test2",
|
||||
"base": BaseModelType("sd-2"),
|
||||
"type": ModelType("vae"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "stabilityai/sdxl-vae",
|
||||
}
|
||||
raw2 = {
|
||||
"path": "/tmp/foo2.ckpt",
|
||||
"name": "model1",
|
||||
"format": ModelFormat("checkpoint"),
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": "main",
|
||||
"config": "/tmp/foo.yaml",
|
||||
"variant": "normal",
|
||||
"original_hash": "111222333444",
|
||||
"source": "https://civitai.com/models/206883/split",
|
||||
}
|
||||
raw3 = {
|
||||
"path": "/tmp/foo3",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test3",
|
||||
"base": BaseModelType("sdxl"),
|
||||
"type": ModelType("main"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author3/model3",
|
||||
}
|
||||
store.add_model("test_config_1", raw1)
|
||||
store.add_model("test_config_2", raw2)
|
||||
store.add_model("test_config_3", raw3)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Session:
|
||||
sess = requests.Session()
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
|
||||
TestAdapter(
|
||||
RepoHFMetadata1,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://civitai.com/api/v1/model-versions/242807",
|
||||
TestAdapter(
|
||||
RepoCivitaiVersionMetadata1,
|
||||
headers={
|
||||
"Content-Length": len(RepoCivitaiVersionMetadata1),
|
||||
},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://civitai.com/api/v1/models/215485",
|
||||
TestAdapter(
|
||||
RepoCivitaiModelMetadata1,
|
||||
headers={
|
||||
"Content-Length": len(RepoCivitaiModelMetadata1),
|
||||
},
|
||||
),
|
||||
)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_store(record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
||||
db = record_store._db # to ensure we are sharing the same database
|
||||
return ModelMetadataStore(db)
|
||||
|
||||
|
||||
def test_metadata_store_put_get(metadata_store: ModelMetadataStore) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
output_metadata = metadata_store.get_metadata("test_config_1")
|
||||
assert input_metadata == output_metadata
|
||||
with pytest.raises(UnknownModelException):
|
||||
metadata_store.add_metadata("unknown_key", input_metadata)
|
||||
|
||||
|
||||
def test_metadata_store_update(metadata_store: ModelMetadataStore) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
input_metadata.name = "new-name"
|
||||
metadata_store.update_metadata("test_config_1", input_metadata)
|
||||
output_metadata = metadata_store.get_metadata("test_config_1")
|
||||
assert output_metadata.name == "new-name"
|
||||
assert input_metadata == output_metadata
|
||||
|
||||
|
||||
def test_metadata_search(metadata_store: ModelMetadataStore) -> None:
|
||||
metadata1 = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata2 = HuggingFaceMetadata(
|
||||
name="model2",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers", "community-contributed"},
|
||||
id="author2/model2",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata3 = HuggingFaceMetadata(
|
||||
name="model3",
|
||||
author="author3",
|
||||
tags={"text-to-image", "checkpoint", "community-contributed"},
|
||||
id="author3/model3",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata_store.add_metadata("test_config_1", metadata1)
|
||||
metadata_store.add_metadata("test_config_2", metadata2)
|
||||
metadata_store.add_metadata("test_config_3", metadata3)
|
||||
|
||||
matches = metadata_store.search_by_author("stabilityai")
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
matches = metadata_store.search_by_author("Sherlock Holmes")
|
||||
assert not matches
|
||||
|
||||
matches = metadata_store.search_by_name("model3")
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
matches = metadata_store.search_by_tag({"text-to-image"})
|
||||
assert len(matches) == 3
|
||||
|
||||
matches = metadata_store.search_by_tag({"text-to-image", "diffusers"})
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
|
||||
matches = metadata_store.search_by_tag({"checkpoint", "community-contributed"})
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
# does the tag table update correctly?
|
||||
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert not matches
|
||||
metadata3.tags.add("licensed-for-commercial-use")
|
||||
metadata_store.update_metadata("test_config_3", metadata3)
|
||||
matches = metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_metadata_civitai_fetch(session: Session) -> None:
|
||||
fetcher = CivitaiMetadataFetch(session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
|
||||
assert isinstance(metadata, CivitaiMetadata)
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
||||
|
||||
def test_metadata_hf_fetch(session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
assert metadata.author == "test_author" # this is not the same as the original
|
||||
assert metadata.files
|
||||
assert metadata.tags == {
|
||||
"diffusers",
|
||||
"onnx",
|
||||
"safetensors",
|
||||
"text-to-image",
|
||||
"license:other",
|
||||
"has_space",
|
||||
"diffusers:StableDiffusionXLPipeline",
|
||||
"region:us",
|
||||
}
|
@ -305,9 +305,7 @@ def test_calc_tiles_min_overlap_input_validation(
|
||||
|
||||
def test_calc_tiles_even_split_single_tile():
|
||||
"""Test calc_tiles_even_split() behavior when a single tile covers the image."""
|
||||
tiles = calc_tiles_even_split(
|
||||
image_height=512, image_width=1024, num_tiles_x=1, num_tiles_y=1, overlap_fraction=0.25
|
||||
)
|
||||
tiles = calc_tiles_even_split(image_height=512, image_width=1024, num_tiles_x=1, num_tiles_y=1, overlap=64)
|
||||
|
||||
expected_tiles = [
|
||||
Tile(
|
||||
@ -322,36 +320,34 @@ def test_calc_tiles_even_split_single_tile():
|
||||
def test_calc_tiles_even_split_evenly_divisible():
|
||||
"""Test calc_tiles_even_split() behavior when the image is evenly covered by multiple tiles."""
|
||||
# Parameters mimic roughly the same output as the original tile generations of the same test name
|
||||
tiles = calc_tiles_even_split(
|
||||
image_height=576, image_width=1600, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25
|
||||
)
|
||||
tiles = calc_tiles_even_split(image_height=576, image_width=1600, num_tiles_x=3, num_tiles_y=2, overlap=64)
|
||||
|
||||
expected_tiles = [
|
||||
# Row 0
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=320, left=0, right=624),
|
||||
overlap=TBLR(top=0, bottom=72, left=0, right=136),
|
||||
coords=TBLR(top=0, bottom=320, left=0, right=576),
|
||||
overlap=TBLR(top=0, bottom=64, left=0, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=320, left=488, right=1112),
|
||||
overlap=TBLR(top=0, bottom=72, left=136, right=136),
|
||||
coords=TBLR(top=0, bottom=320, left=512, right=1088),
|
||||
overlap=TBLR(top=0, bottom=64, left=64, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=320, left=976, right=1600),
|
||||
overlap=TBLR(top=0, bottom=72, left=136, right=0),
|
||||
coords=TBLR(top=0, bottom=320, left=1024, right=1600),
|
||||
overlap=TBLR(top=0, bottom=64, left=64, right=0),
|
||||
),
|
||||
# Row 1
|
||||
Tile(
|
||||
coords=TBLR(top=248, bottom=576, left=0, right=624),
|
||||
overlap=TBLR(top=72, bottom=0, left=0, right=136),
|
||||
coords=TBLR(top=256, bottom=576, left=0, right=576),
|
||||
overlap=TBLR(top=64, bottom=0, left=0, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=248, bottom=576, left=488, right=1112),
|
||||
overlap=TBLR(top=72, bottom=0, left=136, right=136),
|
||||
coords=TBLR(top=256, bottom=576, left=512, right=1088),
|
||||
overlap=TBLR(top=64, bottom=0, left=64, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=248, bottom=576, left=976, right=1600),
|
||||
overlap=TBLR(top=72, bottom=0, left=136, right=0),
|
||||
coords=TBLR(top=256, bottom=576, left=1024, right=1600),
|
||||
overlap=TBLR(top=64, bottom=0, left=64, right=0),
|
||||
),
|
||||
]
|
||||
assert tiles == expected_tiles
|
||||
@ -360,36 +356,34 @@ def test_calc_tiles_even_split_evenly_divisible():
|
||||
def test_calc_tiles_even_split_not_evenly_divisible():
|
||||
"""Test calc_tiles_even_split() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
|
||||
# Parameters mimic roughly the same output as the original tile generations of the same test name
|
||||
tiles = calc_tiles_even_split(
|
||||
image_height=400, image_width=1200, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25
|
||||
)
|
||||
tiles = calc_tiles_even_split(image_height=400, image_width=1200, num_tiles_x=3, num_tiles_y=2, overlap=64)
|
||||
|
||||
expected_tiles = [
|
||||
# Row 0
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=224, left=0, right=464),
|
||||
overlap=TBLR(top=0, bottom=56, left=0, right=104),
|
||||
coords=TBLR(top=0, bottom=232, left=0, right=440),
|
||||
overlap=TBLR(top=0, bottom=64, left=0, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=224, left=360, right=824),
|
||||
overlap=TBLR(top=0, bottom=56, left=104, right=104),
|
||||
coords=TBLR(top=0, bottom=232, left=376, right=816),
|
||||
overlap=TBLR(top=0, bottom=64, left=64, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=224, left=720, right=1200),
|
||||
overlap=TBLR(top=0, bottom=56, left=104, right=0),
|
||||
coords=TBLR(top=0, bottom=232, left=752, right=1200),
|
||||
overlap=TBLR(top=0, bottom=64, left=64, right=0),
|
||||
),
|
||||
# Row 1
|
||||
Tile(
|
||||
coords=TBLR(top=168, bottom=400, left=0, right=464),
|
||||
overlap=TBLR(top=56, bottom=0, left=0, right=104),
|
||||
coords=TBLR(top=168, bottom=400, left=0, right=440),
|
||||
overlap=TBLR(top=64, bottom=0, left=0, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=168, bottom=400, left=360, right=824),
|
||||
overlap=TBLR(top=56, bottom=0, left=104, right=104),
|
||||
coords=TBLR(top=168, bottom=400, left=376, right=816),
|
||||
overlap=TBLR(top=64, bottom=0, left=64, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=168, bottom=400, left=720, right=1200),
|
||||
overlap=TBLR(top=56, bottom=0, left=104, right=0),
|
||||
coords=TBLR(top=168, bottom=400, left=752, right=1200),
|
||||
overlap=TBLR(top=64, bottom=0, left=64, right=0),
|
||||
),
|
||||
]
|
||||
|
||||
@ -399,28 +393,26 @@ def test_calc_tiles_even_split_not_evenly_divisible():
|
||||
def test_calc_tiles_even_split_difficult_size():
|
||||
"""Test calc_tiles_even_split() behavior when the image is a difficult size to spilt evenly and keep div8."""
|
||||
# Parameters are a difficult size for other tile gen routines to calculate
|
||||
tiles = calc_tiles_even_split(
|
||||
image_height=1000, image_width=1000, num_tiles_x=2, num_tiles_y=2, overlap_fraction=0.25
|
||||
)
|
||||
tiles = calc_tiles_even_split(image_height=1000, image_width=1000, num_tiles_x=2, num_tiles_y=2, overlap=64)
|
||||
|
||||
expected_tiles = [
|
||||
# Row 0
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=560, left=0, right=560),
|
||||
overlap=TBLR(top=0, bottom=128, left=0, right=128),
|
||||
coords=TBLR(top=0, bottom=528, left=0, right=528),
|
||||
overlap=TBLR(top=0, bottom=64, left=0, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=0, bottom=560, left=432, right=1000),
|
||||
overlap=TBLR(top=0, bottom=128, left=128, right=0),
|
||||
coords=TBLR(top=0, bottom=528, left=464, right=1000),
|
||||
overlap=TBLR(top=0, bottom=64, left=64, right=0),
|
||||
),
|
||||
# Row 1
|
||||
Tile(
|
||||
coords=TBLR(top=432, bottom=1000, left=0, right=560),
|
||||
overlap=TBLR(top=128, bottom=0, left=0, right=128),
|
||||
coords=TBLR(top=464, bottom=1000, left=0, right=528),
|
||||
overlap=TBLR(top=64, bottom=0, left=0, right=64),
|
||||
),
|
||||
Tile(
|
||||
coords=TBLR(top=432, bottom=1000, left=432, right=1000),
|
||||
overlap=TBLR(top=128, bottom=0, left=128, right=0),
|
||||
coords=TBLR(top=464, bottom=1000, left=464, right=1000),
|
||||
overlap=TBLR(top=64, bottom=0, left=64, right=0),
|
||||
),
|
||||
]
|
||||
|
||||
@ -428,11 +420,13 @@ def test_calc_tiles_even_split_difficult_size():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["image_height", "image_width", "num_tiles_x", "num_tiles_y", "overlap_fraction", "raises"],
|
||||
["image_height", "image_width", "num_tiles_x", "num_tiles_y", "overlap", "raises"],
|
||||
[
|
||||
(128, 128, 1, 1, 0.25, False), # OK
|
||||
(128, 128, 1, 1, 127, False), # OK
|
||||
(128, 128, 1, 1, 0, False), # OK
|
||||
(128, 128, 2, 1, 0, False), # OK
|
||||
(128, 128, 2, 2, 0, False), # OK
|
||||
(128, 128, 2, 1, 120, True), # overlap equals tile_height.
|
||||
(128, 128, 1, 2, 120, True), # overlap equals tile_width.
|
||||
(127, 127, 1, 1, 0, True), # image size must be dividable by 8
|
||||
],
|
||||
)
|
||||
@ -441,15 +435,15 @@ def test_calc_tiles_even_split_input_validation(
|
||||
image_width: int,
|
||||
num_tiles_x: int,
|
||||
num_tiles_y: int,
|
||||
overlap_fraction: float,
|
||||
overlap: int,
|
||||
raises: bool,
|
||||
):
|
||||
"""Test that calc_tiles_even_split() raises an exception if the inputs are invalid."""
|
||||
if raises:
|
||||
with pytest.raises(ValueError):
|
||||
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction)
|
||||
with pytest.raises((AssertionError, ValueError)):
|
||||
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap)
|
||||
else:
|
||||
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction)
|
||||
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap)
|
||||
|
||||
|
||||
#############################################
|
||||
|
Reference in New Issue
Block a user